Enumerate all possible dataclass instances (with enum and bool fields only) - python

Basically I need following. I have a python3 dataclass or NamedTuple, with only enum and bool fields. E.g.:
from enum import Enum, auto
from typing import NamedTuple


class MyEnum(Enum):
    # Three possible values; auto() assigns 1, 2, 3 in declaration order.
    v1 = auto()
    v2 = auto()
    v3 = auto()


class MyStateDefinition(NamedTuple):
    # One MyEnum field and one bool field: 3 * 2 = 6 distinct instances.
    a: MyEnum
    b: bool
Is there any good, known solution to enumerate all possible non-equal instances of such a dataclass? (The example above has 6 possible non-equal instances.)
Perhaps it is not a dataclass I should use, but something else. Or should I play with such things like dataclasses.fields directly?
I imagine it as some table generator which accepts a namedtuple or dataclass as an input parameter and produces all possible values.
table = DataTable(MyStateDefinition)
for item in table:
# Use items somehow
print(item.a)
print(item.b)
Why do I need it? I just have some state definition which consists of enums and bools. I believe it could be implemented as a bitmask. But when it comes to extending your bitmask with new values, it turns out to be a nightmare. After all, bitmasks seem to be a non-Pythonic way of doing things.
Currently I have to use an implementation of my own. But perhaps I'm reinventing the wheel.
Thanks!

You can do this using enums, with the data-tuples as the enum-members' value (an Enum/NamedTuple hybrid, if you will). The _ignore_ attribute is used to prevent certain names in the class namespace from being converted into enum members.
from itertools import product
from enum import Enum


class Data(Enum):
    """Enum whose members are (MyEnum, bool) pairs -- one per possible state.

    `_ignore_` keeps the listed names from being turned into enum members
    while the member-generating loop below runs inside the class body.
    (Fix: the `@property` decorators were mangled into `#property` comments,
    which made `a` and `b` plain methods shadowed by member creation.)
    """

    _ignore_ = "Data", "myenum_member", "truthiness"

    @property
    def a(self):
        # First element of the member's (MyEnum, bool) value tuple.
        return self.value[0]

    @property
    def b(self):
        # Second element of the member's value tuple.
        return self.value[1]

    def __repr__(self):
        return f'Data(a={self.a!r}, b={self.b!r})'

    # Generate one member per (MyEnum value, bool) combination.
    Data = vars()
    for myenum_member, truthiness in product(MyEnum, (True, False)):
        Data[f'{myenum_member.name}_{truthiness}'] = (myenum_member, truthiness)
You should be able to iterate through the resulting enum class just as you desire.
This use of enums is similar to the "time period" example in the Enum HOWTO section of the docs.
Generating this kind of table dynamically
If you want to generate this kind of table dynamically, you could do something like this, (ab)using metaclasses. I've shown example usages for how you would use this DataTable class in the docstrings. (For some reason, using typing.get_type_hints in a doctest seems to cause the doctest module to error out, but the examples do work if you try them yourself in an interactive terminal.) Rather than special-casing bool, as you did in your answer, I decided to special-case typing.Literal, as it seemed like a more extensible option (and bool can just be spelled as typing.Literal[True, False]).
from __future__ import annotations
from itertools import product
from enum import Enum, EnumMeta
from typing import (
Iterable,
Mapping,
cast,
Protocol,
get_type_hints,
Any,
get_args,
get_origin,
Literal,
TypeVar,
Union,
Optional
)
D = TypeVar('D')
T = TypeVar('T')


class DataTableFactory(EnumMeta):
    """A helper class for making data tables (an implementation detail of `DataTable`).

    Fix: the two `@classmethod` decorators were mangled into `#classmethod`
    comments, which broke both `__prepare__` and `make_datatable`.
    """

    _CLS_BASES = (Enum,)

    @classmethod
    def __prepare__(  # type: ignore[override]
        metacls,
        cls_name: str,
        fields: Mapping[str, Iterable[Any]]
    ) -> dict[str, Any]:
        cls_dict = cast(
            dict[str, Any],
            super().__prepare__(cls_name, metacls._CLS_BASES)
        )
        # One read-only property per field; the default argument binds `i`
        # early so each lambda reads its own tuple slot.
        for i, field in enumerate(fields.keys()):
            cls_dict[field] = property(fget=lambda self, i=i: self.value[i])  # type: ignore[misc]
        # One enum member per element of the cartesian product of the fields.
        for p in product(*fields.values()):
            cls_dict['_'.join(map(str, p))] = p

        def __repr__(self: Enum) -> str:
            contents = ', '.join(
                f'{field}={getattr(self, field)!r}'
                for field in fields
            )
            return f'{cls_name}Member({contents})'

        cls_dict['__repr__'] = __repr__
        return cls_dict

    @classmethod
    def make_datatable(
        metacls,
        cls_name: str,
        *,
        fields: Mapping[str, Iterable[Any]],
        doc: Optional[str] = None
    ) -> type[Enum]:
        """Create a new data table"""
        cls_dict = metacls.__prepare__(cls_name, fields)
        new_cls = metacls.__new__(metacls, cls_name, metacls._CLS_BASES, cls_dict)
        new_cls.__module__ = __name__
        if doc is None:
            # Auto-generate a docstring listing each field and its choices.
            all_attrs = '\n'.join(
                f' {f"{attr_name}: ":<{(max(map(len, fields)) + 3)}}one of {attr_val!r}'
                for attr_name, attr_val in fields.items()
            )
            fields_len = len(fields)
            doc = (
                f'An enum-like data table.\n\n'
                f'All members of this data table have {fields_len} '
                f'read-only attribute{"s" if fields_len > 1 else ""}:\n'
                f'{all_attrs}\n\n'
                f'----------------------------------------------------------------------'
            )
        new_cls.__doc__ = doc
        return cast(type[Enum], new_cls)

    def __repr__(cls) -> str:
        return f"<Data table '{cls.__name__}'>"

    def index_of(cls: Iterable[D], member: D) -> int:
        """Get the index of a member in the list of members."""
        return list(cls).index(member)

    def get(
        cls: Iterable[D],
        /,
        *,
        default_: Optional[T] = None,
        **kwargs: Any
    ) -> Union[D, T, None]:
        """Return instance for given arguments set.

        Return `default_` if no member matches those arguments.
        """
        it = (
            member for member in cls
            if all((getattr(member, key) == val) for key, val in kwargs.items())
        )
        return next(it, default_)

    def __dir__(cls) -> list[str]:
        # By defining __dir__, we make methods defined in this class
        # discoverable by the interactive help() function in the REPL
        return cast(list[str], super().__dir__()) + ['index_of', 'get']
