Basically I need the following. I have a Python 3 dataclass or NamedTuple, with only enum and bool fields. E.g.:
from enum import Enum, auto
from typing import NamedTuple
class MyEnum(Enum):
v1 = auto()
v2 = auto()
v3 = auto()
class MyStateDefinition(NamedTuple):
a: MyEnum
b: bool
Is there any well-known solution to enumerate all possible non-equal instances of such a dataclass? (The example above has 6 possible non-equal instances.)
Perhaps it is not a dataclass I should use, but something else. Or should I play with things like dataclasses.fields directly?
I imagine it as some table generator which accepts a namedtuple or dataclass as an input parameter and produces all possible values.
table = DataTable(MyStateDefinition)
for item in table:
# Use items somehow
print(item.a)
print(item.b)
Why do I need it? I just have a state definition which consists of enums and bools. I believe it could be implemented as a bitmask, but when it comes to extending your bitmask with new values, it turns out to be a nightmare. After all, bitmasks seem to be a non-pythonic way of doing things.
Currently I have to use an implementation of my own. But perhaps I'm reinventing the wheel.
Thanks!
You can do this using enums, with the data-tuples as the enum-members' value (an Enum/NamedTuple hybrid, if you will). The _ignore_ attribute is used to prevent certain names in the class namespace from being converted into enum members.
from itertools import product
from enum import Enum
class Data(Enum):
_ignore_ = "Data", "myenum_member", "truthiness"
    @property
def a(self):
return self.value[0]
    @property
def b(self):
return self.value[1]
def __repr__(self):
return f'Data(a={self.a!r}, b={self.b!r})'
Data = vars()
for myenum_member, truthiness in product(MyEnum, (True, False)):
Data[f'{myenum_member.name}_{truthiness}'] = (myenum_member, truthiness)
You should be able to iterate through the resulting enum class just as you desire.
This use of enums is similar to the "time period" example in the Enum HOWTO section of the docs.
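For example, here's a quick sketch of how you could iterate over the resulting members and look one up by its value (reusing MyEnum from the question):
# Iterate over all six generated members:
for member in Data:
    print(repr(member), member.a, member.b)
# Value lookup by (enum, bool) pair also works, since that tuple is each member's value:
assert Data((MyEnum.v2, False)) is Data.v2_False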
Generating this kind of table dynamically
If you want to generate this kind of table dynamically, you could do something like this, (ab)using metaclasses. I've shown examples of how you would use this DataTable class in the docstrings. (For some reason, using typing.get_type_hints in a doctest seems to cause the doctest module to error out, but the examples do work if you try them yourself in an interactive terminal.) Rather than special-casing bool, as you did in your answer, I decided to special-case typing.Literal, as it seemed like a more extensible option (and bool can just be spelled as typing.Literal[True, False]).
from __future__ import annotations
from itertools import product
from enum import Enum, EnumMeta
from typing import (
Iterable,
Mapping,
cast,
Protocol,
get_type_hints,
Any,
get_args,
get_origin,
Literal,
TypeVar,
Union,
Optional
)
D = TypeVar('D')
T = TypeVar('T')
class DataTableFactory(EnumMeta):
"""A helper class for making data tables (an implementation detail of `DataTable`)."""
_CLS_BASES = (Enum,)
    @classmethod
def __prepare__( # type: ignore[override]
metacls,
cls_name: str,
fields: Mapping[str, Iterable[Any]]
) -> dict[str, Any]:
cls_dict = cast(
dict[str, Any],
super().__prepare__(cls_name, metacls._CLS_BASES)
)
for i, field in enumerate(fields.keys()):
cls_dict[field] = property(fget=lambda self, i=i: self.value[i]) # type: ignore[misc]
for p in product(*fields.values()):
cls_dict['_'.join(map(str, p))] = p
def __repr__(self: Enum) -> str:
contents = ', '.join(
f'{field}={getattr(self, field)!r}'
for field in fields
)
return f'{cls_name}Member({contents})'
cls_dict['__repr__'] = __repr__
return cls_dict
    @classmethod
def make_datatable(
metacls,
cls_name: str,
*,
fields: Mapping[str, Iterable[Any]],
doc: Optional[str] = None
) -> type[Enum]:
"""Create a new data table"""
cls_dict = metacls.__prepare__(cls_name, fields)
new_cls = metacls.__new__(metacls, cls_name, metacls._CLS_BASES, cls_dict)
new_cls.__module__ = __name__
if doc is None:
all_attrs = '\n'.join(
f' {f"{attr_name}: ":<{(max(map(len, fields)) + 3)}}one of {attr_val!r}'
for attr_name, attr_val in fields.items()
)
fields_len = len(fields)
doc = (
f'An enum-like data table.\n\n'
f'All members of this data table have {fields_len} '
f'read-only attribute{"s" if fields_len > 1 else ""}:\n'
f'{all_attrs}\n\n'
f'----------------------------------------------------------------------'
)
new_cls.__doc__ = doc
return cast(type[Enum], new_cls)
def __repr__(cls) -> str:
return f"<Data table '{cls.__name__}'>"
def index_of(cls: Iterable[D], member: D) -> int:
"""Get the index of a member in the list of members."""
return list(cls).index(member)
def get(
cls: Iterable[D],
/,
*,
default_: Optional[T] = None,
**kwargs: Any
) -> Union[D, T, None]:
"""Return instance for given arguments set.
Return `default_` if no member matches those arguments.
"""
it = (
member for member in cls
if all((getattr(member, key) == val) for key, val in kwargs.items())
)
return next(it, default_)
def __dir__(cls) -> list[str]:
# By defining __dir__, we make methods defined in this class
# discoverable by the interactive help() function in the REPL
return cast(list[str], super().__dir__()) + ['index_of', 'get']
class TypedStructProto(Protocol):
"""In order to satisfy this interface, a type must have an __annotations__ dict."""
__annotations__: dict[str, Union[Iterable[Any], type[Literal[True]]]]
class DataTableMeta(type):
"""Metaclass for `DataTable`."""
__call__ = DataTableFactory.make_datatable # type: ignore[assignment]
class DataTable(metaclass=DataTableMeta):
"""A mechanism to create 'data table enumerations' -- not really a class at all!
Example usage
-------------
>>> Cars = DataTable('Cars', fields={'make': ('Toyota', 'Audi'), 'colour': ('Red', 'Blue')})
>>> Cars
<Data table 'Cars'>
>>> list(Cars)
[CarsMember(make=Toyota, colour=Red), CarsMember(make=Toyota, colour=Blue), CarsMember(make=Audi, colour=Red), CarsMember(make=Audi, colour=Blue)]
>>> Cars.get(make='Audi', colour='Red')
CarsMember(make=Audi, colour=Red)
>>> Cars.index_of(_)
2
"""
    @classmethod
def from_struct(cls, cls_name: str, *, struct: type[TypedStructProto], doc: Optional[str] = None) -> type[Enum]:
"""Make a DataTable from a "typed struct" -- e.g. a dataclass, NamedTuple or TypedDict.
Example usage (works the same way with dataclasses and TypedDicts)
-------------------------------------------------------------------
>>> from enum import Enum, auto
>>> from typing import NamedTuple, Literal
>>> class E(Enum):
... v1 = auto()
... v2 = auto()
... v3 = auto()
...
>>> class BoolsEndEnums(NamedTuple):
... a: E
... b: Literal[True, False]
...
>>> BoolsEndEnumsTable = DataTable.from_struct('BoolsEndEnumsTable', struct=BoolsEndEnums)
>>> list(BoolsEndEnumsTable)
[BoolsEndEnumsTableMember(a=E.v1, b=True), BoolsEndEnumsTableMember(a=E.v1, b=False), BoolsEndEnumsTableMember(a=E.v2, b=True), BoolsEndEnumsTableMember(a=E.v2, b=False), BoolsEndEnumsTableMember(a=E.v3, b=True), BoolsEndEnumsTableMember(a=E.v3, b=False)]
"""
fields = get_type_hints(struct)
for field_name, field_val in fields.items():
if get_origin(field_val) is Literal:
fields[field_name] = get_args(field_val)
return cast(type[Enum], cls(cls_name, fields=fields, doc=doc)) # type: ignore[call-arg]
I've had to do some "interesting" things with the type hints, but MyPy is sort of happy with all this.
Also posting my own implementation. It's not ideal; I had to use some protected members.
Usage:
from typing import NamedTuple
from datatable import DataTable
class BoolsEndEnums(NamedTuple):
a: E
b: bool
tbl = DataTable(BoolsEndEnums)
item = tbl[0]
print(item.a) # a is v1
print(item.b) # b is False
See test_datatable.py, _test_cls for more usage examples.
datatable.py
import collections
import dataclasses
from collections.abc import Iterable
from enum import Enum
from typing import Union, Any, Tuple, Iterator, get_type_hints, NamedTuple
def is_cls_namedtuple(cls):
return issubclass(cls, tuple) and hasattr(cls, "_fields")
class DataTable(Iterable):
def __init__(self, data_cls):
self._table = []
self._index = {}
self._rindex = {}
self._named_tuple_cls = None
fields = None
if dataclasses.is_dataclass(data_cls):
fields = [f.name for f in dataclasses.fields(data_cls)]
self._named_tuple_cls = collections.namedtuple(
f"{data_cls.__name__}_immutable",
fields
)
elif is_cls_namedtuple(data_cls):
self._named_tuple_cls = data_cls
fields = data_cls._fields
else:
raise ValueError(
"Only dataclasses and NamedTuple subclasses are supported."
)
hints = get_type_hints(data_cls)
self._build_table([], [(f, hints[f]) for f in fields])
def index_of(self, instance):
"""
Returns record index of given instance in table.
:param instance:
:return:
"""
index = self._as_index(instance)
return self._rindex.get(index)
def get(self, **kw):
"""
Returns instance for given arguments set
:param kw:
:return:
"""
index = self._as_index(kw)
return self._table[self._rindex[index]]
def __len__(self):
return len(self._table)
def __getitem__(self, i: Union[int, slice]):
return self._table[i]
def __iter__(self) -> Iterator:
return self._table.__iter__()
def _build_table(self, defined_fields, remained_fields):
if not remained_fields:
instance = self._named_tuple_cls(**dict(defined_fields))
item_id = len(self._table)
self._index[item_id] = instance
self._rindex[self._as_index(defined_fields)] = item_id
self._table.append(instance)
return
next_name, next_type = remained_fields[0]
remained_fields = remained_fields[1:]
if issubclass(next_type, Enum):
for v in next_type:
self._build_table(
defined_fields + [(next_name, v)],
remained_fields
)
return
if next_type is bool:
self._build_table(
defined_fields + [(next_name, False)],
remained_fields
)
self._build_table(
defined_fields + [(next_name, True)],
remained_fields
)
return
raise ValueError(f"Got unexpected dataclass field type: {next_type}")
    @staticmethod
def _as_index(v: Union[Any, Tuple[str, Any]]):
items = None
if dataclasses.is_dataclass(v):
items = dataclasses.asdict(v).items()
elif is_cls_namedtuple(type(v)):
items = v._asdict().items()
elif isinstance(v, dict):
items = v.items()
else:
            assert isinstance(v, collections.abc.Sequence)
items = v
return tuple(sorted(items, key=lambda x: x[0]))
test_datatable.py
import dataclasses
from enum import Enum, auto
from typing import NamedTuple
import pytest
from datatable import DataTable
class E(Enum):
v1 = auto()
v2 = auto()
v3 = auto()
@dataclasses.dataclass
class BoolsEndEnums:
a: E
b: bool
class BoolsEndEnumsNamedTuple(NamedTuple):
a: E
b: bool
@dataclasses.dataclass
class HugeSetOfValues:
a: int
b: bool
class NotSupportedCls:
pass
def _test_cls(cls):
tbl = DataTable(cls)
first = cls(E.v1, False)
last = cls(E.v3, True)
expected_num_entries = 6
assert tbl.index_of(first) == 0
assert tbl.index_of(last) == (expected_num_entries - 1)
assert len(tbl) == expected_num_entries
actual_third = tbl.get(a=E.v2, b=False)
assert actual_third.a == E.v2
assert actual_third.b is False
actual_forth = tbl[3]
assert actual_forth.a == E.v2
assert actual_forth.b is True
items = [item for item in tbl]
actual_fifth = items[4]
assert actual_fifth.a == E.v3
assert actual_fifth.b is False
# Test that we can't change result
with pytest.raises(AttributeError):
tbl[0].a = E.v2
def test_dataclass():
_test_cls(BoolsEndEnums)
def test_namedtuple():
_test_cls(BoolsEndEnumsNamedTuple)
def test_datatable_neg():
"""
Generic negative tests
"""
with pytest.raises(ValueError):
DataTable(HugeSetOfValues)
with pytest.raises(ValueError):
DataTable(NotSupportedCls)
I have an abstract base class GameNodeState that contains a Type enum:
import abc
import enum
class GameNodeState(metaclass=abc.ABCMeta):
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
The names in the enum are generic because they must make sense for any subclass of GameNodeState. But when I subclass GameNodeState, as GameState and RoundState, I would like to be able to add concrete aliases to the members of GameNodeState.Type if the enum is accessed through the subclass. For example, if the GameState subclass aliases INTERMEDIATE as ROUND and RoundState aliases INTERMEDIATE as TURN, I would like the following behaviour:
>>> GameNodeState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.TURN
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> GameNodeState.Type.TURN
AttributeError: TURN
My first thought was this:
class GameState(GameNodeState):
class Type(GameNodeState.Type):
ROUND = GameNodeState.Type.INTERMEDIATE.value
class RoundState(GameNodeState):
class Type(GameNodeState.Type):
TURN = GameNodeState.Type.INTERMEDIATE.value
But enums can't be subclassed.
Note: there are obviously more attributes and methods in the GameNodeState hierarchy; I stripped it down to the bare minimum here to focus on this particular thing.
Refinement
(Original solution below.)
I've extracted an intermediate concept from the code above, namely the concept of an enum union. This can be used to obtain the behaviour above, and is useful in other contexts too. The code can be found here, and I've asked a Code Review question about it.
I'll add the code here as well for reference:
import enum
import itertools as itt
from functools import reduce
import operator
from typing import Literal, Union
import more_itertools as mitt
AUTO = object()
class UnionEnumMeta(enum.EnumMeta):
"""
The metaclass for enums which are the union of several sub-enums.
Union enums have the _subenums_ attribute which is a tuple of the enums forming the
union.
"""
    @classmethod
def make_union(
mcs, *subenums: enum.EnumMeta, name: Union[str, Literal[AUTO], None] = AUTO
) -> enum.EnumMeta:
"""
Create an enum whose set of members is the union of members of several enums.
Order matters: where two members in the union have the same value, they will
be considered as aliases of each other, and the one appearing in the first
        enum in the sequence will be used as the canonical member (the aliases will
be associated to this enum member).
:param subenums: Sequence of sub-enums to make a union of.
:param name: Name to use for the enum class. AUTO will result in a combination
of the names of all subenums, None will result in "UnionEnum".
:return: An enum class which is the union of the given subenums.
"""
subenums = mcs._normalize_subenums(subenums)
class UnionEnum(enum.Enum, metaclass=mcs):
pass
union_enum = UnionEnum
union_enum._subenums_ = subenums
if duplicate_names := reduce(
set.intersection, (set(subenum.__members__) for subenum in subenums)
):
raise ValueError(
f"Found duplicate member names in enum union: {duplicate_names}"
)
# If aliases are defined, the canonical member will be the one that appears
# first in the sequence of subenums.
# dict union keeps last key so we have to do it in reverse:
union_enum._value2member_map_ = value2member_map = reduce(
operator.or_, (subenum._value2member_map_ for subenum in reversed(subenums))
)
# union of the _member_map_'s but using the canonical member always:
union_enum._member_map_ = member_map = {
name: value2member_map[member.value]
for name, member in itt.chain.from_iterable(
subenum._member_map_.items() for subenum in subenums
)
}
# only include canonical aliases in _member_names_
union_enum._member_names_ = list(
mitt.unique_everseen(
itt.chain.from_iterable(subenum._member_names_ for subenum in subenums),
key=member_map.__getitem__,
)
)
if name is AUTO:
name = (
"".join(subenum.__name__.removesuffix("Enum") for subenum in subenums)
+ "UnionEnum"
)
UnionEnum.__name__ = name
elif name is not None:
UnionEnum.__name__ = name
return union_enum
def __repr__(cls):
return f"<union of {', '.join(map(str, cls._subenums_))}>"
def __instancecheck__(cls, instance):
return any(isinstance(instance, subenum) for subenum in cls._subenums_)
    @classmethod
def _normalize_subenums(mcs, subenums):
"""Remove duplicate subenums and flatten nested unions"""
# we will need to collapse at most one level of nesting, with the inductive
# hypothesis that any previous unions are already flat
subenums = mitt.collapse(
(e._subenums_ if isinstance(e, mcs) else e for e in subenums),
base_type=enum.EnumMeta,
)
subenums = mitt.unique_everseen(subenums)
return tuple(subenums)
def enum_union(*enums, **kwargs):
return UnionEnumMeta.make_union(*enums, **kwargs)
Once we have that, we can just define the extend_enum decorator to compute the union of the base enum and the enum "extension", which will result in the desired behaviour:
def extend_enum(base_enum):
def decorator(extension_enum):
return enum_union(base_enum, extension_enum)
return decorator
Usage:
class GameNodeState(metaclass=abc.ABCMeta):
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
class RoundState(GameNodeState):
    @extend_enum(GameNodeState.Type)
class Type(enum.Enum):
TURN = GameNodeState.Type.INTERMEDIATE.value
class GameState(GameNodeState):
    @extend_enum(GameNodeState.Type)
class Type(enum.Enum):
ROUND = GameNodeState.Type.INTERMEDIATE.value
Now all of the examples above produce the same output (plus the added instance check, i.e. isinstance(RoundState.Type.TURN, RoundState.Type) returns True).
I think this is a cleaner solution because it doesn't involve mucking around with descriptors; it doesn't need to know anything about the owner class (this works just as well with top-level classes).
Attribute lookup through subclasses and instances of GameNodeState should automatically link to the correct "extension" (i.e., union), as long as the extension enum is added with the same name as for the GameNodeState superclass so that it hides the original definition.
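As an aside, enum_union can also be used on its own, outside any class hierarchy. Here is a minimal sketch; Fruit, Veg and FoodEnum are made-up names purely for illustration:
class Fruit(enum.Enum):
    APPLE = 1
    BANANA = 2
class Veg(enum.Enum):
    CARROT = 3
FoodEnum = enum_union(Fruit, Veg)
assert FoodEnum.APPLE is Fruit.APPLE          # members are shared, not copied
assert list(FoodEnum) == [Fruit.APPLE, Fruit.BANANA, Veg.CARROT]
assert isinstance(Veg.CARROT, FoodEnum)       # thanks to the custom __instancecheck__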
Original
Not sure how bad of an idea this is, but here is a solution using a descriptor wrapped around the enum that gets the set of aliases based on the class from which it is being accessed.
import enum
import itertools
from typing import Dict, Type

class ExtensibleClassEnum:
class ExtensionWrapperMeta(enum.EnumMeta):
        @classmethod
def __prepare__(mcs, name, bases):
# noinspection PyTypeChecker
classdict: enum._EnumDict = super().__prepare__(name, bases)
classdict["_ignore_"] = ["base_descriptor", "extension_enum"]
return classdict
# noinspection PyProtectedMember
def __new__(mcs, cls, bases, classdict):
base_descriptor = classdict.pop("base_descriptor")
extension_enum = classdict.pop("extension_enum")
wrapper_enum = super().__new__(mcs, cls, bases, classdict)
wrapper_enum.base_descriptor = base_descriptor
wrapper_enum.extension_enum = extension_enum
base, extension = base_descriptor.base_enum, extension_enum
if set(base._member_map_.keys()) & set(extension._member_map_.keys()):
raise ValueError("Found duplicate names in extension")
# dict union keeps last key so we have to do it in reverse:
wrapper_enum._value2member_map_ = (
extension._value2member_map_ | base._value2member_map_
)
# union of both _member_map_'s but using the canonical member always:
wrapper_enum._member_map_ = {
name: wrapper_enum._value2member_map_[member.value]
for name, member in itertools.chain(
base._member_map_.items(), extension._member_map_.items()
)
}
# aliases shouldn't appear in _member_names_
wrapper_enum._member_names_ = list(
m.name for m in wrapper_enum._value2member_map_.values()
)
return wrapper_enum
def __repr__(self):
# have to use vars() to avoid triggering the descriptor
base_descriptor = vars(self)["base_descriptor"]
return (
f"<extension wrapper enum for {base_descriptor.base_enum}"
f" in {base_descriptor._extension2owner[self]}>"
)
def __init__(self, base_enum):
if not issubclass(base_enum, enum.Enum):
raise TypeError(base_enum)
self.base_enum = base_enum
# The user won't be able to retrieve the descriptor object itself, just
# the enum, so we have to forward calls to register_extension:
self.base_enum.register_extension = staticmethod(self.register_extension)
# mapping of owner class -> extension for subclasses that define an extension
self._extensions: Dict[Type, ExtensibleClassEnum.ExtensionWrapperMeta] = {}
# reverse mapping
self._extension2owner: Dict[ExtensibleClassEnum.ExtensionWrapperMeta, Type] = {}
# add the base enum as the base extension via __set_name__:
self._pending_extension = base_enum
    @property
def base_owner(self):
# will be initialised after __set_name__ is called with base owner
return self._extension2owner[self.base_enum]
def __set_name__(self, owner, name):
# step 2 of register_extension: determine the class that defined it
self._extensions[owner] = self._pending_extension
self._extension2owner[self._pending_extension] = owner
del self._pending_extension
def __get__(self, instance, owner):
# Only compute extensions once:
if owner in self._extensions:
return self._extensions[owner]
# traverse in MRO until we find the closest supertype defining an extension
for supertype in owner.__mro__:
if supertype in self._extensions:
extension = self._extensions[supertype]
break
else:
raise TypeError(f"{owner} is not a subclass of {self.base_owner}")
# Cache the result
self._extensions[owner] = extension
return extension
def make_extension(self, extension: enum.EnumMeta):
class ExtensionWrapperEnum(
enum.Enum, metaclass=ExtensibleClassEnum.ExtensionWrapperMeta
):
base_descriptor = self
extension_enum = extension
return ExtensionWrapperEnum
def register_extension(self, extension_enum):
"""Decorator for enum extensions"""
# need a way to determine owner class
# add a temporary attribute that we will use when __set_name__ is called:
if hasattr(self, "_pending_extension"):
# __set_name__ not called after the previous call to register_extension
raise RuntimeError(
"An extension was created outside of a class definition",
self._pending_extension,
)
self._pending_extension = self.make_extension(extension_enum)
return self
Usage would be as follows:
class GameNodeState(metaclass=abc.ABCMeta):
    @ExtensibleClassEnum
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
class RoundState(GameNodeState):
    @GameNodeState.Type.register_extension
class Type(enum.Enum):
TURN = GameNodeState.Type.INTERMEDIATE.value
class GameState(GameNodeState):
    @GameNodeState.Type.register_extension
class Type(enum.Enum):
ROUND = GameNodeState.Type.INTERMEDIATE.value
Then:
>>> (RoundState.Type.TURN
... == RoundState.Type.INTERMEDIATE
... == GameNodeState.Type.INTERMEDIATE
... == GameState.Type.INTERMEDIATE
... == GameState.Type.ROUND)
...
True
>>> RoundState.Type.__members__
mappingproxy({'INIT': <Type.INIT: 1>,
'INTERMEDIATE': <Type.INTERMEDIATE: 2>,
'END': <Type.END: 3>,
'TURN': <Type.INTERMEDIATE: 2>})
>>> list(RoundState.Type)
[<Type.INTERMEDIATE: 2>, <Type.INIT: 1>, <Type.END: 3>]
>>> GameNodeState.Type.TURN
Traceback (most recent call last):
...
File "C:\Program Files\Python39\lib\enum.py", line 352, in __getattr__
raise AttributeError(name) from None
AttributeError: TURN
I want to know a simple way to make a dataclass bar frozen.
@dataclass
class Bar:
foo: int
bar = Bar(foo=1)
In other words, I want a function like the following some_fn_to_freeze
frozen_bar = some_fn_to_freeze(bar)
frozen_bar.foo = 2 # Error
And, an inverse function some_fn_to_unfreeze
bar = some_fn_to_unfreeze(frozen_bar)
bar.foo = 3 # not Error
The standard way to mutate a frozen dataclass is to use dataclasses.replace:
old_bar = Bar(foo=123)
new_bar = dataclasses.replace(old_bar, foo=456)
assert new_bar.foo == 456
For more complex use-cases, you can use the dataclass utils module from: https://github.com/google/etils
It adds a my_dataclass = my_dataclass.unfrozen() method, which allows mutating frozen dataclasses directly:
# pip install etils[edc]
import dataclasses
from typing import Any

from etils import edc

@edc.dataclass(allow_unfrozen=True)  # Add the `unfrozen()`/`frozen()` methods
@dataclasses.dataclass(frozen=True)
class A:
x: Any = None
y: Any = None
old_a = A(x=A(x=A()))
# After a is unfrozen, the updates on nested attributes will be propagated
# to the top-level parent.
a = old_a.unfrozen()
a.x.x.x = 123
a.x.y = 'abc'
a = a.frozen() # `frozen()` recursively call `dataclasses.replace`
# Only the `unfrozen` object is mutated. Not the original one.
assert a == A(x=A(x=A(x = 123), y='abc'))
assert old_a == A(x=A(x=A()))
As seen in the example, you can get unfrozen/frozen copies of the dataclass; this was explicitly designed for mutating nested dataclasses.
@edc.dataclass also adds an a.replace(**kwargs) method to the dataclass (an alias of dataclasses.replace):
a = A()
a = a.replace(x=123, y=456)
assert a == A(x=123, y=456)
dataclass doesn't have built-in support for that. Frozen-ness is tracked on a class-wide basis, not per-instance, and there's no support for automatically generating frozen or unfrozen equivalents of dataclasses.
While you could try to do something to generate new dataclasses on the fly, it'd interact very poorly with isinstance, ==, and other things you'd want to work. It's probably safer to just write two dataclasses and converter methods:
@dataclass
class Bar:
foo: int
def as_frozen(self):
return FrozenBar(self.foo)
@dataclass(frozen=True)
class FrozenBar:
foo: int
def as_unfrozen(self):
return Bar(self.foo)
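Usage could then look something like this (a small sketch of the two-class approach above):
b = Bar(foo=1)
fb = b.as_frozen()        # FrozenBar(foo=1)
# fb.foo = 2              # would raise dataclasses.FrozenInstanceError
b2 = fb.as_unfrozen()
b2.foo = 3                # fine, Bar is mutable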
Python dataclasses are great, but the attrs package is a more flexible alternative, if you are able to use a third-party library. For example:
import attr
# Your class of interest.
@attr.s()
class Bar(object):
val = attr.ib()
# A frozen variant of it.
@attr.s(frozen=True)
class FrozenBar(Bar):
pass
# Three instances:
# - Bar.
# - FrozenBar based on that Bar.
# - Bar based on that FrozenBar.
b1 = Bar(123)
fb = FrozenBar(**attr.asdict(b1))
b2 = Bar(**attr.asdict(fb))
# We can modify the Bar instances.
b1.val = 777
b2.val = 888
# Check current vals.
for x in (b1, fb, b2):
print(x)
# But we cannot modify the FrozenBar instance.
try:
fb.val = 999
except attr.exceptions.FrozenInstanceError:
print(fb, 'unchanged')
Output:
Bar(val=777)
FrozenBar(val=123)
Bar(val=888)
FrozenBar(val=123) unchanged
I'm using the following code to get a frozen copy of a dataclass class or instance:
import dataclasses
from dataclasses import dataclass, fields, asdict
import typing
from typing import Type, TypeVar
FDC_SELF = TypeVar('FDC_SELF', bound='FreezableDataClass')
@dataclass
class FreezableDataClass:
    @classmethod
def get_frozen_dataclass(cls: Type[FDC_SELF]) -> Type[FDC_SELF]:
"""
        @return: a generated frozen dataclass definition, compatible with the calling class
"""
cls_fields = fields(cls)
frozen_cls_name = 'Frozen' + cls.__name__
frozen_dc_namespace = {
'__name__': frozen_cls_name,
'__module__': __name__,
}
excluded_from_freezing = cls.attrs_excluded_from_freezing()
for attr in dir(cls):
if attr.startswith('__') or attr in excluded_from_freezing:
continue
attr_def = getattr(cls, attr)
if hasattr(attr_def, '__func__'):
attr_def = classmethod(getattr(attr_def, '__func__'))
frozen_dc_namespace[attr] = attr_def
frozen_dc = dataclasses.make_dataclass(
cls_name=frozen_cls_name,
fields=[(f.name, f.type, f) for f in cls_fields],
bases=(),
namespace=frozen_dc_namespace,
frozen=True,
)
globals()[frozen_dc.__name__] = frozen_dc
return frozen_dc
    @classmethod
def attrs_excluded_from_freezing(cls) -> typing.Iterable[str]:
return tuple()
def get_frozen_instance(self: FDC_SELF) -> FDC_SELF:
"""
        @return: an instance of a generated frozen dataclass, compatible with the current dataclass, with copied values
"""
cls = type(self)
frozen_dc = cls.get_frozen_dataclass()
# noinspection PyArgumentList
return frozen_dc(**asdict(self))
Derived classes can override attrs_excluded_from_freezing to exclude methods which wouldn't work on a frozen dataclass.
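For illustration, usage might look roughly like the following sketch; Point is a hypothetical example class, and this assumes the helper above behaves as described:
@dataclass
class Point(FreezableDataClass):
    x: int = 0
    y: int = 0
p = Point(x=1, y=2)
frozen_p = p.get_frozen_instance()   # instance of a generated, frozen 'FrozenPoint' dataclass
# frozen_p.x = 5                     # would raise dataclasses.FrozenInstanceError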
Why didn't I prefer the other existing answers?
- 3rd party libraries - etils.edc: if I were to use a solution from one of the previous answers, it would be this one, e.g. to get the ability to recursively freeze/unfreeze.
- 3rd party libraries - attrs
- duplicated code
Let's say, I have a pre-existing mapping as a dictionary:
value_map = {'a': 1, 'b': 2}
I can create an enum class from this like so:
from enum import Enum
MyEnum = Enum('MyEnum', value_map)
and use it like so
a = MyEnum.a
print(a.value)
>>> 1
print(a.name)
>>> 'a'
But then I want to add some methods to my new enum class:
def double_value(self):
return self.value * 2
Of course, I can do this:
class MyEnum(Enum):
a = 1
b = 2
    @property
def double_value(self):
return self.value * 2
But as I said, I have to use a pre-defined value mapping dictionary, so I cannot do this.
How can this be achieved? I tried to inherit from another class defining this method like a mixin, but I couldn't figure it out.
You can pass in a base type with mixin methods into the functional API, with the type argument:
>>> import enum
>>> value_map = {'a': 1, 'b': 2}
>>> class DoubledEnum:
...     @property
... def double_value(self):
... return self.value * 2
...
>>> MyEnum = enum.Enum('MyEnum', value_map, type=DoubledEnum)
>>> MyEnum.a.double_value
2
For a fully functional approach that never uses a class statement, you can create the base mix-in with the type() function:
DoubledEnum = type('DoubledEnum', (), {'double_value': property(double_value)})
MyEnum = enum.Enum('MyEnum', value_map, type=DoubledEnum)
You can also use the enum.EnumMeta metaclass directly, the same way Python does when you create a class MyEnum(enum.Enum): ... subclass:
Create a class dictionary using the metaclass __prepare__ hook
Call the metaclass, passing in the class name, the bases ((enum.Enum,) here), and the class dictionary created in step 1.
The custom dictionary subclass that enum.EnumMeta uses isn't really designed for easy reuse; it implements a __setitem__ hook to record metadata, but doesn't override the dict.update() method, so we need to use a little care when using your value_map dictionary:
import enum
def enum_with_extras(name, value_map, bases=enum.Enum, **extras):
if not isinstance(bases, tuple):
bases = bases,
if not any(issubclass(b, enum.Enum) for b in bases):
bases += enum.Enum,
classdict = enum.EnumMeta.__prepare__(name, bases)
for key, value in {**value_map, **extras}.items():
classdict[key] = value
return enum.EnumMeta(name, bases, classdict)
Then pass in double_value=property(double_value) to that function (together with the enum name and value_map dictionary):
>>> def double_value(self):
... return self.value * 2
...
>>> MyEnum = enum_with_extras('MyEnum', value_map, double_value=property(double_value))
>>> MyEnum.a
<MyEnum.a: 1>
>>> MyEnum.a.double_value
2
You are otherwise allowed to create subclasses of an enum without members (anything that's a descriptor is not a member, so functions, properties, classmethods, etc.), so you can define an enum without members first:
class DoubledEnum(enum.Enum):
    @property
def double_value(self):
return self.value * 2
which is an acceptable base class both for the functional API (e.g. enum.Enum(..., type=DoubledEnum)) and for the metaclass approach I encoded as enum_with_extras().
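For example, combining the member-less base enum with the functional API could look like this (a small sketch reusing the value_map from above):
MyEnum = enum.Enum('MyEnum', value_map, type=DoubledEnum)
assert MyEnum.a.double_value == 2
assert isinstance(MyEnum.a, DoubledEnum)  # members are instances of the member-less base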
You can create a new metaclass (either using a meta-metaclass or a factory function, as I do below) that derives from enum.EnumMeta (the metaclass for enums) and simply adds the members before creating the class:
import enum
import collections.abc
def enum_metaclass_with_default(default_members):
"""Creates an Enum metaclass where `default_members` are added"""
if not isinstance(default_members, collections.abc.Mapping):
default_members = enum.Enum('', default_members).__members__
default_members = dict(default_members)
class EnumMetaWithDefaults(enum.EnumMeta):
def __new__(mcs, name, bases, classdict):
"""Updates classdict adding the default members and
creates a new Enum class with these members
"""
# Update the classdict with default_members
# if they don't already exist
for k, v in default_members.items():
if k not in classdict:
classdict[k] = v
# Add `enum.Enum` as a base class
# Can't use `enum.Enum` in `bases`, because
# that uses `==` instead of `is`
bases = tuple(bases)
for base in bases:
if base is enum.Enum:
break
else:
bases = (enum.Enum,) + bases
return super(EnumMetaWithDefaults, mcs).__new__(mcs, name, bases, classdict)
return EnumMetaWithDefaults
value_map = {'a': 1, 'b': 2}
class MyEnum(metaclass=enum_metaclass_with_default(value_map)):
    @property
def double_value(self):
return self.value * 2
assert MyEnum.a.double_value == 2
A different solution is to directly update locals(), as inside an Enum class body it is replaced with a mapping that creates enum members when you assign values.
import enum
value_map = {'a': 1, 'b': 2}
def set_enum_values(locals, value_map):
# Note that we can't use `locals.update(value_map)`
# because it's `locals.__setitem__(k, v)` that
# creates the enum value, and `update` doesn't
# call `__setitem__`.
    for k, v in value_map.items():
locals[k] = v
class MyEnum(enum.Enum):
set_enum_values(locals(), value_map)
    @property
def double_value(self):
return self.value * 2
assert MyEnum.a.double_value == 2
This seems well-defined enough, and a = 1 is most likely going to be the same as locals()['a'] = 1, but it might change in the future. The first solution is more robust and less hacky (and I haven't tested it in other Python implementations, but it probably works the same).
PLUS: Adding more stuff (a dirty hack) to @Artyer's answer. 🤗
Note that you can also provide "additional" capabilities to an Enum if you create it from a dict, see...
from enum import Enum
_colors = {"RED": (1, "It's the color of blood."), "BLUE": (2, "It's the color of the sky.")}
def _set_members_colors(locals: dict):
    for k, v in _colors.items():
locals[k] = v[0]
class Colors(int, Enum):
_set_members_colors(locals())
    @property
def description(self):
        return _colors[self.name][1]
print(str(Colors.RED))
print(str(Colors.RED.value))
print(str(Colors.RED.description))
Output...
Colors.RED
1
It's the color of blood.
Thanks! 😉
I've been messing around with python's enum library and have come across a conundrum. In the docs, they show an example of an auto-numbering enum, wherein something is defined:
class Color(AutoNumber):
red = ()
green = ()
...
I want to make a similar class, but one where the value is automatically set from the name of the member AND which keeps the functionality you get from the str and Enum mixin stuff.
So something like:
class Animal(MagicStrEnum):
horse = ()
dog = ()
Animal.dog == 'dog' # True
I've looked at the source code of the enum module and tried a lot of variations messing around with __new__ and the EnumMeta class
Update: 2017-03-01
In Python 3.6 (and Aenum 2.01) Flag and IntFlag classes have been added; part of that was a new auto() helper that makes this trivially easy:
>>> class AutoName(Enum):
... def _generate_next_value_(name, start, count, last_values):
... return name
...
>>> class Ordinal(AutoName):
... NORTH = auto()
... SOUTH = auto()
... EAST = auto()
... WEST = auto()
...
>>> list(Ordinal)
[<Ordinal.NORTH: 'NORTH'>, <Ordinal.SOUTH: 'SOUTH'>, <Ordinal.EAST: 'EAST'>, <Ordinal.WEST: 'WEST'>]
Original answer
The difficulty with an AutoStr class is that the name of the enum member is not passed into the code that creates it, so it is unavailable for use. Another wrinkle is that str is immutable, so we can't change those types of enums after they have been created (by using a class decorator, for example).
The easiest thing to do is use the Functional API:
Animal = Enum('Animal', [(a, a) for a in ('horse', 'dog')], type=str)
which gives us:
>>> list(Animal)
[<Animal.horse: 'horse'>, <Animal.dog: 'dog'>]
>>> Animal.dog == 'dog'
True
The next easiest thing to do, assuming you want to make a base class for your future enumeration use, would be something like my DocEnum:
class DocEnum(Enum):
"""
compares equal to all cased versions of its name
    accepts a docstring for each member
"""
def __new__(cls, *args):
"""Ignores arguments (will be handled in __init__)"""
obj = object.__new__(cls)
obj._value_ = None
return obj
def __init__(self, doc=None):
# first, fix _value_
self._value_ = self._name_.lower()
self.__doc__ = doc
def __eq__(self, other):
        if isinstance(other, str):
return self._value_ == other.lower()
elif not isinstance(other, self.__class__):
return NotImplemented
return self is other
def __hash__(self):
# keep DocEnum hashable
return hash(self._value_)
def __ne__(self, other):
return not self == other
and in use:
class SpecKind(DocEnum):
REQUIRED = "required value"
OPTION = "single value per name"
MULTI = "multiple values per name (list form)"
FLAG = "boolean value per name"
KEYWORD = 'unknown options'
Note that unlike the first option, DocEnum members are not strs.
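For example, a quick sketch of how the SpecKind members above would behave:
assert SpecKind.REQUIRED == 'required'        # the lowercased name is the value
assert SpecKind.REQUIRED == 'REQUIRED'        # any cased version compares equal
assert not isinstance(SpecKind.REQUIRED, str) # members are not strs
print(SpecKind.REQUIRED.__doc__)              # required value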
If you want to do it the hard way: subclass EnumMeta and fiddle with the new Enum's class dictionary before the members are created:
from enum import EnumMeta, Enum, _EnumDict
class StrEnumMeta(EnumMeta):
def __new__(metacls, cls, bases, oldclassdict):
"""
Scan through `oldclassdict` and convert any value that is a plain tuple
into a `str` of the name instead
"""
newclassdict = _EnumDict()
for k, v in oldclassdict.items():
if v == ():
v = k
newclassdict[k] = v
return super().__new__(metacls, cls, bases, newclassdict)
class AutoStrEnum(str, Enum, metaclass=StrEnumMeta):
"base class for name=value str enums"
class Animal(AutoStrEnum):
horse = ()
dog = ()
whale = ()
print(Animal.horse)
print(Animal.horse == 'horse')
print(Animal.horse.name, Animal.horse.value)
Which gives us:
Animal.horse
True
horse horse
Disclosure: I am the author of the Python stdlib Enum, the enum34 backport, and the Advanced Enumeration (aenum) library.
Perhaps you are looking for the name attribute, which is automatically provided by the Enum class:
>>> class Animal(Enum):
... ant = 1
... bee = 2
... cat = 3
... dog = 4
...
>>> Animal.ant.name == "ant"
True
Though if you really want to shoot yourself in the foot, here's how. I'm sure this will introduce a whole world of gotchas (I've eliminated the most obvious one).
from enum import Enum, EnumMeta, _EnumDict
class AutoStrEnumDict(_EnumDict):
def __setitem__(self, key, value):
super().__setitem__(key, key)
class AutoStrEnumMeta(EnumMeta):
    @classmethod
def __prepare__(metacls, cls, bases):
return AutoStrEnumDict()
def __init__(self, name, bases, attrs):
super().__init__(name, bases, attrs)
# override Enum.__str__
# can't put these on the class directly otherwise EnumMeta overwrites them
# should also consider resetting __repr__, __format__ and __reduce_ex__
if self.__str__ is not str.__str__:
self.__str__ = str.__str__
class AutoStrNameEnum(str, Enum, metaclass=AutoStrEnumMeta):
pass
class Animal(AutoStrNameEnum):
horse = ()
dog = ()
print(Animal.horse)
assert Animal.horse == "horse"
assert str(Animal.horse) == "horse"
# and not equal to "Animal.horse" (the gotcha mentioned earlier)