I want to populate an attribute of a dataclass using the default_factory method. However, since the factory method is only meaningful in the context of this specific class, I want to keep it inside the class (e.g. as a static or class method). For example:
from dataclasses import dataclass, field
from typing import List
@dataclass
class Deck:
    cards: List[str] = field(default_factory=self.create_cards)  # NameError: `self` does not exist in a class body

    @staticmethod
    def create_cards():
        """Return the default list of cards."""
        return ['King', 'Queen']
However, I get this error (as expected) on line 6:
NameError: name 'self' is not defined
How can I overcome this issue? I don't want to move the create_cards() method out of the class.
One possible solution is to move it to __post_init__(self). For example:
@dataclass
class Deck:
    """Deck whose default cards are produced by a method of the class itself."""

    cards: List[str] = field(default_factory=list)

    def __post_init__(self):
        # An empty list means "no cards supplied": fill in the defaults here,
        # where `self` (and therefore create_cards) is available.
        if not self.cards:
            self.cards = self.create_cards()

    def create_cards(self):
        """Return the default list of cards."""
        return ['King', 'Queen']
Output:
d1 = Deck()
print(d1) # prints Deck(cards=['King', 'Queen'])
d2 = Deck(["Captain"])
print(d2) # prints Deck(cards=['Captain'])
I adapted momo's answer to be self contained in a class and without thread-safety (since I was using this in asyncio.PriorityQueue context):
from dataclasses import dataclass, field
from typing import Any, ClassVar
@dataclass(order=True)
class FifoPriorityQueueItem:
    """Queue item ordered by ``priority`` first, then by creation order.

    ``data`` is excluded from comparisons. ``sequence`` is overwritten in
    ``__post_init__`` with the next value of the class-wide ``counter``, so
    items of equal priority compare in the order they were created.
    """

    data: Any = field(default=None, compare=False)
    priority: int = 10
    # Placeholder only: __post_init__ always replaces it. An int default
    # matches the annotation (the previous factory produced the set {0}).
    sequence: int = 0
    counter: ClassVar[int] = 0

    def get_data(self):
        """Return the payload carried by this item."""
        return self.data

    def __post_init__(self):
        self.sequence = FifoPriorityQueueItem.next_seq()

    @staticmethod
    def next_seq():
        """Increment the shared counter and return its new value."""
        FifoPriorityQueueItem.counter += 1
        return FifoPriorityQueueItem.counter
def main():
    """Show that FifoPriorityQueueItem keeps FIFO order within a priority."""
    import asyncio

    print('with FifoPriorityQueueItem is FIFO')
    q = asyncio.PriorityQueue()
    q.put_nowait(FifoPriorityQueueItem('z'))
    q.put_nowait(FifoPriorityQueueItem('y'))
    q.put_nowait(FifoPriorityQueueItem('b', priority=1))
    q.put_nowait(FifoPriorityQueueItem('x'))
    q.put_nowait(FifoPriorityQueueItem('a', priority=1))
    while not q.empty():
        print(q.get_nowait().get_data())

    # Plain tuples fall back to comparing the payload when priorities tie,
    # so insertion order is lost.
    print('without FifoPriorityQueueItem is no longer FIFO')
    q.put_nowait((10, 'z'))
    q.put_nowait((10, 'y'))
    q.put_nowait((1, 'b'))
    q.put_nowait((10, 'x'))
    q.put_nowait((1, 'a'))
    while not q.empty():
        print(q.get_nowait()[1])


if __name__ == '__main__':
    main()
Results in:
with FifoPriorityQueueItem is FIFO
b
a
z
y
x
without FifoPriorityQueueItem is no longer FIFO
a
b
x
y
z
One option is to wait until after you define the field object to make create_cards a static method. Make it a regular function, use it as such to define the cards field, then replace it with a static method that wraps the function.
from dataclasses import dataclass, field
from typing import List
@dataclass
class Deck:
    """Deck whose default factory is later exposed as a static method."""

    # Define a regular function first (we'll replace it later,
    # so it's not going to be an instance method)
    def create_cards():
        return ['King', 'Queen']

    # Use create_cards as a regular function
    cards: List[str] = field(default_factory=create_cards)

    # *Now* make it a static method; wrap the function itself, not the
    # `cards` Field object.
    create_cards = staticmethod(create_cards)