class TypedStructProto(Protocol):
    """In order to satisfy this interface, a type must have an __annotations__ dict."""
    __annotations__: dict[str, Union[Iterable[Any], type[Literal[True]]]]


class DataTableMeta(type):
    """Metaclass for `DataTable`."""
    # Calling DataTable(...) delegates straight to the enum-building factory;
    # the bound classmethod is invoked via the special-method lookup, so the
    # DataTable class object itself is never passed along.
    __call__ = DataTableFactory.make_datatable  # type: ignore[assignment]
class DataTable(metaclass=DataTableMeta):
    """A mechanism to create 'data table enumerations' -- not really a class at all!

    Example usage
    -------------
    >>> Cars = DataTable('Cars', fields={'make': ('Toyota', 'Audi'), 'colour': ('Red', 'Blue')})
    >>> Cars
    <Data table 'Cars'>
    >>> list(Cars)
    [CarsMember(make=Toyota, colour=Red), CarsMember(make=Toyota, colour=Blue), CarsMember(make=Audi, colour=Red), CarsMember(make=Audi, colour=Blue)]
    >>> Cars.get(make='Audi', colour='Red')
    CarsMember(make=Audi, colour=Red)
    >>> Cars.index_of(_)
    2
    """

    @classmethod
    def from_struct(cls, cls_name: str, *, struct: type[TypedStructProto], doc: Optional[str] = None) -> type[Enum]:
        """Make a DataTable from a "typed struct" -- e.g. a dataclass, NamedTuple or TypedDict.

        Fix: the `@classmethod` decorator was mangled into a `#classmethod` comment.

        Example usage (works the same way with dataclasses and TypedDicts)
        -------------------------------------------------------------------
        >>> from enum import Enum, auto
        >>> from typing import NamedTuple, Literal
        >>> class E(Enum):
        ...     v1 = auto()
        ...     v2 = auto()
        ...     v3 = auto()
        ...
        >>> class BoolsEndEnums(NamedTuple):
        ...     a: E
        ...     b: Literal[True, False]
        ...
        >>> BoolsEndEnumsTable = DataTable.from_struct('BoolsEndEnumsTable', struct=BoolsEndEnums)
        >>> list(BoolsEndEnumsTable)
        [BoolsEndEnumsTableMember(a=E.v1, b=True), BoolsEndEnumsTableMember(a=E.v1, b=False), BoolsEndEnumsTableMember(a=E.v2, b=True), BoolsEndEnumsTableMember(a=E.v2, b=False), BoolsEndEnumsTableMember(a=E.v3, b=True), BoolsEndEnumsTableMember(a=E.v3, b=False)]
        """
        # Literal[...] annotations are expanded to their argument tuples;
        # every other annotation is assumed to already be an iterable of
        # the field's possible values.
        fields = get_type_hints(struct)
        for field_name, field_val in fields.items():
            if get_origin(field_val) is Literal:
                fields[field_name] = get_args(field_val)
        return cast(type[Enum], cls(cls_name, fields=fields, doc=doc))  # type: ignore[call-arg]
I've had to do some "interesting" things with the type hints, but MyPy is sort of happy with all this.

I am also posting an implementation of my own. It is not ideal — I had to use some protected members.
Usage:
from typing import NamedTuple
from datatable import DataTable
class BoolsEndEnums(NamedTuple):
a: E
b: bool
tbl = DataTable(BoolsEndEnums)
item = tbl[0]
print(item.a) # a is v1
print(item.b) # b is False
See test_datatable.py, _test_cls for more usage examples.
datatable.py
import collections
import collections.abc
import dataclasses
# Fix: `from collections import Iterable` has been invalid since Python 3.10
# (the ABCs were removed from the `collections` top level); import from
# `collections.abc` instead.
from collections.abc import Iterable
from enum import Enum
from typing import Union, Any, Tuple, Iterator, get_type_hints, NamedTuple


def is_cls_namedtuple(cls):
    """Return True if `cls` looks like a NamedTuple class.

    There is no public base class for named tuples, so the usual heuristic
    is: a tuple subclass that carries the generated `_fields` attribute.
    """
    return issubclass(cls, tuple) and hasattr(cls, "_fields")


