Pushing Radix Sort (and python) to its limits - python

I've been immensely frustrated with many of the implementations of python radix sort out there on the web.
They consistently use a radix of 10 and get the digits of the numbers they iterate over by dividing by a power of 10 or taking the log10 of the number. This is incredibly inefficient, as log10 is not a particularly quick operation compared to bit shifting, which is nearly 100 times faster!
A much more efficient implementation uses a radix of 256 and sorts the number byte by byte. This allows for all of the 'byte getting' to be done using the ridiculously quick bit operators. Unfortunately, it seems that absolutely nobody out there has implemented a radix sort in python that uses bit operators instead of logarithms.
So, I took matters into my own hands and came up with this beast, which runs at about half the speed of sorted on small arrays and runs nearly as quickly on larger ones (e.g. len around 10,000,000):
import itertools

def radix_sort(unsorted):
    "Fast implementation of radix sort for any size num."
    maximum, minimum = max(unsorted), min(unsorted)
    max_bits = maximum.bit_length()
    highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1
    min_bits = minimum.bit_length()
    lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1
    sorted_list = unsorted
    for offset in xrange(lowest_byte, highest_byte):
        sorted_list = radix_sort_offset(sorted_list, offset)
    return sorted_list

def radix_sort_offset(unsorted, offset):
    "Helper function for radix sort, sorts each offset."
    byte_check = (0xFF << offset*8)
    buckets = [[] for _ in xrange(256)]
    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)
    return list(itertools.chain.from_iterable(buckets))
This version of radix sort works by finding which bytes it has to sort by (if you pass it only integers below 256, it'll sort just one byte, etc.), then sorting byte by byte from the LSB up: the numbers are dumped into 256 buckets in order and the buckets are chained back together. Repeat this for each byte that needs to be sorted and you have your nice sorted array in O(n) time.
However, it's not as fast as it could be, and I'd like to make it faster before I write about it as a better radix sort than all the other radix sorts out there.
Running cProfile on this tells me that a lot of time is being spent on the append method for lists, which makes me think that this block:
for num in unsorted:
    byte_at_offset = (num & byte_check) >> offset*8
    buckets[byte_at_offset].append(num)
in radix_sort_offset is eating a lot of time. This is also the block that, if you really look at it, does 90% of the work for the whole sort. This code looks like it could be numpy-ized, which I think would result in quite a performance boost. Unfortunately, I'm not very good with numpy's more complex features so haven't been able to figure that out. Help would be very appreciated.
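Not part of the original question, but for reference: one way a single pass could be vectorized with numpy (a sketch under the assumption of non-negative integers stored in a fixed-width unsigned dtype, not a drop-in replacement for the bucket loop) is to extract the key byte for every element at once and then do a stable argsort on it:

import numpy as np

def radix_sort_offset_np(arr, offset):
    # arr is assumed to be a numpy array with an unsigned integer dtype
    keys = (arr >> (offset * 8)) & 0xFF          # byte at this offset, for every element at once
    order = np.argsort(keys, kind='mergesort')   # mergesort is stable, preserving earlier passes
    return arr[order]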
I'm currently using itertools.chain.from_iterable to flatten the buckets, but if anyone has a faster suggestion I'm sure it would help as well.
Originally, I had a get_byte function that returned the nth byte of a number, but inlining that code gave me a huge speed boost, so I kept it inlined.
Any other comments on the implementation or ways to squeeze out more performance are also appreciated. I want to hear anything and everything you've got.

You already realized that
for num in unsorted:
    byte_at_offset = (num & byte_check) >> offset*8
    buckets[byte_at_offset].append(num)
is where most of the time goes - good ;-)
There are two standard tricks for speeding that kind of thing, both having to do with moving invariants out of loops:
Compute "offset*8" outside the loop. Store it in a local variable. Saves a multiplication per iteration.
Add bucketappender = [bucket.append for bucket in buckets] outside the loop. Saves a method lookup per iteration.
Combine them, and the loop looks like:
for num in unsorted:
    bucketappender[(num & byte_check) >> ofs8](num)
Collapsing it to one statement also saves a pair of local variable store/fetch opcodes per iteration.
But, at a higher level, the standard way to speed radix sort is to use a larger radix. What's magical about 256? Nothing, apart from that it's convenient for bit-shifting. But so are 512, 1024, 2048 ... it's a classical time/space tradeoff.
PS: for very long numbers,
(num >> offset*8) & 0xff
will run faster. That's because your num & byte_check takes time proportional to log(num) - it generally has to create an integer about as big as num.
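For what it's worth, a hedged sketch (mine, not Tim's code) that combines the suggestions above - hoisting the shift, pre-binding the appenders, shift-then-mask, and a wider 11-bit digit (radix 2048) - might look like:

import itertools

SHIFT = 11                 # 11-bit digits, i.e. radix 2048
RADIX = 1 << SHIFT
MASK = RADIX - 1

def radix_sort_offset_2048(unsorted, offset):
    buckets = [[] for _ in range(RADIX)]
    bucketappender = [bucket.append for bucket in buckets]
    shift = offset * SHIFT                           # hoisted out of the loop
    for num in unsorted:
        bucketappender[(num >> shift) & MASK](num)   # shift first, then mask
    return list(itertools.chain.from_iterable(buckets))

The driver loop would of course also need to step through 11-bit digits (roughly max_bits // SHIFT passes) instead of bytes.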

This is an old thread, but I came across it when looking to radix sort an array of positive integers. I was trying to see if I could do any better than the already wickedly fast timsort (hats off to you again, Tim Peters) behind python's builtin sorted and sort! Either I don't understand certain aspects of the above code, or, if I do, the code as presented above has some problems IMHO.
1. It only sorts byte positions from the highest byte of the smallest item up to the highest byte of the biggest item. This may be okay for some special data, but in general the approach fails to differentiate items that differ only in their lower bits. For example:
arr=[65535,65534]
radix_sort(arr)
produces the wrong output:
[65535, 65534]
2. The range used to loop over the helper function is not correct: if lowest_byte and highest_byte are the same, the helper function is never executed at all. BTW I had to change xrange to range in 2 places.
With modifications to address the above 2 points, I got it to work. But it takes 10-20 times as long as python's builtin sorted or sort! I know timsort is very efficient and takes advantage of already sorted runs in the data, but I was trying to see if I could use the prior knowledge that my data is all positive integers to some advantage in my sorting. Why is the radix sort doing so badly compared to timsort? The array sizes I was using are on the order of 80K items. Is it because the timsort implementation, in addition to its algorithmic efficiency, also has other efficiencies stemming from its possible use of low-level libraries? Or am I missing something entirely? The modified code I used is below:
import itertools

def radix_sort(unsorted):
    "Fast implementation of radix sort for any size num."
    maximum, minimum = max(unsorted), min(unsorted)
    max_bits = maximum.bit_length()
    highest_byte = max_bits // 8 if max_bits % 8 == 0 else (max_bits // 8) + 1
    # min_bits = minimum.bit_length()
    # lowest_byte = min_bits // 8 if min_bits % 8 == 0 else (min_bits // 8) + 1
    sorted_list = unsorted
    # xrange changed to range, lowest_byte deleted from the arguments
    for offset in range(highest_byte):
        sorted_list = radix_sort_offset(sorted_list, offset)
    return sorted_list

def radix_sort_offset(unsorted, offset):
    "Helper function for radix sort, sorts each offset."
    byte_check = (0xFF << offset*8)
    # xrange changed to range
    buckets = [[] for _ in range(256)]
    for num in unsorted:
        byte_at_offset = (num & byte_check) >> offset*8
        buckets[byte_at_offset].append(num)
    return list(itertools.chain.from_iterable(buckets))

You could simply use one of the existing C or C++ implementations, for example integer_sort from Boost.Sort or u4_sort from usort. It is surprisingly easy to call native C or C++ code from Python; see How to sort an array of integers faster than quicksort?
I totally get your frustration. Although it's been more than 2 years, numpy still does not have radix sort. I will let the NumPy developers know that they could simply grab one of the existing implementations; licensing should not be an issue.

Related

Fastest way to find powers of two higher than a number

I am trying to find a very fast way to find the next higher power of 2 than a very large number (1,000,000 digits). For example, I have 1009 and want to find the next higher power of two, which is 1024 or 2**10.
I tried using a loop, but for large numbers this is very, very slow:
y=0
while (1<<y)<1009:
    y+=1
print(1<<y)
1024
While this works, it's slow for numbers with more than a million digits. Is there a faster algorithm to find the next higher power of 2 than a number that large?
ANSWERED BY @JonClements
Using 2**number.bit_length() works perfectly, so this will work for large numbers as well. Thanks Jon.
Here's a code example from Jon's implementation:
2**j.bit_length()
1024
Here's a code example using the shift operator
2<<(j.bit_length()-1)
1024
Here is the time difference using the million-digit number; the shift operator with bit_length is significantly faster:
len(str(aa))
1000000
import time

def useBITLENGTHwithshiftoperator(hm):
    return 1<<hm.bit_length()-1<<1

def useBITLENGTHwithpowersoperator(hm):
    return 2**hm.bit_length()

start = time.time()
l=useBITLENGTHwithpowersoperator(aa)
end = time.time()
print(end - start)
0.014303922653198242

start = time.time()
l=useBITLENGTHwithshiftoperator(aa)
end = time.time()
print(end - start)
0.0002968311309814453
Take 2^ceiling(logBase2(x)) - should work unless x is a power of 2, and you can check for that with: if x == ceiling(x).
I do not code in python but millions of digits implies bignums so:
try to look inside your bignum lib
It might return the number of words or bits used in O(1), as some number representations need it to speed up other stuff. In that case you can obtain your answer in O(1) for free.
As @JonClements suggested in a comment, try bit_length() and measure whether it is O(1) or O(log(n)) ...
Your while is O(n^3) instead of O(n^2)
You are bit-shifting from 1 over and over again in each iteration. Why not just shift the last result by 1 bit again instead? Something like
for (y=0,yy=1;yy<1009;y++,yy<<=1);
using log2 might be faster
If the bignum class you use has it implemented correctly, then above some number-size threshold log2(1009) might be significantly faster. But that depends on the type of numbers you are using and on the bignum implementation itself.
bit-shifting can be even faster
If you have some upper limit on your numbers, you can use binary search, converting your bit-shifting into O(n.log2(n)).
If not, you can start by bit-shifting 32 bits at a time instead of 1, and once you reach the target size, bit-shift by 1 bit. Or even use more layers, like 1024/128/16/1 bits. The complexity would still be O(n^2), but the constant factor would be ~1024 times smaller, speeding up your code ~1024 times for big numbers...
Another option is to find the limit by shifting by 1 bit, then by 2, then by 4, 8, 16, 32, 64, ... until the result is bigger than your target number, and from there either bit-shift back or use binary search. This one would be O(n.log2(n)) even without any upper limit (see the sketch below).
However, all of these bring much higher overhead and will slow down the processing of smaller numbers.
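A rough sketch (mine, not the answerer's code) of that doubling-then-binary-search idea; like the original while loop it returns the smallest power of two >= x, and each big shift costs O(n), giving O(n.log2(n)) overall:

def next_pow2(x):
    if x <= 1:
        return 1
    lo, hi = 0, 1
    while (1 << hi) < x:          # gallop: try exponents 1, 2, 4, 8, ...
        lo, hi = hi, hi * 2
    while lo + 1 < hi:            # binary search, invariant: (1 << lo) < x <= (1 << hi)
        mid = (lo + hi) // 2
        if (1 << mid) < x:
            lo = mid
        else:
            hi = mid
    return 1 << hi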
Constructing 2^(y-1) < x <= 2^y might be possible to enhance too. For example, by using the bit-shifting approach to find y, you get your answer as a byproduct for free. With floating point or fixed point numbers you can directly construct such a number by computing the exponent for 1 or by setting the correct bit in zero... But for arbitrary numbers (where the size of the number is dynamic) this is much harder/slower. So it all boils down to what kind of bignum class you have and what values you use.

Is there any efficient way to increment the corresponding set positions of an integer in an integer array?

Any solution consuming less than O(Bit Length) time is welcome. I need to process around 100 million large integers.
answer = [0 for i in xrange(100)]

def pluginBits(val):
    global answer
    for j in xrange(len(answer)):
        if val <= 0:
            break
        answer[j] += (val & 1)
        val >>= 1
A speedier way to do this would be to use '{:b}'.format(someval) to convert from integer to a string of '1's and '0's. Python still needs to do similar work to perform this conversion, but doing it at the C layer in the interpreter internals involves significantly less overhead for larger values.
For conversion to actual list of integer 1s and 0s, you could do something like:
# Done once at top level to make translation table:
import string
bitstr_to_intval = string.maketrans(b'01', b'\x00\x01')
# Done for each value to convert:
bits = bytearray('{:b}'.format(origint).translate(bitstr_to_intval))
Since bytearray is a mutable sequence of values in range(256) that iterates the actual int values, you don't need to convert to list; it should be usable in 99% of the places the list would be used, using less memory and running faster.
This does generate the values in the reverse of the order your code produces (that is, bits[-1] here is the same as your answer[0], bits[-2] is your answer[1], etc.), and it's unpadded, but since you're summing bits, the padding isn't needed, and reversing the result is a trivial reversing slice (add [::-1] to the end). Summing the bits from each input can be made much faster by making answer a numpy array (that allows a bulk element-wise addition at the C layer), and putting it all together gets:
import string
import numpy

bitstr_to_intval = string.maketrans(b'01', b'\x00\x01')
answer = numpy.zeros(100, numpy.uint64)

def pluginBits(val):
    bits = bytearray('{:b}'.format(val).translate(bitstr_to_intval))[::-1]
    answer[:len(bits)] += bits
In local tests, this definition of pluginBits takes a little under one-seventh the time to sum the bits at each position for 10,000 random input integers of 100 bits each, and gets the same results.
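A hypothetical usage sketch (mine, not part of the answer), tallying bit positions over a batch of values:

import random

values = [random.getrandbits(100) for _ in xrange(10000)]
for v in values:
    pluginBits(v)
# answer[i] now holds how many of the 10000 values had bit i set
print(answer[:8])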

Sum of primes below 2,000,000 in python

I am attempting problem 10 of Project Euler, which is the summation of all primes below 2,000,000. I have tried implementing the Sieve of Eratosthenes using Python, and the code I wrote works perfectly for numbers below 10,000.
However, when I attempt to find the summation of primes for bigger numbers, the code takes too long to run (finding the sum of primes up to 100,000 took 315 seconds). The algorithm clearly needs optimization.
Yes, I have looked at other posts on this website, like Fastest way to list all primes below N, but the solutions there had very little explanation as to how the code worked (I am still a beginner programmer) so I was not able to actually learn from them.
Can someone please help me optimize my code, and clearly explain how it works along the way?
Here is my code:
primes_below_number = 2000000 # number to find summation of all primes below number
numbers = (range(1, primes_below_number + 1, 2)) # creates a list excluding even numbers
pos = 0 # index position
sum_of_primes = 0 # total sum
number = numbers[pos]
while number < primes_below_number and pos < len(numbers) - 1:
    pos += 1
    number = numbers[pos] # moves to next prime in list numbers
    sum_of_primes += number # adds prime to total sum
    num = number
    while num < primes_below_number:
        num += number
        if num in numbers[:]:
            numbers.remove(num) # removes multiples of prime found
print sum_of_primes + 2
As I said before, I am new to programming, therefore a thorough explanation of any complicated concepts would be deeply appreciated. Thank you.
As you've seen, there are various ways to implement the Sieve of Eratosthenes in Python that are more efficient than your code. I don't want to confuse you with fancy code, but I can show how to speed up your code a fair bit.
Firstly, searching a list isn't fast, and removing elements from a list is even slower. However, Python provides a set type which is quite efficient at performing both of those operations (although it does chew up a bit more RAM than a simple list). Happily, it's easy to modify your code to use a set instead of a list.
Another optimization is that we don't have to check for prime factors all the way up to primes_below_number, which I've renamed to hi in the code below. It's sufficient to just go to the square root of hi, since if a number is composite it must have a factor less than or equal to its square root.
We don't need to keep a running total of the sum of the primes. It's better to do that at the end using Python's built-in sum() function, which operates at C speed, so it's much faster than doing the additions one by one at Python speed.
# number to find summation of all primes below number
hi = 2000000
# create a set excluding even numbers
numbers = set(xrange(3, hi + 1, 2))
for number in xrange(3, int(hi ** 0.5) + 1):
    if number not in numbers:
        # number must have been removed because it has a prime factor
        continue
    num = number
    while num < hi:
        num += number
        if num in numbers:
            # Remove multiples of prime found
            numbers.remove(num)
print 2 + sum(numbers)
You should find that this code runs in a few seconds; it takes around 5 seconds on my 2GHz single-core machine.
You'll notice that I've moved the comments so that they're above the line they're commenting on. That's the preferred style in Python since we prefer short lines, and also inline comments tend to make the code look cluttered.
There's another small optimization that can be made to the inner while loop, but I let you figure that out for yourself. :)
First, removing numbers from the list will be very slow. Instead of this, make a list:
primes = [True] * primes_below_number
primes[0] = False
primes[1] = False
Now in your loop, when you find a prime p, change primes[k*p] to False for all suitable k. (You wouldn't actually multiply, you'd continually add p, of course.)
At the end,
primes = [n for n in range(primes_below_number) if primes[n]]
This should be a great deal faster.
Second, you can stop looking once your find a prime greater than the square root of primes_below_number, since a composite number must have a prime factor that doesn't exceed its square root.
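A minimal sketch (mine, not the answerer's code) putting the boolean-list idea and the square-root cutoff together:

def sum_primes_below(limit):
    is_prime = [True] * limit
    is_prime[0] = is_prime[1] = False
    for p in range(2, int(limit ** 0.5) + 1):
        if is_prime[p]:
            # cross off multiples by repeated addition, starting at p*p;
            # smaller multiples were already handled by smaller primes
            for multiple in range(p * p, limit, p):
                is_prime[multiple] = False
    return sum(n for n in range(limit) if is_prime[n])

print(sum_primes_below(2000000))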
Try using numpy; it should make it faster. Replace range with xrange; it may help you.
Here's an optimization for your code:
import itertools

primes_below_number = 2000000
numbers = list(range(3, primes_below_number, 2))
pos = 0
while pos < len(numbers) - 1:
    number = numbers[pos]
    numbers = list(
        itertools.chain(
            itertools.islice(numbers, 0, pos + 1),
            itertools.ifilter(
                lambda n: n % number != 0,
                itertools.islice(numbers, pos + 1, len(numbers))
            )
        )
    )
    pos += 1
sum_of_primes = sum(numbers) + 2
print sum_of_primes
The optimizations here are:
Moved the sum outside the loop.
Instead of removing elements from a list we can just create another one, memory is not an issue here (I hope).
When creating the new list we create it by chaining two parts, the first part is everything before the current number (we already checked those), and the second part is everything after the current number but only if they are not divisible by the current number.
Using itertools can make things faster since we'd be using iterators instead of looping through the whole list more than once.
Another solution would be to not remove parts of the list but disable them like @saulspatz said.
And here's the fastest way I was able to find: http://www.wolframalpha.com/input/?i=sum+of+all+primes+below+2+million 😁
Update
Here is the boolean method:
import itertools

primes_below_number = 2000000
numbers = [v % 2 != 0 for v in xrange(primes_below_number)]
numbers[0] = False
numbers[1] = False
numbers[2] = True
number = 3
while number < primes_below_number:
    n = number * 3 # We already excluded even numbers
    while n < primes_below_number:
        numbers[n] = False
        n += number
    number += 1
    while number < primes_below_number and not numbers[number]:
        number += 1
sum_of_numbers = sum(itertools.imap(lambda index_n: index_n[1] and index_n[0] or 0, enumerate(numbers)))
print(sum_of_numbers)
This executes in seconds (took 3 seconds on my 2.4GHz machine).
Instead of storing a list of numbers, you can instead store an array of boolean values. This use of a bitmap can be thought of as a way to implement a set, which works well for dense sets (there aren't big gaps between the values of members).
An answer on a recent python sieve question uses this implementation python-style. It turns out a lot of people have implemented a sieve, or something they thought was a sieve, and then come on SO to ask why it was slow. :P Look at the related-questions sidebar from some of them if you want more reading material.
Finding the element that holds the boolean that says whether a number is in the set or not is easy and extremely fast. array[i] is a boolean value that's true if i is in the set, false if not. The memory address can be computed directly from i with a single addition.
(I'm glossing over the fact that an array of boolean might be stored with a whole byte for each element, rather than the more efficient implementation of using every single bit for a different element. Any decent sieve will use a bitmap.)
Removing a number from the set is as simple as setting array[i] = false, regardless of the previous value. No searching, no comparison, no tracking of what happened, just one memory operation. (Well, two for a bitmap: load the old byte, clear the correct bit, store it. Memory is byte-addressable, but not bit-addressable.)
An easy optimization of the bitmap-based sieve is to not even store the even-numbered bytes, because there is only one even prime, and we can special-case it to double our memory density. Then the membership-status of i is held in array[i/2]. (Dividing by powers of two is easy for computers. Other values are much slower.)
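A hedged sketch (mine) of that odds-only storage scheme, using a bytearray where index i stands for the odd number 2*i + 1:

def sum_primes_below_odds(limit):
    size = limit // 2                  # one slot each for 1, 3, 5, ...
    sieve = bytearray([1]) * size
    sieve[0] = 0                       # 1 is not prime
    i = 1
    while (2 * i + 1) ** 2 < limit:
        if sieve[i]:
            p = 2 * i + 1
            # first useful composite is p*p; odd number m lives at index m // 2
            for j in range(p * p // 2, size, p):
                sieve[j] = 0
        i += 1
    return 2 + sum(2 * i + 1 for i in range(1, size) if sieve[i])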
An SO question, Why is Sieve of Eratosthenes more efficient than the simple "dumb" algorithm?, has many links to good stuff about the sieve. This one in particular has some good discussion about it, in words rather than just code. (Never mind the fact that it's talking about a common Haskell implementation that looks like a sieve, but actually isn't. They call this the "unfaithful" sieve in their graphs, and so on.)
Discussion on that question brought up the point that trial division may be faster than big sieves, for some uses, because clearing the bits for all multiples of every prime touches a lot of memory in a cache-unfriendly pattern. CPUs are much faster than memory these days.

How can I vectorize this python count sort so it is absolutely as fast as it can be?

I am trying to write a count sort in python to beat the built-in timsort in certain situations. Right now it beats the built in sorted function, but only for very large arrays (1 million integers in length and longer, I haven't tried over 10 million) and only for a range no larger than 10,000. Additionally, the victory is narrow, with count sort only winning by a significant margin in random lists specifically tailored to it.
I have read about astounding performance gains that can be gained from vectorizing python code, but I don't particularly understand how to do it or how it could be used here. I would like to know how I can vectorize this code to speed it up, and any other performance suggestions are welcome.
Current fastest version for just python and stdlibs:
from itertools import chain, repeat

def untimed_countsort(unsorted_list):
    counts = {}
    for num in unsorted_list:
        try:
            counts[num] += 1
        except KeyError:
            counts[num] = 1
    sorted_list = list(
        chain.from_iterable(
            # .get avoids a KeyError for values in the range that never occur
            repeat(num, counts.get(num, 0))
            for num in xrange(min(counts), max(counts) + 1)))
    return sorted_list
All that counts is raw speed here, so sacrificing even more space for speed gains is completely fair game.
I realize the code is fairly short and clear already, so I don't know how much room there is for improvement in speed.
If anyone has a change to the code to make it shorter, as long as it doesn't make it slower, that would be awesome as well.
Execution time is down almost 80%! Now three times as fast as Timsort on my current tests!
The absolute fastest way to do this by a LONG shot is using this one-liner with numpy:
import numpy

def np_sort(unsorted_np_array):
    return numpy.repeat(numpy.arange(1+unsorted_np_array.max()), numpy.bincount(unsorted_np_array))
This runs about 10-15 times faster than the pure python version, and about 40 times faster than Timsort. It takes a numpy array in and outputs a numpy array.
With numpy, this function reduces to the following:
import numpy

def countsort(unsorted):
    unsorted = numpy.asarray(unsorted)
    return numpy.repeat(numpy.arange(1+unsorted.max()), numpy.bincount(unsorted))
This ran about 40 times faster when I tried it on 100000 random ints from the interval [0, 10000). bincount does the counting, and repeat converts from counts to a sorted array.
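If the values don't start near zero, a hedged variant (mine, not from the answer) shifts by the minimum first so bincount's table only spans the actual value range, mirroring the pure-python version's use of min():

import numpy as np

def np_countsort(arr):
    arr = np.asarray(arr)
    lo = arr.min()
    counts = np.bincount(arr - lo)                    # counts per value, offset by lo
    return np.repeat(np.arange(lo, lo + counts.size), counts)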
Without thinking about your algorithm, this will help get rid of most of your pure python loops (which are quite slow) and turn them into comprehensions or generators (always faster than regular for blocks). Also, if you have to make a list consisting of all the same elements, the [x]*n syntax is probably the fastest way to go. The sum is used to flatten the list of lists.
from collections import defaultdict

def countsort(unsorted_list):
    lmin, lmax = min(unsorted_list), max(unsorted_list) + 1
    counts = defaultdict(int)
    for j in unsorted_list:
        counts[j] += 1
    # the empty-list start value makes sum() flatten the list of lists
    return sum([[num]*counts[num] for num in xrange(lmin, lmax) if num in counts], [])
Note that this is not vectorized, nor does it use numpy.

Why is my MergeSort so slow in Python?

I'm having some troubles understanding this behaviour.
I'm measuring the execution time with the timeit-module and get the following results for 10000 cycles:
Merge : 1.22722930395
Bubble: 0.810706578175
Select: 0.469924766812
This is my code for MergeSort:
def mergeSort(array):
    if len(array) <= 1:
        return array
    else:
        left = array[:len(array)/2]
        right = array[len(array)/2:]
        return merge(mergeSort(left),mergeSort(right))

def merge(array1,array2):
    merged_array=[]
    while len(array1) > 0 or len(array2) > 0:
        if array2 and not array1:
            merged_array.append(array2.pop(0))
        elif (array1 and not array2) or array1[0] < array2[0]:
            merged_array.append(array1.pop(0))
        else:
            merged_array.append(array2.pop(0))
    return merged_array
Edit:
I've changed the list operations to use pointers and my tests now work with a list of 1000 random numbers from 0-1000. (btw: I changed to only 10 cycles here)
result:
Merge : 0.0574434420723
Bubble: 1.74780097558
Select: 0.362952293025
This is my rewritten merge definition:
def merge(array1, array2):
    merged_array = []
    pointer1, pointer2 = 0, 0
    while pointer1 < len(array1) and pointer2 < len(array2):
        if array1[pointer1] < array2[pointer2]:
            merged_array.append(array1[pointer1])
            pointer1 += 1
        else:
            merged_array.append(array2[pointer2])
            pointer2 += 1
    while pointer1 < len(array1):
        merged_array.append(array1[pointer1])
        pointer1 += 1
    while pointer2 < len(array2):
        merged_array.append(array2[pointer2])
        pointer2 += 1
    return merged_array
seems to work pretty well now :)
list.pop(0) pops the first element and has to shift all remaining ones, this is an additional O(n) operation which must not happen.
Also, slicing a list object creates a copy:
left = array[:len(array)/2]
right = array[len(array)/2:]
Which means you're also using O(n * log(n)) memory instead of O(n).
I can't see BubbleSort, but I bet it works in-place, no wonder it's faster.
You need to rewrite it to work in-place. Instead of copying part of original list, pass starting and ending indexes.
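A rough sketch (mine, not the answerer's code) of what sorting by index ranges could look like, merging through one reusable scratch buffer instead of slicing and popping:

def merge_sort_inplace(a, lo=0, hi=None, scratch=None):
    if hi is None:
        hi = len(a)
    if scratch is None:
        scratch = [None] * len(a)
    if hi - lo <= 1:
        return
    mid = (lo + hi) // 2
    merge_sort_inplace(a, lo, mid, scratch)
    merge_sort_inplace(a, mid, hi, scratch)
    i, j, k = lo, mid, lo
    while i < mid and j < hi:                 # merge the two sorted halves
        if a[i] <= a[j]:
            scratch[k] = a[i]
            i += 1
        else:
            scratch[k] = a[j]
            j += 1
        k += 1
    scratch[k:k + mid - i] = a[i:mid]         # copy whichever half has leftovers
    scratch[k + mid - i:hi] = a[j:hi]
    a[lo:hi] = scratch[lo:hi]                 # write the merged run back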
For starters: I cannot reproduce your timing results, on 100 cycles and lists of size 10000. The exhaustive benchmark with timeit of all implementations discussed in this answer (including bubblesort and your original snippet) is posted as a gist here. I find the following results for the average duration of a single run:
Python's native (Tim)sort : 0.0144600081444
Bubblesort : 26.9620819092
(Your) Original Mergesort : 0.224888720512
Now, to make your function faster, you can do a few things.
Edit: Well, apparently, I was wrong on that one (thanks cwillu). Length computation takes O(1) in python. But removing useless computation everywhere still improves things a bit (Original Mergesort: 0.224888720512, no-length Mergesort: 0.195795390606):
def nolenmerge(array1,array2):
    merged_array=[]
    while array1 or array2:
        if not array1:
            merged_array.append(array2.pop(0))
        elif (not array2) or array1[0] < array2[0]:
            merged_array.append(array1.pop(0))
        else:
            merged_array.append(array2.pop(0))
    return merged_array

def nolenmergeSort(array):
    n = len(array)
    if n <= 1:
        return array
    left = array[:n/2]
    right = array[n/2:]
    return nolenmerge(nolenmergeSort(left),nolenmergeSort(right))
Second, as suggested in this answer, pop(0) is linear. Rewrite your merge to pop() at the end:
def fastmerge(array1,array2):
    merged_array=[]
    while array1 or array2:
        if not array1:
            merged_array.append(array2.pop())
        elif (not array2) or array1[-1] > array2[-1]:
            merged_array.append(array1.pop())
        else:
            merged_array.append(array2.pop())
    merged_array.reverse()
    return merged_array
This is again faster: no-len Mergesort: 0.195795390606, no-len Mergesort+fastmerge: 0.126505711079
Third - and this would only be useful as-is if you were using a language that does tail call optimization; without it, it's a bad idea - your recursive call to mergeSort is not tail-recursive; it calls both mergeSort(left) and mergeSort(right) while there is remaining work in the call (the merge).
But you can make it tail-recursive by using CPS (this will run out of stack size for even modest lists if you don't do tco):
def cps_merge_sort(array):
    return cpsmergeSort(array,lambda x:x)

def cpsmergeSort(array,continuation):
    n = len(array)
    if n <= 1:
        return continuation(array)
    left = array[:n/2]
    right = array[n/2:]
    return cpsmergeSort(left, lambda leftR:
        cpsmergeSort(right, lambda rightR:
            continuation(fastmerge(leftR,rightR))))
Once this is done, you can do TCO by hand to defer the call stack management done by recursion to the while loop of a normal function (trampolining, explained e.g. here, trick originally due to Guy Steele). Trampolining and CPS work great together.
You write a thunking function, that "records" and delays application: it takes a function and its arguments, and returns a function that returns (that original function applied to those arguments).
thunk = lambda name, *args: lambda: name(*args)
You then write a trampoline that manages calls to thunks: it applies a thunk until the thunk returns a result (as opposed to another thunk)
def trampoline(bouncer):
    while callable(bouncer):
        bouncer = bouncer()
    return bouncer
Then all that's left is to "freeze" (thunk) all your recursive calls from the original CPS function, to let the trampoline unwrap them in proper sequence. Your function now returns a thunk, without recursion (and discarding its own frame), at every call:
def tco_cpsmergeSort(array,continuation):
    n = len(array)
    if n <= 1:
        return continuation(array)
    left = array[:n/2]
    right = array[n/2:]
    return thunk(tco_cpsmergeSort, left, lambda leftR:
        thunk(tco_cpsmergeSort, right, lambda rightR:
            continuation(fastmerge(leftR,rightR))))

mycpomergesort = lambda l: trampoline(tco_cpsmergeSort(l,lambda x:x))
Sadly this does not go that fast (recursive mergesort: 0.126505711079, this trampolined version: 0.170638551712). OK, I guess the stack blowup of the recursive merge sort algorithm is in fact modest: as soon as you get out of the leftmost path in the array-slicing recursion pattern, the algorithm starts returning (& removing frames). So for 10K-sized lists, you get a function stack of at most log_2(10 000) = 14 ... pretty modest.
You can do a slightly more involved stack-based TCO elimination along the lines of what this SO answer gives:
def leftcomb(l):
    maxn,leftcomb = len(l),[]
    n = maxn/2
    while maxn > 1:
        leftcomb.append((l[n:maxn],False))
        maxn,n = n,n/2
    return l[:maxn],leftcomb

def tcomergesort(l):
    l,stack = leftcomb(l)
    while stack: # l sorted, stack contains tagged slices
        i,ordered = stack.pop()
        if ordered:
            l = fastmerge(l,i)
        else:
            stack.append((l,True)) # store return call
            rsub,ssub = leftcomb(i)
            stack.extend(ssub) # recurse
            l = rsub
    return l
But this goes only a tad faster (trampolined mergesort: 0.170638551712, this stack-based version: 0.144994809628). Apparently, the stack-building that python does at the recursive calls of our original merge sort is pretty inexpensive.
The final results? On my machine (Ubuntu natty's stock Python 2.7.1+), the average run timings (out of 100 runs - except for Bubblesort -, list of size 10000, containing random integers of size 0-10000000) are:
Python's native (Tim)sort : 0.0144600081444
Bubblesort : 26.9620819092
Original Mergesort : 0.224888720512
no-len Mergesort : 0.195795390606
no-len Mergesort + fastmerge : 0.126505711079
trampolined CPS Mergesort + fastmerge : 0.170638551712
stack-based mergesort + fastmerge: 0.144994809628
Your merge-sort has a big constant factor, you have to run it on large lists to see the asymptotic complexity benefit.
Umm.. 1,000 records?? You are still well within polynomial coefficient dominance here.. If I have
selection-sort: 15 * n ^ 2 (reads) + 5 * n^2 (swaps)
insertion-sort: 5 * n ^2 (reads) + 15 * n^2 (swaps)
merge-sort: 200 * n * log(n) (reads) 1000 * n * log(n) (merges)
You're going to be in a close race for a long while.. By the way, 2x faster in sorting is NOTHING. Try 100x slower. That's where the real differences are felt. Try "won't finish in my life-time" algorithms (there are known regular expressions that take this long to match simple strings).
So try 1M or 1G records and let us know if you still think merge-sort isn't doing too well.
That being said..
There are lots of things causing this merge-sort to be expensive. First of all, nobody ever runs quick or merge sort on small scale data-structures.. Where you have if (len <= 1), people generally put:
if (len <= 16) : (use inline insertion-sort)
else: merge-sort
At EACH propagation level, since insertion-sort has a smaller coefficient cost at smaller sizes of n. Note that 50% of your work is done in this last mile.
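A hedged sketch (mine, not the answerer's code) of that cutoff, reusing the fastmerge helper from the earlier answer and an assumed threshold of 16:

CUTOFF = 16    # assumed threshold; worth tuning

def insertion_sort(a):
    for i in range(1, len(a)):
        key = a[i]
        j = i - 1
        while j >= 0 and a[j] > key:
            a[j + 1] = a[j]
            j -= 1
        a[j + 1] = key
    return a

def hybrid_merge_sort(a):
    if len(a) <= CUTOFF:
        return insertion_sort(a)
    mid = len(a) // 2
    return fastmerge(hybrid_merge_sort(a[:mid]), hybrid_merge_sort(a[mid:]))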
Next, you are needlessly running array1.pop(0) instead of maintaining index-counters. If you're lucky, python is efficiently managing start-of-array offsets, but all else being equal, you're mutating input parameters.
Also, you know the size of the target array during merge, why copy-and-double the merged_array repeatedly.. Pre-allocate the size of the target array at the start of the function.. That'll save at least a dozen 'clones' per merge-level.
In general, merge-sort uses 2x the size of RAM.. Your algorithm is probably using 20x because of all the temporary merge buffers (hopefully python can free structures before recursion). It breaks elegance, but generally the best merge-sort algorithms make an immediate allocation of a merge buffer equal to the size of the source array, and you perform complex address arithmetic (or array-index + span-length) to just keep merging data-structures back and forth. It won't be as elegant as a simple recursive problem like this, but it's somewhat close.
In C-sorting, cache-coherence is your biggest enemy. You want hot data-structures so you maximize your cache. By allocating transient temp buffers (even if the memory manager is returning pointers to hot memory) you run the risk of making slow DRAM calls (pre-filling cache-lines for data you're about to over-write). This is one advantage insertion-sort, selection-sort and quick-sort have over merge-sort (when implemented as above).
Speaking of which, something like quick-sort is both naturally-elegant code, naturally efficient-code, and doesn't waste any memory (google it on wikipedia- they have a javascript implementation from which to base your code). Squeezing the last ounce of performance out of quick-sort is hard (especially in scripting languages, which is why they generally just use the C-api to do that part), and you have a worst-case of O(n^2). You can try and be clever by doing a combination bubble-sort/quick-sort to mitigate worst-case.
Happy coding.
