How to compare Enums in Python? - python

Since Python 3.4, the Enum class exists.
I am writing a program, where some constants have a specific order and I wonder which way is the most pythonic to compare them:
class Information(Enum):
    """How much derivative information is requested, in increasing order."""

    ValueOnly = 0
    FirstDerivative = 1
    SecondDerivative = 2
Now there is a method, which needs to compare a given information of Information with the different enums:
# Desired usage from the question. `value`, `jacobian` and `hessian` are
# assumed to be defined elsewhere. With a plain Enum the `>=` comparisons
# below raise TypeError, because Enum members define no ordering.
information = Information.FirstDerivative
print(value)
if information >= Information.FirstDerivative:
print(jacobian)
if information >= Information.SecondDerivative:
print(hessian)
The direct comparison does not work with Enums, so there are three approaches and I wonder which one is preferred:
Approach 1: Use values:
# Approach 1: order by comparing the members' underlying values explicitly.
if information.value >= Information.FirstDerivative.value:
...
Approach 2: Use IntEnum:
# Approach 2: IntEnum members are ints, so `>=` works directly -- but they
# also compare equal/ordered against plain ints and other IntEnums.
class Information(IntEnum):
...
Approach 3: Not using Enums at all:
class Information:
    """Namespace of plain integer constants (the pre-``Enum`` idiom).

    The attributes are ordinary ints, so ``<``/``>=`` work directly, but
    nothing stops them from being mixed with unrelated integers.
    """

    ValueOnly = 0
    FirstDerivative = 1
    SecondDerivative = 2
Each approach works: Approach 1 is a bit more verbose, Approach 2 uses the IntEnum class, which is not recommended, and Approach 3 seems to be the way this was done before Enum was added.
I tend to use Approach 1, but I am not sure.
Thanks for any advice!

You should always implement the rich comparison operators if you want to use them with an Enum. Using the functools.total_ordering class decorator, you only need to implement an __eq__ method along with a single ordering, e.g. __lt__. Since enum.Enum already implements __eq__ this becomes even easier:
>>> import enum
>>> from functools import total_ordering
>>> @total_ordering
... class Grade(enum.Enum):
... A = 5
... B = 4
... C = 3
... D = 2
... F = 1
... def __lt__(self, other):
... if self.__class__ is other.__class__:
... return self.value < other.value
... return NotImplemented
...
>>> Grade.A >= Grade.B
True
>>> Grade.A >= 3
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: unorderable types: Grade() >= int()
Terrible, horrible, ghastly things can happen with IntEnum. It was mostly included for backwards-compatibility sake, enums used to be implemented by subclassing int. From the docs:
For the vast majority of code, Enum is strongly recommended, since
IntEnum breaks some semantic promises of an enumeration (by being
comparable to integers, and thus by transitivity to other unrelated
enumerations). It should be used only in special cases where there’s
no other choice; for example, when integer constants are replaced with
enumerations and backwards compatibility is required with code that
still expects integers.
Here's an example of why you don't want to do this:
>>> class GradeNum(enum.IntEnum):
... A = 5
... B = 4
... C = 3
... D = 2
... F = 1
...
>>> class Suit(enum.IntEnum):
... spade = 4
... heart = 3
... diamond = 2
... club = 1
...
>>> GradeNum.A >= GradeNum.B
True
>>> GradeNum.A >= 3
True
>>> GradeNum.B == Suit.spade
True
>>>

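I hadn't encountered Enum before, so I scanned the docs (https://docs.python.org/3/library/enum.html) and found OrderedEnum (section 8.13.13.2 at the time of writing). Isn't this what you want? Note that OrderedEnum cannot be imported from the enum module: it is an example recipe in the documentation that you copy into your own code (it defines __ge__, __gt__, __le__ and __lt__ by comparing member values). From the doc: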
>>> class Grade(OrderedEnum):
... A = 5
... B = 4
... C = 3
... D = 2
... F = 1
...
>>> Grade.C < Grade.A
True

Combining some of the above ideas, you can subclass enum.Enum to make it comparable to strings and numbers, and then build your enums on this class instead:
import numbers
import enum
class EnumComparable(enum.Enum):
    """Enum base class whose members compare by value.

    Members can be compared with:

    * objects exposing a ``.value`` attribute (e.g. members of this or
      another enum): the two values are compared;
    * real numbers: ``self.value`` is compared with the number;
    * strings: matched against the member *name*, for ``==``, ``<=`` and
      ``>=`` only.

    Any other operand yields ``NotImplemented`` so Python can try the
    reflected operation or fall back to its default behaviour.
    """

    def _compare(self, other, op, *, match_name=False):
        """Apply the binary predicate *op* following the rules above."""
        # Catch only the errors an incompatible operand can cause; a bare
        # ``except`` would also swallow KeyboardInterrupt and genuine bugs.
        try:
            return op(self.value, other.value)
        except (AttributeError, TypeError):
            pass
        if isinstance(other, numbers.Real):
            try:
                return op(self.value, other)
            except TypeError:
                return NotImplemented
        if match_name and isinstance(other, str):
            return self.name == other
        return NotImplemented

    def __gt__(self, other):
        return self._compare(other, lambda a, b: a > b)

    def __lt__(self, other):
        return self._compare(other, lambda a, b: a < b)

    def __ge__(self, other):
        return self._compare(other, lambda a, b: a >= b, match_name=True)

    def __le__(self, other):
        return self._compare(other, lambda a, b: a <= b, match_name=True)

    def __eq__(self, other):
        if self.__class__ is other.__class__:
            # Enum members are singletons, so same-class equality is
            # identity. Writing ``self == other`` here would call this
            # method again and recurse until RecursionError.
            return self is other
        return self._compare(other, lambda a, b: a == b, match_name=True)

    # Defining __eq__ in a class body sets __hash__ to None, which would make
    # members unusable as dict keys / set elements; keep Enum's hash.
    __hash__ = enum.Enum.__hash__

You can create a simple decorator to resolve this too:
from enum import Enum
from functools import total_ordering
def enum_ordering(cls):
    """Class decorator giving an Enum value-based ordering.

    Installs ``__lt__`` (comparing ``.value`` of members of the same type)
    and lets ``functools.total_ordering`` derive ``<=``, ``>`` and ``>=``.
    Ordering members of different types raises ``ValueError``.
    """

    def __lt__(self, other):
        if type(other) != type(self):
            raise ValueError("Cannot compare different Enums")
        return self.value < other.value

    cls.__lt__ = __lt__
    return total_ordering(cls)
@enum_ordering
class Foos(Enum):
    a = 1
    b = 3
    c = 2


@enum_ordering
class Names(Enum):
    # A second, unrelated enum used by the demo below.
    a = 1
    b = 3
    c = 2


assert Names.a < Names.c
assert Names.c < Names.b
assert Names.a != Foos.a  # different Enums are never equal
# Ordering members of different Enums raises ValueError.
try:
    Names.a < Foos.c
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError when ordering different Enums")
For bonus points you could implement the other methods in @VoteCoffee's answer above.

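For those who want value-based == between two enum instances (enum_instance_1 == enum_instance_2), add an __eq__ method to your Enum class as follows. (Plain Enum members already support ==, comparing by identity; also note that defining __eq__ without __hash__ makes the members unhashable.)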
def __eq__(self, other):
    """Value-based equality between members of the same Enum class.

    Objects of any other class get ``NotImplemented`` so Python can try the
    reflected comparison (and finally fall back to identity).
    """
    if self.__class__ is not other.__class__:
        return NotImplemented
    return other.value == self.value


def __hash__(self):
    """Keep members hashable; defining ``__eq__`` alone sets ``__hash__`` to None."""
    # Same-class members with equal values are the same member (aliases
    # resolve to the canonical one), so hashing the name is consistent.
    return hash(self._name_)

Related

__eq__ order enforcement in Python

A slightly long question to sufficiently explain the background...
Assuming there's a builtin class A:
class A:
    """Value holder compared through its ``a`` attribute."""

    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        # Assumes *other* also exposes ``.a``; otherwise AttributeError.
        mine, theirs = self.a, other.a
        return mine == theirs
It's expected to compare in this way:
# Two A instances with different payloads compare unequal.
a1, a2 = A(1), A(2)
a1 == a2 # False
For some reason, the team introduced a wrapper on top of it (The code example doesn't actually wrap A to simplify the code complexity.)
class WrapperA:
    """Wrapper keeping its payload in ``pa``; compares against other wrappers."""

    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        # Assumes *other* also exposes ``.pa``; otherwise AttributeError.
        mine, theirs = self.pa, other.pa
        return mine == theirs
Again, it's expected to compare in this way:
# Two wrappers with different payloads compare unequal.
wa1, wa2 = WrapperA(1), WrapperA(2)
wa1 == wa2 # False
Although it's expected to use either A or WrapperA, the problem is some code bases contain both usages, thus following comparison failed:
# Mixed comparisons fail: each __eq__ reads an attribute (`pa` or `a`) that
# the other class does not have.
a, wa = A(), WrapperA()
wa == a # AttributeError
a == wa # AttributeError
A known solution is to modify __eq__:
For wa == a:
class WrapperA:
    """Wrapper (payload in ``pa``) whose equality also accepts plain ``A``."""

    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        other_is_a = isinstance(other, A)
        mine = self.pa
        # ``A`` keeps its payload in ``.a``; wrappers keep it in ``.pa``.
        theirs = other.a if other_is_a else other.pa
        return mine == theirs
For a == wa:
class A:
    """Value holder (payload in ``a``) whose equality also accepts ``WrapperA``."""

    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        other_is_wrapper = isinstance(other, WrapperA)
        mine = self.a
        # ``WrapperA`` keeps its payload in ``.pa``; ``A`` keeps it in ``.a``.
        theirs = other.pa if other_is_wrapper else other.a
        return mine == theirs
Modifying WrapperA is expected. For A, since it is a builtin thing, two solutions are:
Use setattr to extend A to support WrapperA.
# Option 1: monkey-patch A.__eq__ at runtime. `eq_that_supports_WrapperA` is
# not defined here -- presumably a WrapperA-aware equality function.
setattr(A, '__eq__', eq_that_supports_WrapperA)
Enforce developer to only compare wa == a (And then don't care about a == wa).
1st option is obviously ugly with duplicated implementation, and 2nd gives developer unnecessary "surprise". So my question is, is there an elegant way to replace any usage of a == wa to wa == a by the Python implementation internally?
Quoting the comment from MisterMiyagi under the question:
Note that == is generally expected to work across all types. A.__eq__ requiring other to be an A is actually a bug that should be fixed. It should at the very least return NotImplemented when it cannot make a decision
This is important, not just a question of style. In fact, according to the documentation:
When a binary (or in-place) method returns NotImplemented the interpreter will try the reflected operation on the other type.
Thus if you just apply MisterMiyagi's comment and fix the logic of __eq__, you'll see your code works fine already:
class A:
    """Value holder; equality is only decided against other ``A`` objects."""

    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        # Returning NotImplemented (instead of raising) lets Python try the
        # reflected ``other.__eq__(self)``.
        if not isinstance(other, A):
            return NotImplemented
        return self.a == other.a


class WrapperA:
    """Wrapper whose equality understands both ``A`` and ``WrapperA``."""

    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        if isinstance(other, WrapperA):
            return self.pa == other.pa
        return NotImplemented
# Trying it: both operand orders work. For `a == wrap_a`, A.__eq__ returns
# NotImplemented, so Python falls back to WrapperA.__eq__.
a = A(5)
wrap_a = WrapperA(5)
print(a == wrap_a)
print(wrap_a == a)
wrap_a.pa = 7
print(a == wrap_a)
print(wrap_a == a)
print(f'{wrap_a.pa=}')
Yields:
True
True
False
False
wrap_a.pa=7
Under the hood, a == wrap_a calls A.__eq__ first, which returns NotImplemented. Python then automatically tries WrapperA.__eq__ instead.
I don't really like this whole thing, since I think that wrapping a builtin and using different attribute names will lead to unexpected behaviour, but anyway, this will work for you:
import inspect
class A:
    """Stand-in for the unmodifiable class: compares ``.a`` attributes."""

    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        return self.a == other.a


class WrapperA:
    """Wrapper that impersonates ``A`` when ``A.__eq__`` reads ``.a``."""

    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa

    def __getattribute__(self, item):
        # If A's __eq__ method is the one asking for 'a', answer with 'pa'.
        if item == 'a':
            # Only the immediate caller's frame is needed. inspect.stack()
            # builds a FrameInfo (with source context) for every frame on
            # the stack, which is very slow to do on each attribute access.
            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            try:
                if (caller is not None
                        and caller.f_code.co_name == '__eq__'
                        and isinstance(caller.f_locals.get('self'), A)):
                    return super().__getattribute__('pa')
            finally:
                # Drop frame references promptly to avoid reference cycles.
                del frame, caller
        return super().__getattribute__(item)
# Trial: both operand orders compare the payloads; A.__eq__ sees `pa`
# whenever it reads `.a` from a WrapperA.
a = A(5)
wrap_a = WrapperA(5)
print(a == wrap_a)
print(wrap_a == a)
wrap_a.pa = 7
print(a == wrap_a)
print(wrap_a == a)
print(f'{wrap_a.pa=}')
Output:
True
True
False
False
wrap_a.pa=7
Similar to Ron Serruya's answer:
This uses __getattr__ instead of __getattribute__, where the first one is only called if the second one raises an AttributeError or explicitly calls it (ref). This means if the wrapper does not implement __eq__ and the equality should only be performed on the underlying data structure (stored in objects of class A), a working example is given by:
class A(object):
    """Stand-in for the wrapped class; equality compares ``_internal_data``."""

    def __init__(self, internal_data=None):
        self._internal_data = internal_data

    def __eq__(self, other):
        return self._internal_data == other._internal_data


class WrapperA(object):
    """Transparent wrapper delegating unknown attribute lookups to an ``A``.

    ``WrapperA`` defines no ``__eq__``: ``object.__eq__`` returns
    NotImplemented for non-identical operands, so ``wrapper == a`` falls back
    to the reflected ``A.__eq__``, which reads ``_internal_data`` through
    ``__getattr__`` below.
    """

    def __init__(self, a_object: A):
        self._a = a_object

    def __getattr__(self, attribute):
        # __getattr__ only runs when normal lookup fails. If '_a' itself is
        # missing (e.g. an instance created without __init__), delegating
        # would recurse forever; raise instead of silently returning None.
        if attribute == '_a':
            raise AttributeError(attribute)
        return getattr(self._a, attribute)
# Compare raw objects and wrappers in every combination. Per the answer the
# printed result is: True False True True True True
a1 = A(internal_data=1)
a2 = A(internal_data=2)
wa1 = WrapperA(a1)
wa2 = WrapperA(a2)
print(
a1 == a1,
a1 == a2,
wa1 == wa1,
a1 == wa1,
a2 == wa2,
wa1 == a1)
>>> True False True True True True

Generalized __eq__() method in Python

I'd like to create a generalized __eq__() method for the following Class. Basically I'd like to be able to add another property (nick) without having to change __eq__()
I imagine I can do this somehow by iterating over dir() but I wonder if there is a way to create a comprehension that just delivers the properties.
class Person:
    """A person with validated ``first``/``last`` name properties."""

    def __init__(self, first, last):
        # These assignments go through the property setters below.
        self.first = first
        self.last = last

    @property
    def first(self):
        # NOTE: assert-based checks are stripped under ``python -O``.
        assert self._first is not None
        return self._first

    @first.setter
    def first(self, fn):
        assert isinstance(fn, str)
        self._first = fn

    @property
    def last(self):
        assert self._last is not None
        return self._last

    @last.setter
    def last(self, ln):
        assert isinstance(ln, str)
        self._last = ln

    @property
    def full(self):
        """Full name as ``"first last"``."""
        return f'{self.first} {self.last}'

    def __eq__(self, other):
        # Defer to Python for unrelated types instead of failing with
        # AttributeError on ``other.first``.
        if not isinstance(other, Person):
            return NotImplemented
        return self.first == other.first and self.last == other.last
# Two people differing only in first name (the question's example).
p = Person('Raymond', 'Salemi')
p2= Person('Ray', 'Salemi')
You could use __dict__ to check if everything is the same, which scales for all attributes:
If the objects are not matching types, I simply return False.
class Person:
    """Person whose equality compares every instance attribute.

    Comparing ``__dict__`` means newly added attributes (e.g. ``nick``)
    take part automatically without touching ``__eq__``.
    """

    def __init__(self, first, last, nick):
        self.first = first
        self.last = last
        self.nick = nick

    def __eq__(self, other):
        # Let Python handle other types (reflected __eq__, then identity)
        # instead of hard-coding False.
        if type(self) is not type(other):
            return NotImplemented
        return self.__dict__ == other.__dict__
>>> p = Person('Ray', 'Salemi', 'Ray')
>>> p2= Person('Ray', 'Salemi', 'Ray')
>>> p3 = Person('Jared', 'Salemi', 'Jarbear')
>>> p == p2
True
>>> p3 == p2
False
>>> p == 1
False
You can get all the properties of a Class with a construct like this:
from itertools import chain
@classmethod
def _properties(cls):
    """Return the names of all ``property`` attributes visible on *cls*.

    The MRO is merged from ``object`` down to *cls*, so a subclass that
    rebinds a name overrides what its bases define.
    """
    merged = {}
    for klass in reversed(cls.__mro__):
        merged.update(vars(klass))
    # isinstance, not a substring test on str(v): the latter also matched
    # e.g. methods whose names merely contain "property".
    return {name for name, attr in merged.items() if isinstance(attr, property)}
The __eq__ would become something like this:
def __eq__(self, other):
    """Compare two objects property-by-property.

    Only properties both sides expose are compared. If each side has
    properties the other lacks, the types are treated as incomparable and
    ``False`` is returned. Objects without ``_properties`` are left to Python
    via ``NotImplemented``.
    """
    other_properties = getattr(other, '_properties', None)
    if other_properties is None:
        return NotImplemented
    mine = self._properties()
    theirs = other_properties()
    shared = mine & theirs
    if mine > shared and theirs > shared:
        # types are not comparable
        return False
    try:
        return all(getattr(self, prop) == getattr(other, prop) for prop in shared)
    except AttributeError:
        return False
The reason to work with the reversed(cls.mro()) is so something like this also works:
class Worker(Person):
    """``Person`` subclass adding a read-only ``wage`` property."""

    @property
    def wage(self):
        return 0
# `p3` is not defined in this snippet; presumably a Person with the same
# first/last name from an earlier example -- confirm before running.
p4 = Worker('Raymond', 'Salemi')
print(p4 == p3)
True
you can try to do this, it will also work if you want eq inside dict and set
def __eq__(self, other):
    """Overrides the default implementation.

    Compares the same fields ``__hash__`` uses, so equal objects always hash
    equally. Comparing the hash values themselves would treat collisions
    (e.g. ``hash(-1) == hash(-2)`` in CPython) as equality.
    """
    if isinstance(self, other.__class__):
        return self._eq_fields() == other._eq_fields()
    return NotImplemented


def __hash__(self):
    """Overrides the default implementation; hashes the equality fields."""
    return hash(self._eq_fields())


def _eq_fields(self):
    """Return the ordered tuple of fields used for equality and hashing.

    Edit this tuple to choose which attributes participate. Order matters
    and falsy values (0, '', None) are kept, so they are distinguished.
    """
    return (self.first,)

Implement a list wrapper with overridden __cmp__ function

I have created a new Python object as follows
class Mylist(list):
    """List subclass meant to compare by length via Python 2's ``__cmp__``.

    NOTE: ``list``'s own rich comparisons take precedence over ``__cmp__``,
    which is why this does not behave as intended.
    """

    def __cmp__(self, other):
        mine, theirs = len(self), len(other)
        return (mine > theirs) - (mine < theirs)
my intend is, when two Mylist objects are compared the object with large number of items should be higher.
# Same length, different contents: intended to compare equal by length.
c=Mylist([4,5,6])
d=Mylist([1,2,3])
after running the above code, c and d are supposed to be equal(c==d <==True). But I am getting
>>> c==d
False
>>> c>d
True
>>>
they are being compared like the list object itself. What did I do wrong?
You need to implement function __eq__.
class Mylist(list):
    """List subclass whose equality (and Python 2 ``__cmp__``) use length."""

    def __cmp__(self, other):
        mine, theirs = len(self), len(other)
        return (mine > theirs) - (mine < theirs)

    def __eq__(self, other):
        return len(other) == len(self)
UPDATE: (previous code does not work perfectly as explained in comments)
Although @tobias_k's answer explains it better, you can do it via the __cmp__ function in Python 2 if you insist. You can enable it by hiding the other comparison methods (__lt__, __le__, __ge__, ...):
class Mylist(list):
    """Python 2 hack: hide list's rich comparisons so ``__cmp__`` is used.

    Each hidden operator is a property whose getter raises AttributeError.
    Python 2 then treats the method as missing and falls back to
    ``__cmp__``. Python 3 never calls ``__cmp__``, so there these operators
    end up unsupported (only ``==`` keeps working).
    """

    def __cmp__(self, other):
        if len(self) > len(other):
            return 1
        elif len(self) < len(other):
            return -1
        else:
            return 0

    def __eq__(self, other):
        return len(self) == len(other)

    # A property getter receives only ``self``; with an extra ``other``
    # parameter, reading the attribute would raise TypeError instead.
    @property
    def __lt__(self): raise AttributeError()

    @property
    def __le__(self): raise AttributeError()

    @property
    def __ne__(self): raise AttributeError()

    @property
    def __gt__(self): raise AttributeError()

    @property
    def __ge__(self): raise AttributeError()
The problem seems to be that list implements all of the rich comparison operators, and __cmp__ will only be called if those are not defined. Thus, it seems like you have to overwrite all of those:
class Mylist(list):
    """List that orders and compares purely by length (Python 2 and 3).

    Every rich comparison is overridden because ``list`` defines them all.
    """

    def __cmp__(self, other):
        # Three-way length comparison. The builtin ``cmp`` no longer exists
        # in Python 3, so compute -1/0/1 directly.
        mine, theirs = len(self), len(other)
        return (mine > theirs) - (mine < theirs)

    def __lt__(self, other):
        return self.__cmp__(other) < 0

    def __le__(self, other):
        return self.__cmp__(other) <= 0

    def __eq__(self, other):
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        return self.__cmp__(other) != 0

    def __gt__(self, other):
        return self.__cmp__(other) > 0

    def __ge__(self, other):
        return self.__cmp__(other) >= 0
BTW, it seems like __cmp__ was removed entirely in Python 3. The above works in Python 2.x, but for compatibility you should probably rather do it like
def __lt__(self, other):
    """Order by length only (Python 3 style, no ``__cmp__`` needed)."""
    is_shorter = len(self) < len(other)
    return is_shorter
Also see these two related questions. Note that while in Python 3 it would be enough to implement __eq__ and __lt__ and have Python infer the rest, this will not work in this case, since list already implements all of them, so you have to overwrite them all.

Python set intersection and __eq__

According to this page, set.intersection test for element equality using the __eq__ method. Can anyone then explain to me why this fails?
>>> class Foo(object):
...     def __eq__(self, other):
...         return True
...
>>> set([Foo()]).intersection([Foo()])
set([])
Using 2.7.3. Is there another (not overly complex) way to do this?
If you override __eq__ you should always override __hash__, too.
"If a == b, then it must be the case that hash(a) == hash(b), else sets
and dictionaries will fail." Eric
__hash__ is used to generate an integer out of an object.
This is used to put the keys of a dict or the elements of sets into buckets so that one can faster find them.
If you do not override __hash__, the default algorithm creates different hash integers even though the objects are equal.
In your case I would do this:
class Foo(object):
    """Every ``Foo`` equals every other ``Foo`` (and nothing else)."""

    def __eq__(self, other):
        same_type = type(self) == type(other)
        return same_type

    def __hash__(self):
        # Equal objects must share a hash; since all Foos are equal, they
        # all land in the same bucket.
        return 1
Because all objects of your class are equal to every other object of that class, they must all be in the same bucket (1) in the set. This way `in` also returns True.
What should __eq__ be like:
if you only compare Foo objects
# Variant for when only Foo objects are ever compared: assumes *other* has a
# `number` attribute (anything else raises AttributeError).
def __eq__(self, other):
return self.number == other.number
if you also compare Foo objects to other objects:
# Variant that tolerates other types: objects of a different type compare
# unequal before `number` is read (`and` short-circuits).
def __eq__(self, other):
return type(self) == type(other) and self.number == other.number
if you have different classes with different algorithms for equal, I recommend double-dispatch.
class Foo:
    """Participant in a double-dispatch equality protocol.

    ``a == b`` asks *b* how it compares to an object of *a*'s kind through a
    ``_equals_<kind>`` method, so both classes decide what equality means.
    Assumes instances carry a ``number`` attribute (not set here).
    """

    def __eq__(self, other):
        compare = getattr(other, '_equals_foo', None)
        if compare is None:
            return NotImplemented
        return compare(self)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable (unusable in
        # sets/dicts). Must agree with _equals_foo, which compares `number`.
        return hash(self.number)

    def _equals_foo(self, other):
        return self.number == other.number

    def _equals_bar(self, other):
        return False  # Foo never equals Bar


class Bar:
    """Double-dispatch participant; every ``Bar`` equals every other ``Bar``."""

    def __eq__(self, other):
        compare = getattr(other, '_equals_bar', None)
        if compare is None:
            return NotImplemented
        return compare(self)

    def __hash__(self):
        # All Bars are equal to each other, so they must share one hash.
        return hash(Bar)

    def _equals_foo(self, other):
        return False  # Foo never equals Bar

    def _equals_bar(self, other):
        return True  # Bar always equals Bar
This way both a and b in a == b decide what equal means.

How to create a class instance without calling initializer?

Is there any way to avoid calling __init__ on a class while initializing it, such as from a class method?
I am trying to create a case and punctuation insensitive string class in Python used for efficient comparison purposes but am having trouble creating a new instance without calling __init__.
>>> class String:
def __init__(self, string):
self.__string = tuple(string.split())
self.__simple = tuple(self.__simple())
def __simple(self):
letter = lambda s: ''.join(filter(lambda s: 'a' <= s <= 'z', s))
return filter(bool, map(letter, map(str.lower, self.__string)))
def __eq__(self, other):
assert isinstance(other, String)
return self.__simple == other.__simple
def __getitem__(self, key):
assert isinstance(key, slice)
string = String()
string.__string = self.__string[key]
string.__simple = self.__simple[key]
return string
def __iter__(self):
return iter(self.__string)
>>> String('Hello, world!')[1:]
Traceback (most recent call last):
File "<pyshell#2>", line 1, in <module>
String('Hello, world!')[1:]
File "<pyshell#1>", line 17, in __getitem__
string = String()
TypeError: __init__() takes exactly 2 positional arguments (1 given)
>>>
What should I replace string = String(); string.__string = self.__string[key]; string.__simple = self.__simple[key] with to initialize the new object with the slices?
EDIT:
As inspired by the answer written below, the initializer has been edited to quickly check for no arguments.
# Initializer that tolerates being called with no argument, so __getitem__
# can create an empty String and fill in its private fields afterwards.
def __init__(self, string=None):
if string is None:
self.__string = self.__simple = ()
else:
self.__string = tuple(string.split())
# This instance attribute shadows the __simple() method from here on.
self.__simple = tuple(self.__simple())
When feasible, letting __init__ get called (and make the call innocuous by suitable arguments) is preferable. However, should that require too much of a contortion, you do have an alternative, as long as you avoid the disastrous choice of using old-style classes (there is no good reason to use old-style classes in new code, and several good reasons not to)...:
# Create an instance without running __init__ by calling __new__ directly
# (requires a new-style class, i.e. one inheriting from object in Python 2).
class String(object):
...
bare_s = String.__new__(String)
This idiom is generally used in classmethods which are meant to work as "alternative constructors", so you'll usually see it used in ways such as...:
@classmethod
def makeit(cls):
    """Alternative constructor that builds an instance without ``__init__``.

    Using ``cls`` rather than a hard-coded class means subclasses that
    inherit this method get instances of themselves.
    """
    self = cls.__new__(cls)
    # etc etc: set up the fields __init__ would normally set, then
    return self
(this way the classmethod will properly be inherited and generate subclass instances when called on a subclass rather than on the base class).
A trick the standard pickle and copy modules use is to create an empty class, instantiate the object using that, and then assign that instance's __class__ to the "real" class. e.g.
>>> class MyClass(object):
... init = False
... def __init__(self):
... print 'init called!'
... self.init = True
... def hello(self):
... print 'hello world!'
...
>>> class Empty(object):
... pass
...
>>> a = MyClass()
init called!
>>> a.hello()
hello world!
>>> print a.init
True
>>> b = Empty()
>>> b.__class__ = MyClass
>>> b.hello()
hello world!
>>> print b.init
False
But note, this approach is very rarely necessary. Bypassing the __init__ can have some unexpected side effects, especially if you're not familiar with the original class, so make sure you know what you're doing.
Using a metaclass provides a nice solution in this example. The metaclass has limited use but works fine.
>>> class MetaInit(type):
def __call__(cls, *args, **kwargs):
if args or kwargs:
return super().__call__(*args, **kwargs)
return cls.__new__(cls)
>>> class String(metaclass=MetaInit):
def __init__(self, string):
self.__string = tuple(string.split())
self.__simple = tuple(self.__simple())
def __simple(self):
letter = lambda s: ''.join(filter(lambda s: 'a' <= s <= 'z', s))
return filter(bool, map(letter, map(str.lower, self.__string)))
def __eq__(self, other):
assert isinstance(other, String)
return self.__simple == other.__simple
def __getitem__(self, key):
assert isinstance(key, slice)
string = String()
string.__string = self.__string[key]
string.__simple = self.__simple[key]
return string
def __iter__(self):
return iter(self.__string)
>>> String('Hello, world!')[1:]
<__main__.String object at 0x02E78830>
>>> _._String__string, _._String__simple
(('world!',), ('world',))
>>>
Addendum:
After six years, my opinion favors Alex Martelli's answer more than my own approach. With meta-classes still on the mind, the following answer shows how the problem can be solved both with and without them:
#! /usr/bin/env python3
METHOD = 'metaclass'


class NoInitMeta(type):
    """Metaclass offering ``Class.new()``: an instance without ``__init__``."""

    def new(cls):
        return cls.__new__(cls)


class String(metaclass=NoInitMeta if METHOD == 'metaclass' else type):
    """Whitespace-split words, compared ignoring case and non-letters.

    ``METHOD`` selects which of three equivalent techniques ``__getitem__``
    uses to build the sliced copy without calling ``__init__``.
    """

    def __init__(self, value):
        words = tuple(value.split())
        self.__value = words
        # Keep only a-z from each case-folded word; drop words left empty.
        stripped = (
            ''.join(ch for ch in word.casefold() if 'a' <= ch <= 'z')
            for word in words
        )
        self.__alpha = tuple(word for word in stripped if word)

    def __str__(self):
        return ' '.join(self.__value)

    def __eq__(self, other):
        if isinstance(other, type(self)):
            return self.__alpha == other.__alpha
        return NotImplemented

    if METHOD == 'metaclass':
        def __getitem__(self, key):
            if isinstance(key, slice):
                # NoInitMeta.new() skips __init__.
                clone = type(self).new()
                clone.__value = self.__value[key]
                clone.__alpha = self.__alpha[key]
                return clone
            raise NotImplementedError
    elif METHOD == 'classmethod':
        def __getitem__(self, key):
            if isinstance(key, slice):
                clone = self.new()
                clone.__value = self.__value[key]
                clone.__alpha = self.__alpha[key]
                return clone
            raise NotImplementedError

        @classmethod
        def new(cls):
            return cls.__new__(cls)
    elif METHOD == 'inline':
        def __getitem__(self, key):
            if isinstance(key, slice):
                klass = type(self)
                clone = klass.__new__(klass)
                clone.__value = self.__value[key]
                clone.__alpha = self.__alpha[key]
                return clone
            raise NotImplementedError
    else:
        raise ValueError('METHOD did not have an appropriate value')

    def __iter__(self):
        return iter(self.__value)


def main():
    text = String('Hello, world!')
    tail = text[1:]
    print(tail)


if __name__ == '__main__':
    main()
Pass another argument to the constructor, like so:
def __init__(self, string, simple=None):
    """Build from raw text, or directly from precomputed field tuples.

    When *simple* is given, *string* is taken as the already-split word
    tuple and both are stored as-is (this is what slicing passes in).
    Otherwise the words are split from *string* and the simplified form is
    computed.
    """
    if simple is not None:
        self.__string = string
        self.__simple = simple
        return
    self.__string = tuple(string.split())
    self.__simple = tuple(self.__simple())
You can then call it like this:
def __getitem__(self, key):
    """Return a new String holding the slice *key* of this one (slices only)."""
    assert isinstance(key, slice)
    words = self.__string[key]
    simplified = self.__simple[key]
    return String(words, simplified)
Also, I'm not sure it's allowed to name both the field and the method __simple. If only for readability, you should change that.

Categories