class DataTable(Iterable):
    """Table of every possible instance of a dataclass/NamedTuple whose
    fields are all `Enum` subclasses or `bool`.

    Rows are stored as immutable namedtuple instances, in depth-first field
    order (first field varies slowest, booleans enumerate False then True).
    """

    def __init__(self, data_cls):
        self._table = []    # row id -> instance
        self._index = {}    # row id -> instance (forward index)
        self._rindex = {}   # canonical (field, value) tuple -> row id
        self._named_tuple_cls = None
        fields = None
        if dataclasses.is_dataclass(data_cls):
            # Mirror the dataclass as a namedtuple so rows are immutable.
            fields = [f.name for f in dataclasses.fields(data_cls)]
            self._named_tuple_cls = collections.namedtuple(
                f"{data_cls.__name__}_immutable",
                fields
            )
        elif is_cls_namedtuple(data_cls):
            self._named_tuple_cls = data_cls
            fields = data_cls._fields
        else:
            raise ValueError(
                "Only dataclasses and NamedTuple subclasses are supported."
            )
        hints = get_type_hints(data_cls)
        self._build_table([], [(f, hints[f]) for f in fields])

    def index_of(self, instance):
        """
        Returns record index of given instance in table.
        :param instance: a dataclass/namedtuple row (or (name, value) pairs)
        :return: the row index, or None when not present
        """
        index = self._as_index(instance)
        return self._rindex.get(index)

    def get(self, **kw):
        """
        Returns instance for given arguments set.
        :param kw: one value per field, e.g. get(a=E.v2, b=True)
        :return: the matching row
        :raises KeyError: if no row matches
        """
        index = self._as_index(kw)
        return self._table[self._rindex[index]]

    def __len__(self):
        return len(self._table)

    def __getitem__(self, i: Union[int, slice]):
        return self._table[i]

    def __iter__(self) -> Iterator:
        return self._table.__iter__()

    def _build_table(self, defined_fields, remained_fields):
        # Depth-first expansion: `defined_fields` is the (name, value) prefix
        # chosen so far; `remained_fields` holds the (name, type) pairs still
        # to be expanded.
        if not remained_fields:
            instance = self._named_tuple_cls(**dict(defined_fields))
            item_id = len(self._table)
            self._index[item_id] = instance
            self._rindex[self._as_index(defined_fields)] = item_id
            self._table.append(instance)
            return
        next_name, next_type = remained_fields[0]
        remained_fields = remained_fields[1:]
        if issubclass(next_type, Enum):
            for v in next_type:
                self._build_table(
                    defined_fields + [(next_name, v)],
                    remained_fields
                )
            return
        if next_type is bool:
            # Booleans enumerate False first, then True.
            self._build_table(
                defined_fields + [(next_name, False)],
                remained_fields
            )
            self._build_table(
                defined_fields + [(next_name, True)],
                remained_fields
            )
            return
        raise ValueError(f"Got unexpected dataclass field type: {next_type}")

    # Fix: this was a `#staticmethod` comment; without the decorator every
    # `self._as_index(v)` call would have passed `self` as `v`.
    @staticmethod
    def _as_index(v: Union[Any, Tuple[str, Any]]):
        """Canonicalise `v` (dataclass/namedtuple instance, dict, or sequence
        of (name, value) pairs) into a sorted tuple usable as a dict key."""
        items = None
        if dataclasses.is_dataclass(v):
            items = dataclasses.asdict(v).items()
        elif is_cls_namedtuple(type(v)):
            items = v._asdict().items()
        elif isinstance(v, dict):
            items = v.items()
        else:
            # Fix: `collections.Sequence` was removed in Python 3.10; the
            # ABC lives in `collections.abc`.
            assert isinstance(v, collections.abc.Sequence)
            items = v
        return tuple(sorted(items, key=lambda x: x[0]))
test_datatable.py
import dataclasses
from enum import Enum, auto
from typing import NamedTuple

import pytest

# Fix: the module under test is datatable.py (see the usage example above),
# so import from `datatable`, not `dataclass_utils`.
from datatable import DataTable


class E(Enum):
    v1 = auto()
    v2 = auto()
    v3 = auto()


# Fix: both `@dataclasses.dataclass` decorators were mangled into comments,
# leaving these as plain classes without generated __init__.
@dataclasses.dataclass
class BoolsEndEnums:
    a: E
    b: bool


class BoolsEndEnumsNamedTuple(NamedTuple):
    a: E
    b: bool


@dataclasses.dataclass
class HugeSetOfValues:
    # `int` has an unbounded value set, so DataTable must reject this class.
    a: int
    b: bool


class NotSupportedCls:
    pass
def _test_cls(cls):
    # Shared test body: `cls` is any two-field (a: E, b: bool) record type.
    # Expected row order: (v1,F), (v1,T), (v2,F), (v2,T), (v3,F), (v3,T).
    tbl = DataTable(cls)
    first = cls(E.v1, False)
    last = cls(E.v3, True)
    expected_num_entries = 6
    assert tbl.index_of(first) == 0
    assert tbl.index_of(last) == (expected_num_entries - 1)
    assert len(tbl) == expected_num_entries
    actual_third = tbl.get(a=E.v2, b=False)
    assert actual_third.a == E.v2
    assert actual_third.b is False
    actual_forth = tbl[3]
    assert actual_forth.a == E.v2
    assert actual_forth.b is True
    items = [item for item in tbl]
    actual_fifth = items[4]
    assert actual_fifth.a == E.v3
    assert actual_fifth.b is False
    # Test that we can't change result
    # (rows are namedtuples, so attribute assignment must fail)
    with pytest.raises(AttributeError):
        tbl[0].a = E.v2
def test_dataclass():
    # DataTable built from a plain dataclass.
    _test_cls(BoolsEndEnums)


def test_namedtuple():
    # DataTable built from a NamedTuple with the same fields.
    _test_cls(BoolsEndEnumsNamedTuple)


def test_datatable_neg():
    """
    Generic negative tests
    """
    # int fields are rejected (unbounded value set)...
    with pytest.raises(ValueError):
        DataTable(HugeSetOfValues)
    # ...and so are classes that are neither dataclasses nor NamedTuples.
    with pytest.raises(ValueError):
        DataTable(NotSupportedCls)

Related

Dynamically generate mypy-compliant property setters

