Equality and inheritance in Python

When reading about how to implement __eq__ in Python, such as in this SO question, you get recommendations like:
class A(object):
    def __init__(self, a, b):
        self._a = a
        self._b = b
    def __eq__(self, other):
        return (self._a, self._b) == (other._a, other._b)
Now, I'm having problems when combining this with inheritance. Specifically, if I define a new class
class B(A):
    def new_method(self):
        return self._a + self._b
Then I get this issue
>>> a = A(1, 2)
>>> b = B(1, 2)
>>> a == b
True
But clearly, a and b are not the (exact) same!
What is the correct way to implement __eq__ with inheritance?

If you mean that instances of different (sub)classes should not be equal, consider comparing their types as well:
def __eq__(self, other):
    return (self._a, self._b, type(self)) == (other._a, other._b, type(other))
In this way A(1, 2) == B(1,2) returns False.
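As a minimal sketch of the type-aware comparison above, the classes from the question can be put together as follows. The isinstance guard that returns NotImplemented for unrelated objects is an assumption added here, not part of the answer; without it, comparing against an object that lacks _a raises AttributeError.

```python
class A(object):
    def __init__(self, a, b):
        self._a = a
        self._b = b

    def __eq__(self, other):
        # Assumption (not in the answer): defer to Python's default
        # handling when `other` is not an A at all.
        if not isinstance(other, A):
            return NotImplemented
        # Including the concrete type makes A(1, 2) and B(1, 2) unequal.
        return (self._a, self._b, type(self)) == (other._a, other._b, type(other))


class B(A):
    def new_method(self):
        return self._a + self._b


print(A(1, 2) == A(1, 2))  # True
print(B(1, 2) == B(1, 2))  # True
print(A(1, 2) == B(1, 2))  # False
print(A(1, 2) == "text")   # False
```

One consequence of this design is that any subclass, even one that adds no state, never compares equal to its parent.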

When comparing two objects where one type is derived from the other, Python tries the comparison method of the derived class first (in this case, B's), even when the derived-class instance is the right-hand operand.
So to ensure that B objects compare unequal to A objects, you can define __eq__ in B:
class B(A):
    def __eq__(self, other):
        return isinstance(other, B) and super(B, self).__eq__(other)
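The dispatch order described above can be observed with a small sketch. The calls list is a hypothetical logging aid added for illustration; it is not part of the answer.

```python
calls = []  # hypothetical log of which __eq__ implementations ran


class A(object):
    def __init__(self, a, b):
        self._a = a
        self._b = b

    def __eq__(self, other):
        calls.append("A.__eq__")
        return (self._a, self._b) == (other._a, other._b)


class B(A):
    def __eq__(self, other):
        calls.append("B.__eq__")
        return isinstance(other, B) and super(B, self).__eq__(other)


a = A(1, 2)
b = B(1, 2)

result = (a == b)
print(result)  # False
print(calls)   # ['B.__eq__']: B's method ran first, A's was never needed

del calls[:]
print(b == B(1, 2))  # True
print(calls)         # ['B.__eq__', 'A.__eq__']
```

Because B.__eq__ returns False rather than NotImplemented for a plain A, Python never falls back to A.__eq__ for a == b.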

Related

Upcast instance to parent class in python

(How) Is it possible in Python to treat an instance of class B exactly as an instance of class A, where A is a parent of B (like up-casting in compiled languages)?
Say we have the following:
class A:
    def __init__(self, prop=None):
        self.prop = prop
    def f(self):
        return 'A.f'
    def g(self):
        return 'A.g', self.f()
class B(A):
    def f(self):
        return 'B.f'
    def g(self):
        return 'B.g', self.f()
Calling A().g() produces ('A.g', 'A.f'), whereas B().g() produces ('B.g', 'B.f'). Calling super(B, B()).g() produces ('A.g', 'B.f'), which is different from compiled languages, so I cannot use super. Instead, I need a function that changes the type from which an instance's methods are resolved, but preserves (a reference to) the original state. In short, I'm looking for a way to do this:
b = B(object())
a = upcast(b, A)
a.prop = object()
assert isinstance(a, A)
assert a.g() == ('A.g', 'A.f')
assert a.prop is b.prop
The closest I could get is
a = copy.copy(b)
a.__class__ = A
a.__dict__ = b.__dict__
(assuming A/B are "nice" "heap" classes), but this makes an unnecessary shallow copy of the instance, including its __dict__, before I discard it. Is there a better way to do this?
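One possible approach, offered only as a sketch, is to create an empty instance of the target class without running __init__ and point its __dict__ at the original's. The upcast helper below is hypothetical, and it assumes ordinary classes that store state in __dict__ (no __slots__, no custom __new__).

```python
class A:
    def __init__(self, prop=None):
        self.prop = prop

    def f(self):
        return 'A.f'

    def g(self):
        return 'A.g', self.f()


class B(A):
    def f(self):
        return 'B.f'

    def g(self):
        return 'B.g', self.f()


def upcast(obj, cls):
    """Return a view of obj whose methods resolve on cls (hypothetical helper)."""
    if not isinstance(obj, cls):
        raise TypeError("%r is not an instance of %s" % (obj, cls.__name__))
    view = cls.__new__(cls)       # allocate without calling __init__
    view.__dict__ = obj.__dict__  # share state by reference, no copy
    return view


b = B(object())
a = upcast(b, A)
a.prop = object()
print(a.g())             # ('A.g', 'A.f')
print(a.prop is b.prop)  # True
```

The view and the original share one attribute dictionary, so changes made through either are visible through both.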

Adding two distinct subclasses as their common superclass

I have a class with a lot of subclasses in my code. Consider the following code:
class Dataset:
    def __init__(self, samples):
        self.data = samples
    def __add__(self, other):
        if isinstance(other, type(self)):
            return type(self)(self.data + other.data)
        return NotImplemented
    def __radd__(self, other):
        if isinstance(other, type(self)):
            return type(self)(other.data + self.data)
        return NotImplemented
    def __str__(self):
        return str(self.data)
class DatasetA(Dataset):
    def __init__(self, samples):
        super().__init__(samples)
        self.a = len(self.data)
    def __str__(self):
        return f"{self.a}: {self.data}"
class DatasetB(Dataset):
    def __init__(self, samples):
        super().__init__(samples)
        self.b = sum(self.data)
    def __str__(self):
        return f"{self.data} (sum: {self.b})"
d = Dataset([1, 2, 3])
a = DatasetA([4, 5])
b = DatasetB([6, 7])
print(d + d)
print(d + a)
print(a + d)
print(d + b)
print(b + d)
print(a + a)
print(b + b)
print(a + b)
print(b + a)
The idea is to define an __add__ method in the superclass that won't require being overridden in each of the subclasses and will still add them correctly (for instance, two DatasetBs should add up to another DatasetB). This works correctly (the first 7 prints are fine), however there is one additional functionality I'd like to implement, represented in the last 2 prints.
I would like any two distinct subclasses to add up to the first common superclass. For example, a + b should result in a Dataset instance. Also if we add up instances of two subclasses of DatasetB, the result should be a DatasetB instance.
I've tried changing return NotImplemented to return super().__add__(other) (and similar for __radd__), but that resulted in an AttributeError on the 3rd print statement already.
Is there a way to implement this desired functionality without breaking the existing one (i.e. still have the first 7 prints execute properly), and without explicitly having to override __add__ in each of the subclasses?
You could change:
return type(self)(self.data + other.data)
to:
klass = next(c for c in type(self).mro() if c in type(other).mro())
return klass(self.data + other.data)
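This will pick the most specific superclass the two have in common, leveraging the method resolution order. Note that the isinstance(other, type(self)) guard must also be relaxed to isinstance(other, Dataset), and likewise in __radd__; otherwise a + b still gets NotImplemented from both methods and raises a TypeError.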
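Putting the MRO lookup together with a guard that accepts any Dataset, a sketch of the base class might look like this. The _common_class helper name is hypothetical, and the sketch assumes every subclass constructor accepts a single samples argument.

```python
class Dataset:
    def __init__(self, samples):
        self.data = samples

    def _common_class(self, other):
        # First class in self's MRO that also appears in other's MRO,
        # i.e. the most specific shared superclass (hypothetical helper).
        other_mro = type(other).mro()
        return next(c for c in type(self).mro() if c in other_mro)

    def __add__(self, other):
        if isinstance(other, Dataset):
            return self._common_class(other)(self.data + other.data)
        return NotImplemented

    def __radd__(self, other):
        if isinstance(other, Dataset):
            return self._common_class(other)(other.data + self.data)
        return NotImplemented

    def __str__(self):
        return str(self.data)


class DatasetA(Dataset):
    def __init__(self, samples):
        super().__init__(samples)
        self.a = len(self.data)


class DatasetB(Dataset):
    def __init__(self, samples):
        super().__init__(samples)
        self.b = sum(self.data)


d = Dataset([1, 2, 3])
a = DatasetA([4, 5])
b = DatasetB([6, 7])
print(type(a + a).__name__)  # DatasetA
print(type(a + d).__name__)  # Dataset
print(type(a + b).__name__)  # Dataset
print((b + a).data)          # [6, 7, 4, 5]
```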

I want to write the addition of two numbers but have it return their product, using operator overloading

I am a newbie. I want to use operator overloading so that writing 3+4 returns the result of 3*4.
I have made a class and defined two methods, __add__ and __mul__:
class A:
    def __init__(self, a,b):
        self.a = a
        self.b = b
    # adding two objects
    def __add__(self, other):
        return self.a + other.a , self.b + other.b
    # multiply two objects
    def __mul__(self, other):
        return self.a * other.a , self.b +other.b
ob1 = A(1)
ob2 = A(2)
ob3 = ob1+ob2
ob4 = ob1*ob2
print(ob3)
print(ob4)
Expected: input 3 and 4 , it should show 3+4 but return 3*4
In your __mul__ and __add__ methods, you need to return an instance of A, not just some values (unless you are doing in place operations). It seems like you only want to add 2 numbers together, so maybe you should try having only 1 parameter __init__:
class A:
    def __init__(self, a):
        self.a = a
    def __add__(self, other):
        return A(self.a * other.a)
Now when you do:
A(3) + A(2)
You are getting back 2*3 as the __add__ method returns a new instance of A whose .a attribute is the product, not sum, of the given two.
You should also consider type checking or error handling as your next step. What if I typed:
A(2) + 10 # not A(10)
Would an error be raised? That's up to you. The easiest way to cause an error to be raised is to return NotImplemented from the method; Python then raises a TypeError if the other operand cannot handle the operation either. The approach below also allows polymorphism, where any object with a .a attribute will work (as long as it is something that can be multiplied).
...
    def __add__(self, other):
        try:
            return A(self.a * other.a)
        except Exception:
            return NotImplemented
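A complete version of this class could look like the sketch below. Two details are assumptions added for illustration: the __repr__ method, and catching only AttributeError instead of Exception so that unrelated errors are not hidden.

```python
class A:
    def __init__(self, a):
        self.a = a

    def __repr__(self):
        # Added for readable output; not part of the answer.
        return "A(%r)" % (self.a,)

    def __add__(self, other):
        try:
            # "+" deliberately returns the product of the two values.
            return A(self.a * other.a)
        except AttributeError:
            # `other` has no .a attribute: let Python raise TypeError.
            return NotImplemented


print(A(3) + A(4))  # A(12)

try:
    A(2) + 10
except TypeError as exc:
    print("TypeError:", exc)
```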

Python set intersection and __eq__

According to this page, set.intersection tests for element equality using the __eq__ method. Can anyone then explain to me why this fails?
>>> class Foo(object):
...     def __eq__(self, other):
...         return True
...
>>> set([Foo()]).intersection([Foo()])
set([])
Using 2.7.3. Is there another (not overly complex) way to do this?
If you override __eq__ you should always override __hash__, too.
"If a == b, then it must be the case that hash(a) == hash(b), else sets and dictionaries will fail." Eric
__hash__ is used to generate an integer from an object.
This integer is used to place the keys of a dict or the elements of a set into buckets so that they can be found faster.
If you do not override __hash__, the default implementation produces different hash values for different instances even though the objects compare equal.
In your case I would do this:
class Foo(object):
    def __eq__(self, other):
        return type(self) == type(other)
    def __hash__(self):
        return 1
Because all objects of your class are equal to every other object of that class, they must all land in the same bucket (1) in the set. This way, the in operator also returns True.
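A quick check of the behaviour described above, as a sketch. The NoHash class is a hypothetical counter-example; note that on Python 3 (unlike the 2.7.3 used in the question) a class that defines __eq__ without __hash__ becomes unhashable instead of silently keeping the default hash.

```python
class Foo(object):
    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        # Every Foo is equal to every other Foo, so a constant hash is valid.
        return 1


common = set([Foo()]).intersection([Foo()])
print(len(common))          # 1
print(Foo() in set([Foo()]))  # True


class NoHash(object):
    # Hypothetical counter-example: __eq__ without __hash__.
    def __eq__(self, other):
        return True


try:
    set([NoHash()])
    raised = False
except TypeError:
    raised = True
print(raised)  # True on Python 3
```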
What should __eq__ look like?
If you only compare Foo objects:
def __eq__(self, other):
    return self.number == other.number
If you also compare Foo objects to other objects:
def __eq__(self, other):
    return type(self) == type(other) and self.number == other.number
If you have different classes with different algorithms for equality, I recommend double dispatch:
class Foo:
    def __eq__(self, other):
        return hasattr(other, '_equals_foo') and other._equals_foo(self)
    def _equals_foo(self, other):
        return self.number == other.number
    def _equals_bar(self, other):
        return False # Foo never equals Bar
class Bar:
    def __eq__(self, other):
        return hasattr(other, '_equals_bar') and other._equals_bar(self)
    def _equals_foo(self, other):
        return False # Foo never equals Bar
    def _equals_bar(self, other):
        return True # Bar always equals Bar
This way both a and b in a == b decide what equal means.
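To try the double-dispatch pattern, the classes need some state to compare. The sketch below adds an __init__ that sets number on Foo, which is an assumption since the original classes do not show a constructor.

```python
class Foo:
    def __init__(self, number):
        # Assumed constructor; the answer only references self.number.
        self.number = number

    def __eq__(self, other):
        # Ask the other object how it compares to a Foo.
        return hasattr(other, '_equals_foo') and other._equals_foo(self)

    def _equals_foo(self, other):
        return self.number == other.number

    def _equals_bar(self, other):
        return False  # Foo never equals Bar


class Bar:
    def __eq__(self, other):
        return hasattr(other, '_equals_bar') and other._equals_bar(self)

    def _equals_foo(self, other):
        return False  # Foo never equals Bar

    def _equals_bar(self, other):
        return True  # Bar always equals Bar


print(Foo(1) == Foo(1))  # True
print(Foo(1) == Foo(2))  # False
print(Foo(1) == Bar())   # False
print(Bar() == Bar())    # True
print(Foo(1) == 3)       # False
```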

How does a Python set([]) check if two objects are equal? What methods does an object need to define to customise this?

I need to create a 'container' object or class in Python, which keeps a record of other objects which I also define. One requirement of this container is that if two objects are deemed to be identical, one (either one) is removed. My first thought was to use a set([]) as the containing object, to complete this requirement.
However, the set does not remove one of the two identical object instances. What must I define to create one?
Here is the Python code.
class Item(object):
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar
    def __repr__(self):
        return "Item(%s, %s)" % (self.foo, self.bar)
    def __eq__(self, other):
        if isinstance(other, Item):
            return ((self.foo == other.foo) and (self.bar == other.bar))
        else:
            return False
    def __ne__(self, other):
        return (not self.__eq__(other))
Interpreter
>>> set([Item(1,2), Item(1,2)])
set([Item(1, 2), Item(1, 2)])
It is clear that __eq__(), which is called by x == y, is not the method called by the set. What is called? What other method must I define?
Note: The Items must remain mutable, and can change, so I cannot provide a __hash__() method. If this is the only way of doing it, then I will rewrite for use of immutable Items.
Yes, you need a __hash__() method as well as the comparison operator, which you have already provided.
class Item(object):
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar
    def __repr__(self):
        return "Item(%s, %s)" % (self.foo, self.bar)
    def __eq__(self, other):
        if isinstance(other, Item):
            return ((self.foo == other.foo) and (self.bar == other.bar))
        else:
            return False
    def __ne__(self, other):
        return (not self.__eq__(other))
    def __hash__(self):
        return hash(self.__repr__())
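I am afraid you will have to provide a __hash__() method. But you can write it in such a way that it does not depend on the mutable attributes of your Item. Note that the version above hashes the repr, which does depend on foo and bar, so an Item must not be modified while it is stored in a set.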
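As a sketch of how the set behaves once __hash__ is defined, the variant below hashes a tuple of the compared attributes instead of the repr string. That substitution is an assumption made for illustration, and it has the same limitation: the hash changes when foo or bar changes.

```python
class Item(object):
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar

    def __repr__(self):
        return "Item(%s, %s)" % (self.foo, self.bar)

    def __eq__(self, other):
        if isinstance(other, Item):
            return (self.foo, self.bar) == (other.foo, other.bar)
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash exactly the fields that __eq__ compares (assumed variant).
        return hash((self.foo, self.bar))


items = set([Item(1, 2), Item(1, 2)])
print(len(items))  # 1: the duplicate was dropped

# Hazard: mutating an Item after insertion changes its hash, so the set
# generally can no longer find it under the new value.
item = Item(1, 2)
container = set([item])
item.foo = 99
print(item in container)
```

If Items really must change while they are stored, one option is to hash on an attribute that never changes, at the cost of equal-valued Items no longer being guaranteed to share a hash.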