This works because the field object is created while the class is being defined, so it doesn't need to be a static method yet.
Related
class NiceClass():
    some_value = SomeObject(...)
    some_other_value = SomeOtherObject(...)

    @classmethod
    def get_all_vars(cls):
        ...
I want get_all_vars() to return [SomeObject(...), SomeOtherObject(...)], or more specifically, the values of the variables in cls.
Solutions tried that didn't work out for me:
return [cls.some_value, cls.some_other_value, ...] (requires listing the variable manually)
subclassing Enum then using list(cls) (requires using some_value.value to access the value elsewhere in the program, also type hinting would be a mess)
namedtuples (nope not touching that subject, heard it was much more complicated than Enum)
[value for key, value in vars(cls).items() if not callable(value) and not key.startswith("__")] (too hacky due to using vars(cls), also for some reason it also includes get_all_vars due to it being a classmethod)
There are two ways. This is a straight answer to your question:
class Foo:
    pass


class Bar:
    """Class whose class-level values can be listed via get_all()."""

    x: int = 1
    y: str = 'hello'
    z: Foo = Foo()

    @classmethod
    def get_all(cls):
        """Return the values in this class's own namespace.

        Dunder entries and classmethod objects are skipped; everything else
        found in ``vars(cls)`` is returned in definition order.
        """
        return [
            value
            for name, value in vars(cls).items()
            if not (name.startswith('__') or isinstance(value, classmethod))
        ]
This is what I suggest:
from dataclasses import dataclass, fields
class Foo:
    pass


@dataclass
class Bar:
    """Dataclass exposing its field defaults and current class-level values."""

    x: int = 1
    y: str = 'hello'
    z: Foo = Foo()

    @classmethod
    def get_defaults(cls):
        """Return the default declared for each dataclass field."""
        return [f.default for f in fields(cls)]

    @classmethod
    def get_all(cls):
        """Return the current class attribute value for each dataclass field."""
        return [getattr(cls, f.name) for f in fields(cls)]
results:
Bar.get_defaults() == Bar.get_all()
# True -> [1, 'hello', __main__.Foo]
Bar.x = 10
Bar.get_defaults() == Bar.get_all()
# False -> [1, 'hello', __main__.Foo] != [10, 'hello', __main__.Foo]
You can create a list of values and define individual attributes at the same time with minimal boilerplate.
class NiceClass():
    # Chained assignment: the list is bound to _all_values and, at the same
    # time, unpacked into the individual attribute names.
    (some_value,
     some_other_value,
     ) = _all_values = [
        SomeObject(...),
        SomeOtherObject(...)
    ]

    @classmethod
    def get_all_vars(cls):
        return cls._all_values
The most obvious drawback to this is that it's easy to get the order of names and values out of sync.
Ideally, you might like to do something like
class NiceClass:
    _attributes = {
        'some_value': SomeObject(...),
        'some_other_value': SomeOtherObject(...)
    }

    @classmethod
    def get_all_vars(cls):
        return cls._attributes.values()
