How to make non-frozen dataclass frozen, and vice versa? - python

I want to know a simple way to make a dataclass bar frozen.
#dataclass
class Bar:
foo: int
bar = Bar(foo=1)
In other words, I want a function like the following some_fn_to_freeze
frozen_bar = some_fn_to_freeze(bar)
frozen_bar.foo = 2 # Error
And, an inverse function some_fn_to_unfreeze
bar = some_fn_to_unfreeze(frozen_bar)
bar.foo = 3 # not Error

The standard way to mutate a frozen dataclass is to use dataclasses.replace:
old_bar = Bar(foo=123)
new_bar = dataclasses.replace(old_bar, foo=456)
assert new_bar.foo == 456
For more complex use-cases, you can use the dataclass utils module from: https://github.com/google/etils
It adds a my_dataclass = my_dataclass.unfrozen() member, which allows you to mutate frozen dataclasses directly:
# pip install etils[edc]
from etils import edc
#edc.dataclass(allow_unfrozen=True) # Add the `unfrozen()`/`frozen` method
#dataclasses.dataclass(frozen=True)
class A:
x: Any = None
y: Any = None
old_a = A(x=A(x=A()))
# After a is unfrozen, the updates on nested attributes will be propagated
# to the top-level parent.
a = old_a.unfrozen()
a.x.x.x = 123
a.x.y = 'abc'
a = a.frozen() # `frozen()` recursively call `dataclasses.replace`
# Only the `unfrozen` object is mutated. Not the original one.
assert a == A(x=A(x=A(x = 123), y='abc'))
assert old_a == A(x=A(x=A()))
As seen in the example, you can return unfrozen/frozen copies of the dataclass, which was explicitly designed to mutate nested dataclasses.
@edc.dataclass also adds an a.replace(**kwargs) method to the dataclass (an alias of dataclasses.replace):
a = A()
a = a.replace(x=123, y=456)
assert a == A(x=123, y=456)

dataclass doesn't have built-in support for that. Frozen-ness is tracked on a class-wide basis, not per-instance, and there's no support for automatically generating frozen or unfrozen equivalents of dataclasses.
While you could try to do something to generate new dataclasses on the fly, it'd interact very poorly with isinstance, ==, and other things you'd want to work. It's probably safer to just write two dataclasses and converter methods:
#dataclass
class Bar:
foo: int
def as_frozen(self):
return FrozenBar(self.foo)
#dataclass(frozen=True)
class FrozenBar:
foo: int
def as_unfrozen(self):
return Bar(self.foo)

Python dataclasses are great, but the attrs package is a more flexible alternative, if you are able to use a third-party library. For example:
import attr
# Your class of interest.
#attr.s()
class Bar(object):
val = attr.ib()
# A frozen variant of it.
#attr.s(frozen = True)
class FrozenBar(Bar):
pass
# Three instances:
# - Bar.
# - FrozenBar based on that Bar.
# - Bar based on that FrozenBar.
b1 = Bar(123)
fb = FrozenBar(**attr.asdict(b1))
b2 = Bar(**attr.asdict(fb))
# We can modify the Bar instances.
b1.val = 777
b2.val = 888
# Check current vals.
for x in (b1, fb, b2):
print(x)
# But we cannot modify the FrozenBar instance.
try:
fb.val = 999
except attr.exceptions.FrozenInstanceError:
print(fb, 'unchanged')
Output:
Bar(val=777)
FrozenBar(val=123)
Bar(val=888)
FrozenBar(val=123) unchanged

I'm using the following code to get a frozen copy of a dataclass class or instance:
import dataclasses
from dataclasses import dataclass, fields, asdict
import typing
from typing import TypeVar
FDC_SELF = TypeVar('FDC_SELF', bound='FreezableDataClass')
#dataclass
class FreezableDataClass:
#classmethod
def get_frozen_dataclass(cls: Type[FDC_SELF]) -> Type[FDC_SELF]:
"""
#return: a generated frozen dataclass definition, compatible with the calling class
"""
cls_fields = fields(cls)
frozen_cls_name = 'Frozen' + cls.__name__
frozen_dc_namespace = {
'__name__': frozen_cls_name,
'__module__': __name__,
}
excluded_from_freezing = cls.attrs_excluded_from_freezing()
for attr in dir(cls):
if attr.startswith('__') or attr in excluded_from_freezing:
continue
attr_def = getattr(cls, attr)
if hasattr(attr_def, '__func__'):
attr_def = classmethod(getattr(attr_def, '__func__'))
frozen_dc_namespace[attr] = attr_def
frozen_dc = dataclasses.make_dataclass(
cls_name=frozen_cls_name,
fields=[(f.name, f.type, f) for f in cls_fields],
bases=(),
namespace=frozen_dc_namespace,
frozen=True,
)
globals()[frozen_dc.__name__] = frozen_dc
return frozen_dc
#classmethod
def attrs_excluded_from_freezing(cls) -> typing.Iterable[str]:
return tuple()
def get_frozen_instance(self: FDC_SELF) -> FDC_SELF:
"""
#return: an instance of a generated frozen dataclass, compatible with the current dataclass, with copied values
"""
cls = type(self)
frozen_dc = cls.get_frozen_dataclass()
# noinspection PyArgumentList
return frozen_dc(**asdict(self))
Derived classes could override attrs_excluded_from_freezing to exclude methods which wouldn't work on a frozen dataclass.
Why didn't I prefer other existing answers?
3rd party libraries - etils.edc: if I were to use a solution from one of the previous answers, it would be this one, e.g. to get the ability to recursively freeze/unfreeze.
3rd party libraries - attrs
duplicated code

Related

How to get all values of variables in class?