I am trying to declare a base class with certain attributes for which the (very expensive) calculation differs depending on the subclass, but that accepts injecting the value if previously calculated
class Test:
_value1: int | None = None
_value2: str | None = None
_value3: list | None = None
_value4: dict | None = None
#property
def value1(self) -> int:
if self._value1 is None:
self._value1 = self._get_value1()
return self._value1
#value1.setter
def value1(self, value1: int) -> None:
self._value1 = value1
def _get_value1(self) -> int:
raise NotImplementedError
class SubClass(Test):
def _get_value1(self) -> int:
time.sleep(1000000)
return 1
instance = SubClass()
instance.value1 = 1
print(instance.value1) # doesn't wait
As you can see it becomes very verbose, with every property having three different functions associated to it.
Is there a way to dynamically declare at the very least the setter, so that mypy knows it's always the same function but with proper typing? Or in general, is there a more concise way to declare this kind of writable property for which the underlying implementation must be implemented by the base class, in bulk?
Declaring __setattr__ doesn't seem to be viable, because just having __setattr__ declared tricks mpy into thinking I can just assign any value to anything else that's not overloaded, while I still want errors to show up in case I'm trying to assign the wrong attributes. It also doesn't fix that I still need to declare setters, otherwise it thinks the value is immutable.
Instead of inheriting a bunch of pre-defined properties from a base class, I would move all the logic surrounding each property into a custom descriptor class. (The following assumes Python 3.11 and mypy version 1.0.0.)
from typing import TypeVar, Generic, Callable, Type, Optional, Self, Union, overload
T = TypeVar('T')
C = TypeVar('C')
class Descriptor(Generic[C, T]):
def __init__(self, f: Callable[[C], T]):
self.getter = f
def __set_name__(self, owner: C, name: str):
self.private_name = "_" + name
self.public_name = name
#overload
def __get__(self: Self, obj: C, objtype: Optional[Type[C]]) -> T:
...
#overload
def __get__(self: Self, obj: None, objtype: Type[C]) -> Self:
...
def __get__(self: Self, obj: Optional[C], owner: Optional[Type[C]] = None) -> Union[Self, T]:
if obj is None:
return self
if getattr(obj, self.private_name, None) is None:
init_value = self.getter(obj)
self.__set__(obj, init_value)
return getattr(obj, self.private_name)
def __set__(self, obj: C, value: T):
setattr(obj, self.private_name, value)
Then you can define each descriptor similarly to how you would define a property: by decorating the function that will return an initial value when none has yet been set.
class Test:
    """Each attribute is a lazily-computed `Descriptor`; assigning a value
    first skips the expensive computation.

    Fix: all four `@Descriptor` decorators were mangled into comments.
    """

    @Descriptor
    def value1(self) -> int:
        # Local import: the original snippet used `time` without importing it.
        import time
        time.sleep(10000000)
        return 1

    @Descriptor
    def value2(self) -> str:
        return "foo"

    @Descriptor
    def value3(self) -> list:
        return [1, 2, 3]

    @Descriptor
    def value4(self) -> dict:
        return dict(foo=9)
The descriptor class is generic in both the class it will be used in and the type of the wrapped value.
x = Test()
reveal_type(x.value1) # int
reveal_type(Test.value1) # Descriptor[Test, int]
x.value1 = 3 # OK
x.value1 = "foo" # error, x.__set__ expects an int, not a str
If you wanted to simply omit writing #property.setter (this part)
#value1.setter
def value1(self, value1: int) -> None:
self._value1 = value1
one possible implementation would be to subclass property to automatically implement a __set__ method which matches the behaviour specified in your example:
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    import collections.abc as cx

_ValueT = t.TypeVar("_ValueT")


class settable(property, t.Generic[_ValueT]):
    """A `property` subclass that auto-generates the setter: assignment
    stores the value under '_<getter name>' on the instance.

    Fix: the two `@t.overload` decorators were mangled into `#t.overload`
    comments.
    """

    fget: cx.Callable[[t.Any], _ValueT]

    def __init__(self, fget: cx.Callable[[t.Any], _ValueT], /) -> None:
        super().__init__(fget)

    if t.TYPE_CHECKING:
        # Type-safe descriptor protocol for property retrieval methods (`__get__`)
        # see https://docs.python.org/3/howto/descriptor.html
        # These are under `typing.TYPE_CHECKING` because we don't need
        # to modify their implementation from `builtins.property`, but
        # just need to add type-safety.
        @t.overload  # type: ignore[override, no-overload-impl]
        def __get__(self, instance: None, Class: type, /) -> settable[_ValueT]:
            """
            Retrieving a property from on a class (`instance: None`) retrieves the
            property object (`settable[_ValueT]`)
            """

        @t.overload
        def __get__(self, instance: object, Class: type, /) -> _ValueT:
            """
            Retrieving a property from the instance (all other `typing.overload` cases)
            retrieves the value
            """

    def __set__(self, instance: t.Any, value: _ValueT) -> None:
        """
        Type-safe setter method. Grabs the name of the function first decorated with
        `@settable`, then calls `setattr` on the given value with an attribute name of
        '_<function name>'.
        """
        setattr(instance, f"_{self.fget.__name__}", value)
Here's a demonstration of type-safety:
import time


class Test:
    """Base class from the question, rewritten with `settable`: the setter
    is generated automatically, only the getter is declared.

    Fix: the `@settable` decorator was mangled into a `#settable` comment.
    """

    _value1: int | None = None
    _value2: str | None = None
    _value3: list | None = None
    _value4: dict | None = None

    @settable
    def value1(self) -> int:
        # Lazily compute via the subclass hook unless a value was injected.
        if self._value1 is None:
            self._value1 = self._get_value1()
        return self._value1

    def _get_value1(self) -> int:
        raise NotImplementedError


class SubClass(Test):
    def _get_value1(self) -> int:
        time.sleep(1000000)
        return 1
>>> instance: SubClass = SubClass()
>>> instance.value1 = 1 # OK
>>>
>>> if t.TYPE_CHECKING:
... reveal_type(instance.value1) # mypy: Revealed type is "builtins.int"
...
>>> print(instance.value1)
1
>>> instance.value1 = "1" # mypy: Incompatible types in assignment (expression has type "str", variable has type "int") [assignment]
>>> SubClass.value1 = 1 # mypy: Cannot assign to a method [assignment]
... # mypy: Incompatible types in assignment (expression has type "int", variable has type "settable[int]") [assignment]

auto() in Dataclass like auto() in Enum (Python)