and have some way to "inject" the contents of _attributes into the class namespace directly. The simplest way to do this is with a mix-in class that defines __init_subclass__:
class AddAttributes:
    """Mix-in that copies a subclass's ``_attributes`` dict onto the subclass."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # cls.__dict__ is a read-only mappingproxy (it has no update()),
        # so the attributes have to be injected with setattr.
        for name, value in cls._attributes.items():
            setattr(cls, name, value)
class NiceClass(AddAttributes):
# As above
...
This might sound like a https://xyproblem.info/ but my solution might work in the other case as well. You can get the fields of an object by using __dict__ or vars (which is considered more pythonic given: Python dictionary from an object's fields)
You could do something like:
class ClassTest:
    def __init__(self):
        self.one = 1
        self.two = 2


# vars() returns the instance __dict__, i.e. only attributes set on self.
list(vars(ClassTest()).values())
But you will see that it has some limitations: it only picks up attributes assigned as self.variable_name, not class-level variables like the ones you used above. It might help you nonetheless, because you can simply define them in __init__.
The dir1/a.py is a class to be tested. I will need to mock/patch the class.
dir1/a.py
from somemodule import get_c
@dataclass
class A:
    x: int
    y: object
    c: ClassVar[C] = get_c()  # get_c() needs to be mocked/patched
test_1.py
@pytest.fixture
def sut() -> A:
    x = 1
    y = Mock()
    return A(x, y)
def test_...(sut): # get_c() is called
''' '''
test_2.py
#patch('dir.a.A')
def test_...(): # get_c() is called
How to mock/patch A.c in the tests?
Because c: ClassVar[C] = get_c() is in the declaration of the A dataclass at the top-level of the dir1/a.py file, it gets run when the module is imported. So get_c will be called unless you take extreme measures (implementing a custom import loader, or patching the dataclass decorator before the import dir1.a gets called, ...).
If you don't want get_c to ever be called in your tests, the best and simplest solution is to change the code of dir1/a.py so that it does not call get_c at import time.
If it is OK for get_c to be called, but that the methods on it should not be used, it becomes simpler : just replace the c default value of your A dataclass with one of your own.
import pytest
import unittest.mock as mocking
from dataclasses import dataclass
from typing import ClassVar, TypeVar
C = TypeVar('C')


def get_c():
    """Stand-in for the real factory; prints so that calls are visible."""
    print("`get_c` called")
    return "THING TO MOCK"


@dataclass
class A:
    x: int
    y: object
    # Evaluated once, while the class body runs (i.e. at import time).
    c: ClassVar[C] = get_c()
@pytest.fixture
def sut() -> A:
    """Build an A whose `c` is replaced by a Mock on the instance."""
    x = 1
    y = mocking.Mock()
    a = A(x, y)
    a.c = mocking.Mock()  # change the value of `c`
    return a


def test_a(sut):
    assert isinstance(sut.c, mocking.Mock)
I want to know a simple way to make a dataclass bar frozen.
@dataclass
class Bar:
    foo: int


bar = Bar(foo=1)
In other words, I want a function like the following some_fn_to_freeze
frozen_bar = some_fn_to_freeze(bar)
frozen_bar.foo = 2 # Error
And, an inverse function some_fn_to_unfreeze
bar = some_fn_to_unfreeze(frozen_bar)
bar.foo = 3  # not Error
The standard way to mutate a frozen dataclass is to use dataclasses.replace:
old_bar = Bar(foo=123)
new_bar = dataclasses.replace(old_bar, foo=456)
assert new_bar.foo == 456
For more complex use-cases, you can use the dataclass utils module from: https://github.com/google/etils
It adds a my_dataclass = my_dataclass.unfrozen() member, which allows you to mutate frozen dataclasses directly
# pip install etils[edc]
from etils import edc
@edc.dataclass(allow_unfrozen=True)  # Add the `unfrozen()`/`frozen` method
@dataclasses.dataclass(frozen=True)
class A:
    x: Any = None
    y: Any = None
old_a = A(x=A(x=A()))
# After a is unfrozen, the updates on nested attributes will be propagated
# to the top-level parent.
a = old_a.unfrozen()
a.x.x.x = 123
a.x.y = 'abc'
a = a.frozen() # `frozen()` recursively call `dataclasses.replace`
# Only the `unfrozen` object is mutated. Not the original one.
assert a == A(x=A(x=A(x = 123), y='abc'))
assert old_a == A(x=A(x=A()))
As seen in the example, you can return unfrozen/frozen copies of the dataclass, which was explicitly designed to mutate nested dataclasses.
@edc.dataclass also adds an a.replace(**kwargs) method to the dataclass (an alias of dataclasses.replace)
a = A()
a = a.replace(x=123, y=456)
assert a == A(x=123, y=456)
dataclass doesn't have built-in support for that. Frozen-ness is tracked on a class-wide basis, not per-instance, and there's no support for automatically generating frozen or unfrozen equivalents of dataclasses.
While you could try to do something to generate new dataclasses on the fly, it'd interact very poorly with isinstance, ==, and other things you'd want to work. It's probably safer to just write two dataclasses and converter methods:
@dataclass
class Bar:
    """Mutable variant."""

    foo: int

    def as_frozen(self):
        """Return a FrozenBar carrying the same value."""
        return FrozenBar(self.foo)


@dataclass(frozen=True)
class FrozenBar:
    """Immutable variant."""

    foo: int

    def as_unfrozen(self):
        """Return a mutable Bar carrying the same value."""
        return Bar(self.foo)
Python dataclasses are great, but the attrs package is a more flexible alternative, if you are able to use a third-party library. For example:
import attr
# Your class of interest.
@attr.s()
class Bar(object):
    val = attr.ib()


# A frozen variant of it.
@attr.s(frozen=True)
class FrozenBar(Bar):
    pass


# Three instances:
# - Bar.
# - FrozenBar based on that Bar.
# - Bar based on that FrozenBar.
b1 = Bar(123)
fb = FrozenBar(**attr.asdict(b1))
b2 = Bar(**attr.asdict(fb))

# We can modify the Bar instances.
b1.val = 777
b2.val = 888

# Check current vals.
for x in (b1, fb, b2):
    print(x)

# But we cannot modify the FrozenBar instance.
try:
    fb.val = 999
except attr.exceptions.FrozenInstanceError:
    print(fb, 'unchanged')
Output:
Bar(val=777)
FrozenBar(val=123)
Bar(val=888)
FrozenBar(val=123) unchanged
I'm using the following code to get a frozen copy of a dataclass class or instance:
import dataclasses
from dataclasses import dataclass, fields, asdict
import typing
from typing import TypeVar
FDC_SELF = TypeVar('FDC_SELF', bound='FreezableDataClass')


@dataclass
class FreezableDataClass:
    """Base dataclass that can generate a frozen twin of itself.

    The twin is built with ``dataclasses.make_dataclass(frozen=True)``, has
    the same fields, and gets copies of the non-dunder attributes found on
    the calling class.
    """

    @classmethod
    def get_frozen_dataclass(cls: typing.Type[FDC_SELF]) -> typing.Type[FDC_SELF]:
        """
        @return: a generated frozen dataclass definition, compatible with the calling class
        """
        cls_fields = fields(cls)
        frozen_cls_name = 'Frozen' + cls.__name__
        frozen_dc_namespace = {
            '__name__': frozen_cls_name,
            '__module__': __name__,
        }
        excluded_from_freezing = cls.attrs_excluded_from_freezing()
        for attr in dir(cls):
            if attr.startswith('__') or attr in excluded_from_freezing:
                continue
            attr_def = getattr(cls, attr)
            # Bound methods looked up on the class (i.e. classmethods) expose
            # __func__; re-wrap them so they bind to the generated class.
            if hasattr(attr_def, '__func__'):
                attr_def = classmethod(getattr(attr_def, '__func__'))
            frozen_dc_namespace[attr] = attr_def
        frozen_dc = dataclasses.make_dataclass(
            cls_name=frozen_cls_name,
            fields=[(f.name, f.type, f) for f in cls_fields],
            bases=(),
            namespace=frozen_dc_namespace,
            frozen=True,
        )
        # Publish the generated class under its name in this module's globals.
        globals()[frozen_dc.__name__] = frozen_dc
        return frozen_dc

    @classmethod
    def attrs_excluded_from_freezing(cls) -> typing.Iterable[str]:
        """Return attribute names that must not be copied to the frozen class."""
        return tuple()

    def get_frozen_instance(self: FDC_SELF) -> FDC_SELF:
        """
        @return: an instance of a generated frozen dataclass, compatible with the current dataclass, with copied values
        """
        cls = type(self)
        frozen_dc = cls.get_frozen_dataclass()
        # noinspection PyArgumentList
        return frozen_dc(**asdict(self))
Derived classes could overwrite attrs_excluded_from_freezing to exclude methods which wouldn't work on a frozen dataclass.
Why didn't I prefer other existing answers?
3rd party libraries - etils.edc, If I would use a solution from one of the previous answers, it would be this one. E.g. to get the ability to recursively freeze/unfreeze.
3rd party libraries - attrs
duplicated code
How can I upgrade values from a base dataclass to one that inherits from it?
Example (Python 3.7.2)
from dataclasses import dataclass
@dataclass
class Person:
    name: str
    smell: str = "good"


@dataclass
class Friend(Person):
    # ... more fields

    def say_hi(self):
        print(f'Hi {self.name}')


friend = Friend(name='Alex')
friend.say_hi()
prints "Hi Alex"
random_stranger = Person(name = 'Bob', smell='OK')
return for random_stranger "Person(name='Bob', smell='OK')"
How do I turn the random_stranger into a friend?
Friend(random_stranger)
returns "Friend(name=Person(name='Bob', smell='OK'), smell='good')"
I'd like to get "Friend(name='Bob', smell='OK')" as a result.
Friend(random_stranger.name, random_stranger.smell)
works, but how do I avoid having to copy all fields?
Or is it possible that I can't use the @dataclass decorator on classes that inherit from dataclasses?
What you are asking for is realized by the factory method pattern, and can be implemented in Python classes straightforwardly using the @classmethod decorator.
Just include a dataclass factory method in your base class definition, like this:
import dataclasses
@dataclasses.dataclass
class Person:
    name: str
    smell: str = "good"

    @classmethod
    def from_instance(cls, instance):
        """Build a ``cls`` object from the field values of ``instance``.

        Values are copied shallowly: ``dataclasses.asdict`` would recurse and
        turn any field that is itself a dataclass into a plain dict. Fields
        declared with ``init=False`` are skipped because ``__init__`` does
        not accept them.
        """
        return cls(**{
            f.name: getattr(instance, f.name)
            for f in dataclasses.fields(instance)
            if f.init
        })
Any new dataclass that inherit from this baseclass can now create instances of each other[1] like this:
@dataclasses.dataclass
class Friend(Person):
    def say_hi(self):
        print(f'Hi {self.name}')


random_stranger = Person(name='Bob', smell='OK')
friend = Friend.from_instance(random_stranger)
# say_hi() prints by itself and returns None, so it is not wrapped in print().
friend.say_hi()
# "Hi Bob"
[1] It won't work if your child classes introduce new fields with no default values, you try to create parent class instances from child class instances, or your parent class has init-only arguments.
You probably do not want to have the class itself be a mutable property, and instead use something such as an enum to indicate a status such as this. Depending on the requirements, you may consider one of a few patterns:
class RelationshipStatus(Enum):
    STRANGER = 0
    FRIEND = 1
    PARTNER = 2


@dataclass
class Person(metaclass=ABCMeta):
    full_name: str
    smell: str = "good"
    status: RelationshipStatus = RelationshipStatus.STRANGER


@dataclass
class GreetablePerson(Person):
    nickname: str = ""

    @property
    def greet_name(self):
        """Return the full name for strangers and the nickname otherwise."""
        if self.status == RelationshipStatus.STRANGER:
            return self.full_name
        return self.nickname

    def say_hi(self):
        print(f"Hi {self.greet_name}")


if __name__ == '__main__':
    random_stranger = GreetablePerson(full_name="Robert Thirstwilder",
                                      nickname="Bobby")
    random_stranger.status = RelationshipStatus.STRANGER
    random_stranger.say_hi()
    random_stranger.status = RelationshipStatus.FRIEND
    random_stranger.say_hi()
You may want, also, to implement this in a trait/mixin style. Instead of creating a GreetablePerson, instead make a class Greetable, also abstract, and make your concrete class inherit both of those.
You may also consider using the excellent, backported, much more flexible attrs package. This would also enable you to create a fresh object with the evolve() function:
friend = attr.evolve(random_stranger, status=RelationshipStatus.FRIEND)
dataclasses.asdict is recursive (see doc),
so if fields themselves are dataclasses, dataclasses.asdict(instance) appearing in other answers breaks. Instead, define:
from dataclasses import fields
def shallow_asdict(instance):
    """Return ``{field name: value}`` for a dataclass instance, without recursing.

    Unlike ``dataclasses.asdict``, values that are themselves dataclasses are
    returned as-is, not converted to dicts.
    """
    return {f.name: getattr(instance, f.name) for f in fields(instance)}
and use it to initialize a Friend object from the Person object's fields:
friend = Friend(**shallow_asdict(random_stranger))
assert friend == Friend(name="Bob", smell="OK")
vars(stranger) gives you a dict of all attributes of the dataclass instance stranger. As the default __init__() method of dataclasses takes keyword arguments, twin_stranger = Person(**vars(stranger)) creates a new instance with a copy of the values. That also works for derived classes if you supply the additional arguments like stranger_got_friend = Friend(**vars(stranger), city='Rome'):
from dataclasses import dataclass
@dataclass
class Person:
    name: str
    smell: str


@dataclass
class Friend(Person):
    city: str

    def say_hi(self):
        print(f'Hi {self.name}')


friend = Friend(name='Alex', smell='good', city='Berlin')
friend.say_hi()  # Hi Alex
stranger = Person(name='Bob', smell='OK')
# vars() yields the instance attributes, which map one-to-one to the
# keyword arguments of the generated __init__.
stranger_got_friend = Friend(**vars(stranger), city='Rome')
stranger_got_friend.say_hi()  # Hi Bob
I'm trying to create a frozen dataclass but I'm having issues with setting a value from __post_init__. Is there a way to set a field value based on values from an init param in a dataclass when using the frozen=True setting?
RANKS = '2,3,4,5,6,7,8,9,10,J,Q,K,A'.split(',')
SUITS = 'H,D,C,S'.split(',')


@dataclass(order=True, frozen=True)
class Card:
    rank: str = field(compare=False)
    suit: str = field(compare=False)
    value: int = field(init=False)

    def __post_init__(self):
        # This is the failing line from the question: plain assignment goes
        # through the frozen __setattr__ and raises FrozenInstanceError.
        self.value = RANKS.index(self.rank) + 1

    def __add__(self, other):
        if isinstance(other, Card):
            return self.value + other.value
        return self.value + other

    def __str__(self):
        return f'{self.rank} of {self.suit}'
and this is the trace
File "C:/Users/user/.PyCharm2018.3/config/scratches/scratch_5.py", line 17, in __post_init__
self.value = RANKS.index(self.rank) + 1
File "<string>", line 3, in __setattr__
dataclasses.FrozenInstanceError: cannot assign to field 'value'
Use the same thing the generated __init__ method does: object.__setattr__.
def __post_init__(self):
    # Bypass the frozen __setattr__, exactly as the generated __init__ does.
    object.__setattr__(self, 'value', RANKS.index(self.rank) + 1)
A solution I use in almost all of my classes is to define additional constructors as classmethods.
Based on the given example, one could rewrite it as follows:
@dataclass(order=True, frozen=True)
class Card:
    rank: str = field(compare=False)
    suit: str = field(compare=False)
    value: int

    @classmethod
    def from_rank_and_suite(cls, rank: str, suit: str) -> "Card":
        """Alternate constructor that derives ``value`` from ``rank``."""
        # A classmethod has no `self`; use the `rank` argument.
        value = RANKS.index(rank) + 1
        return cls(rank=rank, suit=suit, value=value)
By this one has all the freedom one requires without having to resort to __setattr__ hacks and without having to give up desired strictness like frozen=True.
Using mutation
Frozen objects should not be changed. But once in a while the need may arise. The accepted answer works perfectly for that. Here is another way of approaching this: return a new instance with the changed values. This may be overkill for some cases, but it's an option.
from copy import deepcopy
@dataclass(frozen=True)
class A:
    a: str = ''
    b: int = 0

    def mutate(self, **options):
        """Return a new instance with ``options`` applied; self is untouched."""
        new_config = deepcopy(self.__dict__)
        # some validation here
        new_config.update(options)
        return self.__class__(**new_config)
Another approach
If you want to set all or many of the values, you can call __init__ again inside __post_init__. Though there are not many use cases.
The following example is not practical, only for demonstrating the possibility.
from dataclasses import dataclass, InitVar
@dataclass(frozen=True)
class A:
    a: str = ''
    b: int = 0
    config: InitVar[dict] = None

    def __post_init__(self, config: dict):
        # Re-running __init__ works on a frozen instance because the
        # generated __init__ assigns through object.__setattr__. Every init
        # field is reassigned, so keys missing from `config` fall back to
        # their defaults.
        if config:
            self.__init__(**config)
The following call
A(config={'a':'a', 'b':1})
will yield
A(a='a', b=1)
without throwing error. This is tested on python 3.7 and 3.9.
Of course, you can directly construct using A(a='hi', b=1), but there may be other uses, e.g. loading configs from a JSON file.
Bonus: an even crazier usage
A(config={'a':'a', 'b':1, 'config':{'a':'b'}})
will yield
A(a='b', b=0) (the innermost call re-runs __init__ with only a='b', so b falls back to its default)
This feels a little bit like 'hacking' the intent of a frozen dataclass, but works well and is clean for making modifications to a frozen dataclass within the post_init method. Note that this decorator could be used for any method (which feels scary, given that you expect the dataclass to be frozen), thus I compensated by asserting the function name this decorator attaches to must be 'post_init'.
Separate from the class, write a decorator that you'll use in the class:
def _defrost(cls):
    """Stash the frozen attribute hooks of ``cls`` and install object's."""
    cls.stash_setattr = cls.__setattr__
    cls.stash_delattr = cls.__delattr__
    cls.__setattr__ = object.__setattr__
    cls.__delattr__ = object.__delattr__


