I'm currently working on a high-performance Python 2.7 project that uses lists tens of thousands of elements in size. Obviously, every operation must be performed as fast as possible.
So, I have two lists: one of them is a list of unique, arbitrary numbers, let's call it A, and the other one is a linear list starting at 1 with the same length as the first list, named B, which represents indices in A (starting at 1).
Something like enumerate, starting with 1.
For example:
A = [500, 300, 400, 200, 100] # The order here is arbitrary, they can be any integers, but every integer can only exist once
B = [ 1, 2, 3, 4, 5] # This is fixed, starting from 1, with exactly as many elements as A
If I have an element of B (called e_B) and want the corresponding element in A, I can simply do correspond_e_A = A[e_B - 1]. No problem.
But now I have a huge list of random, non-unique integers, and I want to know the indices of the integers that are in A, and what the corresponding elements in B are.
I think I have a reasonable solution for the first question:
indices_of_existing = numpy.nonzero(numpy.in1d(random_list, A))[0]
What is great about this approach is that there is no need to map() single operations; numpy's in1d just returns a boolean array like [True, True, False, True, ...]. Using nonzero() I can get the indices of the elements in random_list that exist in A. Perfect, I think.
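For example (a small sketch, with a made-up random_list), the intermediate results look roughly like this:

import numpy

A = [500, 300, 400, 200, 100]
random_list = [200, 100, 50, 300]              # made-up query values

mask = numpy.in1d(random_list, A)              # array([ True,  True, False,  True])
indices_of_existing = numpy.nonzero(mask)[0]   # array([0, 1, 3])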
But for the second question, I'm stumped.
I tried something like:
corresponding_e_B = map(lambda x: numpy.where(A == x)[0][0] + 1, random_list)
This correctly gives me the indices, but it's not optimal: firstly I need a map(), secondly I need a lambda, and finally numpy.where() does not stop once the item has been found (remember, A has only unique elements), meaning that it scales horribly with huge datasets like mine.
I took a look at bisect, but it seems bisect only works with single requests, not with lists, meaning that I'd still have to use map() and build my list elementwise (that's slow, isn't it?)
Since I'm quite new to Python, I was hoping anyone here might have an idea? Maybe a library I don't know yet?
I think you would be well advised to use a hash table for your lookups instead of numpy.in1d, which uses an O(n log n) merge sort as a preprocessing step.
>>> A = [500, 300, 400, 200, 100]
>>> index = { k:i for i,k in enumerate(A, 1) }
>>> random_list = [200, 100, 50]
>>> [i for i,x in enumerate(random_list) if x in index]
[0, 1]
>>> map(index.get, random_list)
[4, 5, None]
>>> filter(None, map(index.get, random_list))
[4, 5]
This is Python 2; in Python 3, map and filter return lazy iterators, so you would need to wrap the filter call in list() if you want the result as a list.
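For reference, a minimal Python 3 sketch of the same lookups (same example data as above):

A = [500, 300, 400, 200, 100]
index = {k: i for i, k in enumerate(A, 1)}
random_list = [200, 100, 50]

existing = [i for i, x in enumerate(random_list) if x in index]   # [0, 1]
corresponding = list(filter(None, map(index.get, random_list)))   # [4, 5]
# filter(None, ...) is safe here because the values in B start at 1, never 0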
I have tried to use builtin functions as much as possible to push the computational burden to the C side (assuming you use CPython). All the names are resolved upfront, so it should be pretty fast.
In general, for maximum performance, you might want to consider using PyPy, a great alternative Python implementation with JIT compilation.
A benchmark to compare multiple approaches is never a bad idea:
import sys
is_pypy = '__pypy__' in sys.builtin_module_names
import timeit
import random
if not is_pypy:
    import numpy
import bisect

n = 10000
m = 10000
q = 100

A = set()
while len(A) < n: A.add(random.randint(0, 2*n))
A = list(A)
queries = set()
while len(queries) < m: queries.add(random.randint(0, 2*n))
queries = list(queries)

# these two solve question one (find indices of queries that exist in A)
if not is_pypy:
    def fun11():
        for _ in range(q):
            numpy.nonzero(numpy.in1d(queries, A))[0]

def fun12():
    index = set(A)
    for _ in range(q):
        [i for i, x in enumerate(queries) if x in index]

# these three solve question two (find corresponding entries of B)
def fun21():
    index = { k: i for i, k in enumerate(A, 1) }
    for _ in range(q):
        [index[i] for i in queries if i in index]

def fun22():
    index = { k: i for i, k in enumerate(A, 1) }
    for _ in range(q):
        list(filter(None, map(index.get, queries)))

def findit(keys, values, key):
    # bisect_left: if the key exists, it is found exactly at position i
    i = bisect.bisect_left(keys, key)
    if i == len(keys) or keys[i] != key:
        return None
    return values[i]

def fun23():
    keys, values = zip(*sorted((k, i) for i, k in enumerate(A, 1)))
    for _ in range(q):
        list(filter(None, [findit(keys, values, x) for x in queries]))

if not is_pypy:
    # note this does not filter out nonexisting elements
    def fun24():
        I = numpy.argsort(A)
        ss = numpy.searchsorted
        maxi = len(I)
        for _ in range(q):
            a = ss(A, queries, sorter=I)
            I[a[a < maxi]]

tests = ("fun12", "fun21", "fun22", "fun23")
if not is_pypy: tests = ("fun11",) + tests + ("fun24",)

if is_pypy:
    # warmup for the JIT
    for f in tests:
        timeit.timeit("%s()" % f, setup = "from __main__ import %s" % f, number=20)

# actual timing
for f in tests:
    print("%s: %.3f" % (f, timeit.timeit("%s()" % f, setup = "from __main__ import %s" % f, number=3)))
Results:
$ python2 -V
Python 2.7.6
$ python3 -V
Python 3.3.3
$ pypy -V
Python 2.7.3 (87aa9de10f9ca71da9ab4a3d53e0ba176b67d086, Dec 04 2013, 12:50:47)
[PyPy 2.2.1 with GCC 4.8.2]
$ python2 test.py
fun11: 1.016
fun12: 0.349
fun21: 0.302
fun22: 0.276
fun23: 2.432
fun24: 0.897
$ python3 test.py
fun11: 0.973
fun12: 0.382
fun21: 0.423
fun22: 0.341
fun23: 3.650
fun24: 0.894
$ pypy ~/tmp/test.py
fun12: 0.087
fun21: 0.073
fun22: 0.128
fun23: 1.131
You can tweak n (size of A), m (size of random_list) and q (number of queries) to your scenario. To my surprise, my attempt to be clever and use builtin functions instead of list comprehensions has not paid off: fun22 is not a lot faster than fun21 (only ~10% in Python 2 and ~25% in Python 3, and it's almost 75% slower in PyPy). A case of premature optimization here. I guess the difference is due to the fact that fun22 builds up an unnecessary temporary list per query in Python 2. We also see that binary search is pretty bad (look at fun23).
import numpy as np

def numpy_optimized(index, values):
    I = np.argsort(values)
    Q = np.searchsorted(values, index, sorter=I)
    return I[Q]
This calculates the same thing as OP, but with the indices in matching order to the values queried, which I imagine is an improvement in functionality. It is up to twice as fast as OP's solution on my machine, which puts it slightly ahead of the non-pypy solutions, if I interpret your benchmarks correctly.
Or, in case we cannot assume all queried values are present in values and we would like missing queries to be silently dropped:
def numpy_optimized_filtered(index, values):
    I = np.argsort(values)
    Q = np.searchsorted(values, index, sorter=I)
    # clip so that queries beyond the largest value do not index out of bounds;
    # such queries are removed by the equality check below anyway
    Q = np.clip(Q, 0, len(values) - 1)
    Z = I[Q]
    return Z[values[Z] == index]
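For instance, with the example data from the question (a quick sketch; note that the returned values are 0-based positions in A, so add 1 to get the entries of B):

A = np.array([500, 300, 400, 200, 100])
queries = np.array([200, 100, 50])
print(numpy_optimized_filtered(queries, A))   # [3 4]  ->  B entries 4 and 5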
Related
Consider some vector:
import numpy as np
v = np.arange(10)
Assume we need to find the last 2 indices satisfying some condition.
For example, in Matlab it would be written as:
find(v < 5, 2, 'last')
answer = [3, 4] (note: Matlab indexing starts from 1)
Question: What would be the clearest way to do that in Python?
A "nice" solution should STOP the search when it finds the 2 desired results; it should NOT search over all elements of the vector.
So np.where does not seem to be "nice" in that sense.
We can easily write that using a "for" loop, but is there an alternative way?
I am afraid of using "for" since it might be slow (at least it is very much so in Matlab).
This attempt doesn't use numpy, and it is probably not very idiomatic.
Nevertheless, if I understand it correctly, zip, filter and reversed are all lazy iterators that take only the elements that they really need. Therefore, you could try this:
x = list(range(10))

from itertools import islice

res = reversed(list(map(
    lambda xi: xi[1],
    islice(
        filter(
            lambda xi: xi[0] < 5,
            zip(reversed(x), reversed(range(len(x))))
        ),
        2
    )
)))
print(list(res))
Output:
[3, 4]
What it does (from inside to outside):
create index range
reverse both array and indices
zip the reversed array with indices
filter the two (value, index)-pairs that you need, extract them by islice
Throw away the values, retain only indices with map
reverse again
Even though it looks somewhat monstrous, it should all be lazy, and stop after it finds the first two elements that you are looking for. I haven't compared it with a simple loop, maybe just using a loop would be both simpler and faster.
Any solution you'd find will iterate over the list even if the loop is 'hidden' inside a function.
The solution to your problem depends on the assumptions you can make e.g. is the list sorted?
For the general case, I'd iterate over the list starting at the end:
def find(condition, k, v):
    indices = []
    for i, var in enumerate(reversed(v)):
        if condition(var):
            indices.append(len(v) - i - 1)
            if len(indices) >= k:
                break
    return indices
The condition should then be passed as a function, so you can use a lambda:
v = range(10)
find(lambda x: x < 5, 3, v)
will output
[4, 3, 2]
I'm not aware of a "good" numpy solution to short-circuiting.
The most principled way to go would be using something like Cython, which (to brutally oversimplify) adds fast loops to Python. Once you have set that up, it would be easy.
If you do not want to do that, you'd have to employ some gymnastics like:
import numpy as np

def find_last_k(vector, condition, k, minchunk=32):
    if k > minchunk:
        minchunk = k
    l, r = vector.size - minchunk, vector.size
    found = []
    n_found = 0
    while r > 0:
        if l <= 0:
            l = 0
        found.append(l + np.where(condition(vector[l:r]))[0])
        n_found += len(found[-1])
        if n_found >= k:
            break
        l, r = 3 * l - 2 * r, l
    return np.concatenate(found[::-1])[-k:]
This tries to balance loop overhead against numpy's "inflexibility" by searching in chunks, which grow exponentially until enough hits are found.
Not exactly pretty, though.
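For example, with the vector from the question, a quick sanity check:

v = np.arange(10)
print(find_last_k(v, lambda x: x < 5, 2))   # [3 4]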
This is what I've found that seems to do the job for the example described: argwhere returns all indices that meet the criterion, and then we take the last two of these as a numpy array:
ind = np.argwhere(v<5)
ind[-2:]
This searches through the entire array so is not optimal but is easy to code.
I wrote a simple script that tests the speed, and this is what I found out. Actually, the for loop was fastest in my case. That really surprised me; check out the results below (I was calculating the sum of squares). Is that because it holds the list in memory, or is that intended? Can anyone explain this?
from functools import reduce
import datetime

def time_it(func, numbers, *args):
    start_t = datetime.datetime.now()
    for i in range(numbers):
        func(args[0])
    print(datetime.datetime.now() - start_t)

def square_sum1(numbers):
    return reduce(lambda sum, next: sum + next**2, numbers, 0)

def square_sum2(numbers):
    a = 0
    for i in numbers:
        i = i**2
        a += i
    return a

def square_sum3(numbers):
    sqrt = lambda x: x**2
    return sum(map(sqrt, numbers))

def square_sum4(numbers):
    return(sum([i**2 for i in numbers]))

time_it(square_sum1, 100000, [1, 2, 5, 3, 1, 2, 5, 3])
time_it(square_sum2, 100000, [1, 2, 5, 3, 1, 2, 5, 3])
time_it(square_sum3, 100000, [1, 2, 5, 3, 1, 2, 5, 3])
time_it(square_sum4, 100000, [1, 2, 5, 3, 1, 2, 5, 3])
0:00:00.302000 #Reduce
0:00:00.144000 #For loop
0:00:00.318000 #Map
0:00:00.290000 #List comprehension
Update: when I tried longer loops, here are the results.
time_it(square_sum1, 100, range(1000))
time_it(square_sum2, 100, range(1000))
time_it(square_sum3, 100, range(1000))
time_it(square_sum4, 100, range(1000))
0:00:00.068992
0:00:00.062955
0:00:00.069022
0:00:00.057446
Python function calls have overhead which makes them relatively slow, so code that uses a simple expression will always be faster than code that wraps that expression in a function; it doesn't matter whether it's a normal def function or a lambda. For that reason, it's best to avoid map or reduce when you'd be passing them a Python function, if you can do the equivalent job with a plain expression in a for loop or a comprehension or generator expression.
There are a couple of minor optimizations that will speed up some of your functions. Don't make unnecessary assignments. E.g.,
def square_sum2a(numbers):
    a = 0
    for i in numbers:
        a += i ** 2
    return a
Also, i * i is quite a bit faster than i ** 2 because multiplication is faster than exponentiation.
As I mentioned in the comments, it's more efficient to pass sum a generator than a list comprehension, especially if the loop is large; it probably won't make a difference with a small list of length 8, but it will be quite noticeable with large lists.
sum(i*i for i in numbers)
As Kelly Bundy mentions in the comments, the generator expression version isn't actually faster than the equivalent list comprehension. Generator expressions are more efficient than list comps in terms of RAM use, but they're not necessarily faster. And when the sequence length is small, the RAM usage differences are negligible, although there is also the time required to allocate & free the RAM used.
I just ran a few tests, with a larger data list. The list comp is still the winner (usually), but the speed differences are generally around 5-10%.
BTW, you shouldn't use sum or next as variable names as that masks the built-in functions with the same names. It won't hurt anything here, but it's still not a good idea, and it makes your code look odd in an editor with more comprehensive syntax highlighting than the SO syntax highlighter.
Here's a new version of your code that uses the timeit module. It does 3 repetitions of 10,000 loops each and sorts the results. As explained in the timeit docs, the important figure to look at in the series of the repetitions is the minimum one.
In a typical case, the lowest value gives a lower bound for how fast your machine can run the given code snippet; higher values in the result vector are typically not caused by variability in Python's speed, but by other processes interfering with your timing accuracy. So the min() of the result is probably the only number you should be interested in.
from timeit import Timer
from functools import reduce

def square_sum1(numbers):
    return reduce(lambda total, u: total + u**2, numbers, 0)

def square_sum1a(numbers):
    return reduce(lambda total, u: total + u*u, numbers, 0)

def square_sum2(numbers):
    a = 0
    for i in numbers:
        i = i**2
        a += i
    return a

def square_sum2a(numbers):
    a = 0
    for i in numbers:
        a += i * i
    return a

def square_sum3(numbers):
    sqr = lambda x: x**2
    return sum(map(sqr, numbers))

def square_sum3a(numbers):
    sqr = lambda x: x*x
    return sum(map(sqr, numbers))

def square_sum4(numbers):
    return(sum([i**2 for i in numbers]))

def square_sum4a(numbers):
    return(sum(i*i for i in numbers))

funcs = (
    square_sum1,
    square_sum1a,
    square_sum2,
    square_sum2a,
    square_sum3,
    square_sum3a,
    square_sum4,
    square_sum4a,
)

data = [1, 2, 5, 3, 1, 2, 5, 3]

def time_test(loops, reps):
    ''' Print timing stats for all the functions '''
    timings = []
    for func in funcs:
        fname = func.__name__
        setup = 'from __main__ import data, ' + fname
        cmd = fname + '(data)'
        t = Timer(cmd, setup)
        result = t.repeat(reps, loops)
        result.sort()
        timings.append((result, fname))
    timings.sort()
    for result, fname in timings:
        print('{0:14} {1}'.format(fname, result))

loops, reps = 10000, 3
time_test(loops, reps)
output
square_sum2a [0.03815755599862314, 0.03817843700016965, 0.038571521999983815]
square_sum4a [0.06384095800240175, 0.06462285799716483, 0.06579178199899616]
square_sum3a [0.07395686000018031, 0.07405958899835241, 0.07463337299850537]
square_sum1a [0.07867341000019223, 0.0788448769999377, 0.07908406700153137]
square_sum2 [0.08781023399933474, 0.08803317899946705, 0.08846573399932822]
square_sum4 [0.10260082300010254, 0.10360279499946046, 0.10415067900248687]
square_sum3 [0.12363515399920288, 0.12434166299863136, 0.1273790529994585]
square_sum1 [0.1276186039976892, 0.13786184099808452, 0.16315817699796753]
The results were obtained on an old single core 32 bit 2GHz machine running Python 3.6.0 on Linux.
This is almost independent of the underlying programming language: abstractions do not come for free.
Meaning: there is always a certain cost for calling methods. A stack frame needs to be established; control flow needs to "jump". And when you think of lower levels, such as CPUs: the code for that method probably needs to be loaded, and so on.
In other words: when your primary requirement is hard-core number crunching, you have to balance ease of use against the cost of the corresponding abstractions.
Beyond that: if you focus on speed, you should look beyond Python, or at least beyond "ordinary" Python. Instead you could turn to modules such as numpy.
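For instance, a sum of squares can be pushed entirely into numpy's C loops; a small sketch (the dtype choice is just to avoid overflow):

import numpy as np

numbers = np.arange(1, 1001, dtype=np.int64)
total = int(np.dot(numbers, numbers))               # sum of squares, computed in C
assert total == sum(i * i for i in range(1, 1001))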
If I have a list that is already sorted and use the in keyword, for example:
a = [1,2,5,6,8,9,10]
print 8 in a
I think this should do a sequential search but can I make it faster by doing binary search?
Is there a pythonic way to search in a sorted list?
The standard library has the bisect module which supports searching in sorted sequences.
However, for small lists, I would bet that the C implementation behind the in operator would beat out bisect. You'd have to measure with a bunch of common cases to determine the real break-even point on your target hardware...
It's worth noting that if you can get away with an unordered iterable (i.e. a set), then you can do the lookup in O(1) time on average (using the in operator), compared to bisection on a sequence which is O(logN) and the in operator on a sequence which is O(N). And, with a set you also avoid the cost of sorting it in the first place :-).
There is a binary search for Python in the standard library, in module bisect. It does not support in/contains as is, but you can write a small function to handle it:
from bisect import bisect_left

def contains(a, x):
    """returns true if sorted sequence `a` contains `x`"""
    i = bisect_left(a, x)
    return i != len(a) and a[i] == x
Then
>>> contains([1,2,3], 3)
True
>>> contains([1,2,3], 4)
False
You might expect this not to be very speedy if bisect were written in Python rather than C, but bisect has had an optional C-accelerated implementation in CPython since Python 2.4; still, for short lists you may find a sequential in faster in quite a lot of cases.
It is hard to time the exact break-even point in CPython. This is because the code is written in C; if you check for a value that is greater than or less than every value in the sequence, then the CPU's branch prediction will play tricks on you, and you get:
In [2]: a = list(range(100))
In [3]: %timeit contains(a, 101)
The slowest run took 8.09 times longer than the fastest. This could mean that an intermediate result is being cached
1000000 loops, best of 3: 370 ns per loop
Here, the best of 3 is not representative of the true running time of the algorithm.
But tweaking tests, I've reached the conclusion that bisecting might be faster than in for lists having as few as 30 elements.
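If you want to locate the break-even point on your own machine, a rough measurement sketch (list size and probe value are arbitrary) could look like this:

import timeit

setup = """
from bisect import bisect_left
a = list(range(30))

def contains(a, x):
    i = bisect_left(a, x)
    return i != len(a) and a[i] == x
"""

print(timeit.timeit('29 in a', setup=setup, number=1000000))
print(timeit.timeit('contains(a, 29)', setup=setup, number=1000000))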
However, if you're doing really many in operations you ought to use a set; you can convert the list into a set once (it does not even need to be sorted) and the in operation will be asymptotically faster than any binary search ever would be:
>>> a = [10, 6, 8, 1, 2, 5, 9]
>>> a_set = set(a)
>>> 10 in a_set
True
On the other hand, sorting a list has greater time-complexity than building a set, so most of the time a set would be the way to go.
I would go with this pure one-liner (providing bisect is imported):
a and a[bisect.bisect_right(a, x) - 1] == x
Stress test:
from bisect import bisect_right
from random import randrange

def contains(a, x):
    return a and a[bisect_right(a, x) - 1] == x

for _ in range(10000):
    a = sorted(randrange(10) for _ in range(10))
    x = randrange(-5, 15)
    assert (x in a) == contains(a, x), f"Error for {x} in {a}"
... doesn't print anything.
What is the fastest way to get a sorted, unique list in Python? (I have a list of hashable things, and want to have something I can iterate over; it doesn't matter whether the list is modified in place, or I get a new list, or an iterable. In my concrete use case, I'm doing this with a throwaway list, so in place would be more memory efficient.)
I've seen solutions like
input = [5, 4, 2, 8, 4, 2, 1]
sorted(set(input))
but it seems to me that first checking for uniqueness and then sorting is wasteful (since when you sort the list, you basically have to determine insertion points, and thus get the uniqueness test as a side effect). Maybe there is something more along the lines of unix's
cat list | sort | uniq
that just picks out consecutive duplications in an already sorted list?
Note: in the question 'Fastest way to uniqify a list in Python' the list is not sorted, and 'What is the cleanest way to do a sort plus uniq on a Python list?' asks for the cleanest / most pythonic way; the accepted answer there suggests sorted(set(input)), which is what I'm trying to improve on.
I believe sorted(set(sequence)) is the fastest way of doing it.
Yes, set iterates over the sequence but that's a C-level loop, which is a lot faster than any looping you would do at python level.
Note that even with groupby you still have O(n) + O(n log n) = O(n log n), and what's worse is that groupby requires a Python-level loop, which dramatically increases the constants in that O(n), so in the end you obtain worse results.
When speaking of CPython, the way to optimize things is to do as much as you can at the C level (see this answer for another example of counter-intuitive performance). To have a faster solution you would have to reimplement the sort in a C extension. And even then, good luck obtaining something as fast as Python's Timsort!
A small comparison of the "canonical solution" versus the groupby solution:
>>> import timeit
>>> sequence = list(range(500)) + list(range(700)) + list(range(1000))
>>> timeit.timeit('sorted(set(sequence))', 'from __main__ import sequence', number=1000)
0.11532402038574219
>>> import itertools
>>> def my_sort(seq):
...     return list(k for k,_ in itertools.groupby(sorted(seq)))
...
>>> timeit.timeit('my_sort(sequence)', 'from __main__ import sequence, my_sort', number=1000)
0.3162040710449219
As you can see it's 3 times slower.
The version provided by jdm is actually even worse:
>>> def make_unique(lst):
...     if len(lst) <= 1:
...         return lst
...     last = lst[-1]
...     for i in range(len(lst) - 2, -1, -1):
...         item = lst[i]
...         if item == last:
...             del lst[i]
...         else:
...             last = item
...
>>> def my_sort2(seq):
...     make_unique(sorted(seq))
...
>>> timeit.timeit('my_sort2(sequence)', 'from __main__ import sequence, my_sort2', number=1000)
0.46814608573913574
Almost 5 times slower.
Note that using seq.sort() followed by make_unique(seq), and make_unique(sorted(seq)), are actually the same thing: since Timsort uses O(n) space you always have some reallocation, so using sorted(seq) does not change the timings much.
jdm's benchmarks give different results because the inputs he is using are way too small, so all the time is taken by the time.clock() calls.
Maybe this is not the answer you are searching for, but anyway, you should take it into consideration.
Basically, you have 2 operations on a list:
unique_list = set(your_list) # O(n) complexity
sorted_list = sorted(unique_list) # O(nlogn) complexity
Now, you say "it seems to me that first checking for uniqueness and then sorting is wasteful", and you are right. But, how bad really is that redundant step? Take n = 1000000:
# sorted(set(a_list))
O(n)       => 1000000
O(n log n) => 1000000 * 20 = 20000000
Total      => 21000000

# Your fastest way
O(n log n) => 20000000
Total      => 20000000
Speed gain: (1 - 20000000/21000000) * 100 = 4.76 %
For n = 5000000, speed gain: ~1.6 %
Now, is that optimization worth it?
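If in doubt, measure the two steps separately on data like yours; a rough sketch (sizes are arbitrary):

import random
import timeit

lst = [random.randrange(10**6) for _ in range(10**6)]

print(timeit.timeit(lambda: set(lst), number=3))          # the "wasteful" O(n) pass
print(timeit.timeit(lambda: sorted(lst), number=3))       # the O(n log n) sort
print(timeit.timeit(lambda: sorted(set(lst)), number=3))  # both together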
This is just something I whipped up in a couple of minutes. The function modifies a list in place, removing consecutive repeats:
def make_unique(lst):
    if len(lst) <= 1:
        return lst
    last = lst[-1]
    for i in range(len(lst) - 2, -1, -1):
        item = lst[i]
        if item == last:
            del lst[i]
        else:
            last = item
Some representative input data:
inp = [
(u"Tomato", "de"), (u"Cherry", "en"), (u"Watermelon", None), (u"Apple", None),
(u"Cucumber", "de"), (u"Lettuce", "de"), (u"Tomato", None), (u"Banana", None),
(u"Squash", "en"), (u"Rubarb", "de"), (u"Lemon", None),
]
Make sure both variants work as wanted:
print inp
print sorted(set(inp))
# copy because we want to modify it in place
inp1 = inp[:]
inp1.sort()
make_unique(inp1)
print inp1
Now to the testing. I'm not using timeit, since I don't want to time the copying of the list, only the sorting. time1 is sorted(set(...)), time2 is list.sort() followed by make_unique, and time3 is the solution with itertools.groupby by Avinash Y.
import time

def time1(number):
    total = 0
    for i in range(number):
        start = time.clock()
        sorted(set(inp))
        total += time.clock() - start
    return total

def time2(number):
    total = 0
    for i in range(number):
        inp1 = inp[:]
        start = time.clock()
        inp1.sort()
        make_unique(inp1)
        total += time.clock() - start
    return total

import itertools

def time3(number):
    total = 0
    for i in range(number):
        start = time.clock()
        list(k for k,_ in itertools.groupby(sorted(inp)))
        total += time.clock() - start
    return total
sort + make_unique is approximately as fast as sorted(set(...)). I'd have to do a couple more iterations to see which one is potentially faster, but within the variations they are very similar. The itertools version is a bit slower.
# done each 3 times
print time1(100000)
# 2.38, 3.01, 2.59
print time2(100000)
# 2.88, 2.37, 2.6
print time3(100000)
# 4.18, 4.44, 4.67
Now with a larger list (the + str(i) is to prevent duplicates):
old_inp = inp[:]
inp = []
for i in range(100):
    for j in old_inp:
        inp.append((j[0] + str(i), j[1]))
print time1(10000)
# 40.37
print time2(10000)
# 35.09
print time3(10000)
# 40.0
Note that if there are a lot of duplicates in the list, the first version is much faster (since it does less sorting).
inp = []
for i in range(100):
    for j in old_inp:
        #inp.append((j[0] + str(i), j[1]))
        inp.append((j[0], j[1]))
print time1(10000)
# 3.52
print time2(10000)
# 26.33
print time3(10000)
# 20.5
import numpy as np
np.unique(...)
The np.unique function returns the sorted unique values of an array-like argument as an ndarray. It works with numpy types, but also with regular Python values, as long as they are orderable.
If you need a regular python list, use np.unique(...).tolist()
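For example, with the list from the question:

import numpy as np

data = [5, 4, 2, 8, 4, 2, 1]
print(np.unique(data))            # [1 2 4 5 8]
print(np.unique(data).tolist())   # [1, 2, 4, 5, 8]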
>>> import itertools
>>> a=[2,3,4,1,2,7,8,3]
>>> list(k for k,_ in itertools.groupby(sorted(a)))
[1, 2, 3, 4, 7, 8]
I've profiled my application, and it spends 90% of its time in plus_minus_variations.
The function finds ways to make various numbers given a list of numbers using addition and subtraction.
For example:
Input
1, 2
Output
1+2=3
1-2=-1
-1+2=1
-1-2=-3
This is my current code. I think it could be improved a lot in terms of speed.
import itertools
import operator

def plus_minus_variations(nums):
    result = dict()
    for i, ops in zip(xrange(2 ** len(nums)),
                      itertools.product([-1, 1], repeat=len(nums))):
        total = sum(map(operator.mul, ops, nums))
        result[total] = ops
    return result
I'm mainly looking for a different algorithm to approach this with. My current one seems pretty inefficient. However, if you have optimization suggestions about the code itself, I'd be happy to hear those too.
Additional:
It's okay if the result is missing some of the answers (or has some extraneous answers) if it finishes a lot faster.
If there are multiple ways to get a number, any of them are fine.
For the list sizes I'm using, 99.9% of the ways produce duplicate numbers.
It's okay if the result doesn't have the way that the numbers were produced, if, again, it finishes a lot faster.
If it is OK not to keep track of how each number was produced, there is no reason to recalculate the sum of every combination from scratch. You can store intermediate results:
def combine(l, r):
    res = set()
    for x in l:
        for y in r:
            res.add( x+y )
            res.add( x-y )
            res.add( -x+y )
            res.add( -x-y )
    return list(res)

def pmv(nums):
    if len(nums) > 1:
        l = pmv( nums[:len(nums)/2] )
        r = pmv( nums[len(nums)/2:] )
        return combine( l, r )
    return nums
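A quick sanity check against the question's own 1, 2 example (written for the Python 2 code above; the set order is arbitrary, so it is sorted here for display):

print(sorted(pmv([1, 2])))   # [-3, -1, 1, 3], i.e. the question's 3, -1, 1, -3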
EDIT: if the way the numbers are generated is important, you can use this variant:
def combine(l, r):
    res = dict()
    for x, q in l.iteritems():
        for y, w in r.iteritems():
            if not res.has_key(x+y):
                res[x+y] = w+q
                res[-x-y] = [-i for i in res[x+y]]
            if not res.has_key(x-y):
                res[x-y] = w+[-i for i in q]
                res[-x+y] = [-i for i in res[x-y]]
    return res

def pmv(nums):
    if len(nums) > 1:
        l = pmv( nums[:len(nums)/2] )
        r = pmv( nums[len(nums)/2:] )
        return combine( l, r )
    return {nums[0]:[1]}
My tests show that it is still faster than the other solutions.
EDITED: Aha! The code is in Python 3, inspired by tyz:
from functools import reduce  # only in Python 3

def process(old, num):
    new = set(map(num.__add__, old))  # use itertools.imap for Python 2
    new.update(map((-num).__add__, old))
    return new

def pmv(nums):
    n = iter(nums)
    x = next(n)
    result = {x, -x}  # set([x, -x]) for Python 2
    return reduce(process, n, result)
Instead of splitting in half and recursing, I use reduce to fold the numbers in one at a time, which greatly reduces the number of function calls.
It takes less than 1 second to compute 256 numbers.
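A quick check that it produces the expected sums (sorted for display, since sets are unordered):

print(sorted(pmv([1, 2, 5])))    # [-8, -6, -4, -2, 2, 4, 6, 8]
print(len(pmv(range(1, 17))))    # 137 distinct sums out of 2**16 sign combinations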
Why take a product of signs and then multiply?
import itertools

def pmv(nums):
    return {sum(i): i for i in itertools.product(*((num, -num) for num in nums))}
It can be faster if you don't need to know how the numbers were produced:
def pmv(nums):
    return set(map(sum, itertools.product(*((num, -num) for num in nums))))
This seems to be significantly faster for large random lists. I guess you could micro-optimize it further, but I prefer readability.
I chunk the list into smaller pieces and create variations for each chunk. Since you get a lot less than 2 ** len(chunk) variations, it's going to be faster. The chunk length is 6; you can play with it to see what the optimal chunk length is.
def pmv(nums):
    chunklen = 6
    res = dict()
    res[0] = ()
    for i in xrange(0, len(nums), chunklen):
        part = plus_minus_variations(nums[i:i+chunklen])
        resnew = dict()
        for (i, j) in itertools.product(res, part):
            resnew[i + j] = tuple(list(res[i]) + list(part[j]))
        res = resnew
    return res
You can get something like a 50% speedup (at least for short lists) just by doing:
from itertools import product, imap
from operator import mul

def plus_minus_variations(nums):
    result = {}
    for ops in product((-1, 1), repeat=len(nums)):
        result[sum(imap(mul, ops, nums))] = ops
    return result
imap won't create intermediate lists you don't need. Importing the names directly into the local namespace saves the time that attribute lookups take. Tuples are faster than lists. Don't store unneeded intermediate items.
I tried this with a dict comprehension but it was a bit slower. I tried it with a set comprehension (not saving the ops) and it was the same speed.
I don't know why you were using zip and xrange at all... you weren't using the result in your calculation.
Edit: I get significant speedups with this version all the way up to the point where your version gives a memory error, not just for short lists.
From a mathematical point of view, you finally arrive at all multiples of the greatest common divisor of your start values.
For example:
start values 2, 4: gcd(2, 4) is 2, so the generated numbers are ..., -4, -2, 0, 2, 4, ...
start values 3, 5: gcd(3, 5) is 1, so you get all integers.
start values 12, 18, 15: gcd(12, 18, 15) is 3, so you get ..., -6, -3, 0, 3, 6, ...
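A tiny sketch of that observation (the helper name is mine, not from the answer above):

from functools import reduce
from math import gcd

def result_spacing(nums):
    # every achievable sum of +/-nums is a multiple of gcd(nums)
    return reduce(gcd, nums)

print(result_spacing([2, 4]))        # 2
print(result_spacing([3, 5]))        # 1
print(result_spacing([12, 18, 15]))  # 3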
This simple iterative method computes all possible sums. It could be about 5 times faster than the recursive method by tyz.
def pmv(nums):
    sums = set()
    sums.add(0)
    for i in range(0, len(nums)):
        partialsums = set()
        for s in sums:
            partialsums.add(s + nums[i])
            partialsums.add(s - nums[i])
        sums = partialsums
    return sums