I often need to make a class I wrote comparable for various reasons (sorting, set usage, ...), and I don't want to write multiple comparison functions each time. How can I set things up so that I only have to write a single function for every new class?
My solution is to create an abstract base class that a class can inherit from, and then override the main comparison function (diff()) with the desired comparison logic.
class Comparable:
    '''
    An abstract class that can be inherited to make a class comparable and sortable.
    For proper functionality, function diff must be overridden.
    '''
    def diff(self, other):
        """
        Calculates the difference in value between self and other and returns a number.
        If the returned number is
        - positive, the value of self is greater than that of other.
        - 0, the objects are equivalent in value.
        - negative, the value of self is less than that of other.
        Used in comparison operations.
        Override this function."""
        return 0
    # Note: defining __eq__ sets __hash__ to None in Python 3, so a subclass
    # that must be usable in sets or as dict keys has to define __hash__ itself.
    def __eq__(self, other):
        return self.diff(other) == 0
    def __ne__(self, other):
        return self.diff(other) != 0
    def __lt__(self, other):
        return self.diff(other) < 0
    def __le__(self, other):
        return self.diff(other) <= 0
    def __gt__(self, other):
        return self.diff(other) > 0
    def __ge__(self, other):
        return self.diff(other) >= 0
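A sketch of how a subclass might plug into this base class, using a hypothetical Version class compared by its (major, minor) pair. The base class is repeated so the snippet runs on its own, with one assumed change: diff raises NotImplementedError instead of returning 0, so a forgotten override fails loudly. Because the base class defines __eq__, Python 3 sets its __hash__ to None, so the subclass defines __hash__ itself to stay usable in sets.

```python
class Comparable:
    """Base class: subclasses override diff(); the six operators are derived from it."""

    def diff(self, other):
        # Negative if self < other, 0 if equal, positive if self > other.
        raise NotImplementedError

    def __eq__(self, other):
        return self.diff(other) == 0

    def __ne__(self, other):
        return self.diff(other) != 0

    def __lt__(self, other):
        return self.diff(other) < 0

    def __le__(self, other):
        return self.diff(other) <= 0

    def __gt__(self, other):
        return self.diff(other) > 0

    def __ge__(self, other):
        return self.diff(other) >= 0


class Version(Comparable):
    """Hypothetical example class, compared by (major, minor)."""

    def __init__(self, major, minor):
        self.major = major
        self.minor = minor

    def _key(self):
        return (self.major, self.minor)

    def diff(self, other):
        a, b = self._key(), other._key()
        # (a > b) - (a < b) yields 1, 0 or -1, like Python 2's cmp().
        return (a > b) - (a < b)

    # Defining __eq__ (in the base class) removed hashability, so restore it,
    # hashing the same data diff() compares: equal objects get equal hashes.
    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        return f"Version({self.major}, {self.minor})"


versions = [Version(2, 0), Version(1, 5), Version(1, 10)]
print(sorted(versions))                     # [Version(1, 5), Version(1, 10), Version(2, 0)]
print(len({Version(1, 5), Version(1, 5)}))  # 1
```

Only diff (and, if needed, __hash__) has to be written per class; everything else is inherited.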
Let's say I have the following two objects, which I use in a DataFrame:
online_price = OnlinePrice(article_id=1,
                           date=datetime.date(2020, 11, 5),
                           min_sale_price=62.99,
                           max_sale_price=83.94,
                           avg_sale_price=83.74,
                           median_sale_price=100.00)
online_price_group = OnlinePriceGroup(title='Test OnlinePriceGroup',
                                      user_id=1,
                                      selection_group_id=1,
                                      periodic_task_id=1,
                                      min_price_multiplier=30,
                                      calculator_method=enums.OnlineGroupCalculatorMethod.USE_LOWEST_OF_RECOMMENDATION_AND_ONLINE_PRICE,
                                      status=enums.OnlineGroupStatus.ACTIVE,
                                      allow_higher_than_current_prices=enums.OnlineGroupUseHigherPriceThanCurrent.ALLOW)
I want to use groupby on 'status', for example to split ACTIVE and INACTIVE.
This gives me the error TypeError: '<' not supported between instances of 'OnlinePriceGroup' and 'OnlinePriceGroup'.
Thanks in advance.
This error occurs because you are using custom classes without defining how two instances of the class compare with each other. How do you determine whether group 1 is larger than group 2? That logic needs to be added first.
A minimal example of this can be seen below:
class GroupClass:
    def __init__(self, val):
        self.val = val
    def __gt__(self, other):
        return self.val > other.val
    def __ge__(self, other):
        return self.val >= other.val
    def __eq__(self, other):
        return self.val == other.val
With the above, you would be able to run:
print(GroupClass(2) > GroupClass(3))
print(GroupClass(2) >= GroupClass(3))
print(GroupClass(2) == GroupClass(3))
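If the objects also need to be sorted (sorting is what calls '<') and used as dictionary or set keys, a sketch along these lines may help. It assumes the ordering follows the single val attribute from the example above; functools.total_ordering derives __le__, __gt__ and __ge__ from __eq__ and __lt__, and __hash__ is defined because defining __eq__ alone makes instances unhashable in Python 3. Returning NotImplemented for other types is an extra assumption, not part of the original answer.

```python
from functools import total_ordering


@total_ordering
class GroupClass:
    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        if not isinstance(other, GroupClass):
            return NotImplemented
        return self.val == other.val

    def __lt__(self, other):
        if not isinstance(other, GroupClass):
            return NotImplemented
        return self.val < other.val

    # Equal objects must hash equally, so hash the same attribute __eq__ uses.
    def __hash__(self):
        return hash(self.val)


items = [GroupClass(3), GroupClass(1), GroupClass(2)]
print([g.val for g in sorted(items)])  # [1, 2, 3]

# Instances can now serve as dict keys; equal values share one entry.
counts = {}
for g in [GroupClass(1), GroupClass(1), GroupClass(2)]:
    counts[g] = counts.get(g, 0) + 1
print({g.val: n for g, n in counts.items()})  # {1: 2, 2: 1}
```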
I'm currently stuck on a problem writing a comparator. The basic idea was to write a function that takes two parameters (two lists), but I want to apply it to a list of these lists using the sorted() function. How should I do it?
Comparator:
def dispersion_sort(frec, srec):
    if isinstance(frec, intervals.Interval) and isinstance(srec, intervals.Interval):
        if frec[DOUBLE_RES_COL] < srec[DOUBLE_RES_COL]:
            return frec
        if frec[DOUBLE_RES_COL] > srec[DOUBLE_RES_COL]:
            return srec
        if frec[DOUBLE_RES_COL].overlaps(srec[DOUBLE_RES_COL]):
            if (frec[DOUBLE_TIME_COL] < srec[DOUBLE_TIME_COL]):
                return frec
            else:
                return srec
    return frec
Sample frec data:
['1', 'Mikhail Nitenko', '#login', '✅', [-0.000509228437634554,0.0007110924383354339], datetime.datetime(2020, 1, 2, 14, 46, 46)]
How I wanted to call it:
results = sorted(results, key=dispersion_sort)
Thanks a lot!
You can use functools.cmp_to_key for this:
from functools import cmp_to_key
results = sorted(results, key=cmp_to_key(dispersion_sort))
It transforms an old-style comparison function (which takes two arguments and returns a negative number, zero, or a positive number) into a new-style key function (which takes one argument).
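For cmp_to_key to sort correctly, the comparison function must return a negative number, zero, or a positive number rather than one of the records. A minimal sketch, assuming hypothetical records of the form [name, score, timestamp] that should be sorted by score and then by timestamp (the column indices and data are made up for illustration):

```python
import datetime
from functools import cmp_to_key

# Hypothetical column indices for this illustration.
SCORE_COL = 1
TIME_COL = 2


def compare_records(a, b):
    """Return <0 if a sorts first, >0 if b sorts first, 0 if tied."""
    if a[SCORE_COL] != b[SCORE_COL]:
        return -1 if a[SCORE_COL] < b[SCORE_COL] else 1
    # Tie on score: the earlier timestamp sorts first.
    if a[TIME_COL] != b[TIME_COL]:
        return -1 if a[TIME_COL] < b[TIME_COL] else 1
    return 0


records = [
    ["b", 0.5, datetime.datetime(2020, 1, 2, 14, 0)],
    ["a", 0.1, datetime.datetime(2020, 1, 3, 9, 0)],
    ["c", 0.5, datetime.datetime(2020, 1, 1, 8, 0)],
]
results = sorted(records, key=cmp_to_key(compare_records))
print([r[0] for r in results])  # ['a', 'c', 'b']
```

When the rule can be expressed as a tuple like this, key=lambda r: (r[SCORE_COL], r[TIME_COL]) gives the same order without cmp_to_key; a comparator is mainly worthwhile when the rule is hard to express as a key.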
If you wanted to explicitly create a comparator, you'd implement a custom class that has these magic methods:
class comparator:
    def __init__(self, obj, *args):
        self.obj = obj
    def __lt__(self, other):
        return mycmp(self.obj, other.obj) < 0
    def __gt__(self, other):
        return mycmp(self.obj, other.obj) > 0
    def __eq__(self, other):
        return mycmp(self.obj, other.obj) == 0
    def __le__(self, other):
        return mycmp(self.obj, other.obj) <= 0
    def __ge__(self, other):
        return mycmp(self.obj, other.obj) >= 0
    def __ne__(self, other):
        return mycmp(self.obj, other.obj) != 0
Here, mycmp is a comparison function like the one you showed, except that it must return a negative number, zero, or a positive number; your current function returns one of the two records instead, so change it accordingly if you want to use it with this class template. You can also choose to put your logic directly in the class itself, in which case each of these methods should return True or False.
Once you have the class ready, you can pass it in directly: key=comparator
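Putting the template together with a concrete mycmp: the sketch below assumes, purely for illustration, that lists should be ordered by length. A wrapper class like this is essentially what functools.cmp_to_key builds for you.

```python
def mycmp(a, b):
    # Hypothetical rule: shorter lists sort first (negative/zero/positive result).
    return len(a) - len(b)


class comparator:
    """Wraps an object so that sorted() compares the wrappers via mycmp."""

    def __init__(self, obj, *args):
        self.obj = obj

    def __lt__(self, other):
        return mycmp(self.obj, other.obj) < 0

    def __gt__(self, other):
        return mycmp(self.obj, other.obj) > 0

    def __eq__(self, other):
        return mycmp(self.obj, other.obj) == 0

    def __le__(self, other):
        return mycmp(self.obj, other.obj) <= 0

    def __ge__(self, other):
        return mycmp(self.obj, other.obj) >= 0

    def __ne__(self, other):
        return mycmp(self.obj, other.obj) != 0


data = [[1, 2, 3], [4], [5, 6]]
# sorted() calls comparator(item) once per item and compares the wrappers.
print(sorted(data, key=comparator))  # [[4], [5, 6], [1, 2, 3]]
```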
I have a simple Python class that I want to be able to compare, so I implemented the comparison operators. I then realized that I've been doing the same thing for many classes, and it feels a lot like code duplication.
class Foo(object):
    def __init__(self, index, data):
        self.index = index
        self.data = data
    def __lt__(self, other):
        return self.index < other.index
    def __gt__(self, other):
        return self.index > other.index
    def __le__(self, other):
        return self.index <= other.index
    def __ge__(self, other):
        return self.index >= other.index
    def __eq__(self, other):
        return self.index == other.index
    def __ne__(self, other):
        return self.index != other.index
So I think a simple solution would be something like this:
class Comparable(object):
    def _compare(self, other):
        raise NotImplementedError()
    def __lt__(self, other):
        return self._compare(other) < 0
    def __gt__(self, other):
        return self._compare(other) > 0
    def __le__(self, other):
        return self._compare(other) <= 0
    def __ge__(self, other):
        return self._compare(other) >= 0
    def __eq__(self, other):
        return self._compare(other) == 0
    def __ne__(self, other):
        return self._compare(other) != 0
class Foo1(Comparable):
    def _compare(self, other):
        return self.index - other.index
class Foo2(Comparable):
    def _compare(self, other):
        ...
class Foo3(Comparable):
    def _compare(self, other):
        ...
But it seems so basic that I feel like I'm reinventing the wheel here.
I'm wondering if there is a more 'native' way to achieve that.
As described in the docs, you can use functools.total_ordering to save some boilerplate when writing all of the comparisons:
To avoid the hassle of providing all six functions, you can implement __eq__, __ne__, and only one of the ordering operators, and use the functools.total_ordering() decorator to fill in the rest.
To be explicit, the six functions they are referring to are: __eq__, __ne__, __lt__, __le__, __gt__, and __ge__.
So, you want some automation while creating rich comparison methods. You can get this behaviour by using the functools.total_ordering() class decorator. See the reference for more details.
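Applied to the Foo class from the question, a Python 3 sketch might look like this: only __eq__ and __lt__ are written by hand, total_ordering supplies __le__, __gt__ and __ge__, and in Python 3 __ne__ is derived from __eq__ automatically. Returning NotImplemented for other types, and the __hash__ method, are extra assumptions rather than part of the answers above.

```python
from functools import total_ordering


@total_ordering
class Foo:
    def __init__(self, index, data):
        self.index = index
        self.data = data

    def __eq__(self, other):
        if not isinstance(other, Foo):
            return NotImplemented
        return self.index == other.index

    def __lt__(self, other):
        if not isinstance(other, Foo):
            return NotImplemented
        return self.index < other.index

    # __eq__ is overridden, so define __hash__ too if Foo goes into sets/dicts.
    def __hash__(self):
        return hash(self.index)


a, b = Foo(1, "x"), Foo(2, "y")
print(a < b, a <= b, a > b, a >= b, a == b, a != b)
# True True False False False True
```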
I have created a new Python class as follows:
class Mylist(list):
    def __cmp__(self,other):
        if len(self)>len(other):
            return 1
        elif len(self)<len(other):
            return -1
        elif len(self)==len(other):
            return 0
My intent is that, when two Mylist objects are compared, the object with the larger number of items should be greater.
c=Mylist([4,5,6])
d=Mylist([1,2,3])
After running the above code, c and d are supposed to be equal (c == d should be True). But I am getting:
>>> c==d
False
>>> c>d
True
>>>
They are being compared like plain list objects. What did I do wrong?
You need to implement the __eq__ method.
class Mylist(list):
    def __cmp__(self,other):
        if len(self)>len(other):
            return 1
        elif len(self)<len(other):
            return -1
        elif len(self)==len(other):
            return 0
    def __eq__(self, other):
        return len(self)==len(other)
UPDATE: (the previous code does not work perfectly, as explained in the comments)
Although @tobias_k's answer explains it better, you can do it via the __cmp__ method in Python 2 if you insist. You can enable it by removing the other comparison methods (__lt__, __le__, __ge__, ...):
class Mylist(list):
    def __cmp__(self,other):
        if len(self)>len(other):
            return 1
        elif len(self)<len(other):
            return -1
        elif len(self)==len(other):
            return 0
    def __eq__(self, other):
        return len(self)==len(other)
    @property
    def __lt__(self, other): raise AttributeError()
    @property
    def __le__(self, other): raise AttributeError()
    @property
    def __ne__(self, other): raise AttributeError()
    @property
    def __gt__(self, other): raise AttributeError()
    @property
    def __ge__(self, other): raise AttributeError()
The problem seems to be that list implements all of the rich comparison operators, and __cmp__ will only be called if those are not defined. Thus, it seems you have to override all of them:
class Mylist(list):
    def __lt__(self, other): return cmp(self, other) < 0
    def __le__(self, other): return cmp(self, other) <= 0
    def __eq__(self, other): return cmp(self, other) == 0
    def __ne__(self, other): return cmp(self, other) != 0
    def __gt__(self, other): return cmp(self, other) > 0
    def __ge__(self, other): return cmp(self, other) >= 0
    def __cmp__(self, other): return cmp(len(self), len(other))
BTW, __cmp__ (along with the built-in cmp()) was removed entirely in Python 3. The above works in Python 2.x, but for compatibility you should rather write the methods like this:
def __lt__(self, other): return len(self) < len(other)
Also see these two related questions. Note that while in Python 3 it would normally be enough to implement __eq__ and __lt__ and let functools.total_ordering infer the rest, this will not work in this case: list already implements all of them, and total_ordering only fills in methods that are missing, so you have to override them all.
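For Python 3, where __cmp__ and cmp() no longer exist, a sketch of the length-based comparison that overrides every rich comparison method directly could look like this:

```python
class Mylist(list):
    """A list whose comparisons look only at the number of items."""

    # list defines all six rich comparisons (including __ne__), so each one
    # is overridden here; total_ordering would leave list's versions in place.
    def __lt__(self, other):
        return len(self) < len(other)

    def __le__(self, other):
        return len(self) <= len(other)

    def __eq__(self, other):
        return len(self) == len(other)

    def __ne__(self, other):
        return len(self) != len(other)

    def __gt__(self, other):
        return len(self) > len(other)

    def __ge__(self, other):
        return len(self) >= len(other)

    # Defining __eq__ sets __hash__ to None, so Mylist stays unhashable,
    # just like list itself.


c = Mylist([4, 5, 6])
d = Mylist([1, 2, 3])
print(c == d, c > d)  # True False
```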
According to this page, set.intersection tests for element equality using the __eq__ method. Can anyone then explain to me why this fails?
>>> class Foo(object):
...     def __eq__(self, other):
...         return True
...
>>> set([Foo()]).intersection([Foo()])
set([])
I'm using Python 2.7.3. Is there another (not overly complex) way to do this?
If you override __eq__, you should always override __hash__, too.
"If a == b, then it must be the case that hash(a) == hash(b), else sets and dictionaries will fail." Eric
__hash__ is used to generate an integer from an object.
This is used to put the keys of a dict or the elements of a set into buckets so that they can be found faster.
If you do not override __hash__, the default (identity-based) algorithm produces different hash integers even when the objects are equal. (In Python 3, defining __eq__ without __hash__ makes instances unhashable altogether.)
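To illustrate the rule, here is a sketch with a hypothetical Point class whose __hash__ is built from the same fields that __eq__ compares. With that in place, set.intersection finds the equal object even though it is a different instance.

```python
class Point:
    """Hypothetical value class: equality and hashing use the same fields."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __hash__(self):
        # Hash exactly the data __eq__ compares, so a == b implies hash(a) == hash(b).
        return hash((self.x, self.y))


common = {Point(1, 2), Point(3, 4)}.intersection([Point(1, 2)])
print(len(common))            # 1
print(Point(3, 4) in common)  # False
```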
In your case I would do this:
class Foo(object):
    def __eq__(self, other):
        return type(self) == type(other)
    def __hash__(self):
        return 1
Because every object of your class is equal to every other object of that class, they must all end up in the same bucket (1) in the set. This way, the in operator also returns True.
What __eq__ should look like:
if you only compare Foo objects:
def __eq__(self, other):
    return self.number == other.number
if you also compare Foo objects to other objects:
def __eq__(self, other):
    return type(self) == type(other) and self.number == other.number
If you have different classes with different algorithms for equality, I recommend double dispatch:
class Foo:
    def __eq__(self, other):
        return hasattr(other, '_equals_foo') and other._equals_foo(self)
    def _equals_foo(self, other):
        return self.number == other.number
    def _equals_bar(self, other):
        return False # Foo never equals Bar
class Bar:
    def __eq__(self, other):
        return hasattr(other, '_equals_bar') and other._equals_bar(self)
    def _equals_foo(self, other):
        return False # Foo never equals Bar
    def _equals_bar(self, other):
        return True # Bar always equals Bar
This way, both a and b in a == b have a say in what equality means.
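Python's own comparison protocol offers a built-in form of this: if a.__eq__(b) returns NotImplemented, Python tries the reflected b.__eq__(a), and if that also returns NotImplemented it falls back to an identity comparison. The sketch below swaps in that mechanism instead of the hand-written _equals_* helpers; the number attribute and constructor are assumptions for illustration.

```python
class Foo:
    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Foo):
            return self.number == other.number
        # Unknown type: let Python try other.__eq__(self) instead.
        return NotImplemented

    def __hash__(self):
        return hash(self.number)


class Bar:
    def __eq__(self, other):
        if isinstance(other, Bar):
            return True  # Bar always equals Bar, as in the original Bar
        return NotImplemented

    def __hash__(self):
        return 0  # all Bars are equal, so they must share one hash value


print(Foo(1) == Foo(1), Foo(1) == Bar(), Bar() == Bar())  # True False True
```

Here Foo(1) == Bar() ends up False because neither side claims to know how to compare with the other, and the identity fallback finds two distinct objects.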