def _refreeze(cls):
    """Restore the hooks stashed by ``_defrost`` and drop the stash."""
    cls.__setattr__ = cls.stash_setattr
    cls.__delattr__ = cls.stash_delattr
    del cls.stash_setattr
    del cls.stash_delattr


def temp_unfreeze_for_postinit(func):
    """Decorator that lets ``__post_init__`` assign attributes on a frozen dataclass.

    The class-level ``__setattr__``/``__delattr__`` are swapped out for the
    duration of the call, which affects every instance of the class while
    the call is running.
    """
    from functools import wraps

    assert func.__name__ == '__post_init__'

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        owner = self.__class__
        _defrost(owner)
        try:
            return func(self, *args, **kwargs)
        finally:
            # Re-freeze even if __post_init__ raises; otherwise the class
            # would be left permanently mutable.
            _refreeze(owner)

    return wrapper
Then, within your frozen dataclass, simply decorate your post_init method!
@dataclasses.dataclass(frozen=True)
class SimpleClass:
    a: int
    # Declared as an InitVar so the generated __init__ accepts `adder` and
    # forwards it to __post_init__ without storing it as a field.
    adder: dataclasses.InitVar[int]

    @temp_unfreeze_for_postinit
    def __post_init__(self, adder):
        self.b = self.a + adder