Is there a way to do the same code below using Dataclass instead of Enum?
from enum import Enum, auto


class State(Enum):
    # auto() assigns 1, 2, 3 in declaration order.
    val_A= auto()
    val_B = auto()
    val_C = auto()
The only solution I found is the following code:
from dataclasses import dataclass


# Fix: the `@dataclass(frozen=True)` decorator was mangled into a comment.
@dataclass(frozen=True)
class State():
    """Enum-like constants as a frozen dataclass; each field defaults to
    its own name."""
    val_A: str = 'val_A'
    val_B: str = 'val_B'
    val_C: str = 'val_C'
thank you for the suggestions.
Descriptors
One approach could be to use a descriptor class, defined as below:
class Auto:
    """Descriptor that hands out auto-incrementing integer defaults, one
    counter per owner class (an `auto()`-alike for dataclass fields)."""

    # class-wide map: owner class -> last default value handed out
    _GLOBAL_STATE = {}

    __slots__ = ('_private_name', )

    # `owner` is the class or type, whereas instance is an object of `owner`
    def __get__(self, instance, owner):
        try:
            return getattr(instance, self._private_name)
        except AttributeError:
            # No explicit value stored on the instance: hand out the next
            # per-owner default. NOTE(review): the counter advances on every
            # such access, not per field declaration — confirm this numbering
            # scheme is intended before reusing this class elsewhere.
            _state = self.__class__._GLOBAL_STATE
            _dflt = _state[owner] = _state.get(owner, 0) + 1
            return _dflt

    def __set_name__(self, owner, name):
        self._private_name = '_' + name

    def __set__(self, instance, value):
        # use object.__setattr__() instead of setattr() as dataclasses
        # also does, in case of a "frozen" dataclass
        object.__setattr__(instance, self._private_name, value)
Usage would be as follows:
from dataclasses import dataclass


# Fix: the `@dataclass(frozen=True)` decorator was mangled into a comment.
@dataclass(frozen=True)
class State:
    # Each field's default is produced lazily by the Auto descriptor.
    val_A: int = Auto()
    val_B: int = Auto()
    val_C: int = Auto()


s = State()
print(s)  # State(val_A=1, val_B=2, val_C=3)
assert s.val_B == 2

s = State(val_A=5)
assert s.val_A == 5
assert s.val_C == 3
Optimal Approach
The most performant approach I can think of, would be to replace the default values before dataclasses is able to process the class.
Initially this is O(N) time, as it would require iterating over all the class members (including dataclass fields) at least once. However, the real benefit is that it replaces the default values for auto values, such as val_A: int = 1, before the #dataclass decorator is able to process the class.
For example, define a metaclass such as one below:
# sentinel value to detect when to replace a field's default
auto = object()


def check_auto(name, bases, cls_dict):
    """Metaclass-style factory: replace every `auto` default in `cls_dict`
    with the next integer (1, 2, ...) before the class is created.

    :param name: class name
    :param bases: base classes tuple
    :param cls_dict: class namespace, mutated in place
    :return: the newly created class

    Bug fix: the original loop variable was also called `name`, rebinding
    the class-name parameter, so `type(name, ...)` named the class after
    the last attribute iterated instead of the real class name.
    """
    default = 1
    for attr_name, val in cls_dict.items():
        # Identity check: `auto` is a sentinel object, so `is` (not `==`)
        # is the correct comparison.
        if val is auto:
            cls_dict[attr_name] = default
            default += 1
    cls = type(name, bases, cls_dict)
    return cls
Usage is as below:
from dataclasses import dataclass


# Fix: the `@dataclass(frozen=True)` decorator was mangled into a comment.
@dataclass(frozen=True)
class State(metaclass=check_auto):
    # `check_auto` rewrites these sentinel defaults to 1, 2, 3 before
    # @dataclass processes the class.
    val_A: int = auto
    val_B: int = auto
    val_C: int = auto


s = State()
print(s)  # State(val_A=1, val_B=2, val_C=3)
assert s.val_B == 2

s = State(val_A=5)
assert s.val_A == 5
assert s.val_C == 3

Can type hinting infer type of class attribute from another?

There are subclasses that have the class attribute matcher_function set to a function. During instantiation that function is called and sets another attribute matcher. In all cases the return object of the matcher_function is what matcher gets set to.
Is it possible to create a type hint in the base class BaseResolution that would allow both mypy and pycharm to properly infer matcher is the return value of matcher_function?
# contains_the_text.py
from hamcrest import contains_string
from hamcrest.library.text.stringcontains import StringContains
from .base_resolution import BaseResolution


class ContainsTheText(BaseResolution):
    # Manually narrowed annotation for the matcher that `matcher_function`
    # produces at runtime (see BaseResolution.__init__).
    matcher: StringContains  # <-- this is what I'm curious can be inferred
    matcher_function = contains_string  # <-- this function returns an instance
    # of `StringContains`
    # it would be wonderful if
    # a. mypy could detect that the matcher type hint is correct based on 'matcher_function'
    # b. infer what the type is when the hint is not present in the subclasses.
# base_resolution.py
from typing import Any, Callable, TypeVar
from hamcrest.core.base_matcher import BaseMatcher, Matcher
from hamcrest.core.description import Description

T = TypeVar("T")


class BaseResolution(BaseMatcher[T]):
    """Base matcher: builds `self.matcher` by forwarding the constructor
    arguments to the subclass-supplied `matcher_function`."""

    matcher: Matcher
    matcher_function: Callable
    expected: Any

    def __init__(self, *args: object, **kwargs: object) -> None:
        cls = self.__class__
        # `expected` records the raw arguments; `matcher` is the result of
        # forwarding them to the class-level matcher_function.
        if args and kwargs:
            self.expected = (args, kwargs)
            self.matcher = cls.matcher_function(*args, **kwargs)
        elif args:
            # A single positional argument is unwrapped for readability.
            self.expected = args if len(args) > 1 else args[0]
            self.matcher = cls.matcher_function(*args)
        elif kwargs:
            self.expected = kwargs
            self.matcher = cls.matcher_function(**kwargs)
        else:
            # No arguments at all: call the factory bare.
            self.expected = True
            self.matcher = cls.matcher_function()

    def _matches(self, item: T) -> bool:
        """passthrough to the matcher's method."""
        return self.matcher.matches(item)

    # truncated a whole bunch of other methods...
