I am looking for the most efficient way to represent small sets of integers in a given range (say 0-10) in Python. In this case, efficiency means fast construction (from an unsorted list), fast queries (a couple per set), and reasonably fast construction of a sorted version (perhaps once per ten sets or so). A priori the candidates are Python's built-in set type (fast query), a sorted array (perhaps faster to construct?), or a bit-array (fast everything if I were in C... but I doubt Python will be that efficient). Any advice on which one to choose?
Thanks.
I'd use a bitmapping and store the members of a "set" in an int...which might actually be faster than the built-in set type in this case -- although I haven't tested that. It would definitely require less storage.
Update
I don't have the time right now to do a full set-like implementation and benchmark it against Python's built-in class, but here's what I believe is a working example illustrating my suggestion. As I think you'd agree, the code looks fairly fast as well as memory efficient.
Given Python's almost transparent "unlimited" long integer capabilities, what is written will automatically work with integer values in a much larger range than you need, although doing so would likely slow things down a bit. ;)
class BitSet(object):
def __init__(self, *bitlist):
self._bitmap = 0
for bitnum in bitlist:
self._bitmap |= (1 << bitnum)
def add(self, bitnum):
self._bitmap |= (1 << bitnum)
def remove(self, bitnum):
if self._bitmap & (1 << bitnum):
self._bitmap &= ~(1 << bitnum)
else:
raise KeyError
def discard(self, bitnum):
self._bitmap &= ~(1 << bitnum)
def clear(self):
self._bitmap = 0
def __contains__(self, bitnum):
return bool(self._bitmap & (1 << bitnum))
def __int__(self):
return self._bitmap
if __name__ == '__main__':
bs = BitSet()
print '28 in bs:', 28 in bs
print 'bs.add(28)'
bs.add(28)
print '28 in bs:', 28 in bs
print
print '5 in bs:', 5 in bs
print 'bs.add(5)'
bs.add(5)
print '5 in bs:', 5 in bs
print
print 'bs.remove(28)'
bs.remove(28)
print '28 in bs:', 28 in bs
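If you also want the occasional sorted list the question mentions, one hedged way to get it from the bitmap (not part of the original answer; the helper name and the 0-10 upper bound are my own) is to walk the possible bit positions in order:

def sorted_members(bitset, upper=11):
    # Bits 0-10 cover the question's 0-10 range; iterate them in order.
    return [i for i in range(upper) if i in bitset]

bs = BitSet(3, 7, 1)
print(sorted_members(bs))  # [1, 3, 7]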
In this case you might just use a list of True/False values. The hash table used by set will be doing the same thing, but it will include overhead for hashing, bucket assignment, and collision detection.
myset = [False] * 11
for i in values:
myset[i] = True
mysorted = [i for i in range(11) if myset[i]]
As always you need to time it yourself to know how it works in your circumstances.
My advice is to stick with the built-in set(). It will be very difficult to write Python code that beats the built-in C code for performance. Speed of construction and speed of lookup will be fastest if you are relying on the built-in C code.
For a sorted list, your best bet is to use the built-in sort feature:
x = set(seq) # build set from some sequence
lst = sorted(x) # get sorted list from set
In general, in Python, the less code you write, the faster it is. The more you can rely on the built-in C underpinnings of Python, the faster. Interpreted Python is 20x to 100x slower than C code in many cases, and it is extremely hard to be so clever that you come out ahead vs. just using the built-in features as intended.
If your sets are guaranteed to always be integers in the range of [0, 10], and you want to make sure the memory footprint is as small as possible, then bit-flags inside an integer would be the way to go.
pow2 = [2**i for i in range(32)]
x = 0 # set with no values
def add_to_int_set(x, n):
return x | pow2[n]
def in_int_set(x, n):
return x & pow2[n]
def list_from_int_set(x):
return [i for i in range(32) if x & pow2[i]]
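A quick usage sketch of these helpers (my own illustration, not part of the original answer):

x = 0
for v in [3, 7, 1]:
    x = add_to_int_set(x, v)
print(bool(in_int_set(x, 7)))   # True
print(list_from_int_set(x))     # [1, 3, 7]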
I'll bet this is actually slower than using the built-in set() functions, but you know that each set will just be an int object: 4 bytes, plus the overhead of a Python object.
If you literally needed billions of them, you could save space by using a NumPy array instead of a Python list; the NumPy array will just store bare integers. In fact, NumPy has a 16-bit integer type, so if your sets are really only in the range of [0, 10] you could get the storage size down to two bytes each using a NumPy array.
http://www.scipy.org/FAQ#head-16a621f03792969969e44df8a9eb360918ce9613
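As a rough sketch of that idea (my own illustration, not from the linked FAQ; the names are hypothetical), the bitmask-as-int trick combines naturally with a NumPy array of 16-bit values:

import numpy as np

masks = np.zeros(1_000_000, dtype=np.uint16)   # one 2-byte bitmask per "set"

def add_member(masks, idx, n):
    masks[idx] |= np.uint16(1 << n)            # set bit n of the idx-th set

def has_member(masks, idx, n):
    return bool(masks[idx] & (1 << n))         # test bit n of the idx-th set

add_member(masks, 0, 7)
print(has_member(masks, 0, 7))                 # True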
Even for small collections, 'contains' checks turn out quite a bit faster with sets.
>>> Timer("3 in values", 'values = [range(10)]').timeit(number = 10**7)
0.5200109481811523
>>> Timer("3 in values", 'values = set(range(10))').timeit(number = 10**7)
0.2755239009857178
On the other hand, as you've indicated, constructing a set takes a little bit longer.
>>> Timer("set(range(10))").timeit(number = 10**7)
5.87517786026001
>>> Timer("list(range(10))").timeit(number = 10**7)
4.129410028457642
There are also some differences when sorting:
>>> Timer("sorted(values)", 'values = set(range(10, 0, -1))').timeit(number = 10**7)
5.277467966079712
>>> Timer("sorted(values)", 'values = list(range(10, 0, -1))').timeit(number = 10**7)
4.3836448192596436
>>> Timer("values.sort()", 'values = list(range(10, 0, -1))').timeit(number = 10**7)
2.073429822921753
Sorting in-place is significantly faster and is only available for lists.
So if you're only doing a small number of queries per collection, lists perform better; when doing a lot of queries, I'd go with sets.
In either case, the differences between such small collections are small.
Building your own collection type in Python for better performance is not recommended.
I saw a video about speed of loops in python, where it was explained that doing sum(range(N)) is much faster than manually looping through range and adding the variables together, since the former runs in C due to built-in functions being used, while in the latter the summation is done in (slow) python. I was curious what happens when adding numpy to the mix. As I expected np.sum(np.arange(N)) is the fastest, but sum(np.arange(N)) and np.sum(range(N)) are even slower than doing the naive for loop.
Why is this?
Here's the script I used to test, with some comments about the supposed causes of the slowdowns where I know them (taken mostly from the video), and the results I got on my machine (python 3.10.0, numpy 1.21.2):
updated script:
import numpy as np
from timeit import timeit
N = 10_000_000
repetition = 10
def sum0(N = N):
s = 0
i = 0
while i < N: # condition is checked in python
s += i
i += 1 # both additions are done in python
return s
def sum1(N = N):
s = 0
for i in range(N): # increment in C
s += i # addition in python
return s
def sum2(N = N):
return sum(range(N)) # everything in C
def sum3(N = N):
return sum(list(range(N)))
def sum4(N = N):
return np.sum(range(N)) # very slow np.array conversion
def sum5(N = N):
# much faster np.array conversion
return np.sum(np.fromiter(range(N),dtype = int))
def sum5v2_(N = N):
# much faster np.array conversion
return np.sum(np.fromiter(range(N),dtype = np.int_))
def sum6(N = N):
# possibly slow conversion to Py_long from np.int
return sum(np.arange(N))
def sum7(N = N):
# list returns a list of np.int-s
return sum(list(np.arange(N)))
def sum7v2(N = N):
# tolist conversion to python int seems faster than the implicit conversion
# in sum(list()) (tolist returns a list of python int-s)
return sum(np.arange(N).tolist())
def sum8(N = N):
return np.sum(np.arange(N)) # everything in numpy (fortran libblas?)
def sum9(N = N):
return np.arange(N).sum() # remove dispatch overhead
def array_basic(N = N):
return np.array(range(N))
def array_dtype(N = N):
return np.array(range(N),dtype = np.int_)
def array_iter(N = N):
# np.sum's source code mentions to use fromiter to convert from generators
return np.fromiter(range(N),dtype = np.int_)
print(f"while loop: {timeit(sum0, number = repetition)}")
print(f"for loop: {timeit(sum1, number = repetition)}")
print(f"sum_range: {timeit(sum2, number = repetition)}")
print(f"sum_rangelist: {timeit(sum3, number = repetition)}")
print(f"npsum_range: {timeit(sum4, number = repetition)}")
print(f"npsum_iterrange: {timeit(sum5, number = repetition)}")
print(f"npsum_iterrangev2: {timeit(sum5, number = repetition)}")
print(f"sum_arange: {timeit(sum6, number = repetition)}")
print(f"sum_list_arange: {timeit(sum7, number = repetition)}")
print(f"sum_arange_tolist: {timeit(sum7v2, number = repetition)}")
print(f"npsum_arange: {timeit(sum8, number = repetition)}")
print(f"nparangenpsum: {timeit(sum9, number = repetition)}")
print(f"array_basic: {timeit(array_basic, number = repetition)}")
print(f"array_dtype: {timeit(array_dtype, number = repetition)}")
print(f"array_iter: {timeit(array_iter, number = repetition)}")
print(f"npsumarangeREP: {timeit(lambda : sum8(N/1000), number = 100000*repetition)}")
print(f"npsumarangeREP: {timeit(lambda : sum9(N/1000), number = 100000*repetition)}")
# Example output:
#
# while loop: 11.493371912998555
# for loop: 7.385945574002108
# sum_range: 2.4605720699983067
# sum_rangelist: 4.509678105998319
# npsum_range: 11.85120212900074
# npsum_iterrange: 4.464334709002287
# npsum_iterrangev2: 4.498494338993623
# sum_arange: 9.537815956995473
# sum_list_arange: 13.290120724996086
# sum_arange_tolist: 5.231948580003518
# npsum_arange: 0.241889145996538
# nparangenpsum: 0.21876695199898677
# array_basic: 11.736577274998126
# array_dtype: 8.71628468400013
# array_iter: 4.303306431000237
# npsumarangeREP: 21.240833958996518
# npsumarangeREP: 16.690092379001726
np.sum(range(N)) is slow mostly because the current NumPy implementation does not use enough information about the exact type/content of the values produced by range(N). The heart of the problem is Python's dynamic typing and arbitrary-precision integers, although NumPy could optimize this specific case.
First of all, range(N) returns a dynamically-typed Python object which is a (special kind of) lazily evaluated sequence. The objects it produces are also dynamically-typed: in practice, pure-Python integers.
The thing is that NumPy is written in the statically-typed language C, so it cannot efficiently work on dynamically-typed pure-Python objects. NumPy's strategy is to convert such objects into C types when it can. One big problem in this case is that the integers provided by the iterable can theoretically be huge: NumPy does not know whether the values can overflow an np.int32 or even an np.int64 type. Thus, NumPy first detects a suitable type to use and then computes the result using that type.
This detection process can be quite expensive and may appear unnecessary here, since all the values produced by range(10_000_000) fit in an np.int32. However, range(5_000_000_000) returns the same object type but yields pure-Python integers that overflow np.int32, and NumPy needs to detect this case automatically so as not to return wrong results. Moreover, even when the input type is correctly identified (np.int32 on my machine), the output result may still be wrong because overflows can occur during the computation of the sum. This is sadly the case on my machine.
NumPy developers decided to deprecate such use and state in the documentation that np.fromiter should be used instead. np.fromiter has a required dtype parameter that lets the user specify the type to use.
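For instance, a minimal sketch of the recommended pattern (my own illustration; int64 is chosen here to be safe against the overflow discussed above):

import numpy as np

N = 10_000_000
# Telling NumPy the dtype (and the count) up front avoids the type-detection work.
result = np.sum(np.fromiter(range(N), dtype=np.int64, count=N))
print(result)  # 49999995000000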
One way to check this behaviour in practice is to simply create a temporary list:
tmp = list(range(10_000_000))
# NumPy implicitly converts the list to a NumPy array but
# still automatically detects the input type to use
np.sum(tmp)
A faster implementation is the following:
tmp = list(range(10_000_000))
# The array is explicitly converted using a well-defined type and
# thus there is no need to perform an automatic detection
# (note that the result is still wrong since it does not fit in a np.int32)
tmp2 = np.array(tmp, dtype=np.int32)
result = np.sum(tmp2)
The first case takes 476 ms on my machine while the second takes 289 ms. Note that np.sum itself takes only 4 ms. Thus, a large part of the time is spent in the conversion of pure-Python integer objects to internal int32 types (more specifically, in the management of pure-Python integers). list(range(10_000_000)) is expensive too, as it takes 205 ms. This is again due to the overhead of pure-Python integers (i.e. allocations, deallocations, reference counting, incrementing variable-sized integers, memory indirections and branches due to dynamic typing) as well as the overhead of the generator.
sum(np.arange(N)) is slow because sum works element by element on a NumPy-defined object. The CPython interpreter needs to call NumPy functions to perform basic additions. Moreover, NumPy-defined integer objects are still Python objects and so they are subject to reference counting, allocation, deallocation, etc. Not to mention that NumPy and CPython add many checks in functions that ultimately just add two native numbers together. A NumPy-aware just-in-time compiler such as Numba can solve this issue. Indeed, Numba takes 23 ms on my machine to compute the sum of np.arange(10_000_000) (with code still written in Python) while the CPython interpreter takes 556 ms.
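A minimal sketch of the Numba approach mentioned above (my own illustration, assuming Numba is installed; timings will differ per machine):

import numpy as np
import numba as nb

@nb.njit  # compiled to native code on the first call
def numba_sum(arr):
    s = 0
    for x in arr:  # this loop runs without per-element Python objects
        s += x
    return s

print(numba_sum(np.arange(10_000_000)))  # 49999995000000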
Let's see if I can summarize the results.
sum can work with any iterable, repeatedly asking for the next value and adding it. range is a generator-like object that's happy to supply the next value on demand:
# sum_range: 1.4830789409988938
Making a list from a range takes time:
# sum_rangelist: 3.6745876889999636
Summing a pregenerated list is actually faster than summing the range:
%%timeit x = list(range(N))
...: sum(x)
np.sum is designed to sum arrays. It's a wrapper to np.add.reduce.
np.sum has a deprecation warning for np.sum(generator), recommending the use of fromiter or Python sum:
# npsum_range: 16.216972655000063
fromiter is the best way of making an array from a generator. Using np.array on range is legacy code and may go away in the future. I think it's the only generator that np.array will accept.
np.array is a general purpose function that can handle many cases, including nested arrays, and conversion to various dtypes. As such it has to process the whole input argument, deducing both shape and dtype.
# npsum_fromiterrange:3.47655400199983
Iteration on a numpy array is slower than a list, since it has to "unbox" each element.
# sum_arange: 16.656015603000924
Similarly making a list from an array is slow; same sort of python level iteration.
# sum_list_arange: 19.500842117000502
arr.tolist() is relatively fast, creating a pure python list in compiled code. So speed is similar to making a list from range.
# sum_arange_tolist: 4.004777374000696
np.sum of an array is pure numpy and quite fast. np.sum(x) where x=np.arange(N) is even faster (by about 4x)
# npsum_arange: 0.2332638230000157
np.sum from range or list is dominated by the cost of creating the array first:
# array_basic: 16.1631146109994
# array_dtype: 16.550737804000164
# array_iter: 3.9803170430004684
From the CPython source code for sum: sum initially attempts a fast path that assumes all inputs are the same type. If that assumption fails, it falls back to just iterating:
/* Fast addition by keeping temporary sums in C instead of new Python objects.
Assumes all inputs are the same type. If the assumption fails, default
to the more general routine.
*/
I'm not entirely certain what is happening under the hood, but it is likely the repeated creation/conversion of C types to Python objects that is causing these slow-downs. It's worth noting that both sum and range are implemented in C.
This next bit is not really an answer to the question, but I wondered if we could speed up sum for python ranges as range is quite a smart object.
To do this I've used functools.singledispatch to override the built-in sum function specifically for the range type; then implemented a small function to calculate the sum of an arithmetic progression.
from functools import singledispatch
def sum_range(range_, /, start=0):
"""Overloaded `sum` for range, compute arithmetic sum"""
n = len(range_)
if not n:
return start
return int(start + (n * (range_[0] + range_[-1]) / 2))
sum = singledispatch(sum)
sum.register(range, sum_range)
def test():
"""
>>> sum(range(0, 100))
4950
>>> sum(range(0, 10, 2))
20
>>> sum(range(0, 9, 2))
20
>>> sum(range(0, -10, -1))
-45
>>> sum(range(-10, 10))
-10
>>> sum(range(-1, -100, -2))
-2500
>>> sum(range(0, 10, 100))
0
>>> sum(range(0, 0))
0
>>> sum(range(0, 100), 50)
5000
>>> sum(range(0, 0), 10)
10
"""
if __name__ == "__main__":
import doctest
doctest.testmod()
I'm not sure if this is complete, but it's definitely faster than looping.
Python has a built-in function sum, which is effectively equivalent to:
import operator
from functools import reduce  # reduce is a builtin in Python 2 but needs this import in Python 3
def sum2(iterable, start=0):
    return start + reduce(operator.add, iterable)
for all types of parameters except strings. It works for numbers and lists, for example:
sum([1,2,3], 0) = sum2([1,2,3],0) = 6 #Note: 0 is the default value for start, but I include it for clarity
sum({888:1}, 0) = sum2({888:1},0) = 888
Why were strings specially left out?
sum( ['foo','bar'], '') # TypeError: sum() can't sum strings [use ''.join(seq) instead]
sum2(['foo','bar'], '') = 'foobar'
I seem to remember discussions in the Python list for the reason, so an explanation or a link to a thread explaining it would be fine.
Edit: I am aware that the standard way is to do "".join. My question is why the option of using sum for strings was banned, while there was no such ban for, say, lists.
Edit 2: Although I believe this is not needed given all the good answers I got, the question is: Why does sum work on an iterable containing numbers or an iterable containing lists but not an iterable containing strings?
Python tries to discourage you from "summing" strings. You're supposed to join them:
"".join(list_of_strings)
It's a lot faster, and uses much less memory.
A quick benchmark:
$ python -m timeit -s 'import operator; strings = ["a"]*10000' 'r = reduce(operator.add, strings)'
100 loops, best of 3: 8.46 msec per loop
$ python -m timeit -s 'import operator; strings = ["a"]*10000' 'r = "".join(strings)'
1000 loops, best of 3: 296 usec per loop
Edit (to answer OP's edit): As to why strings were apparently "singled out", I believe it's simply a matter of optimizing for a common case, as well as of enforcing best practice: you can join strings much faster with ''.join, so explicitly forbidding strings on sum will point this out to newbies.
BTW, this restriction has been in place "forever", i.e., since sum was added as a built-in function (rev. 32347).
You can in fact use sum(..) to concatenate strings, if you use the appropriate starting object! Of course, if you go this far you have already understood enough to use "".join(..) anyway..
>>> class ZeroObject(object):
... def __add__(self, other):
... return other
...
>>> sum(["hi", "there"], ZeroObject())
'hithere'
Here's the source: http://svn.python.org/view/python/trunk/Python/bltinmodule.c?revision=81029&view=markup
In the builtin_sum function we have this bit of code:
/* reject string values for 'start' parameter */
if (PyObject_TypeCheck(result, &PyBaseString_Type)) {
PyErr_SetString(PyExc_TypeError,
"sum() can't sum strings [use ''.join(seq) instead]");
Py_DECREF(iter);
return NULL;
}
Py_INCREF(result);
}
So.. that's your answer.
It's explicitly checked in the code and rejected.
From the docs:
The preferred, fast way to concatenate a sequence of strings is by calling ''.join(sequence).
By making sum refuse to operate on strings, Python has encouraged you to use the correct method.
Short answer: Efficiency.
Long answer: The sum function has to create an object for each partial sum.
Assume that the amount of time required to create an object is directly proportional to the size of its data. Let N denote the number of elements in the sequence to sum.
doubles are always the same size, which makes sum's running time O(1)×N = O(N).
int (formerly known as long) is arbitrary-length. Let M denote the absolute value of the largest sequence element. Then sum's worst-case running time is lg(M) + lg(2M) + lg(3M) + ... + lg(NM) = N×lg(M) + lg(N!) = O(N log N).
For str (where M = the length of the longest string), the worst-case running time is M + 2M + 3M + ... + NM = M×(1 + 2 + ... + N) = O(N²).
Thus, summing strings would be much slower than summing numbers.
str.join does not allocate any intermediate objects. It preallocates a buffer large enough to hold the joined strings, and copies the string data. It runs in O(N) time, much faster than sum.
The Reason Why
dan04 has an excellent explanation for the costs of using sum on large lists of strings.
The missing piece as to why str is not allowed for sum is that many, many people were trying to use sum for strings, and not many use sum for lists and tuples and other O(n**2) data structures. The trap is that sum works just fine for short lists of strings, but then gets put in production where the lists can be huge, and the performance slows to a crawl. This was such a common trap that the decision was made to ignore duck-typing in this instance, and not allow strings to be used with sum.
Edit: Moved the parts about immutability to history.
Basically, it's a question of preallocation. When you use a statement such as
sum(["a", "b", "c", ..., ])
and expect it to work similarly to a reduce statement, the code generated looks something like this:
v1 = "" + "a" # must allocate v1 and set its size to len("") + len("a")
v2 = v1 + "b" # must allocate v2 and set its size to len("a") + len("b")
...
res = v10000 + "$" # must allocate res and set its size to len(v9999) + len("$")
In each of these steps a new string is created, which already incurs some copying overhead as the strings get longer and longer. But that's perhaps not the main point here. What's more important is that every new string on each line must be allocated at its specific size (I don't know whether it must allocate on every iteration of the reduce statement; there may be some obvious heuristics, and Python might allocate a bit extra here and there for reuse, but at several points the new string will be large enough that this won't help anymore and Python must allocate again, which is rather expensive).
A dedicated method like join, however, has the job of figuring out the real size of the string before it starts, and can therefore in theory allocate only once, at the beginning, and then just fill that new string, which is much cheaper than the other solution.
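To make the idea concrete, here is an illustrative sketch of the join strategy (not CPython's actual implementation): measure the total size first, allocate a single buffer, then copy each piece into place.

def join_like(strings):
    encoded = [s.encode() for s in strings]
    total = sum(len(b) for b in encoded)   # one pass to learn the final size
    buf = bytearray(total)                 # single allocation, never regrown
    pos = 0
    for b in encoded:
        buf[pos:pos + len(b)] = b          # just copy; no intermediate strings
        pos += len(b)
    return buf.decode()

print(join_like(["foo", "bar", "baz"]))    # foobarbaz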
I don't know why, but this works!
import operator
from functools import reduce  # builtin in Python 2; needs this import in Python 3

def sum_of_strings(list_of_strings):
    return reduce(operator.add, list_of_strings)
Say we add a group of long strings to a hashset, and then test if some string already exists in this hashset. Is the time complexity going to be constant for adding and retrieving operations? Or does it depend on the length of the strings?
For example, if we have three strings.
s1 = 'abcdefghijklmn'
s2 = 'dalkfdboijaskjd'
s3 = 'abcdefghijklmn'
Then we do:
pool = set()
pool.add(s1)
pool.add(s2)
print s3 in pool # => True
print 'zzzzzzzzzz' in pool # => False
Would time complexity of the above operations be a factor of the string length?
Another question: what if we are hashing a tuple? Something like (1,2,3,4,5,6,7,8,9)?
I appreciate your help!
==================================
I understand that there are resources around, like this one, that talk about why hashing is constant time and about collision issues. However, they usually assume that the length of the key can be neglected. This question asks whether hashing still takes constant time when the key has a length that cannot be neglected. For example, if we test N times whether a key of length K is in the set, is the time complexity O(N) or O(N*K)?
One of the best ways to answer something like this is to dig into the implementation :)
Notwithstanding some of the optimization magic described in the header of setobject.c, adding an object into a set reuses the hash of a string on which hash() has already been called once (recall, strings are immutable), or calls the type's hash implementation.
For Unicode/bytes objects, we end up via here to _Py_HashBytes, which seems to have an optimization for small strings, otherwise it uses the compile-time configured hash function, all of which naturally are somewhat O(n)-ish. But again, this seems to only happen once per string object.
For tuples, the hash implementation can be found here – apparently a simplified, non-cached xxHash.
However, once the hash has been computed, the time complexity for sets should be around O(1).
EDIT: A quick, not very scientific benchmark:
import time
def make_string(c, n):
return c * n
def make_tuple(el, n):
return (el,) * n
def hashtest(gen, n):
# First compute how long generation alone takes
gen_time = time.perf_counter()
for x in range(n):
gen()
gen_time = time.perf_counter() - gen_time
# Then compute how long hashing and generation takes
hash_and_gen_time = time.perf_counter()
for x in range(n):
hash(gen())
hash_and_gen_time = time.perf_counter() - hash_and_gen_time
# Return the two
return (hash_and_gen_time, gen_time)
for gen in (make_string, make_tuple):
for obj_length in (10000, 20000, 40000):
t = f"{gen.__name__} x {obj_length}"
# Using `b'hello'.decode()` here to avoid any cached hash shenanigans
hash_and_gen_time, gen_time = hashtest(
lambda: gen(b"hello".decode(), obj_length), 10000
)
hash_time = hash_and_gen_time - gen_time
print(t, hash_time, obj_length / hash_time)
outputs
make_string x 10000 0.23490356100000004 42570.66158311665
make_string x 20000 0.47143921999999994 42423.284172241765
make_string x 40000 0.942087403 42458.905482254915
make_tuple x 10000 0.45578034300000025 21940.393335480014
make_tuple x 20000 0.9328520900000008 21439.62608263008
make_tuple x 40000 1.8562772150000004 21548.505620158674
which basically says hashing sequences, be they strings or tuples, is linear time, yet hashing strings is a lot faster than hashing tuples.
EDIT 2: this proves strings and bytestrings cache their hashes:
import time
s = ('x' * 500_000_000)
t0 = time.perf_counter()
a = hash(s)
t1 = time.perf_counter()
print(t1 - t0)
t0 = time.perf_counter()
b = hash(s)
t2 = time.perf_counter()
assert a == b
print(t2 - t0)
outputs
0.26157095399999997
1.201999999977943e-06
Strictly speaking it depends on the implementation of the hash set and the way you're using it (there may be cleverness that will optimize some of the time away in specialized circumstances), but in general, yes, you should expect that it will take O(n) time to hash a key to do an insert or lookup where n is the size of the key. Usually hash sets are assumed to be O(1), but there's an implicit assumption there that the keys are of fixed size and that hashing them is a O(1) operation (in other words, there's an assumption that the key size is negligible compared to the number of items in the set).
Optimizing the storage and retrieval of really big chunks of data is why databases are a thing. :)
Average case is O(1).
However, the worst case is O(n), with n being the number of elements in the set. This case is caused by hashing collisions.
you can read more about it in here
https://www.geeksforgeeks.org/internal-working-of-set-in-python/
Wiki is your friend
https://wiki.python.org/moin/TimeComplexity
For the operations above, it seems that they are all O(1) for a set.
I've been immensely frustrated with many of the implementations of python radix sort out there on the web.
They consistently use a radix of 10 and get the digits of the numbers they iterate over by dividing by a power of 10 or taking the log10 of the number. This is incredibly inefficient, as log10 is not a particularly quick operation compared to bit shifting, which is nearly 100 times faster!
A much more efficient implementation uses a radix of 256 and sorts the number byte by byte. This allows for all of the 'byte getting' to be done using the ridiculously quick bit operators. Unfortunately, it seems that absolutely nobody out there has implemented a radix sort in python that uses bit operators instead of logarithms.
So, I took matters into my own hands and came up with this beast, which runs at about half the speed of sorted on small arrays and runs nearly as quickly on larger ones (e.g. len around 10,000,000):
import itertools
def radix_sort(unsorted):
"Fast implementation of radix sort for any size num."
maximum, minimum = max(unsorted), min(unsorted)
max_bits = maximum.bit_length()
highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1
min_bits = minimum.bit_length()
lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1
sorted_list = unsorted
for offset in xrange(lowest_byte, highest_byte):
sorted_list = radix_sort_offset(sorted_list, offset)
return sorted_list
def radix_sort_offset(unsorted, offset):
"Helper function for radix sort, sorts each offset."
byte_check = (0xFF << offset*8)
buckets = [[] for _ in xrange(256)]
for num in unsorted:
byte_at_offset = (num & byte_check) >> offset*8
buckets[byte_at_offset].append(num)
return list(itertools.chain.from_iterable(buckets))
This version of radix sort works by finding which bytes it has to sort by (if you pass it only integers below 256, it'll sort just one byte, etc.) then sorting each byte from LSB up by dumping them into buckets in order then just chaining the buckets together. Repeat this for each byte that needs to be sorted and you have your nice sorted array in O(n) time.
However, it's not as fast as it could be, and I'd like to make it faster before I write about it as a better radix sort than all the other radix sorts out there.
Running cProfile on this tells me that a lot of time is being spent on the append method for lists, which makes me think that this block:
for num in unsorted:
byte_at_offset = (num & byte_check) >> offset*8
buckets[byte_at_offset].append(num)
in radix_sort_offset is eating a lot of time. This is also the block that, if you really look at it, does 90% of the work for the whole sort. This code looks like it could be numpy-ized, which I think would result in quite a performance boost. Unfortunately, I'm not very good with numpy's more complex features so haven't been able to figure that out. Help would be very appreciated.
I'm currently using itertools.chain.from_iterable to flatten the buckets, but if anyone has a faster suggestion I'm sure it would help as well.
Originally, I had a get_byte function that returned the nth byte of a number, but inlining the code gave me a huge speed boost, so I kept it inlined.
Any other comments on the implementation or ways to squeeze out more performance are also appreciated. I want to hear anything and everything you've got.
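For what it's worth, here is one hedged way the per-byte pass could be NumPy-ized (my own sketch, not benchmarked against the code above): extract the byte keys for the whole array at once and let a stable argsort per byte play the role of the buckets. Whether this beats the pure-Python buckets will depend on the array size.

import numpy as np

def radix_sort_numpy(unsorted):
    "Sketch: LSD radix sort, one stable pass per byte, using NumPy."
    arr = np.asarray(unsorted, dtype=np.uint64)
    if arr.size == 0:
        return arr
    n_bytes = max(1, (int(arr.max()).bit_length() + 7) // 8)
    for offset in range(n_bytes):
        byte_keys = (arr >> np.uint64(8 * offset)) & np.uint64(0xFF)
        # A stable sort on this byte preserves the order established by the lower bytes.
        arr = arr[np.argsort(byte_keys, kind='stable')]
    return arr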
You already realized that
for num in unsorted:
byte_at_offset = (num & byte_check) >> offset*8
buckets[byte_at_offset].append(num)
is where most of the time goes - good ;-)
There are two standard tricks for speeding that kind of thing, both having to do with moving invariants out of loops:
Compute "offset*8" outside the loop. Store it in a local variable. Save a multiplication per iteration.
Add bucketappender = [bucket.append for bucket in buckets] outside the loop. Saves a method lookup per iteration.
Combine them, and the loop looks like:
for num in unsorted:
bucketappender[(num & byte_check) >> ofs8](num)
Collapsing it to one statement also saves a pair of local variable store/fetch opcodes per iteration.
But, at a higher level, the standard way to speed up radix sort is to use a larger radix. What's magical about 256? Nothing, apart from that it's convenient for bit-shifting. But so are 512, 1024, 2048 ... it's a classical time/space tradeoff.
PS: for very long numbers,
(num >> offset*8) & 0xff
will run faster. That's because your num & byte_check takes time proportional to log(num) - it generally has to create an integer about as big as num.
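Putting those suggestions together, a sketch of one distribution pass with a configurable radix (2**radix_bits buckets, hoisted appends, and shift-then-mask byte extraction) might look like this; the 11-bit default is just an example value:

def radix_pass(unsorted, shift, radix_bits=11):
    "One distribution pass over 2**radix_bits buckets (sketch)."
    mask = (1 << radix_bits) - 1
    buckets = [[] for _ in range(1 << radix_bits)]
    append = [bucket.append for bucket in buckets]  # hoisted method lookups
    for num in unsorted:
        append[(num >> shift) & mask](num)          # shift first, then mask
    return [num for bucket in buckets for num in bucket]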
This is an old thread, but I came across this when looking to radix sort an array of positive integers. I was trying to see if I can do any better than the already wickedly fast timsort (hats off to you again, Tim Peters) which implements python's builtin sorted and sort! Either I don't understand certain aspects of the above code, or if I do, the code as presented above has some problems IMHO.
It only sorts bytes starting with the highest byte of the smallest item and ending with the highest byte of the biggest item. This may be okay in some cases of special data. But in general the approach fails to differentiate items which differ on account of the lower bits. For example:
arr=[65535,65534]
radix_sort(arr)
produces the wrong output:
[65535, 65534]
The range used to loop over the helper function is not correct. What I mean is that if lowest_byte and highest_byte are the same, execution of the helper function is altogether skipped. BTW I had to change xrange to range in 2 places.
With modifications to address the above 2 points, I got it to work. But it is taking 10-20 times the time of python's builtin sorted or sort! I know timsort is very efficient and takes advantage of already sorted runs in the data. But I was trying to see if I can use the prior knowledge that my data is all positive integers to some advantage in my sorting. Why is the radix sort doing so badly compared to timsort? The array sizes I was using are in the order of 80K items. Is it because the timsort implementation in addition to its algorithmic efficiency has also other efficiencies stemming from possible use of low level libraries? Or am I missing something entirely? The modified code I used is below:
import itertools
def radix_sort(unsorted):
"Fast implementation of radix sort for any size num."
maximum, minimum = max(unsorted), min(unsorted)
max_bits = maximum.bit_length()
highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1
# min_bits = minimum.bit_length()
# lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1
sorted_list = unsorted
# xrange changed to range, lowest_byte deleted from the arguments
for offset in range(highest_byte):
sorted_list = radix_sort_offset(sorted_list, offset)
return sorted_list
def radix_sort_offset(unsorted, offset):
"Helper function for radix sort, sorts each offset."
byte_check = (0xFF << offset*8)
# xrange changed to range
buckets = [[] for _ in range(256)]
for num in unsorted:
byte_at_offset = (num & byte_check) >> offset*8
buckets[byte_at_offset].append(num)
return list(itertools.chain.from_iterable(buckets))
You could simply use one of the existing C or C++ implementations, such as, for example, integer_sort from Boost.Sort or u4_sort from usort. It is surprisingly easy to call native C or C++ code from Python; see How to sort an array of integers faster than quicksort?
I totally get your frustration. Although it's been more than 2 years, numpy still does not have radix sort. I will let the NumPy developers know that they could simply grab one of the existing implementations; licensing should not be an issue.
I'm having some troubles understanding this behaviour.
I'm measuring the execution time with the timeit-module and get the following results for 10000 cycles:
Merge : 1.22722930395
Bubble: 0.810706578175
Select: 0.469924766812
This is my code for MergeSort:
def mergeSort(array):
if len(array) <= 1:
return array
else:
left = array[:len(array)/2]
right = array[len(array)/2:]
return merge(mergeSort(left),mergeSort(right))
def merge(array1,array2):
merged_array=[]
while len(array1) > 0 or len(array2) > 0:
if array2 and not array1:
merged_array.append(array2.pop(0))
elif (array1 and not array2) or array1[0] < array2[0]:
merged_array.append(array1.pop(0))
else:
merged_array.append(array2.pop(0))
return merged_array
Edit:
I've changed the list operations to use pointers and my tests now work with a list of 1000 random numbers from 0-1000. (btw: I changed to only 10 cycles here)
result:
Merge : 0.0574434420723
Bubble: 1.74780097558
Select: 0.362952293025
This is my rewritten merge definition:
def merge(array1, array2):
merged_array = []
pointer1, pointer2 = 0, 0
while pointer1 < len(array1) and pointer2 < len(array2):
if array1[pointer1] < array2[pointer2]:
merged_array.append(array1[pointer1])
pointer1 += 1
else:
merged_array.append(array2[pointer2])
pointer2 += 1
while pointer1 < len(array1):
merged_array.append(array1[pointer1])
pointer1 += 1
while pointer2 < len(array2):
merged_array.append(array2[pointer2])
pointer2 += 1
return merged_array
seems to work pretty well now :)
list.pop(0) pops the first element and has to shift all the remaining ones; this is an additional O(n) operation that should be avoided.
Also, slicing a list object creates a copy:
left = array[:len(array)/2]
right = array[len(array)/2:]
Which means you're also using O(n * log(n)) memory instead of O(n).
I can't see BubbleSort, but I bet it works in-place, no wonder it's faster.
You need to rewrite it to work in-place. Instead of copying part of the original list, pass starting and ending indexes.
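A hedged sketch of what that index-based rewrite might look like (my own illustration, reusing a single scratch buffer instead of slicing):

def merge_sort_inplace(array, lo=0, hi=None, buf=None):
    "Sorts array[lo:hi] in place, reusing one scratch buffer."
    if hi is None:
        hi = len(array)
        buf = array[:]             # one allocation for the whole sort
    if hi - lo <= 1:
        return array
    mid = (lo + hi) // 2
    merge_sort_inplace(array, lo, mid, buf)
    merge_sort_inplace(array, mid, hi, buf)
    buf[lo:hi] = array[lo:hi]      # copy both sorted halves into the scratch buffer
    i, j = lo, mid
    for k in range(lo, hi):        # merge back into the original list
        if j >= hi or (i < mid and buf[i] <= buf[j]):
            array[k] = buf[i]
            i += 1
        else:
            array[k] = buf[j]
            j += 1
    return array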
For starters: I cannot reproduce your timing results, on 100 cycles and lists of size 10000. The exhaustive benchmark with timeit of all implementations discussed in this answer (including bubblesort and your original snippet) is posted as a gist here. I find the following results for the average duration of a single run:
Python's native (Tim)sort : 0.0144600081444
Bubblesort : 26.9620819092
(Your) Original Mergesort : 0.224888720512
Now, to make your function faster, you can do a few things.
Edit: Well, apparently, I was wrong on that one (thanks cwillu). Length computation takes O(1) in Python. But removing useless computation everywhere still improves things a bit (Original Mergesort: 0.224888720512, no-length Mergesort: 0.195795390606):
def nolenmerge(array1,array2):
merged_array=[]
while array1 or array2:
if not array1:
merged_array.append(array2.pop(0))
elif (not array2) or array1[0] < array2[0]:
merged_array.append(array1.pop(0))
else:
merged_array.append(array2.pop(0))
return merged_array
def nolenmergeSort(array):
n = len(array)
if n <= 1:
return array
left = array[:n/2]
right = array[n/2:]
return nolenmerge(nolenmergeSort(left),nolenmergeSort(right))
Second, as suggested in this answer, pop(0) is linear. Rewrite your merge to pop() at the end:
def fastmerge(array1,array2):
merged_array=[]
while array1 or array2:
if not array1:
merged_array.append(array2.pop())
elif (not array2) or array1[-1] > array2[-1]:
merged_array.append(array1.pop())
else:
merged_array.append(array2.pop())
merged_array.reverse()
return merged_array
This is again faster: no-len Mergesort: 0.195795390606, no-len Mergesort+fastmerge: 0.126505711079
Third - and this would only be useful as-is if you were using a language that does tail call optimization; without it, it's a bad idea - your recursive mergeSort is not tail-recursive: it calls both (mergeSort left) and (mergeSort right) while there is remaining work in the call (the merge).
But you can make the merge tail-recursive by using CPS (this will run out of stack size for even modest lists if you don't do tco):
def cps_merge_sort(array):
return cpsmergeSort(array,lambda x:x)
def cpsmergeSort(array,continuation):
n = len(array)
if n <= 1:
return continuation(array)
left = array[:n/2]
right = array[n/2:]
return cpsmergeSort (left, lambda leftR:
cpsmergeSort(right, lambda rightR:
continuation(fastmerge(leftR,rightR))))
Once this is done, you can do TCO by hand to defer the call stack management done by recursion to the while loop of a normal function (trampolining, explained e.g. here, trick originally due to Guy Steele). Trampolining and CPS work great together.
You write a thunking function, that "records" and delays application: it takes a function and its arguments, and returns a function that returns (that original function applied to those arguments).
thunk = lambda name, *args: lambda: name(*args)
You then write a trampoline that manages calls to thunks: it applies a thunk until the thunk returns a result (as opposed to another thunk)
def trampoline(bouncer):
while callable(bouncer):
bouncer = bouncer()
return bouncer
Then all that's left is to "freeze" (thunk) all your recursive calls from the original CPS function, to let the trampoline unwrap them in proper sequence. Your function now returns a thunk, without recursion (and discarding its own frame), at every call:
def tco_cpsmergeSort(array,continuation):
n = len(array)
if n <= 1:
return continuation(array)
left = array[:n/2]
right = array[n/2:]
return thunk (tco_cpsmergeSort, left, lambda leftR:
thunk (tco_cpsmergeSort, right, lambda rightR:
(continuation(fastmerge(leftR,rightR)))))
mycpomergesort = lambda l: trampoline(tco_cpsmergeSort(l,lambda x:x))
Sadly this does not go that fast (recursive mergesort: 0.126505711079, this trampolined version: 0.170638551712). OK, I guess the stack blowup of the recursive merge sort algorithm is in fact modest: as soon as you get out of the leftmost path in the array-slicing recursion pattern, the algorithm starts returning (and removing frames). So for 10K-sized lists, you get a function stack of at most log_2(10 000) = 14 ... pretty modest.
You can do a slightly more involved stack-based TCO elimination along the lines of what this SO answer gives:
def leftcomb(l):
maxn,leftcomb = len(l),[]
n = maxn/2
while maxn > 1:
leftcomb.append((l[n:maxn],False))
maxn,n = n,n/2
return l[:maxn],leftcomb
def tcomergesort(l):
l,stack = leftcomb(l)
while stack: # l sorted, stack contains tagged slices
i,ordered = stack.pop()
if ordered:
l = fastmerge(l,i)
else:
stack.append((l,True)) # store return call
rsub,ssub = leftcomb(i)
stack.extend(ssub) #recurse
l = rsub
return l
But this goes only a tad faster (trampolined mergesort: 0.170638551712, this stack-based version: 0.144994809628). Apparently, the stack-building that Python does for the recursive calls of our original merge sort is pretty inexpensive.
The final results? On my machine (Ubuntu natty's stock Python 2.7.1+), the average run timings (out of 100 runs - except for Bubblesort - on lists of size 10000 containing random integers between 0 and 10000000) are:
Python's native (Tim)sort : 0.0144600081444
Bubblesort : 26.9620819092
Original Mergesort : 0.224888720512
no-len Mergesort : 0.195795390606
no-len Mergesort + fastmerge : 0.126505711079
trampolined CPS Mergesort + fastmerge : 0.170638551712
stack-based mergesort + fastmerge: 0.144994809628
Your merge-sort has a big constant factor, you have to run it on large lists to see the asymptotic complexity benefit.
Umm.. 1,000 records?? You are still well within the polynomial coefficient dominance here.. If I have
selection-sort: 15 * n ^ 2 (reads) + 5 * n^2 (swaps)
insertion-sort: 5 * n ^2 (reads) + 15 * n^2 (swaps)
merge-sort: 200 * n * log(n) (reads) 1000 * n * log(n) (merges)
You're going to be in a close race for a long while.. By the way, 2x faster in sorting is NOTHING. Try 100x slower. That's where the real differences are felt. Try "won't finish in my lifetime" algorithms (there are known regular expressions that take this long to match simple strings).
So try 1M or 1G records and let us know if you still think merge-sort isn't doing too well.
That being said..
There are lots of things causing this merge-sort to be expensive. First of all, nobody ever runs quick or merge sort on small scale data-structures.. Where you have if (len <= 1), people generally put:
if (len <= 16) : (use inline insertion-sort)
else: merge-sort
At EACH propagation level.
Since insertion-sort has a smaller coefficient cost at smaller sizes of n. Note that 50% of your work is done in this last mile.
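As a hedged sketch of that cutoff idea (reusing the index-based merge from the question's edit; the cutoff of 16 and the helper names are just examples):

def insertion_sort(array):
    "Simple in-place insertion sort for short runs."
    for i in range(1, len(array)):
        key = array[i]
        j = i - 1
        while j >= 0 and array[j] > key:
            array[j + 1] = array[j]
            j -= 1
        array[j + 1] = key
    return array

def hybrid_merge_sort(array, cutoff=16):
    "Fall back to insertion sort below the cutoff, merge sort above it."
    if len(array) <= cutoff:
        return insertion_sort(array)
    mid = len(array) // 2
    return merge(hybrid_merge_sort(array[:mid], cutoff),
                 hybrid_merge_sort(array[mid:], cutoff))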
Next, you are needlessly running array1.pop(0) instead of maintaining index counters. If you're lucky, Python is efficiently managing start-of-array offsets, but all else being equal, you're mutating the input parameters.
Also, you know the size of the target array during merge, why copy-and-double the merged_array repeatedly.. Pre-allocate the size of the target array at the start of the function.. That'll save at least a dozen 'clones' per merge-level.
In general, merge-sort uses 2x the size of RAM.. Your algorithm is probably using 20x because of all the temporary merge buffers (hopefully python can free structures before recursion). It breaks elegance, but generally the best merge-sort algorithms make an immediate allocation of a merge buffer equal to the size of the source array, and you perform complex address arithmetic (or array-index + span-length) to just keep merging data-structures back and forth. It won't be as elegant as a simple recursive problem like this, but it's somewhat close.
In C-sorting, cache-coherence is your biggest enemy. You want hot data-structures so you maximize your cache. By allocating transient temp buffers (even if the memory manager is returning pointers to hot memory) you run the risk of making slow DRAM calls (pre-filling cache-lines for data you're about to over-write). This is one advantage insertion-sort,selection-sort and quick-sort have over merge-sort (when implemented as above)
Speaking of which, something like quick-sort is both naturally-elegant code, naturally efficient-code, and doesn't waste any memory (google it on wikipedia- they have a javascript implementation from which to base your code). Squeezing the last ounce of performance out of quick-sort is hard (especially in scripting languages, which is why they generally just use the C-api to do that part), and you have a worst-case of O(n^2). You can try and be clever by doing a combination bubble-sort/quick-sort to mitigate worst-case.
Happy coding.