Solution avoiding object mutation using cached property
This is a simplified version of @Anna Giasson's answer.
Frozen dataclasses work well together with caching from the functools module. Instead of using a dataclass field, you can define a method decorated with @functools.cached_property, which gets evaluated only upon the first lookup of the attribute. Here is a minimal version of the original example:
from dataclasses import dataclass
import functools
@dataclass(frozen=True)
class Card:
    rank: str

    @functools.cached_property
    def value(self):
        """Derived value, computed on first access and then cached."""
        # just for demonstration:
        # this gets printed only once per Card instance
        print("Evaluate value")
        return len(self.rank)
card = Card(rank="foo")
assert card.value == 3
assert card.value == 3
In practice, if the evaluation is cheap, you can also use a non-cached @property decorator.
Commenting with my own solution as I stumbled upon this with the same question but found none of the solutions suited my application.
Here the property that, much like OP, I tried to create in a post_init method initially is the bit_mask property.
I got it to work with the cached_property decorator in functools, since I wanted the property to be static/immutable much like the other properties in the dataclass.
The function create_bitmask is defined elsewhere in my code, but you can see that it depends on the other properties of the dataclass instance.
Hopefully, someone else might find this helpful.
from dataclasses import dataclass
from functools import cached_property
# NOTE(review): `Callable` must be importable here (e.g. from typing);
# the snippet's own imports do not bring it in.
@dataclass(frozen=True)
class Register:
    subsection: str
    name: str
    abbreviation: str
    address: int
    n_bits: int
    _get_method: Callable[[int], int]
    _set_method: Callable[[int, int], None]
    _save_method: Callable[[int, int], None]

    @cached_property
    def bit_mask(self) -> int:
        # Cached so the mask is computed at most once per instance; it
        # depends only on n_bits.
        return create_bitmask(
            n_bits=self.n_bits,
            start_bit=0,
            size=self.n_bits,
            set_val=True
        )

    def get(self) -> int:
        """Read the value at ``address`` and mask it with ``bit_mask``."""
        raw_value = self._get_method(self.address)
        return raw_value & self.bit_mask

    def set(self, value: int) -> None:
        """Pass the masked ``value`` to ``_set_method`` for ``address``."""
        self._set_method(
            self.address,
            value & self.bit_mask
        )

    def save(self, value: int) -> None:
        """Pass the masked ``value`` to ``_save_method`` for ``address``."""
        self._save_method(
            self.address,
            value & self.bit_mask
        )