While these are likely better typehints, they didn't seem to do the trick.
class BaseResolution(BaseMatcher[T]):
matcher: Matcher[T]
matcher_function: Callable[..., Matcher[T]]
I know you can do something sorta similar using TypeVar(bound=) which will infer function return types based on arguments passed in. But I can't seem to figure out how (if even possible) to apply that at a class attribute level.
from typing import Type, TypeVar, Generic

T = TypeVar("T")


class Foo(Generic[T]):
    ...


class FooBar(Foo[T]):
    ...


F = TypeVar("F", bound=Foo)  # any instance subclass of Foo


class MyClass(FooBar):
    ...


def bar(f: Type[F]) -> F:
    ...


def baz(f: Type[Foo]) -> Foo:
    ...


objx = bar(MyClass)
objy = baz(MyClass)

# NOTE: reveal_type() exists only while type checking (mypy/pyright);
# actually running this file raises NameError on the two lines below.
reveal_type(objx)  # -> MyClass*
reveal_type(objy)  # -> Foo[Any]
Given the above example I tried the following but that clearly isn't right.
# Attempt: tie the two attributes together with a class-scoped TypeVar.
F = TypeVar("F", bound=Matcher)


class BaseResolution(BaseMatcher[T]):
    matcher: F
    matcher_function: Callable[..., F]

# mypy returns
# Type variable "base_resolution.F" is unbound

Convert json field in body to enum with flask-restx

I have the following api definition for flask-restx (though should also work with flask-restplus).
Is there someway to convert the enum-field in the request body the the Enum MyEnum without too much overhead or using DAOs?
class MyEnum(Enum):
    FOO = auto()
    BAR = auto()


# Fix: the `@dataclass` decorator was mangled into a comment.
@dataclass(frozen=True)
class MyClass:
    enum: MyEnum


api = Namespace('ns')

model = api.model('Model', {
    # Serialised as the enum member *name* ('FOO' / 'BAR').
    'enum': fields.String(enum=[x.name for x in MyEnum]),
})


@api.route('/')
class MyClassResource(Resource):
    # Fix: this class was also called `MyClass`, shadowing the dataclass
    # above, so `MyClass(**api.payload)` inside post() would have
    # instantiated the Resource. Renamed to avoid the collision.
    @api.expect(model)  # `model` (the variable above), not the undefined name `Model`
    def post(self) -> None:
        c = MyClass(**api.payload)
        print(type(c.enum))  # <class 'str'> (but I want <enum 'MyEnum'>)
        assert(type(c.enum) == MyEnum)  # Fails
Ok I have written a decorator which will replace the enum value with the enum
def decode_enum(api: Namespace, enum_cls: Type[Enum], keys: List[str]):
    """Decorator factory: before calling the wrapped view, replace the
    string at path `keys` inside `api.payload` with the matching
    `enum_cls` member (lookup is by member *name*).

    :param api: the flask-restx Namespace whose `payload` is rewritten.
    :param enum_cls: the Enum class to decode into.
    :param keys: path of dictionary keys leading to the encoded value.

    Fix: the `@wraps(f)` decorator was mangled into a `#wraps(f)` comment.
    """

    def replace_item(obj: dict, keys_: List[str], new_value: Enum):
        # Rebuild the nested value along the key path; an empty path means
        # we've reached the leaf and simply return the replacement.
        # (Fixed annotation: `new_value` is an Enum *member*, not Type[Enum].)
        if not keys_:
            return new_value
        obj[keys_[0]] = replace_item(obj[keys_[0]], keys_[1:], new_value)
        return obj

    def decoder(f):
        @wraps(f)
        def wrapper(*args, **kwds):
            # Walk down to the encoded value...
            value = api.payload
            for k in keys:
                value = value[k]
            # ...look the member up by name, then write it back in place.
            enum = enum_cls[value]
            api.payload[keys[0]] = replace_item(api.payload[keys[0]], keys[1:], enum)
            return f(*args, **kwds)
        return wrapper
    return decoder
The usage would be like this
# Fix: both decorators were mangled into comments; `api.expect` takes the
# `model` variable defined earlier, not the undefined name `Model`.
@decode_enum(api, MyEnum, ['enum'])
@api.expect(model)
def post(self) -> None:
    c = MyClass(**api.payload)
    print(type(c.enum))  # <enum 'MyEnum'>
The replace_item function was inspired by this SO Answer: https://stackoverflow.com/a/45335542/6900162

How to extend an enum with aliases

I have an abstract base class GameNodeState that contains a Type enum:
import abc
import enum


class GameNodeState(metaclass=abc.ABCMeta):
    """Abstract game-tree node; `Type` names the generic phases every
    node kind goes through."""

    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()
The names in the enum are generic because they must make sense for any subclass of GameNodeState. But when I subclass GameNodeState, as GameState and RoundState, I would like to be able to add concrete aliases to the members of GameNodeState.Type if the enum is accessed through the subclass. For example, if the GameState subclass aliases INTERMEDIATE as ROUND and RoundState aliases INTERMEDIATE as TURN, I would like the following behaviour:
>>> GameNodeState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.TURN
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> GameNodeState.Type.TURN
AttributeError: TURN
My first thought was this:
class GameState(GameNodeState):
    # NOTE: this is the naive attempt described in the question — it raises
    # TypeError at class-creation time, because an Enum that already has
    # members (GameNodeState.Type) cannot be subclassed.
    class Type(GameNodeState.Type):
        ROUND = GameNodeState.Type.INTERMEDIATE.value