class NiceClass():
some_value = SomeObject(...)
some_other_value = SomeOtherObject(...)
#classmethod
def get_all_vars(cls):
...
I want get_all_vars() to return [SomeObject(...), SomeOtherObject(...)], or more specifically, the values of the variables in cls.
Solutions tried that didn't work out for me:
return [cls.some_value, cls.some_other_value, ...] (requires listing the variable manually)
subclassing Enum then using list(cls) (requires using some_value.value to access the value elsewhere in the program, also type hinting would be a mess)
namedtuples (nope not touching that subject, heard it was much more complicated than Enum)
[value for key, value in vars(cls).items() if not callable(value) and not key.startswith("__")] (too hacky due to using vars(cls), also for some reason it also includes get_all_vars due to it being a classmethod)
There are two ways. This is a straight answer to your question:
class Foo:
pass
class Bar:
x: int = 1
y: str = 'hello'
z: Foo = Foo()
#classmethod
def get_all(cls):
xs = []
for name, value in vars(cls).items():
if not (name.startswith('__') or isinstance(value, classmethod)):
xs.append(value)
return xs
This is what I suggest:
from dataclasses import dataclass, fields
class Foo:
pass
#dataclass
class Bar:
x: int = 1
y: str = 'hello'
z: Foo = Foo()
#classmethod
def get_defaults(cls):
return [f.default for f in fields(cls)]
#classmethod
def get_all(cls):
return [getattr(cls, f.name) for f in fields(cls)]
results:
Bar.get_defaults() == Bar.get_all()
# True -> [1, 'hello', __main__.Foo]
Bar.x = 10
Bar.get_defaults() == Bar.get_all()
# False -> [1, 'hello', __main__.Foo] != [10, 'hello', __main__.Foo]
You can create a list of values and define individual attributes at the same time with minimal boilerplate.
class NiceClass():
(some_value,
some_other_value,
) = _all_values = [
SomeObject(...),
SomeOtherObject(...)
]
#classmethod
def get_all_vars(cls):
return cls._all_values
The most obvious drawback to this is that it's easy to get the order of names and values out of sync.
Ideally, you might like to do something like
class NiceClass:
_attributes = {
'some_value': SomeObject(...),
'some_other_value': SomeOtherObject(...)
}
#classmethod
def get_all_vars(cls):
return cls._attributes.values()
and have some way to "inject" the contents of _attributes into the class namespace directly. The simplest way to do this is with a mix-in class that defines __init_subclass__:
class AddAttributes:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.__dict__.update(cls._attributes)
class NiceClass(AddAttributes):
# As above
...
This might be an XY problem (https://xyproblem.info/), but my solution might work in the other case as well. You can get the fields of an object by using __dict__ or vars (the latter is considered more Pythonic; see: Python dictionary from an object's fields).
You could do something like:
class ClassTest:
def __init__(self):
self.one = 1
self.two = 2
list(vars(ClassTest()).values())
But you will see that it has some limitations. It does not pick up variables that are not defined as self.variable_name, such as the class-level attributes you have used above. It might help you nonetheless, because you can simply define them in __init__.

Distinguishing between Pydantic Models with same fields

I'm using Pydantic to define hierarchical data in which there are models with identical attributes.
However, when I save and load these models, Pydantic can no longer distinguish which model was used and picks the first one in the field type annotation.
I understand that this is expected behavior based on the documentation.
However, the class type information is important to my application.
What is the recommended way to distinguish between different classes in Pydantic? One hack is to simply add an extraneous field to one of the models, but I'd like to find a more elegant solution.
See the simplified example below: container is initialized with data of type DataB, but after exporting and loading, the new container has data of type DataA as it's the first element in the type declaration of container.data.
Thanks for your help!
from abc import ABC
from pydantic import BaseModel #pydantic 1.8.2
from typing import Union
class Data(BaseModel, ABC):
""" base class for a Member """
number: float
class DataA(Data):
""" A type of Data"""
pass
class DataB(Data):
""" Another type of Data """
pass
class Container(BaseModel):
""" container holds a subclass of Data """
data: Union[DataA, DataB]
# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)
# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)
# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataA
As correctly noted in the comments, without storing additional information models cannot be distinguished when parsing.
As of today (pydantic v1.8.2), the most canonical way to distinguish models when parsing in a Union (in case of ambiguity) is to explicitly add a type specifier Literal. It will look like this:
from abc import ABC
from pydantic import BaseModel
from typing import Union, Literal
class Data(BaseModel, ABC):
""" base class for a Member """
number: float
class DataA(Data):
""" A type of Data"""
tag: Literal['A'] = 'A'
class DataB(Data):
""" Another type of Data """
tag: Literal['B'] = 'B'
class Container(BaseModel):
""" container holds a subclass of Data """
data: Union[DataA, DataB]
# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)
# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)
# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataB
This method can be automated, but use it at your own risk, since it breaks static typing and uses objects that may change in future versions:
from pydantic.fields import ModelField
class Data(BaseModel, ABC):
""" base class for a Member """
number: float
def __init_subclass__(cls, **kwargs):
name = 'tag'
value = cls.__name__
annotation = Literal[value]
tag_field = ModelField.infer(name=name, value=value, annotation=annotation, class_validators=None, config=cls.__config__)
cls.__fields__[name] = tag_field
cls.__annotations__[name] = annotation
class DataA(Data):
""" A type of Data"""
pass
class DataB(Data):
""" Another type of Data """
pass
Just wanted to take the opportunity to list another possible alternative here to pydantic - which already supports this use case very well, as per below answer.
I am the creator and maintainer of a relatively newer and lesser-known JSON serialization library, the Dataclass Wizard - which relies on the Python dataclasses module to perform its magic. As of the latest version, 0.14.0, the dataclass-wizard now supports dataclasses within Union types. Previously, it did not support dataclasses within Union types at all, which was kind of a glaring omission, and something on my "to-do" list of things to (eventually) add support for.
As of the latest, it should now support defining dataclasses within Union types. The reason it did not generally work before, is because the data being de-serialized is often a JSON object, which only knows simple types such as arrays and dictionaries, for example. A dict type would not otherwise match any of the Union[Data1, Data2] types, even if the object had all the correct dataclass fields as keys. This is simply because it doesn't compare the dict object against each of the dataclass fields in the Union types, though that might change in a future release.
So in any case, here is a simple example to demonstrate the usage of dataclasses in Union types, using a class inheritance model with the JSONWizard mixin class:
With Class Inheritance
from abc import ABC
from dataclasses import dataclass
from typing import Union
from dataclass_wizard import JSONWizard
#dataclass
class Data(ABC):
""" base class for a Member """
number: float
class DataA(Data, JSONWizard):
""" A type of Data"""
class _(JSONWizard.Meta):
"""
This defines a custom tag that uniquely identifies the dataclass.
"""
tag = 'A'
class DataB(Data, JSONWizard):
""" Another type of Data """
class _(JSONWizard.Meta):
"""
This defines a custom tag that uniquely identifies the dataclass.
"""
tag = 'B'
#dataclass
class Container(JSONWizard):
""" container holds a subclass of Data """
data: Union[DataA, DataB]
The usage is shown below, and is again pretty straightforward. It relies on a special __tag__ key set in a dictionary or JSON object to marshal it into the correct dataclass, based on the Meta.tag value for that class, that we have set up above.
print('== Load with DataA ==')
input_dict = {
'data': {
'number': '1.0',
'__tag__': 'A'
}
}
# De-serialize the `dict` object to a `Container` instance.
container = Container.from_dict(input_dict)
print(repr(container))
# prints:
# Container(data=DataA(number=1.0))
# Show the prettified JSON representation of the instance.
print(container)
# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataA
print()
print('== Load with DataB ==')
# initialize container with DataB
data_b = DataB(number=2.0)
container = Container(data=data_b)
print(repr(container))
# prints:
# Container(data=DataB(number=2.0))
# Show the prettified JSON representation of the instance.
print(container)
# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB
# Assert we end up with the same instance when serializing and de-serializing
# our data.
string = container.to_json()
assert container == Container.from_json(string)
Without Class Inheritance
Here is the same example as above, but with relying solely on dataclasses, without using any special class inheritance model:
from abc import ABC
from dataclasses import dataclass
from typing import Union
from dataclass_wizard import asdict, fromdict, LoadMeta
#dataclass
class Data(ABC):
""" base class for a Member """
number: float
class DataA(Data):
""" A type of Data"""
class DataB(Data):
""" Another type of Data """
#dataclass
class Container:
""" container holds a subclass of Data """
data: Union[DataA, DataB]
# Setup tags for the dataclasses. This can be passed into either
# `LoadMeta` or `DumpMeta`.
#
# Note that I'm not a fan of this syntax either, so it might change. I was
# thinking of something more explicit, like `LoadMeta(...).bind_to(class)`
LoadMeta(DataA, tag='A')
LoadMeta(DataB, tag='B')
# The rest is the same as before.
# initialize container with DataB
data = DataB(number=2.0)
container = Container(data=data)
print(repr(container))
# prints:
# Container(data=DataB(number=2.0))
# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB
# Assert we end up with the same data when serializing and de-serializing.
out_dict = asdict(container)
assert container == fromdict(Container, out_dict)
I'm trying to hack something together in the meantime using custom validators.
Basically the class decorator adds a class_name: str field, which is added to the json string. The validator then looks up the correct subclass based on its value.
def register_distinct_subclasses(fields: tuple):
""" fields is tuple of subclasses that we want to be registered as distinct """
field_map = {field.__name__: field for field in fields}
def _register_distinct_subclasses(cls):
""" cls is the superclass of fields, which we add a new validator to """
orig_init = cls.__init__
class _class:
class_name: str
def __init__(self, **kwargs):
class_name = type(self).__name__
kwargs["class_name"] = class_name
orig_init(**kwargs)
#classmethod
def __get_validators__(cls):
yield cls.validate
#classmethod
def validate(cls, v):
if isinstance(v, dict):
class_name = v.get("class_name")
json_string = json.dumps(v)
else:
class_name = v.class_name
json_string = v.json()
cls_type = field_map[class_name]
return cls_type.parse_raw(json_string)
return _class
return _register_distinct_subclasses
which is called as follows
Data = register_distinct_subclasses((DataA, DataB))(Data)

How to extend an enum with aliases

I have an abstract base class GameNodeState that contains a Type enum:
import abc
import enum
class GameNodeState(metaclass=abc.ABCMeta):
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
The names in the enum are generic because they must make sense for any subclass of GameNodeState. But when I subclass GameNodeState, as GameState and RoundState, I would like to be able to add concrete aliases to the members of GameNodeState.Type if the enum is accessed through the subclass. For example, if the GameState subclass aliases INTERMEDIATE as ROUND and RoundState aliases INTERMEDIATE as TURN, I would like the following behaviour:
>>> GameNodeState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.TURN
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> GameNodeState.Type.TURN
AttributeError: TURN
My first thought was this:
class GameState(GameNodeState):
class Type(GameNodeState.Type):
ROUND = GameNodeState.Type.INTERMEDIATE.value
class RoundState(GameNodeState):
class Type(GameNodeState.Type):
TURN = GameNodeState.Type.INTERMEDIATE.value
But enums that already define members can't be subclassed.
Note: there are obviously more attributes and methods in the GameNodeState hierarchy, I stripped it down to the bare minimum here to focus on this particular thing.
Refinement
(Original solution below.)
I've extracted an intermediate concept from the code above, namely the concept of enum union. This can be used to obtain the behaviour above, and is also useful in other contexts. The code can be found here, and I've asked a Code Review question.
I'll add the code here as well for reference:
import enum
import itertools as itt
from functools import reduce
import operator
from typing import Literal, Union
import more_itertools as mitt
AUTO = object()
class UnionEnumMeta(enum.EnumMeta):
"""
The metaclass for enums which are the union of several sub-enums.
Union enums have the _subenums_ attribute which is a tuple of the enums forming the
union.
"""
#classmethod
def make_union(
mcs, *subenums: enum.EnumMeta, name: Union[str, Literal[AUTO], None] = AUTO
) -> enum.EnumMeta:
"""
Create an enum whose set of members is the union of members of several enums.
Order matters: where two members in the union have the same value, they will
be considered as aliases of each other, and the one appearing in the first
enum in the sequence will be used as the canonical members (the aliases will
be associated to this enum member).
:param subenums: Sequence of sub-enums to make a union of.
:param name: Name to use for the enum class. AUTO will result in a combination
of the names of all subenums, None will result in "UnionEnum".
:return: An enum class which is the union of the given subenums.
"""
subenums = mcs._normalize_subenums(subenums)
class UnionEnum(enum.Enum, metaclass=mcs):
pass
union_enum = UnionEnum
union_enum._subenums_ = subenums
if duplicate_names := reduce(
set.intersection, (set(subenum.__members__) for subenum in subenums)
):
raise ValueError(
f"Found duplicate member names in enum union: {duplicate_names}"
)
# If aliases are defined, the canonical member will be the one that appears
# first in the sequence of subenums.
# dict union keeps last key so we have to do it in reverse:
union_enum._value2member_map_ = value2member_map = reduce(
operator.or_, (subenum._value2member_map_ for subenum in reversed(subenums))
)
# union of the _member_map_'s but using the canonical member always:
union_enum._member_map_ = member_map = {
name: value2member_map[member.value]
for name, member in itt.chain.from_iterable(
subenum._member_map_.items() for subenum in subenums
)
}
# only include canonical aliases in _member_names_
union_enum._member_names_ = list(
mitt.unique_everseen(
itt.chain.from_iterable(subenum._member_names_ for subenum in subenums),
key=member_map.__getitem__,
)
)
if name is AUTO:
name = (
"".join(subenum.__name__.removesuffix("Enum") for subenum in subenums)
+ "UnionEnum"
)
UnionEnum.__name__ = name
elif name is not None:
UnionEnum.__name__ = name
return union_enum
def __repr__(cls):
return f"<union of {', '.join(map(str, cls._subenums_))}>"
def __instancecheck__(cls, instance):
return any(isinstance(instance, subenum) for subenum in cls._subenums_)
#classmethod
def _normalize_subenums(mcs, subenums):
"""Remove duplicate subenums and flatten nested unions"""
# we will need to collapse at most one level of nesting, with the inductive
# hypothesis that any previous unions are already flat
subenums = mitt.collapse(
(e._subenums_ if isinstance(e, mcs) else e for e in subenums),
base_type=enum.EnumMeta,
)
subenums = mitt.unique_everseen(subenums)
return tuple(subenums)
def enum_union(*enums, **kwargs):
return UnionEnumMeta.make_union(*enums, **kwargs)
Once we have that, we can just define the extend_enum decorator to compute the union of the base enum and the enum "extension", which will result in the desired behaviour:
def extend_enum(base_enum):
def decorator(extension_enum):
return enum_union(base_enum, extension_enum)
return decorator
Usage:
class GameNodeState(metaclass=abc.ABCMeta):
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
class RoundState(GameNodeState):
#extend_enum(GameNodeState.Type)
class Type(enum.Enum):
TURN = GameNodeState.Type.INTERMEDIATE.value
class GameState(GameNodeState):
#extend_enum(GameNodeState.Type)
class Type(enum.Enum):
ROUND = GameNodeState.Type.INTERMEDIATE.value
Now all of the examples above produce the same output (plus the added instance check, i.e. isinstance(RoundState.Type.TURN, RoundState.Type) returns True).
I think this is a cleaner solution because it doesn't involve mucking around with descriptors; it doesn't need to know anything about the owner class (this works just as well with top-level classes).
Attribute lookup through subclasses and instances of GameNodeState should automatically link to the correct "extension" (i.e., union), as long as the extension enum is added with the same name as for the GameNodeState superclass so that it hides the original definition.
Original
Not sure how bad of an idea this is, but here is a solution using a descriptor wrapped around the enum that gets the set of aliases based on the class from which it is being accessed.
class ExtensibleClassEnum:
class ExtensionWrapperMeta(enum.EnumMeta):
#classmethod
def __prepare__(mcs, name, bases):
# noinspection PyTypeChecker
classdict: enum._EnumDict = super().__prepare__(name, bases)
classdict["_ignore_"] = ["base_descriptor", "extension_enum"]
return classdict
# noinspection PyProtectedMember
def __new__(mcs, cls, bases, classdict):
base_descriptor = classdict.pop("base_descriptor")
extension_enum = classdict.pop("extension_enum")
wrapper_enum = super().__new__(mcs, cls, bases, classdict)
wrapper_enum.base_descriptor = base_descriptor
wrapper_enum.extension_enum = extension_enum
base, extension = base_descriptor.base_enum, extension_enum
if set(base._member_map_.keys()) & set(extension._member_map_.keys()):
raise ValueError("Found duplicate names in extension")
# dict union keeps last key so we have to do it in reverse:
wrapper_enum._value2member_map_ = (
extension._value2member_map_ | base._value2member_map_
)
# union of both _member_map_'s but using the canonical member always:
wrapper_enum._member_map_ = {
name: wrapper_enum._value2member_map_[member.value]
for name, member in itertools.chain(
base._member_map_.items(), extension._member_map_.items()
)
}
# aliases shouldn't appear in _member_names_
wrapper_enum._member_names_ = list(
m.name for m in wrapper_enum._value2member_map_.values()
)
return wrapper_enum
def __repr__(self):
# have to use vars() to avoid triggering the descriptor
base_descriptor = vars(self)["base_descriptor"]
return (
f"<extension wrapper enum for {base_descriptor.base_enum}"
f" in {base_descriptor._extension2owner[self]}>"
)
def __init__(self, base_enum):
if not issubclass(base_enum, enum.Enum):
raise TypeError(base_enum)
self.base_enum = base_enum
# The user won't be able to retrieve the descriptor object itself, just
# the enum, so we have to forward calls to register_extension:
self.base_enum.register_extension = staticmethod(self.register_extension)
# mapping of owner class -> extension for subclasses that define an extension
self._extensions: Dict[Type, ExtensibleClassEnum.ExtensionWrapperMeta] = {}
# reverse mapping
self._extension2owner: Dict[ExtensibleClassEnum.ExtensionWrapperMeta, Type] = {}
# add the base enum as the base extension via __set_name__:
self._pending_extension = base_enum
#property
def base_owner(self):
# will be initialised after __set_name__ is called with base owner
return self._extension2owner[self.base_enum]
def __set_name__(self, owner, name):
# step 2 of register_extension: determine the class that defined it
self._extensions[owner] = self._pending_extension
self._extension2owner[self._pending_extension] = owner
del self._pending_extension
def __get__(self, instance, owner):
# Only compute extensions once:
if owner in self._extensions:
return self._extensions[owner]
# traverse in MRO until we find the closest supertype defining an extension
for supertype in owner.__mro__:
if supertype in self._extensions:
extension = self._extensions[supertype]
break
else:
raise TypeError(f"{owner} is not a subclass of {self.base_owner}")
# Cache the result
self._extensions[owner] = extension
return extension
def make_extension(self, extension: enum.EnumMeta):
class ExtensionWrapperEnum(
enum.Enum, metaclass=ExtensibleClassEnum.ExtensionWrapperMeta
):
base_descriptor = self
extension_enum = extension
return ExtensionWrapperEnum
def register_extension(self, extension_enum):
"""Decorator for enum extensions"""
# need a way to determine owner class
# add a temporary attribute that we will use when __set_name__ is called:
if hasattr(self, "_pending_extension"):
# __set_name__ not called after the previous call to register_extension
raise RuntimeError(
"An extension was created outside of a class definition",
self._pending_extension,
)
self._pending_extension = self.make_extension(extension_enum)
return self
Usage would be as follows:
class GameNodeState(metaclass=abc.ABCMeta):
#ExtensibleClassEnum
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
class RoundState(GameNodeState):
#GameNodeState.Type.register_extension
class Type(enum.Enum):
TURN = GameNodeState.Type.INTERMEDIATE.value
class GameState(GameNodeState):
#GameNodeState.Type.register_extension
class Type(enum.Enum):
ROUND = GameNodeState.Type.INTERMEDIATE.value
Then:
>>> (RoundState.Type.TURN
... == RoundState.Type.INTERMEDIATE
... == GameNodeState.Type.INTERMEDIATE
... == GameState.Type.INTERMEDIATE
... == GameState.Type.ROUND)
...
True
>>> RoundState.Type.__members__
mappingproxy({'INIT': <Type.INIT: 1>,
'INTERMEDIATE': <Type.INTERMEDIATE: 2>,
'END': <Type.END: 3>,
'TURN': <Type.INTERMEDIATE: 2>})
>>> list(RoundState.Type)
[<Type.INTERMEDIATE: 2>, <Type.INIT: 1>, <Type.END: 3>]
>>> GameNodeState.Type.TURN
Traceback (most recent call last):
...
File "C:\Program Files\Python39\lib\enum.py", line 352, in __getattr__
raise AttributeError(name) from None
AttributeError: TURN

Using class or static method as default_factory in dataclasses

I want to populate an attribute of a dataclass using the default_factory method. However, since the factory method is only meaningful in the context of this specific class, I want to keep it inside the class (e.g. as a static or class method). For example:
from dataclasses import dataclass, field
from typing import List
#dataclass
class Deck:
cards: List[str] = field(default_factory=self.create_cards)
#staticmethod
def create_cards():
return ['King', 'Queen']
However, I get this error (as expected) on line 6:
NameError: name 'self' is not defined
How can I overcome this issue? I don't want to move the create_cards() method out of the class.
One possible solution is to move it to __post_init__(self). For example:
#dataclass
class Deck:
cards: List[str] = field(default_factory=list)
def __post_init__(self):
if not self.cards:
self.cards = self.create_cards()
def create_cards(self):
return ['King', 'Queen']
Output:
d1 = Deck()
print(d1) # prints Deck(cards=['King', 'Queen'])
d2 = Deck(["Captain"])
print(d2) # prints Deck(cards=['Captain'])
I adapted momo's answer to be self contained in a class and without thread-safety (since I was using this in asyncio.PriorityQueue context):
from dataclasses import dataclass, field
from typing import Any, ClassVar
#dataclass(order=True)
class FifoPriorityQueueItem:
data: Any=field(default=None, compare=False)
priority: int=10
sequence: int=field(default_factory=lambda: {0})
counter: ClassVar[int] = 0
def get_data(self):
return self.data
def __post_init__(self):
self.sequence = FifoPriorityQueueItem.next_seq()
#staticmethod
def next_seq():
FifoPriorityQueueItem.counter += 1
return FifoPriorityQueueItem.counter
def main():
import asyncio
print('with FifoPriorityQueueItem is FIFO')
q = asyncio.PriorityQueue()
q.put_nowait(FifoPriorityQueueItem('z'))
q.put_nowait(FifoPriorityQueueItem('y'))
q.put_nowait(FifoPriorityQueueItem('b', priority=1))
q.put_nowait(FifoPriorityQueueItem('x'))
q.put_nowait(FifoPriorityQueueItem('a', priority=1))
while not q.empty():
print(q.get_nowait().get_data())
print('without FifoPriorityQueueItem is no longer FIFO')
q.put_nowait((10, 'z'))
q.put_nowait((10, 'y'))
q.put_nowait((1, 'b'))
q.put_nowait((10, 'x'))
q.put_nowait((1, 'a'))
while not q.empty():
print(q.get_nowait()[1])
if __name__ == '__main__':
main()
Results in:
with FifoPriorityQueueItem is FIFO
b
a
z
y
x
without FifoPriorityQueueItem is no longer FIFO
a
b
x
y
z
One option is to wait until after you define the field object to make create_cards a static method. Make it a regular function, use it as such to define the cards field, then replace it with a static method that wraps the function.
from dataclasses import dataclass, field
from typing import List
#dataclass
class Deck:
# Define a regular function first (we'll replace it later,
# so it's not going to be an instance method)
def create_cards():
return ['King', 'Queen']
# Use create_cards as a regular function
cards: List[str] = field(default_factory=create_cards)
# *Now* make it a static method
create_cards = staticmethod(create_cards)
This works because the field object is created while the class is being defined, so it doesn't need to be a static method yet.

How to extend Python Enum?

Is it possible to extend classes created using the new Enum functionality in Python 3.4? How?
Simple subclassing doesn't appear to work. An example like
from enum import Enum
class EventStatus(Enum):
success = 0
failure = 1
class BookingStatus(EventStatus):
duplicate = 2
unknown = 3
will give an exception like TypeError: Cannot extend enumerations or (in more recent versions) TypeError: BookingStatus: cannot extend enumeration 'EventStatus'.
How can I make it so that BookingStatus reuses the enumeration values from EventStatus and adds more?
Subclassing an enumeration is allowed only if the enumeration does not define any members.
Allowing subclassing of enums that define members would lead to a violation of some important invariants of types and instances.
https://docs.python.org/3/howto/enum.html#restricted-enum-subclassing
So no, it's not directly possible.
While uncommon, it is sometimes useful to create an enum from many modules. The aenum1 library supports this with an extend_enum function:
from aenum import Enum, extend_enum
class Index(Enum):
DeviceType = 0x1000
ErrorRegister = 0x1001
for name, value in (
('ControlWord', 0x6040),
('StatusWord', 0x6041),
('OperationMode', 0x6060),
):
extend_enum(Index, name, value)
assert len(Index) == 5
assert list(Index) == [Index.DeviceType, Index.ErrorRegister, Index.ControlWord, Index.StatusWord, Index.OperationMode]
assert Index.DeviceType.value == 0x1000
assert Index.StatusWord.value == 0x6041
1 Disclosure: I am the author of the Python stdlib Enum, the enum34 backport, and the Advanced Enumeration (aenum) library.
Calling the Enum class directly and making use of chain allows the extension (joining) of an existing enum.
I came upon the problem of extending enums while working on a CANopen
implementation. Parameter indices in the range from 0x1000 to 0x2000
are generic to all CANopen nodes while e.g. the range from 0x6000
onwards depends on whether the node is a drive, io-module, etc.
nodes.py:
from enum import IntEnum
class IndexGeneric(IntEnum):
""" This enum holds the index value of genric object entrys
"""
DeviceType = 0x1000
ErrorRegister = 0x1001
Idx = IndexGeneric
drives.py:
from itertools import chain
from enum import IntEnum
from nodes import IndexGeneric
class IndexDrives(IntEnum):
""" This enum holds the index value of drive object entrys
"""
ControlWord = 0x6040
StatusWord = 0x6041
OperationMode = 0x6060
Idx= IntEnum('Idx', [(i.name, i.value) for i in chain(IndexGeneric,IndexDrives)])
I tested this approach on Python 3.8. We may inherit from an existing enum, but we must also inherit from the base Enum class (listed in the last position).
Docs:
A new Enum class must have one base Enum class, up to one concrete
data type, and as many object-based mixin classes as needed. The order
of these base classes is:
class EnumName([mix-in, ...,] [data-type,] base-enum):
pass
Example:
class Cats(Enum):
SIBERIAN = "siberian"
SPHINX = "sphinx"
class Animals(Cats, Enum):
LABRADOR = "labrador"
CORGI = "corgi"
After that you may access Cats from Animals:
>>> Animals.SIBERIAN
<Cats.SIBERIAN: 'siberian'>
But if you want to iterate over this enum, only new members were accessible:
>>> list(Animals)
[<Animals.LABRADOR: 'labrador'>, <Animals.CORGI: 'corgi'>]
Actually this way is for inheriting methods from base class, but you may use it for members with these restrictions.
Another way (a bit hacky)
As described above, you can write a function that joins two enums into one. I wrote this example:
def extend_enum(inherited_enum):
    """Return a class decorator that rebuilds an enum with extra members.

    The decorated enum is replaced by a brand-new ``Enum`` carrying the same
    class name, whose members are those of ``inherited_enum`` followed by
    those of the decorated enum. If both define the same name, the decorated
    enum's value is used.
    """
    def wrapper(added_enum):
        merged = {
            member.name: member.value
            for source in (inherited_enum, added_enum)
            for member in source
        }
        return Enum(added_enum.__name__, merged)
    return wrapper
class Cats(Enum):
SIBERIAN = "siberian"
SPHINX = "sphinx"
#extend_enum(Cats)
class Animals(Enum):
LABRADOR = "labrador"
CORGI = "corgi"
But here we run into another problem. If we want to compare members, the comparison fails:
>>> Animals.SIBERIAN == Cats.SIBERIAN
False
Here we may compare only names and values of newly created members:
>>> Animals.SIBERIAN.value == Cats.SIBERIAN.value
True
But if we need iteration over new Enum, it works ok:
>>> list(Animals)
[<Animals.SIBERIAN: 'siberian'>, <Animals.SPHINX: 'sphinx'>, <Animals.LABRADOR: 'labrador'>, <Animals.CORGI: 'corgi'>]
So choose your way: simple inheritance, inheritance emulation with a decorator (in fact, re-creation), or adding a new dependency like aenum (I haven't tested it, but I expect it supports all the features I described).
For correct type specification, you could use the Union operator:
from enum import Enum
from typing import Union
class EventStatus(Enum):
success = 0
failure = 1
class BookingSpecificStatus(Enum):
duplicate = 2
unknown = 3
BookingStatus = Union[EventStatus, BookingSpecificStatus]
example_status: BookingStatus
example_status = BookingSpecificStatus.duplicate
example_status = EventStatus.success
Plenty of good answers here already but here's another one purely using Enum's Functional API.
Probably not the most beautiful solution, but it avoids code duplication, works out of the box, needs no additional packages/libraries, and should be sufficient to cover most use cases:
from enum import Enum
class EventStatus(Enum):
success = 0
failure = 1
BookingStatus = Enum(
"BookingStatus",
[es.name for es in EventStatus] + ["duplicate", "unknown"],
start=0,
)
for bs in BookingStatus:
print(bs.name, bs.value)
# success 0
# failure 1
# duplicate 2
# unknown 3
If you'd like to be explicit about the values assigned, you can use:
BookingStatus = Enum(
"BookingStatus",
[(es.name, es.value) for es in EventStatus] + [("duplicate", 6), ("unknown", 7)],
)
for bs in BookingStatus:
print(bs.name, bs.value)
# success 0
# failure 1
# duplicate 6
# unknown 7
I've opted to use a metaclass approach to this problem.
from enum import EnumMeta
class MetaClsEnumJoin(EnumMeta):
    """
    Metaclass that creates a new `enum.Enum` from multiple existing Enums.

    @code
        from enum import Enum

        ENUMA = Enum('ENUMA', {'a': 1, 'b': 2})
        ENUMB = Enum('ENUMB', {'c': 3, 'd': 4})
        class ENUMJOINED(Enum, metaclass=MetaClsEnumJoin, enums=(ENUMA, ENUMB)):
            pass

        print(ENUMJOINED.a)
        print(ENUMJOINED.b)
        print(ENUMJOINED.c)
        print(ENUMJOINED.d)
    @endcode
    """

    @classmethod
    def __prepare__(metacls, name, bases, enums=None, **kargs):
        """
        Generates the class's namespace.

        @param enums Iterable of `enum.Enum` classes to include in the new class. Their
            members are copied into the namespace in iteration order, before the class
            body runs. A member name that occurs in more than one source Enum is a
            conflict and raises `TypeError`.
        @throws ValueError if the `enums` class keyword argument is not supplied.
        """
        # `kargs` holds any other class keyword arguments, e.g. {"myArg1": 1, "myArg2": 2}.
        if enums is None:
            raise ValueError('Class keyword argument `enums` must be defined to use this metaclass.')
        ret = super().__prepare__(name, bases, **kargs)
        for enm in enums:
            for item in enm:
                ret[item.name] = item.value  # Throws `TypeError` if conflict.
        return ret

    def __new__(metacls, name, bases, namespace, **kargs):
        # DO NOT send "**kargs" to "type.__new__". It won't catch them and
        # you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
        return super().__new__(metacls, name, bases, namespace)

    def __init__(cls, name, bases, namespace, **kargs):
        # DO NOT send "**kargs" to "type.__init__" in Python 3.5 and older. You'll get a
        # "TypeError: type.__init__() takes no keyword arguments" exception.
        super().__init__(name, bases, namespace)
This metaclass can be used like so:
>>> from enum import Enum
>>>
>>> ENUMA = Enum('ENUMA', {'a': 1, 'b': 2})
>>> ENUMB = Enum('ENUMB', {'c': 3, 'd': 4})
>>> class ENUMJOINED(Enum, metaclass=MetaClsEnumJoin, enums=(ENUMA, ENUMB)):
... e = 5
... f = 6
...
>>> print(repr(ENUMJOINED.a))
<ENUMJOINED.a: 1>
>>> print(repr(ENUMJOINED.b))
<ENUMJOINED.b: 2>
>>> print(repr(ENUMJOINED.c))
<ENUMJOINED.c: 3>
>>> print(repr(ENUMJOINED.d))
<ENUMJOINED.d: 4>
>>> print(repr(ENUMJOINED.e))
<ENUMJOINED.e: 5>
>>> print(repr(ENUMJOINED.f))
<ENUMJOINED.f: 6>
This approach creates a new Enum using the same name-value pairs as the source Enums, but the resulting Enum members are still unique. The names and values will be the same, but they will fail direct comparisons to their origins following the spirit of Python's Enum class design:
>>> ENUMA.b.name == ENUMJOINED.b.name
True
>>> ENUMA.b.value == ENUMJOINED.b.value
True
>>> ENUMA.b == ENUMJOINED.b
False
>>> ENUMA.b is ENUMJOINED.b
False
>>>
Note what happens in the event of a namespace conflict:
>>> ENUMC = Enum('ENUMA', {'a': 1, 'b': 2})
>>> ENUMD = Enum('ENUMB', {'a': 3})
>>> class ENUMJOINEDCONFLICT(Enum, metaclass=MetaClsEnumJoin, enums=(ENUMC, ENUMD)):
... pass
...
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 19, in __prepare__
File "C:\Users\jcrwfrd\AppData\Local\Programs\Python\Python37\lib\enum.py", line 100, in __setitem__
raise TypeError('Attempted to reuse key: %r' % key)
TypeError: Attempted to reuse key: 'a'
>>>
This is due to the base enum.EnumMeta.__prepare__ returning a special enum._EnumDict instead of the typical dict object that behaves different upon key assignment. You may wish to suppress this error message by surrounding it with a try-except TypeError, or there may be a way to modify the namespace before calling super().__prepare__(...).
Another way :
Letter = Enum(value="Letter", names={"A": 0, "B": 1})
LetterExtended = Enum(value="Letter", names=dict({"C": 2, "D": 3}, **{i.name: i.value for i in Letter}))
Or :
LetterDict = {"A": 0, "B": 1}
Letter = Enum(value="Letter", names=LetterDict)
LetterExtendedDict = dict({"C": 2, "D": 3}, **LetterDict)
LetterExtended = Enum(value="Letter", names=LetterExtendedDict)
Output :
>>> Letter.A
<Letter.A: 0>
>>> Letter.C
Traceback (most recent call last):
File "<input>", line 1, in <module>
File "D:\jhpx\AppData\Local\Programs\Python\Python36\lib\enum.py", line 324, in __getattr__
raise AttributeError(name) from None
AttributeError: C
>>> LetterExtended.A
<Letter.A: 0>
>>> LetterExtended.C
<Letter.C: 2>
I think you could do it in this way:
from typing import List
from enum import Enum
def extend_enum(current_enum, names: List[str], values: List = None):
    """Create a new Enum holding ``names``/``values`` plus ``current_enum``'s members.

    Args:
        current_enum: Existing enum whose members are copied into the result.
        names: Names of the members to add.
        values: Values for the added members, paired with ``names`` by
            position. When omitted (or empty), each added member uses its
            own name as its value.

    Returns:
        A new ``Enum`` class named after ``current_enum``. The added members
        come first, followed by the members of ``current_enum``.
    """
    # Work on copies: the caller's lists must not be mutated, and when
    # ``values`` is omitted the two lists must not alias each other
    # (appending to one would otherwise interleave names and values).
    all_names = list(names)
    all_values = list(values) if values else list(names)
    for item in current_enum:
        all_names.append(item.name)
        all_values.append(item.value)
    return Enum(current_enum.__name__, dict(zip(all_names, all_values)))
class EventStatus(Enum):
success = 0
failure = 1
class BookingStatus(object):
duplicate = 2
unknown = 3
BookingStatus = extend_enum(EventStatus, ['duplicate','unknown'],[2,3])
The key points are:
python could change anything at runtime
class is object too
You can't extend enums but you can create a new one by merging them.
Tested for Python 3.6
from enum import Enum
class DummyEnum(Enum):
a = 1
class AnotherDummyEnum(Enum):
b = 2
def merge_enums(class_name: str, enum1, enum2, result_type=Enum):
    """Build a new enum containing the members of ``enum1`` then ``enum2``.

    Args:
        class_name: Name given to the resulting enum class.
        enum1: First source enum class.
        enum2: Second source enum class. If it shares a member name with
            ``enum1``, its value replaces the earlier one.
        result_type: Enum base used to create the result (default ``Enum``).

    Returns:
        A new enum class whose members keep the definition order of the
        sources.

    Raises:
        TypeError: If either source is not derived from ``Enum``.
    """
    # Imported here so the function does not depend on a module-level import.
    from itertools import chain

    if not (issubclass(enum1, Enum) and issubclass(enum2, Enum)):
        raise TypeError(
            f'{enum1} and {enum2} must be derived from Enum class'
        )
    # Iterate the chain directly: passing it through set() would make the
    # member order depend on hash order instead of definition order.
    attrs = {attr.name: attr.value for attr in chain(enum1, enum2)}
    return result_type(class_name, attrs, module=__name__)
result_enum = merge_enums(
class_name='DummyResultEnum',
enum1=DummyEnum,
enum2=AnotherDummyEnum,
)
Decorator to extend Enum
To expand on Mikhail Bulygin's answer, a decorator can be used to extend an Enum (and support equality by using a custom Enum base class).
1. Enum base class with value-based equality
from enum import Enum
from typing import Any
class EnumBase(Enum):
    """Enum base class whose members compare equal by value.

    Two members are equal when their values are equal, even if they belong
    to different enum classes. Comparison with a non-Enum object is False.
    """

    def __eq__(self, other: Any) -> bool:
        """Return True when ``other`` is an Enum member with the same value."""
        if isinstance(other, Enum):
            return self.value == other.value
        return False

    def __hash__(self) -> int:
        """Hash by value so that members which compare equal hash equally.

        Defining ``__eq__`` alone would leave members unhashable, so they
        could not be used in sets or as dict keys.
        """
        return hash(self.value)
2. Decorator to extend Enum class
from typing import Callable
def extend_enum(parent_enum: EnumBase) -> Callable[[EnumBase], EnumBase]:
    """Return a class decorator that adds ``parent_enum``'s members to an enum.

    The decorated class is replaced by a new ``EnumBase`` with the same
    class name. Its members are those of ``parent_enum`` followed by those
    of the decorated class; on a shared name the decorated class's value is
    used.
    """
    def wrapper(extended_enum: EnumBase) -> EnumBase:
        merged = {
            member.name: member.value
            for source in (parent_enum, extended_enum)
            for member in source
        }
        return EnumBase(extended_enum.__name__, merged)
    return wrapper
Example
>>> from enum import Enum
>>> from typing import Any, Callable
>>> class EnumBase(Enum):
def __eq__(self, other: Any) -> bool:
if isinstance(other, Enum):
return self.value == other.value
return False
>>> def extend_enum(parent_enum: EnumBase) -> Callable[[EnumBase], EnumBase]:
def wrapper(extended_enum: EnumBase) -> EnumBase:
joined = {}
for item in parent_enum:
joined[item.name] = item.value
for item in extended_enum:
joined[item.name] = item.value
return EnumBase(extended_enum.__name__, joined)
return wrapper
>>> class Parent(EnumBase):
A = 1
B = 2
>>> #extend_enum(Parent)
class ExtendedEnum(EnumBase):
C = 3
>>> Parent.A == ExtendedEnum.A
True
>>> list(ExtendedEnum)
[<ExtendedEnum.A: 1>, <ExtendedEnum.B: 2>, <ExtendedEnum.C: 3>]
Yes, you can modify an Enum. The example code, below, is somewhat hacky and it obviously depends on internals of Enum which it has no business whatsoever to depend on. On the other hand, it works.
from enum import IntEnum


class ExtIntEnum(IntEnum):
    """IntEnum base class whose subclasses can gain members after creation.

    HACK: this writes directly to private bookkeeping attributes of the
    enum machinery (``_member_map_``, ``_value2member_map_``,
    ``_member_names_``), which may change between Python versions.
    """

    @classmethod
    def _add(cls, value, name):
        """Register a new member called ``name`` with integer ``value``.

        Raises:
            ValueError: If ``name`` or ``value`` is already used by a member.
        """
        # Guard against silently replacing or aliasing an existing member.
        if name in cls._member_map_:
            raise ValueError('member name already in use: %r' % (name,))
        if value in cls._value2member_map_:
            raise ValueError('member value already in use: %r' % (value,))
        # Members are instances of an int subclass, so build one directly
        # instead of going through the enum's own value-lookup __new__.
        obj = int.__new__(cls, value)
        obj._value_ = value
        obj._name_ = name
        obj.__objclass__ = cls
        # NOTE(review): also bind the member as a plain class attribute so
        # `Cls.name` resolves where the enum module does not consult
        # `_member_map_` on attribute access. type.__setattr__ is used to
        # bypass the enum metaclass' attribute-assignment checks.
        type.__setattr__(cls, name, obj)
        cls._member_map_[name] = obj
        cls._value2member_map_[value] = obj
        cls._member_names_.append(name)
class Fubar(ExtIntEnum):
foo = 1
bar = 2
Fubar._add(3,"baz")
Fubar._add(4,"quux")
Specifically, observe the obj = int.__new__() line. The enum module jumps through a few hoops to find the correct __new__ method for the class that should be enumerated. We ignore these hoops here because we already know how integers (or rather, instances of subclasses of int) are created.
It's a good idea not to use this in production code. If you have to, you really should add guards against duplicate values or names.
I wanted to inherit from Django's IntegerChoices, which is not possible due to the "Cannot extend enumerations" limitation. I figured it could be done with a relatively simple metaclass.
CustomMetaEnum.py:
class CustomMetaEnum(type):
    """Metaclass that collects a class's int constants and supports inheritance.

    Every integer class attribute (ex. A = 1) is recorded in ``_fields``.
    A subclass also receives the constants recorded on its base classes.
    """

    def __new__(mcs, name, bases, namespace):
        """Create the class and store its collected constants in ``_fields``."""
        # Copy constants from the namespace to the fields dict.
        fields = {key: value for key, value in namespace.items() if isinstance(value, int)}
        # In case we're about to create a subclass, copy the constants from
        # the base classes' _fields. Bases without _fields (for example
        # `object` or a plain mixin) contribute nothing, and a constant
        # redefined by the new class keeps the new class's value.
        for base in bases:
            for key, value in getattr(base, '_fields', {}).items():
                fields.setdefault(key, value)
        # Save constants as _fields in the new class' namespace.
        namespace['_fields'] = fields
        return super().__new__(mcs, name, bases, namespace)

    # The choices property is often used in Django.
    # If other methods such as values(), labels() etc. are needed
    # they can be implemented below (for inspiration, see the source of
    # Django's IntegerChoices).
    @property
    def choices(cls):
        """Return the collected constants as a list of (value, name) pairs."""
        return [(value, key) for key, value in cls._fields.items()]
main.py:
from CustomMetaEnum import CustomMetaEnum
class States(metaclass=CustomMetaEnum):
A = 1
B = 2
C = 3
print("States: ")
print(States.A)
print(States.B)
print(States.C)
print(States.choices)
print("MoreStates: ")
class MoreStates(States):
D = 22
pass
print(MoreStates.A)
print(MoreStates.B)
print(MoreStates.C)
print(MoreStates.D)
print(MoreStates.choices)
python3.8 main.py:
States:
1
2
3
[(1, 'A'), (2, 'B'), (3, 'C')]
MoreStates:
1
2
3
22
[(22, 'D'), (1, 'A'), (2, 'B'), (3, 'C')]
Conceptually, it does not make sense to extend an enumeration in this sense. The problem is that this violates the Liskov Substitution Principle: instances of a subclass are supposed to be usable anywhere an instance of the base class could be used, but an instance of BookingStatus could not reliably be used anywhere that an EventStatus is expected. After all, if that instance had a value of BookingStatus.duplicate or BookingStatus.unknown, that would not be a valid enumeration value for an EventStatus.
We can create a new class that reuses the EventStatus setup by using the functional API. A basic example:
event_status_codes = [s.name for s in EventStatus]
BookingStatus = Enum(
'BookingStatus', event_status_codes + ['duplicate', 'unknown']
)
This approach re-numbers the enumeration values, ignoring what they were in EventStatus. We can also pass name-value pairs in order to specify the enum values; this lets us do a bit more analysis, in order to reuse the old values and auto-number new ones:
def extend_enum(result_name, base, *new_names):
    """Create a new Enum that reuses ``base``'s members and appends new ones.

    Args:
        result_name: Name of the enum class to create.
        base: Existing enum whose member names and values are copied.
        *new_names: Names of the members to add. They are numbered
            consecutively, starting one above the largest value in ``base``
            (or at 1 when ``base`` has no members).

    Returns:
        A new ``Enum`` class with the copied members followed by the new ones.
    """
    base_values = [(v.name, v.value) for v in base]
    # default=0 keeps a member-less base from making max() raise ValueError.
    next_number = max((v.value for v in base), default=0) + 1
    new_values = [(name, i) for i, name in enumerate(new_names, next_number)]
    return Enum(result_name, base_values + new_values)
# Now we can do:
BookingStatus = extend_enum('BookingStatus', EventStatus, 'duplicate', 'unknown')

Categories