Avoiding mutation as proposed by Peter Barmettler is what I tend to do in such cases. It feels much more consistent with the frozen=True feature. As a side note, order=True and the __add__ method made me think you would like to sort and compute a score based on a list of cards.
This might be a possible approach:
from __future__ import annotations
from dataclasses import dataclass
RANKS = '2,3,4,5,6,7,8,9,10,J,Q,K,A'.split(',')
SUITS = 'H,D,C,S'.split(',')


@dataclass(frozen=True)
class Card:
    rank: str
    suit: str

    @property
    def value(self) -> int:
        """1-based position of ``rank`` within RANKS."""
        return RANKS.index(self.rank) + 1

    def __lt__(self, __o: Card) -> bool:
        # Defining __lt__ is enough for sorted() to order cards by value.
        return self.value < __o.value

    def __str__(self) -> str:
        return f'{self.rank} of {self.suit}'

    @classmethod
    def score(cls, cards: list[Card]) -> int:
        """Return the sum of the values of ``cards``."""
        return sum(card.value for card in cards)
c1 = Card('A', 'H')
c2 = Card('3', 'D')
cards = [c1, c2]
Card.score(cards) # -> 15
sorted(cards) # -> [Card(rank='3', suit='D'), Card(rank='A', suit='H')]
The scoring logic does not need to be a class method, but this feels ok since the logic determining the value of a card is inside the class as well.