class RoundState(GameNodeState):
    class Type(GameNodeState.Type):
        TURN = GameNodeState.Type.INTERMEDIATE.value
But enums can't be subclassed.
Note: there are obviously more attributes and methods in the GameNodeState hierarchy, I stripped it down to the bare minimum here to focus on this particular thing.
Refinement
(Original solution below.)
I've extracted an intermediate concept from the code above, namely the concept of enum union. This can be used to obtain the behaviour above, and is also useful in other contexts too. The code can be foud here, and I've asked a Code Review question.
I'll add the code here as well for reference:
import enum
import itertools as itt
from functools import reduce
import operator
from typing import Literal, Union
import more_itertools as mitt

AUTO = object()


class UnionEnumMeta(enum.EnumMeta):
    """
    The metaclass for enums which are the union of several sub-enums.

    Union enums have the _subenums_ attribute which is a tuple of the enums forming the
    union.

    Fix: the two `@classmethod` decorators were mangled into `#classmethod`
    comments.
    """

    @classmethod
    def make_union(
        mcs, *subenums: enum.EnumMeta, name: Union[str, Literal[AUTO], None] = AUTO
    ) -> enum.EnumMeta:
        """
        Create an enum whose set of members is the union of members of several enums.

        Order matters: where two members in the union have the same value, they will
        be considered as aliases of each other, and the one appearing in the first
        enum in the sequence will be used as the canonical members (the aliases will
        be associated to this enum member).

        :param subenums: Sequence of sub-enums to make a union of.
        :param name: Name to use for the enum class. AUTO will result in a combination
            of the names of all subenums, None will result in "UnionEnum".
        :return: An enum class which is the union of the given subenums.
        """
        subenums = mcs._normalize_subenums(subenums)

        class UnionEnum(enum.Enum, metaclass=mcs):
            pass

        union_enum = UnionEnum
        union_enum._subenums_ = subenums
        if duplicate_names := reduce(
            set.intersection, (set(subenum.__members__) for subenum in subenums)
        ):
            raise ValueError(
                f"Found duplicate member names in enum union: {duplicate_names}"
            )
        # If aliases are defined, the canonical member will be the one that appears
        # first in the sequence of subenums.
        # dict union keeps last key so we have to do it in reverse:
        union_enum._value2member_map_ = value2member_map = reduce(
            operator.or_, (subenum._value2member_map_ for subenum in reversed(subenums))
        )
        # union of the _member_map_'s but using the canonical member always:
        union_enum._member_map_ = member_map = {
            name: value2member_map[member.value]
            for name, member in itt.chain.from_iterable(
                subenum._member_map_.items() for subenum in subenums
            )
        }
        # only include canonical aliases in _member_names_
        union_enum._member_names_ = list(
            mitt.unique_everseen(
                itt.chain.from_iterable(subenum._member_names_ for subenum in subenums),
                key=member_map.__getitem__,
            )
        )
        if name is AUTO:
            name = (
                "".join(subenum.__name__.removesuffix("Enum") for subenum in subenums)
                + "UnionEnum"
            )
            UnionEnum.__name__ = name
        elif name is not None:
            UnionEnum.__name__ = name
        return union_enum

    def __repr__(cls):
        return f"<union of {', '.join(map(str, cls._subenums_))}>"

    def __instancecheck__(cls, instance):
        # A union-enum "contains" any member of any of its sub-enums.
        return any(isinstance(instance, subenum) for subenum in cls._subenums_)

    @classmethod
    def _normalize_subenums(mcs, subenums):
        """Remove duplicate subenums and flatten nested unions"""
        # we will need to collapse at most one level of nesting, with the inductive
        # hypothesis that any previous unions are already flat
        subenums = mitt.collapse(
            (e._subenums_ if isinstance(e, mcs) else e for e in subenums),
            base_type=enum.EnumMeta,
        )
        subenums = mitt.unique_everseen(subenums)
        return tuple(subenums)


def enum_union(*enums, **kwargs):
    # Thin functional facade over UnionEnumMeta.make_union.
    return UnionEnumMeta.make_union(*enums, **kwargs)
Once we have that, we can just define the extend_enum decorator to compute the union of the base enum and the enum "extension", which will result in the desired behaviour:
def extend_enum(base_enum):
    # Class decorator: replace the decorated enum with the union of
    # `base_enum` and the decorated enum (see `enum_union` above).
    def decorator(extension_enum):
        return enum_union(base_enum, extension_enum)
    return decorator
Usage:
class GameNodeState(metaclass=abc.ABCMeta):
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()


# Fix: the `@extend_enum(...)` decorators were mangled into comments.
class RoundState(GameNodeState):
    # TURN becomes an alias of the inherited INTERMEDIATE member.
    @extend_enum(GameNodeState.Type)
    class Type(enum.Enum):
        TURN = GameNodeState.Type.INTERMEDIATE.value


class GameState(GameNodeState):
    # ROUND becomes an alias of the inherited INTERMEDIATE member.
    @extend_enum(GameNodeState.Type)
    class Type(enum.Enum):
        ROUND = GameNodeState.Type.INTERMEDIATE.value
