I have Input and Output pandera SchemaModels, where Output inherits from Input; this accurately reflects that all attributes of the Input schema are also in scope for the Output schema.
What I want to avoid is inheriting all attributes as required (non-Optional), which is how they are rightly defined on the Input schema. Instead, I want to keep them required on the Input schema but choose which of them remain required on the Output schema, while the other inherited attributes become optional.
This pydantic question is similar and has a solution based on defining an __init_subclass__ method in the parent class. However, that doesn't work out of the box for pandera classes, and I'm not sure whether it is even implementable or the right approach.
import pandas as pd
import pandera as pa
from typing import Optional
from pandera.typing import Index, DataFrame, Series, Category
class InputSchema(pa.SchemaModel):
reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
def __init_subclass__(cls, optional_fields=None, **kwargs):
super().__init_subclass__(**kwargs)
if optional_fields:
for field in optional_fields:
cls.__fields__[field].outer_type_ = Optional
cls.__fields__[field].required = False
class OutputSchema(InputSchema, optional_fields=['reporting_date']):
test: Series[str] = pa.Field()
@pa.check_types
def func(inputs: DataFrame[InputSchema]) -> DataFrame[OutputSchema]:
inputs = inputs.drop(columns=['reporting_date'])
inputs['test'] = 'a'
return inputs
data = pd.DataFrame({'reporting_date': ['2023-01-11', '2023-01-12']})
func(data)
Error:
---> 18 class OutputSchema(InputSchema, optional_fields=['reporting_date']):
KeyError: 'reporting_date'
Edit:
Desired outcome: to be able to set which fields from the inherited schema remain required while the rest become optional:
class InputSchema(pa.SchemaModel):
reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
other_field: Series[str] = pa.Field()
class OutputSchema(InputSchema, required=['reporting_date']):
test: Series[str] = pa.Field()
The resulting OutputSchema would have reporting_date and test as required, while other_field becomes optional.
A similar question was asked on pandera's issue tracker, with a docs update on track for the next pandera release. There is no clean solution, but the simplest one is to exclude columns by overloading to_schema:
import pandera as pa
from pandera.typing import Series
class InputSchema(pa.SchemaModel):
reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
class OutputSchema(InputSchema):
test: Series[str]
@classmethod
def to_schema(cls) -> pa.DataFrameSchema:
return super().to_schema().remove_columns(["reporting_date"])
This runs without SchemaError against your check function.
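For completeness, a minimal usage sketch (assuming pandas is installed and the two classes above are on hand) that mirrors the check function from the question:

import pandas as pd
from pandera.typing import DataFrame

@pa.check_types
def func(inputs: DataFrame[InputSchema]) -> DataFrame[OutputSchema]:
    inputs = inputs.drop(columns=["reporting_date"])
    inputs["test"] = "a"
    return inputs

data = pd.DataFrame({"reporting_date": ["2023-01-11", "2023-01-12"]})
func(data)  # passes: reporting_date is no longer part of OutputSchema's columns

If, rather than removing the inherited columns, you want to keep them but mark them as no longer required (the behaviour described in the question's edit), the same to_schema hook can be combined with DataFrameSchema.update_columns. This is only a sketch, under the assumption that update_columns accepts a per-column "required" override:

class OutputSchemaRelaxed(InputSchema):
    test: Series[str]

    @classmethod
    def to_schema(cls) -> pa.DataFrameSchema:
        schema = super().to_schema()
        # keep reporting_date and test required, relax every other inherited column
        optional = {
            name: {"required": False}
            for name in schema.columns
            if name not in ("reporting_date", "test")
        }
        return schema.update_columns(optional)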
Here is a solution that reuses the existing type annotation from the input schema:
import pandera as pa
import pandas as pd
from typing import Optional
from pandera.typing import Index, DataFrame, Series, Category
from pydantic import Field, BaseModel
from typing import Annotated, Type
def copy_field(from_model: Type[BaseModel], fname: str, annotations: dict[str, ...]):
annotations[fname] = from_model.__annotations__[fname]
class InputSchema(pa.SchemaModel):
reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
not_inherit: Series[str]
class OutputSchema(pa.SchemaModel):
test: Series[str] = pa.Field()
copy_field(InputSchema, "reporting_date", __annotations__)
# reporting_date: Series[pa.DateTime] = pa.Field(coerce=True)
# not_inherit: Optional[Series[str]]
data = pd.DataFrame({
'reporting_date': ['2023-01-11', '2023-01-12'],
'not_inherit': ['a','a']
})
@pa.check_types
def func(
inputs: DataFrame[InputSchema]
) -> DataFrame[OutputSchema]:
inputs = inputs.drop(columns=['not_inherit'])
inputs['test'] = 'a'
return inputs
func(data)
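If you also want to carry a column over but make it optional on the output schema (what the commented-out not_inherit line hints at), a small variant of copy_field can wrap the copied annotation in Optional. copy_field_optional is a hypothetical helper written for this answer, not part of pandera:

def copy_field_optional(from_model: Type[BaseModel], fname: str, annotations: dict):
    # same as copy_field, but the copied column becomes optional in the new schema
    annotations[fname] = Optional[from_model.__annotations__[fname]]

class OutputSchemaWithOptional(pa.SchemaModel):
    test: Series[str] = pa.Field()
    copy_field(InputSchema, "reporting_date", __annotations__)
    copy_field_optional(InputSchema, "not_inherit", __annotations__)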
I am using pydantic to manage settings for an app that supports different datasets. Each has a set of overridable defaults, but they are different per datasets. Currently, I have all of the logic correctly implemented via validators:
from typing import Optional
from pydantic import BaseModel, validator
class DatasetSettings(BaseModel):
dataset_name: str
table_name: Optional[str] = None  # filled in by the validator below when not supplied
@validator("table_name", always=True)
def validate_table_name(cls, v, values):
if isinstance(v, str):
return v
if values["dataset_name"] == "DATASET_1":
return "special_dataset_1_default_table"
if values["dataset_name"] == "DATASET_2":
return "special_dataset_2_default_table"
return "default_table"
class AppSettings(BaseModel):
dataset_settings: DatasetSettings
app_url: str
This way, I get different defaults based on dataset_name, but the user can override them if necessary. This is the desired behavior. The trouble is that once there are more than a handful of such fields and names, it gets to be a mess to read and to maintain. It seems like inheritance/polymorphism would solve this problem but the pydantic factory logic seems too hardcoded to make it feasible, especially with nested models.
class Dataset1Settings(DatasetSettings):
dataset_name: str = "DATASET_1"
table_name: str = "special_dataset_1_default_table"
class Dataset2Settings(DatasetSettings):
dataset_name: str = "DATASET_2"
table_name: str = "special_dataset_2_default_table"
def dataset_settings_factory(dataset_name, table_name=None):
if dataset_name == "DATASET_1":
return Dataset1Settings(dataset_name, table_name)
if dataset_name == "DATASET_2":
return Dataset2Settings(dataset_name, table_name)
return DatasetSettings(dataset_name, table_name)
class AppSettings(BaseModel):
dataset_settings: DatasetSettings
app_url: str
Options I've considered:
Create a new set of default dataset settings models, override __init__ of DatasetSettings, instantiate the subclass and copy its attributes into the parent class. Kind of clunky.
Override __init__ of AppSettings using the dataset_settings_factory to set the dataset_settings attribute of AppSettings. Not so good because the default behavior doesn't work in the DatasetSettings at all, only when instantiated as a nested model in AppSettings.
I was hoping Field(default_factory=dataset_settings_factory) would work, but the default_factory is only for actual defaults so it has zero args. Is there some other way to intercept the args of a particular pydantic field and use a custom factory?
Another option would be to use Discriminated/Tagged Unions.
But your solution (without looking in detail) looks fine too.
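For reference, a minimal sketch of the discriminated/tagged union approach for this case (assumes pydantic >= 1.9, where Field(discriminator=...) is available). The per-dataset defaults live on the variant models and dataset_name doubles as the tag:

from typing import Annotated, Literal, Union  # Annotated needs Python 3.9+ or typing_extensions
from pydantic import BaseModel, Field

class Dataset1Settings(BaseModel):
    dataset_name: Literal["DATASET_1"] = "DATASET_1"
    table_name: str = "special_dataset_1_default_table"

class Dataset2Settings(BaseModel):
    dataset_name: Literal["DATASET_2"] = "DATASET_2"
    table_name: str = "special_dataset_2_default_table"

DatasetSettings = Annotated[
    Union[Dataset1Settings, Dataset2Settings],
    Field(discriminator="dataset_name"),
]

class AppSettings(BaseModel):
    dataset_settings: DatasetSettings
    app_url: str

app = AppSettings(
    dataset_settings={"dataset_name": "DATASET_2"},
    app_url="https://example.com",  # placeholder
)
assert app.dataset_settings.table_name == "special_dataset_2_default_table"

Overriding still works as before: passing table_name in the dataset_settings dict takes precedence over the variant's default.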
I ended up solving the problem following the first option, as follows. Code is runnable with pydantic 1.8.2 and pydantic 1.9.1.
from typing import Optional
from pydantic import BaseModel, Field
class DatasetSettings(BaseModel):
dataset_name: Optional[str] = Field(default="DATASET_1")
table_name: Optional[str] = None
def __init__(self, **data):
factory_dict = {"DATASET_1": Dataset1Settings, "DATASET_2": Dataset2Settings}
dataset_name = (
data["dataset_name"]
if "dataset_name" in data
else self.__fields__["dataset_name"].default
)
if dataset_name in factory_dict:
data = factory_dict[dataset_name](**data).dict()
super().__init__(**data)
class Dataset1Settings(BaseModel):
dataset_name: str = "DATASET_1"
table_name: str = "special_dataset_1_default_table"
class Dataset2Settings(BaseModel):
dataset_name: str = "DATASET_2"
table_name: str = "special_dataset_2_default_table"
class AppSettings(BaseModel):
dataset_settings: DatasetSettings = Field(default_factory=DatasetSettings)
app_url: Optional[str]
app_settings = AppSettings(dataset_settings={"dataset_name": "DATASET_1"})
assert app_settings.dataset_settings.table_name == "special_dataset_1_default_table"
app_settings = AppSettings(dataset_settings={"dataset_name": "DATASET_2"})
assert app_settings.dataset_settings.table_name == "special_dataset_2_default_table"
# bonus: no args mode
app_settings = AppSettings()
assert app_settings.dataset_settings.table_name == "special_dataset_1_default_table"
A couple of gotchas I discovered along the way:
If Dataset1Settings inherits from DatasetSettings, it enters a recursive loop, calling __init__ from __init__ ad infinitum. This could be broken with some introspection (see the sketch at the end of this answer), but I opted for the duck approach.
The current solution destroys any validators on DatasetSettings. I'm sure there's a way to call the validation logic anyway, but as it stands it effectively sidesteps whatever class-level validation you have by only initializing with super().__init__().
The same thing works for BaseSettings objects, but you have to drag along their cumbersome __init__ args:
def __init__(
self,
_env_file: Union[Path, str, None] = None,
_env_file_encoding: Optional[str] = None,
_secrets_dir: Union[Path, str, None] = None,
**values: Any
):
...
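Regarding the first gotcha: the introspection mentioned there could look roughly like the following (an untested sketch, reusing the imports from the solution above). The idea is to run the factory dispatch only when the base class itself is being constructed, so that the concrete settings classes can inherit from DatasetSettings without recursing:

class DatasetSettings(BaseModel):
    dataset_name: Optional[str] = Field(default="DATASET_1")
    table_name: Optional[str] = None

    def __init__(self, **data):
        # dispatch only when constructing the base class directly; subclass
        # construction (inside the factory branch) falls straight through to BaseModel.__init__
        if type(self) is DatasetSettings:
            factory_dict = {"DATASET_1": Dataset1Settings, "DATASET_2": Dataset2Settings}
            dataset_name = data.get("dataset_name", self.__fields__["dataset_name"].default)
            if dataset_name in factory_dict:
                data = factory_dict[dataset_name](**data).dict()
        super().__init__(**data)

With that guard in place, Dataset1Settings and Dataset2Settings could inherit from DatasetSettings again instead of duck-typing it.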
I wonder if there is a way to implement subclasses of a base class for different types. Each subclass should have its own input and output types while providing the same behaviour as the base class.
Background: I want to process voltage and temperature samples. 100 voltage samples form a VoltageDataset. 100 temperature samples form a TemperatureDataset. Multiple VoltageDatasets form a VoltageDataCluster. Same for temperature. The processing of Datasets depends on their physical quantity. To ensure that voltage related processing can't be applied to temperature samples I'd like to add type hints.
So it would be nice if there were a way to specify that VoltageDataCluster's append_dataset method accepts only VoltageDataset as its input type. Same for temperature.
Is there a way to implement this behaviour without copy&pasting?
# base class
class DataCluster:
def __init__(self, name):
self.name = name
self.datasets = list()
def append_dataset(self, dataset: Dataset) -> None:
self.datasets.append(dataset)
# subclass that should allow VoltageDataset input only.
class VoltageDataCluster(DataCluster):
pass
# subclass that should allow TemperatureDataset input only.
class TemperatureDataCluster(DataCluster):
pass
Thanks!
Niklas
You could use pydantic generic models.
from typing import Generic, TypeVar, List
from pydantic.generics import GenericModel
DataT = TypeVar('DataT')
class DataCluster(GenericModel, Generic[DataT]):
name: str
datasets: List[DataT] = []
def append_dataset(self, dataset: DataT) -> None:
self.datasets.append(dataset)
voltage_cluster = DataCluster[VoltageDataset](name="name")
voltage_cluster.append_dataset(some_voltage_dataset)
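The snippet assumes VoltageDataset is defined elsewhere. A hypothetical definition (the samples field is made up) just to make the sketch self-contained and show where the type checking kicks in:

from typing import List
from pydantic import BaseModel

class VoltageDataset(BaseModel):
    samples: List[float] = []

class TemperatureDataset(BaseModel):
    samples: List[float] = []

some_voltage_dataset = VoltageDataset(samples=[3.3, 3.2, 3.4])
voltage_cluster = DataCluster[VoltageDataset](name="name")
voltage_cluster.append_dataset(some_voltage_dataset)

# A static type checker (e.g. mypy) will now reject mixing quantities:
# voltage_cluster.append_dataset(TemperatureDataset())  # error: incompatible type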
When you inherit from a class, the subclass automatically picks up the parent's functionality, so there is no need to copy and paste. I'll illustrate this with an example.
# DataCluster.py
class DataCluster:
def __init__(self, name):
self.name = name
def printHello(self):
print("Hello")
# This will work in sub classes that have a "data" attribute
def printData(self):
print(self.data)
# VoltageDataCluster.py
from superclasses.DataCluster import DataCluster
class VoltageDataCluster(DataCluster):
def __init__(self, differentInput):
self.differentInput = differentInput
self.data = "someotherdata"
# mainclass.py
from superclasses.DataCluster import DataCluster
from superclasses.VoltageDataCluster import VoltageDataCluster
dc = DataCluster("mark")
dc.printHello()
# The input for this class is not name
vdc = VoltageDataCluster("Some Other Input")
# These methods are only defined in DataCluster
vdc.printHello()
vdc.printData()
As you can see, even though we only defined the "printHello" method in the super class, the other class inherited this method while using different inputs. So no copy and pasting required. Here is a runnable example (I added comments to tell you where to find each file used).
EDIT: Added a data attribute so it's more relevant to your example.
I'm using Pydantic to define hierarchical data in which there are models with identical attributes.
However, when I save and load these models, Pydantic can no longer distinguish which model was used and picks the first one in the field type annotation.
I understand that this is expected behavior based on the documentation.
However, the class type information is important to my application.
What is the recommended way to distinguish between different classes in Pydantic? One hack is to simply add an extraneous field to one of the models, but I'd like to find a more elegant solution.
See the simplified example below: container is initialized with data of type DataB, but after exporting and loading, the new container has data of type DataA as it's the first element in the type declaration of container.data.
Thanks for your help!
from abc import ABC
from pydantic import BaseModel #pydantic 1.8.2
from typing import Union
class Data(BaseModel, ABC):
""" base class for a Member """
number: float
class DataA(Data):
""" A type of Data"""
pass
class DataB(Data):
""" Another type of Data """
pass
class Container(BaseModel):
""" container holds a subclass of Data """
data: Union[DataA, DataB]
# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)
# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)
# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataA
As correctly noted in the comments, without storing additional information, the models cannot be distinguished when parsing.
As of today (pydantic v1.8.2), the most canonical way to distinguish models when parsing a Union (in case of ambiguity) is to explicitly add a Literal type specifier. It will look like this:
from abc import ABC
from pydantic import BaseModel
from typing import Union, Literal
class Data(BaseModel, ABC):
""" base class for a Member """
number: float
class DataA(Data):
""" A type of Data"""
tag: Literal['A'] = 'A'
class DataB(Data):
""" Another type of Data """
tag: Literal['B'] = 'B'
class Container(BaseModel):
""" container holds a subclass of Data """
data: Union[DataA, DataB]
# initialize container with DataB
data = DataB(number=1.0)
container = Container(data=data)
# export container to string and load new container from string
string = container.json()
new_container = Container.parse_raw(string)
# look at type of container.data
print(type(new_container.data).__name__)
# >>> DataB
This method can be automated, but use it at your own risk, since it breaks static typing and relies on objects that may change in future versions:
from pydantic.fields import ModelField
class Data(BaseModel, ABC):
""" base class for a Member """
number: float
def __init_subclass__(cls, **kwargs):
name = 'tag'
value = cls.__name__
annotation = Literal[value]
tag_field = ModelField.infer(name=name, value=value, annotation=annotation, class_validators=None, config=cls.__config__)
cls.__fields__[name] = tag_field
cls.__annotations__[name] = annotation
class DataA(Data):
""" A type of Data"""
pass
class DataB(Data):
""" Another type of Data """
pass
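A quick round-trip check of the automated variant (pydantic 1.8.x assumed), re-declaring the container against these DataA/DataB classes:

class Container(BaseModel):
    data: Union[DataA, DataB]

container = Container(data=DataB(number=1.0))
new_container = Container.parse_raw(container.json())
print(type(new_container.data).__name__)
# >>> DataB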
Just wanted to take the opportunity to list another possible alternative here to pydantic, which already supports this use case very well, as per the answer above.
I am the creator and maintainer of a relatively newer and lesser-known JSON serialization library, the Dataclass Wizard - which relies on the Python dataclasses module to perform its magic. As of the latest version, 0.14.0, the dataclass-wizard now supports dataclasses within Union types. Previously, it did not support dataclasses within Union types at all, which was kind of a glaring omission, and something on my "to-do" list of things to (eventually) add support for.
As of the latest, it should now support defining dataclasses within Union types. The reason it did not generally work before is that the data being de-serialized is often a JSON object, which only knows simple types such as arrays and dictionaries, for example. A dict type would not otherwise match any of the Union[Data1, Data2] types, even if the object had all the correct dataclass fields as keys. This is simply because it doesn't compare the dict object against each of the dataclass fields in the Union types, though that might change in a future release.
So in any case, here is a simple example to demonstrate the usage of dataclasses in Union types, using a class inheritance model with the JSONWizard mixin class:
With Class Inheritance
from abc import ABC
from dataclasses import dataclass
from typing import Union
from dataclass_wizard import JSONWizard
@dataclass
class Data(ABC):
""" base class for a Member """
number: float
class DataA(Data, JSONWizard):
""" A type of Data"""
class _(JSONWizard.Meta):
"""
This defines a custom tag that uniquely identifies the dataclass.
"""
tag = 'A'
class DataB(Data, JSONWizard):
""" Another type of Data """
class _(JSONWizard.Meta):
"""
This defines a custom tag that uniquely identifies the dataclass.
"""
tag = 'B'
@dataclass
class Container(JSONWizard):
""" container holds a subclass of Data """
data: Union[DataA, DataB]
The usage is shown below, and is again pretty straightforward. It relies on a special __tag__ key set in a dictionary or JSON object to marshal it into the correct dataclass, based on the Meta.tag value we set up for each class above.
print('== Load with DataA ==')
input_dict = {
'data': {
'number': '1.0',
'__tag__': 'A'
}
}
# De-serialize the `dict` object to a `Container` instance.
container = Container.from_dict(input_dict)
print(repr(container))
# prints:
# Container(data=DataA(number=1.0))
# Show the prettified JSON representation of the instance.
print(container)
# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataA
print()
print('== Load with DataB ==')
# initialize container with DataB
data_b = DataB(number=2.0)
container = Container(data=data_b)
print(repr(container))
# prints:
# Container(data=DataB(number=2.0))
# Show the prettified JSON representation of the instance.
print(container)
# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB
# Assert we end up with the same instance when serializing and de-serializing
# our data.
string = container.to_json()
assert container == Container.from_json(string)
Without Class Inheritance
Here is the same example as above, but with relying solely on dataclasses, without using any special class inheritance model:
from abc import ABC
from dataclasses import dataclass
from typing import Union
from dataclass_wizard import asdict, fromdict, LoadMeta
@dataclass
class Data(ABC):
""" base class for a Member """
number: float
class DataA(Data):
""" A type of Data"""
class DataB(Data):
""" Another type of Data """
@dataclass
class Container:
""" container holds a subclass of Data """
data: Union[DataA, DataB]
# Setup tags for the dataclasses. This can be passed into either
# `LoadMeta` or `DumpMeta`.
#
# Note that I'm not a fan of this syntax either, so it might change. I was
# thinking of something more explicit, like `LoadMeta(...).bind_to(class)`
LoadMeta(DataA, tag='A')
LoadMeta(DataB, tag='B')
# The rest is the same as before.
# initialize container with DataB
data = DataB(number=2.0)
container = Container(data=data)
print(repr(container))
# prints:
# Container(data=DataB(number=2.0))
# Assert we load the correct dataclass from the annotated `Union` types
assert type(container.data) == DataB
# Assert we end up with the same data when serializing and de-serializing.
out_dict = asdict(container)
assert container == fromdict(Container, out_dict)
I'm trying to hack something together in the meantime using custom validators.
Basically the class decorator adds a class_name: str field, which is added to the json string. The validator then looks up the correct subclass based on its value.
import json

def register_distinct_subclasses(fields: tuple):
""" fields is tuple of subclasses that we want to be registered as distinct """
field_map = {field.__name__: field for field in fields}
def _register_distinct_subclasses(cls):
""" cls is the superclass of fields, which we add a new validator to """
orig_init = cls.__init__
class _class:
class_name: str
def __init__(self, **kwargs):
class_name = type(self).__name__
kwargs["class_name"] = class_name
orig_init(**kwargs)
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if isinstance(v, dict):
class_name = v.get("class_name")
json_string = json.dumps(v)
else:
class_name = v.class_name
json_string = v.json()
cls_type = field_map[class_name]
return cls_type.parse_raw(json_string)
return _class
return _register_distinct_subclasses
which is called as follows
Data = register_distinct_subclasses((DataA, DataB))(Data)
I have an abstract base class GameNodeState that contains a Type enum:
import abc
import enum
class GameNodeState(metaclass=abc.ABCMeta):
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
The names in the enum are generic because they must make sense for any subclass of GameNodeState. But when I subclass GameNodeState, as GameState and RoundState, I would like to be able to add concrete aliases to the members of GameNodeState.Type if the enum is accessed through the subclass. For example, if the GameState subclass aliases INTERMEDIATE as ROUND and RoundState aliases INTERMEDIATE as TURN, I would like the following behaviour:
>>> GameNodeState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.TURN
<Type.INTERMEDIATE: 2>
>>> RoundState.Type.INTERMEDIATE
<Type.INTERMEDIATE: 2>
>>> GameNodeState.Type.TURN
AttributeError: TURN
My first thought was this:
class GameState(GameNodeState):
class Type(GameNodeState.Type):
ROUND = GameNodeState.Type.INTERMEDIATE.value
class RoundState(GameNodeState):
class Type(GameNodeState.Type):
TURN = GameNodeState.Type.INTERMEDIATE.value
But enums can't be subclassed.
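For reference, attempting the subclassing above fails at class-creation time with an error along these lines (exact wording varies between Python versions):

TypeError: Cannot extend enumerations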
Note: there are obviously more attributes and methods in the GameNodeState hierarchy, I stripped it down to the bare minimum here to focus on this particular thing.
Refinement
(Original solution below.)
I've extracted an intermediate concept from the code above, namely the concept of an enum union. This can be used to obtain the behaviour above, and is also useful in other contexts. The code can be found here, and I've asked a Code Review question about it.
I'll add the code here as well for reference:
import enum
import itertools as itt
from functools import reduce
import operator
from typing import Literal, Union
import more_itertools as mitt
AUTO = object()
class UnionEnumMeta(enum.EnumMeta):
"""
The metaclass for enums which are the union of several sub-enums.
Union enums have the _subenums_ attribute which is a tuple of the enums forming the
union.
"""
@classmethod
def make_union(
mcs, *subenums: enum.EnumMeta, name: Union[str, Literal[AUTO], None] = AUTO
) -> enum.EnumMeta:
"""
Create an enum whose set of members is the union of members of several enums.
Order matters: where two members in the union have the same value, they will
be considered as aliases of each other, and the one appearing in the first
enum in the sequence will be used as the canonical members (the aliases will
be associated to this enum member).
:param subenums: Sequence of sub-enums to make a union of.
:param name: Name to use for the enum class. AUTO will result in a combination
of the names of all subenums, None will result in "UnionEnum".
:return: An enum class which is the union of the given subenums.
"""
subenums = mcs._normalize_subenums(subenums)
class UnionEnum(enum.Enum, metaclass=mcs):
pass
union_enum = UnionEnum
union_enum._subenums_ = subenums
if duplicate_names := reduce(
set.intersection, (set(subenum.__members__) for subenum in subenums)
):
raise ValueError(
f"Found duplicate member names in enum union: {duplicate_names}"
)
# If aliases are defined, the canonical member will be the one that appears
# first in the sequence of subenums.
# dict union keeps last key so we have to do it in reverse:
union_enum._value2member_map_ = value2member_map = reduce(
operator.or_, (subenum._value2member_map_ for subenum in reversed(subenums))
)
# union of the _member_map_'s but using the canonical member always:
union_enum._member_map_ = member_map = {
name: value2member_map[member.value]
for name, member in itt.chain.from_iterable(
subenum._member_map_.items() for subenum in subenums
)
}
# only include canonical aliases in _member_names_
union_enum._member_names_ = list(
mitt.unique_everseen(
itt.chain.from_iterable(subenum._member_names_ for subenum in subenums),
key=member_map.__getitem__,
)
)
if name is AUTO:
name = (
"".join(subenum.__name__.removesuffix("Enum") for subenum in subenums)
+ "UnionEnum"
)
UnionEnum.__name__ = name
elif name is not None:
UnionEnum.__name__ = name
return union_enum
def __repr__(cls):
return f"<union of {', '.join(map(str, cls._subenums_))}>"
def __instancecheck__(cls, instance):
return any(isinstance(instance, subenum) for subenum in cls._subenums_)
@classmethod
def _normalize_subenums(mcs, subenums):
"""Remove duplicate subenums and flatten nested unions"""
# we will need to collapse at most one level of nesting, with the inductive
# hypothesis that any previous unions are already flat
subenums = mitt.collapse(
(e._subenums_ if isinstance(e, mcs) else e for e in subenums),
base_type=enum.EnumMeta,
)
subenums = mitt.unique_everseen(subenums)
return tuple(subenums)
def enum_union(*enums, **kwargs):
return UnionEnumMeta.make_union(*enums, **kwargs)
Once we have that, we can just define the extend_enum decorator to compute the union of the base enum and the enum "extension", which will result in the desired behaviour:
def extend_enum(base_enum):
def decorator(extension_enum):
return enum_union(base_enum, extension_enum)
return decorator
Usage:
class GameNodeState(metaclass=abc.ABCMeta):
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
class RoundState(GameNodeState):
@extend_enum(GameNodeState.Type)
class Type(enum.Enum):
TURN = GameNodeState.Type.INTERMEDIATE.value
class GameState(GameNodeState):
@extend_enum(GameNodeState.Type)
class Type(enum.Enum):
ROUND = GameNodeState.Type.INTERMEDIATE.value
Now all of the examples above produce the same output (plus the added instance check, i.e. isinstance(RoundState.Type.TURN, RoundState.Type) returns True).
I think this is a cleaner solution because it doesn't involve mucking around with descriptors; it doesn't need to know anything about the owner class (this works just as well with top-level classes).
Attribute lookup through subclasses and instances of GameNodeState should automatically link to the correct "extension" (i.e., union), as long as the extension enum is added with the same name as for the GameNodeState superclass so that it hides the original definition.
Original
Not sure how bad of an idea this is, but here is a solution using a descriptor wrapped around the enum that gets the set of aliases based on the class from which it is being accessed.
import abc
import enum
import itertools
from typing import Dict, Type

class ExtensibleClassEnum:
class ExtensionWrapperMeta(enum.EnumMeta):
@classmethod
def __prepare__(mcs, name, bases):
# noinspection PyTypeChecker
classdict: enum._EnumDict = super().__prepare__(name, bases)
classdict["_ignore_"] = ["base_descriptor", "extension_enum"]
return classdict
# noinspection PyProtectedMember
def __new__(mcs, cls, bases, classdict):
base_descriptor = classdict.pop("base_descriptor")
extension_enum = classdict.pop("extension_enum")
wrapper_enum = super().__new__(mcs, cls, bases, classdict)
wrapper_enum.base_descriptor = base_descriptor
wrapper_enum.extension_enum = extension_enum
base, extension = base_descriptor.base_enum, extension_enum
if set(base._member_map_.keys()) & set(extension._member_map_.keys()):
raise ValueError("Found duplicate names in extension")
# dict union keeps last key so we have to do it in reverse:
wrapper_enum._value2member_map_ = (
extension._value2member_map_ | base._value2member_map_
)
# union of both _member_map_'s but using the canonical member always:
wrapper_enum._member_map_ = {
name: wrapper_enum._value2member_map_[member.value]
for name, member in itertools.chain(
base._member_map_.items(), extension._member_map_.items()
)
}
# aliases shouldn't appear in _member_names_
wrapper_enum._member_names_ = list(
m.name for m in wrapper_enum._value2member_map_.values()
)
return wrapper_enum
def __repr__(self):
# have to use vars() to avoid triggering the descriptor
base_descriptor = vars(self)["base_descriptor"]
return (
f"<extension wrapper enum for {base_descriptor.base_enum}"
f" in {base_descriptor._extension2owner[self]}>"
)
def __init__(self, base_enum):
if not issubclass(base_enum, enum.Enum):
raise TypeError(base_enum)
self.base_enum = base_enum
# The user won't be able to retrieve the descriptor object itself, just
# the enum, so we have to forward calls to register_extension:
self.base_enum.register_extension = staticmethod(self.register_extension)
# mapping of owner class -> extension for subclasses that define an extension
self._extensions: Dict[Type, ExtensibleClassEnum.ExtensionWrapperMeta] = {}
# reverse mapping
self._extension2owner: Dict[ExtensibleClassEnum.ExtensionWrapperMeta, Type] = {}
# add the base enum as the base extension via __set_name__:
self._pending_extension = base_enum
@property
def base_owner(self):
# will be initialised after __set_name__ is called with base owner
return self._extension2owner[self.base_enum]
def __set_name__(self, owner, name):
# step 2 of register_extension: determine the class that defined it
self._extensions[owner] = self._pending_extension
self._extension2owner[self._pending_extension] = owner
del self._pending_extension
def __get__(self, instance, owner):
# Only compute extensions once:
if owner in self._extensions:
return self._extensions[owner]
# traverse in MRO until we find the closest supertype defining an extension
for supertype in owner.__mro__:
if supertype in self._extensions:
extension = self._extensions[supertype]
break
else:
raise TypeError(f"{owner} is not a subclass of {self.base_owner}")
# Cache the result
self._extensions[owner] = extension
return extension
def make_extension(self, extension: enum.EnumMeta):
class ExtensionWrapperEnum(
enum.Enum, metaclass=ExtensibleClassEnum.ExtensionWrapperMeta
):
base_descriptor = self
extension_enum = extension
return ExtensionWrapperEnum
def register_extension(self, extension_enum):
"""Decorator for enum extensions"""
# need a way to determine owner class
# add a temporary attribute that we will use when __set_name__ is called:
if hasattr(self, "_pending_extension"):
# __set_name__ not called after the previous call to register_extension
raise RuntimeError(
"An extension was created outside of a class definition",
self._pending_extension,
)
self._pending_extension = self.make_extension(extension_enum)
return self
Usage would be as follows:
class GameNodeState(metaclass=abc.ABCMeta):
@ExtensibleClassEnum
class Type(enum.Enum):
INIT = enum.auto()
INTERMEDIATE = enum.auto()
END = enum.auto()
class RoundState(GameNodeState):
@GameNodeState.Type.register_extension
class Type(enum.Enum):
TURN = GameNodeState.Type.INTERMEDIATE.value
class GameState(GameNodeState):
@GameNodeState.Type.register_extension
class Type(enum.Enum):
ROUND = GameNodeState.Type.INTERMEDIATE.value
Then:
>>> (RoundState.Type.TURN
... == RoundState.Type.INTERMEDIATE
... == GameNodeState.Type.INTERMEDIATE
... == GameState.Type.INTERMEDIATE
... == GameState.Type.ROUND)
...
True
>>> RoundState.Type.__members__
mappingproxy({'INIT': <Type.INIT: 1>,
'INTERMEDIATE': <Type.INTERMEDIATE: 2>,
'END': <Type.END: 3>,
'TURN': <Type.INTERMEDIATE: 2>})
>>> list(RoundState.Type)
[<Type.INTERMEDIATE: 2>, <Type.INIT: 1>, <Type.END: 3>]
>>> GameNodeState.Type.TURN
Traceback (most recent call last):
...
File "C:\Program Files\Python39\lib\enum.py", line 352, in __getattr__
raise AttributeError(name) from None
AttributeError: TURN
I have a web application with many models and many class-based views. Most of the code looks like this:
from typing import TypeVar, Type
M = TypeVar('M', bound='Model')
TypeModel = Type[M]
# ---------- models
class Model:
@classmethod
def factory(cls: TypeModel) -> M:
return cls()
class ModelOne(Model):
def one(self):
return
class ModelTwo(Model):
def two(self):
return
# ---------- views
class BaseView:
model: TypeModel
@property
def obj(self) -> M:
return self.model.factory()
def logic(self):
raise NotImplementedError
class One(BaseView):
model = ModelOne
def logic(self):
self.obj.  # how can I get suggestions for the methods of ModelOne here?
...
class Two(BaseView):
model = ModelTwo
def logic(self):
self.obj.  # how can I get suggestions for the methods of ModelTwo here?
...
I want to have a property obj which is an instance of the specified model in the view. How can I achieve this?
Thank you
You need to make your BaseView class generic with respect to M. So, you should do something like this:
from typing import TypeVar, Type, Generic
M = TypeVar('M', bound='Model')
# Models
class Model:
@classmethod
def factory(cls: Type[M]) -> M:
return cls()
class ModelOne(Model):
def one(self):
return
class ModelTwo(Model):
def two(self):
return
# Views
# A BaseView is now a generic type and will use M as a placeholder.
class BaseView(Generic[M]):
model: Type[M]
@property
def obj(self) -> M:
return self.model.factory()
def logic(self):
raise NotImplementedError
# The subclasses now specify what kind of model the BaseView should be
# working against when they subclass it.
class One(BaseView[ModelOne]):
model = ModelOne
def logic(self):
self.obj.one()
class Two(BaseView[ModelTwo]):
model = ModelTwo
def logic(self):
self.obj.two()
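With the generic base in place, a type checker can now narrow obj per view, so autocompletion suggests the right model's methods. For instance, mypy reports roughly:

reveal_type(One().obj)  # Revealed type is "ModelOne"
reveal_type(Two().obj)  # Revealed type is "ModelTwo"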
One note: I got rid of your TypeModel type alias. This is partly stylistic and partly pragmatic.
Stylistically, when I look at a type signature, I want to be able to immediately determine whether or not it's using generics/typevars. Using type aliases tends to obscure that, and I don't really like using context-sensitive types.
Pragmatically, both PyCharm's type checker and mypy tend to struggle a little when you make excessive use of type aliases containing typevars.