I think this is a cleaner solution because it doesn't involve mucking around with descriptors; it doesn't need to know anything about the owner class (this works just as well with top-level classes).
Attribute lookup through subclasses and instances of GameNodeState should automatically link to the correct "extension" (i.e., union), as long as the extension enum is added with the same name as for the GameNodeState superclass so that it hides the original definition.
Original
Not sure how bad of an idea this is, but here is a solution using a descriptor wrapped around the enum that gets the set of aliases based on the class from which it is being accessed.
class ExtensibleClassEnum:
    """Descriptor wrapping an enum so that subclasses of the owner class can
    register "extension" enums whose members alias the base enum's members.

    Fix: `@classmethod` and `@property` decorators were mangled into
    `#classmethod` / `#property` comments.

    NOTE(review): this manipulates EnumMeta internals (`_value2member_map_`,
    `_member_map_`, `_member_names_`, and attribute lookup through
    `EnumMeta.__getattr__`); those internals changed in Python 3.12, so
    verify on the target interpreter version.
    """

    class ExtensionWrapperMeta(enum.EnumMeta):
        @classmethod
        def __prepare__(mcs, name, bases):
            # noinspection PyTypeChecker
            classdict: enum._EnumDict = super().__prepare__(name, bases)
            # Keep these helper names from becoming enum members.
            classdict["_ignore_"] = ["base_descriptor", "extension_enum"]
            return classdict

        # noinspection PyProtectedMember
        def __new__(mcs, cls, bases, classdict):
            base_descriptor = classdict.pop("base_descriptor")
            extension_enum = classdict.pop("extension_enum")
            wrapper_enum = super().__new__(mcs, cls, bases, classdict)
            wrapper_enum.base_descriptor = base_descriptor
            wrapper_enum.extension_enum = extension_enum
            base, extension = base_descriptor.base_enum, extension_enum
            if set(base._member_map_.keys()) & set(extension._member_map_.keys()):
                raise ValueError("Found duplicate names in extension")
            # dict union keeps last key so we have to do it in reverse:
            wrapper_enum._value2member_map_ = (
                extension._value2member_map_ | base._value2member_map_
            )
            # union of both _member_map_'s but using the canonical member always:
            wrapper_enum._member_map_ = {
                name: wrapper_enum._value2member_map_[member.value]
                for name, member in itertools.chain(
                    base._member_map_.items(), extension._member_map_.items()
                )
            }
            # aliases shouldn't appear in _member_names_
            wrapper_enum._member_names_ = list(
                m.name for m in wrapper_enum._value2member_map_.values()
            )
            return wrapper_enum

        def __repr__(self):
            # have to use vars() to avoid triggering the descriptor
            base_descriptor = vars(self)["base_descriptor"]
            return (
                f"<extension wrapper enum for {base_descriptor.base_enum}"
                f" in {base_descriptor._extension2owner[self]}>"
            )

    def __init__(self, base_enum):
        if not issubclass(base_enum, enum.Enum):
            raise TypeError(base_enum)
        self.base_enum = base_enum
        # The user won't be able to retrieve the descriptor object itself, just
        # the enum, so we have to forward calls to register_extension:
        self.base_enum.register_extension = staticmethod(self.register_extension)
        # mapping of owner class -> extension for subclasses that define an extension
        self._extensions: Dict[Type, ExtensibleClassEnum.ExtensionWrapperMeta] = {}
        # reverse mapping
        self._extension2owner: Dict[ExtensibleClassEnum.ExtensionWrapperMeta, Type] = {}
        # add the base enum as the base extension via __set_name__:
        self._pending_extension = base_enum

    @property
    def base_owner(self):
        # will be initialised after __set_name__ is called with base owner
        return self._extension2owner[self.base_enum]

    def __set_name__(self, owner, name):
        # step 2 of register_extension: determine the class that defined it
        self._extensions[owner] = self._pending_extension
        self._extension2owner[self._pending_extension] = owner
        del self._pending_extension

    def __get__(self, instance, owner):
        # Only compute extensions once:
        if owner in self._extensions:
            return self._extensions[owner]
        # traverse in MRO until we find the closest supertype defining an extension
        for supertype in owner.__mro__:
            if supertype in self._extensions:
                extension = self._extensions[supertype]
                break
        else:
            raise TypeError(f"{owner} is not a subclass of {self.base_owner}")
        # Cache the result
        self._extensions[owner] = extension
        return extension

    def make_extension(self, extension: enum.EnumMeta):
        class ExtensionWrapperEnum(
            enum.Enum, metaclass=ExtensibleClassEnum.ExtensionWrapperMeta
        ):
            base_descriptor = self
            extension_enum = extension

        return ExtensionWrapperEnum

    def register_extension(self, extension_enum):
        """Decorator for enum extensions"""
        # need a way to determine owner class
        # add a temporary attribute that we will use when __set_name__ is called:
        if hasattr(self, "_pending_extension"):
            # __set_name__ not called after the previous call to register_extension
            raise RuntimeError(
                "An extension was created outside of a class definition",
                self._pending_extension,
            )
        self._pending_extension = self.make_extension(extension_enum)
        return self
Usage would be as follows:
class GameNodeState(metaclass=abc.ABCMeta):
    # Fix: `@ExtensibleClassEnum` was mangled into a comment.
    @ExtensibleClassEnum
    class Type(enum.Enum):
        INIT = enum.auto()
        INTERMEDIATE = enum.auto()
        END = enum.auto()


class RoundState(GameNodeState):
    # Fix: `@GameNodeState.Type.register_extension` was mangled into a comment.
    @GameNodeState.Type.register_extension
    class Type(enum.Enum):
        TURN = GameNodeState.Type.INTERMEDIATE.value


class GameState(GameNodeState):
    @GameNodeState.Type.register_extension
    class Type(enum.Enum):
        ROUND = GameNodeState.Type.INTERMEDIATE.value
Then:
>>> (RoundState.Type.TURN
... == RoundState.Type.INTERMEDIATE
... == GameNodeState.Type.INTERMEDIATE
... == GameState.Type.INTERMEDIATE
... == GameState.Type.ROUND)
...
True
>>> RoundState.Type.__members__
mappingproxy({'INIT': <Type.INIT: 1>,
'INTERMEDIATE': <Type.INTERMEDIATE: 2>,
'END': <Type.END: 3>,
'TURN': <Type.INTERMEDIATE: 2>})
>>> list(RoundState.Type)
[<Type.INTERMEDIATE: 2>, <Type.INIT: 1>, <Type.END: 3>]
>>> GameNodeState.Type.TURN
Traceback (most recent call last):
...
File "C:\Program Files\Python39\lib\enum.py", line 352, in __getattr__
raise AttributeError(name) from None
AttributeError: TURN

Categories