Recursion, memoization and mutable default arguments in Python

"Base" meaning without just using lru_cache. All of these are "fast enough" -- I'm not looking for the fastest algorithm -- but the timings surprised me so I was hoping I could learn something about how Python "works".
Simple loop (/tail recursion):
def fibonacci(n):
    a, b = 0, 1
    if n in (a, b): return n
    for _ in range(n - 1):
        a, b = b, a + b
    return b
Simple memoized:
def fibonacci(n, memo={0: 0, 1: 1}):
    if len(memo) <= n:
        memo[n] = fibonacci(n - 1) + fibonacci(n - 2)
    return memo[n]
Using a generator:
def fib_seq():
    a, b = 0, 1
    yield a
    yield b
    while True:
        a, b = b, a + b
        yield b

def fibonacci(n):
    return next(x for (i, x) in enumerate(fib_seq()) if i == n)
I expected the first, being dead simple, to be the fastest. It's not. The second is by far the fastest, despite the recursion and lots of function calls. The third is cool, and uses "modern" features, but is even slower, which is disappointing. (I was tempted to think of generators as in some ways an alternative to memoization -- since they remember their state -- and since they're implemented in C I was hoping they'd be faster.)
Typical results:
loop: about 140 μs
memo: about 430 ns
genr: about 250 μs
So can anyone explain, in particular, why memoization is an order of magnitude faster than a simple loop?
EDIT:
It's clear now that I have (like many before me) simply stumbled upon Python's mutable default arguments. This behavior explains both the real and the apparent gains in execution speed.
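For anyone who hasn't hit this before, a tiny self-contained demonstration of the behavior (my own example, unrelated to Fibonacci):

def append_to(item, bucket=[]):
    # The default list is created once, when the function is defined, and
    # that same object is reused by every call that omits `bucket`.
    bucket.append(item)
    return bucket

print(append_to(1))  # [1]
print(append_to(2))  # [1, 2] -- the default list remembered the first call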

What you're seeing is the whole point of memoization. The first time you call the function, the memo cache is empty and it has to recurse. But the next time you call it with the same or a lower argument, the answer is already in the cache, so it returns immediately. If you perform thousands of calls, you're amortizing that first call's time over all the other calls. That's what makes memoization such a useful optimization: you only pay the cost the first time.
If you want to see how long it takes when the cache is fresh and you have to do all the recursions, you can pass the initial cache as an explicit argument in the benchmark call:
fibonacci(100, {0:0, 1:1})
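One subtlety: in the memoized version above, the recursive calls fibonacci(n - 1) and fibonacci(n - 2) don't pass memo along, so they still consult the shared default dict even when you supply a fresh one at the top level. A minimal sketch of a variant that threads the cache through every call (the name fibonacci_cold is my own), so that a call without a cache really does start cold:

def fibonacci_cold(n, memo=None):
    # A fresh cache per top-level call; the cache is passed explicitly
    # to every recursive call, so nothing is shared between calls.
    if memo is None:
        memo = {0: 0, 1: 1}
    if n not in memo:
        memo[n] = fibonacci_cold(n - 1, memo) + fibonacci_cold(n - 2, memo)
    return memo[n]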

Related

Why is this recursive solution faster than the iterative solution?

I wrote a recursive solution for something today, and as it goes, curiosity led me down a weird path. I wanted to see how an optimized recursive solution compares to an iterative solution so I chose a classic, the Nth Fibonacci to test with.
I was surprised to find that the recursive solution with memoization is much faster than the iterative solution and I would like to know why.
Here is the code (using python3):
import time
import sys
sys.setrecursionlimit(10000)
## recursive:
def fibr(n, memo = {}):
    if n <= 1:
        return n
    if n in memo:
        return memo[n]
    memo[n] = fibr(n-1, memo) + fibr(n-2, memo)
    return memo[n]
## iterative:
def fibi(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a
rstart = time.time()
for n in range(10000):
    fibr(n)
rend = time.time()
istart = time.time()
for n in range(10000):
    fibi(n)
iend = time.time()
print(f"recursive: {rend-rstart}")
print(f"iterative: {iend-istart}")
My results:
recursive: 0.010010004043579102
iterative: 6.274333238601685
Unless I'm mistaken, both the recursive solution and the iterative solution are about as optimized as they can get? If I'm wrong about that, I'd like to know why.
If not, what would cause the iterative solution to be so much slower? It seems to be slower for all values of n, but harder to notice when n is something more reasonable, like <1000. (I'm using 10000 as you can see above)
Some things I've tried:
I thought it might be the magic swapping in the iterative solution a, b = b, a + b, so I tried replacing it with a more traditional "swap" pattern:
tmp = a + b
a = b
b = tmp
#a, b = b, a + b
But the results are basically the same, so that's not the problem there.
I re-arranged the code so that the iterative solution runs first, just to see if there was some weird cache issue at the OS level. It doesn't change the results, so that's probably not it.
My understanding here (and it might be wrong) is that the recursive solution with memoization is O(n). And the iterative solution is also O(n) simply because it iterates from 0..n.
Am I missing something really obvious? I feel like I must be missing something here.
You might expect that def fibr(n, memo = {}): — when called without a supplied memo — turns into something translated a bit like:
def _fibr_without_defaults(n, memo):
    ...

def fibr(n):
    return _fibr_without_defaults(n, {})
That is, if memo is missing it implicitly gets a blank dictionary per your default.
In actuality, it translates into something more like:
def _fibr_without_defaults(n, memo):
    ...

_fibr_memo_default = {}

def fibr(n):
    return _fibr_without_defaults(n, _fibr_memo_default)
That is, a default argument value is not "use this to construct a default value" but instead "use this actual value by default". Every single call to fibr you make (without supplying memo) is sharing a default memo dictionary.
That means in:
for n in range(10000):
    fibr(n)
The prior iterations of your loop are filling out memo for the future iterations. When n is 1000, for example, all the work performed by n<=999 is still stored.
By contrast, the iterative version always starts iterating from 0, no matter what work prior iterative calls performed.
If you perform the translation above by hand, so that you really do get a fresh empty memo for each call, you'll see the iterative version is faster. (Makes sense; inserting things into a dictionary and retrieving them just to do the same work as simple iteration will be slower.)
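A minimal sketch of that hand translation (the names are just illustrative); note that with a truly fresh cache the recursion depth now grows with n, so the recursion limit matters for large n:

def _fibr_without_defaults(n, memo):
    # Same body as fibr, but the cache must always be passed in.
    if n <= 1:
        return n
    if n in memo:
        return memo[n]
    memo[n] = _fibr_without_defaults(n - 1, memo) + _fibr_without_defaults(n - 2, memo)
    return memo[n]

def fibr_fresh(n):
    # A brand-new cache per top-level call, so nothing carries over
    # between iterations of the benchmark loop.
    return _fibr_without_defaults(n, {})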
They are not the same.
The recursive version uses the memoization pattern, calculating the result of fibr(n) only once and storing/caching it for an instant return if it is needed again. Over the whole benchmark loop that makes it an O(n) algorithm.
The iterative version calculates everything from scratch on every call, so over the benchmark loop it does roughly 0 + 1 + ... + n = n(n+1)/2 iterations in total, i.e. it's O(n^2) (I think).

Recursion with memory vs loop

I've made two functions for computing the Fibonacci sequence, one using recursion with memory and one using a loop:
def fib_rec(n, dic = {0 : 0, 1 : 1}):
    if n in dic:
        return dic[n]
    else:
        fib = fib_rec(n - 2, dic) + fib_rec(n - 1, dic)
        dic[n] = fib
        return fib

def fib_loop(n):
    if n == 0 or n == 1:
        return n
    else:
        smaller = 0
        larger = 1
        for i in range(1, n):
            smaller, larger = larger, smaller + larger
        return larger
I've heard that the Fibonacci Sequence often is solved using recursion, but I'm wondering why. Both my algorithms are of linear time complexity, but the one using a loop will not have to carry a dictionary of all past Fibonacci numbers, it also won't exceed Python's recursion depth.
Is this problem solved using recursion only to teach recursion or am I missing something?
The usual recursive O(N) Fibonacci implementation is more like this:
def fib(n, a=0, b=1):
    if n == 0: return a
    if n == 1: return b
    return fib(n - 1, b, a + b)
The advantage with this approach (aside from the fact that it uses O(1) memory) is that it is tail-recursive: some compilers and/or runtimes can take advantage of that to secretly convert it to a simple JUMP instruction. This is called tail-call optimization.
Python, sadly, doesn't use this strategy, so it will use extra memory for the call stack, which as you noted quickly runs into Python's recursion depth limit.
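Since CPython won't do this for you, the tail-recursive form can be rewritten by hand as a loop; a minimal sketch, where the accumulator arguments a and b simply become loop variables:

def fib_iter(n):
    # Manual "tail-call elimination" of the tail-recursive version above:
    # the accumulators a and b become loop variables.
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a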
The Fibonacci sequence is mostly a toy problem, used for teaching people how to write algorithms and about big Oh notation. It has elegant functional solutions as well as showing the strengths of dynamic programming (basically your dictionary-based solution), but it's also practically a solved problem.
We can also go a lot faster. The page https://www.nayuki.io/page/fast-fibonacci-algorithms describes how. It includes a fast doubling algorithm written in Python:
#
# Fast doubling Fibonacci algorithm (Python)
# by Project Nayuki, 2015. Public domain.
# https://www.nayuki.io/page/fast-fibonacci-algorithms
#

# (Public) Returns F(n).
def fibonacci(n):
    if n < 0:
        raise ValueError("Negative arguments not implemented")
    return _fib(n)[0]

# (Private) Returns the tuple (F(n), F(n+1)).
def _fib(n):
    if n == 0:
        return (0, 1)
    else:
        a, b = _fib(n // 2)
        c = a * (b * 2 - a)
        d = a * a + b * b
        if n % 2 == 0:
            return (c, d)
        else:
            return (d, c + d)
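The code relies on the standard doubling identities F(2k) = F(k)·(2·F(k+1) − F(k)) and F(2k+1) = F(k)² + F(k+1)². A quick sanity check of the function as defined above:

# First few values, plus a larger one.
assert [fibonacci(i) for i in range(10)] == [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
print(fibonacci(100))  # 354224848179261915075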

Why is this O(n) three-way set disjointness algorithm slower than the O(n^3) version?

O(n), because converting a list to a set is O(n) time, getting the intersection is O(n) time, and len is O(1):
def disjoint3c(A, B, C):
    """Return True if there is no element common to all three lists."""
    return len(set(A) & set(B) & set(C)) == 0
or similarly, should be clearly O(N)
def set_disjoint_medium(a, b, c):
    a, b, c = set(a), set(b), set(c)
    for elem in a:
        if elem in b and elem in c:
            return False
    return True
yet this O(n^3) code:
def set_disjoint_slowest(a, b, c):
    for e1 in a:
        for e2 in b:
            for e3 in c:
                if e1 == e2 == e3:
                    return False
    return True
runs faster
See the timings below, where algorithm one is the O(n^3) version and algorithm three is the O(n) set code. (Algorithm two is actually O(n^2): it optimizes algorithm one by checking for disjointness before the third loop starts.)
Size Input (n): 10000
Algorithm One: 0.014993906021118164
Algorithm Two: 0.013481855392456055
Algorithm Three: 0.01955580711364746
Size Input (n): 100000
Algorithm One: 0.15916991233825684
Algorithm Two: 0.1279449462890625
Algorithm Three: 0.18677806854248047
Size Input (n): 1000000
Algorithm One: 1.581618070602417
Algorithm Two: 1.146049976348877
Algorithm Three: 1.8179030418395996
The comments have already clarified the big-O notation, so I will just start by testing the code.
Here is the setup I used for testing the speed of the code.
import random

# Collapsed these because already known:
def disjoint3c(A, B, C): ...
def set_disjoint_medium(a, b, c): ...
def set_disjoint_slowest(a, b, c): ...

a = [random.randrange(100) for i in xrange(10000)]
b = [random.randrange(100) for i in xrange(10000)]
c = [random.randrange(100) for i in xrange(10000)]

# Ran timeit. Results with the timeit module:
# 1) 0.00635750419422
# 2) 0.0061145967287
# 3) 0.0487953200969
Now to the results: as you can see, the O(n^3) solution runs about 8 times slower than the other solutions, but that is still fast for such an algorithm (and it is even faster in your test). Why does this happen?
Because the medium and slowest solutions you used finish executing as soon as a common element is detected, the full complexity of the code is never realized: they break out as soon as they find an answer. Why did the slowest solution run almost as fast as the other ones in your test? Probably because it finds the answer close to the beginning of the lists.
To test this, you could create the lists like this. Try this yourself.
a = range(1000)
b = range(1000, 2000)
c = range(2000, 3000)
Now the real difference between the times will be obvious because the slowest solution will have to run until it finishes all iterations, because there is no common element.
So it is a question of worst-case versus best-case performance.
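To put numbers on it, a rough harness like the following could be used (a sketch assuming the three functions from the question are defined in the current module; exact timings will of course vary):

import timeit

# Fully disjoint inputs: no algorithm can bail out early, so the
# asymptotic difference finally shows up.
a = list(range(1000))
b = list(range(1000, 2000))
c = list(range(2000, 3000))

for fn in (disjoint3c, set_disjoint_medium, set_disjoint_slowest):
    t = timeit.timeit(lambda: fn(a, b, c), number=10)
    print(fn.__name__, t)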
Edit (not part of the question): So, what if you want to keep the speed of finding an early common occurrence, but also don't want to increase the complexity? I made a rough solution for that; maybe more experienced users can suggest faster code.
import itertools

def mysol(a, b, c):
    store = [set(), set(), set()]
    # izip_longest is Python 2; use itertools.zip_longest on Python 3.
    for i, j, k in itertools.izip_longest(a, b, c):
        # Test against None (the padding value) so that 0 isn't skipped.
        if i is not None: store[0].add(i)
        if j is not None: store[1].add(j)
        if k is not None: store[2].add(k)
        if ((i in store[1] and i in store[2]) or
                (j in store[0] and j in store[2]) or
                (k in store[0] and k in store[1])):
            return False
    return True
What is basically being done in this code is: you avoid converting all the lists to sets at the beginning. Instead, you iterate through all the lists at the same time, adding elements to the sets and checking for common occurrences. So you keep the speed of finding an early solution, but it is still slow in the worst case that I showed.
As for speed, this runs 3-4 times slower than your first two solutions in the worst case, but 4-10 times faster than those solutions on randomized lists.
Note: the fact that you are finding all common elements of the three lists (in the first solution) unquestionably means that there is a faster solution in theory, because you only need to know whether there is even a single common element, and that knowledge is enough.
O notation ignores all constant factors, so it only really answers questions about infinite data sets. For any finite set, it is only a rule of thumb.
With interpreted languages such as Python and R, constant factors can be pretty large. They need to create and collect many objects, which is all O(1) but not free. So it is fairly common to see 100x performance differences between virtually equivalent pieces of code, unfortunately.
Secondly, the first algorithm computes all common elements, while the others fail on the first one found. If you benchmark algX(a, a, a) (yes, all three sets the same) it will do much more work than the others!
I would not be surprised to see a sort-based O(n log n) algorithm be very competitive (because sorting is usually incredibly well optimized). For integers, I would use numpy arrays, and by avoiding the Python interpreter as much as possible you can get very fast. While numpy's in1d and intersect1d will likely give you an O(n^2) or O(n^3) algorithm, they may end up being faster as long as your sets are usually disjoint.
Also note that in your case, the sets won't necessarily be pairwise disjoint... algX(set(), a, a) == True.
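Following up on the numpy suggestion, a minimal sketch of what that route might look like (disjoint3_np is my own name; np.intersect1d returns the sorted unique values common to two arrays):

import numpy as np

def disjoint3_np(a, b, c):
    # Intersect a and b first; if that is already empty we are done.
    ab = np.intersect1d(np.asarray(a), np.asarray(b))
    if ab.size == 0:
        return True
    # Otherwise check whether anything in that intersection also appears in c.
    return np.intersect1d(ab, np.asarray(c)).size == 0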

How many combinations are possible?

The recursive formula for computing the number of ways of choosing k items out of a set of n items, denoted C(n,k), is:
C(n, k) = 1                            if k = 0
C(n, k) = 0                            if n < k
C(n, k) = C(n-1, k-1) + C(n-1, k)      otherwise
I'm trying to write a recursive function C that computes C(n, k) using this recursive formula. As far as I can tell the code I have written should work, but it doesn't give me the correct answers.
This is my code:
def combinations(n, k):
    # base case
    if k == 0:
        return 1
    elif n < k:
        return 0
    # recursive case
    else:
        return combinations(n-1, k-1) + combinations(n-1, k)
The answers should look like this:
>>> c(2, 1)
0
>>> c(1, 2)
2
>>> c(2, 5)
10
but I get other numbers... I don't see where the problem is in my code.
I would try reversing the arguments, because as written n < k.
I think you mean this:
>>> c(2, 1)
2
>>> c(5, 2)
10
Your calls, e.g. c(2, 5), mean that n=2 and k=5 (as per your definition of C at the top of your question). So n < k, and as such the result should be 0, which is exactly what happens with your implementation. All the other examples yield the correct results as well.
Are you sure that the arguments of your example test cases are in the correct order? They are all c(k, n) calls. So either those calls are wrong, or the order of the arguments in your definition of C is off.
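For reference, with the arguments the right way round, the implementation from the question already gives the expected values:

>>> combinations(2, 1)
2
>>> combinations(5, 2)
10
>>> combinations(2, 5)   # n < k, so 0 by the definition above
0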
This is one of those times where you really shouldn't be using a recursive function. Computing combinations is very simple to do directly. For something like a factorial function, using recursion is no big deal, because it can be optimized with tail recursion anyway.
Here's the reason why:
Why do we never use this definition for the Fibonacci sequence when we are writing a program?
def fibbonacci(idx):
    if idx < 2:
        return idx
    else:
        return fibbonacci(idx-1) + fibbonacci(idx-2)
The reason is that, because of the two separate recursive calls, it is prohibitively slow. Multiple separate recursive calls should be avoided if possible, for that reason.
If you do insist on using recursion, I would recommend reading this page first. A better recursive implementation will require only one recursive call each time, as in the sketch below. Rosetta Code seems to have some pretty good recursive implementations as well.
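A sketch of such a version, using the identity C(n, k) = C(n-1, k-1) · n / k so that each level makes only one recursive call (my own example; for the direct route mentioned above, math.comb(n, k) in Python 3.8+ does this in the standard library):

def combinations_single(n, k):
    # Uses C(n, k) = C(n-1, k-1) * n // k -- one recursive call per level
    # instead of two.
    if k < 0 or k > n:
        return 0
    if k == 0:
        return 1
    return combinations_single(n - 1, k - 1) * n // k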

python 2.7 - Recursive Fibonacci blows up

I have two functions fib1 and fib2 to calculate Fibonacci.
def fib1(n):
    if n < 2:
        return 1
    else:
        return fib1(n-1) + fib1(n-2)

def fib2(n):
    def fib2h(s, c, n):
        if n < 1:
            return s
        else:
            return fib2h(c, s + c, n-1)
    return fib2h(1, 1, n)
fib2 works fine until it blows up the recursion limit. If I understand correctly, Python doesn't optimize for tail recursion. That is fine by me.
What gets me is that fib1 slows to a halt even with very small values of n. Why is that happening? How come it doesn't hit the recursion limit before it gets sluggish?
Basically, you are wasting lots of time computing fib1 for the same values of n over and over. You can easily memoize the function like this:
def fib1(n, memo={}):
    if n in memo:
        return memo[n]
    if n < 2:
        memo[n] = 1
    else:
        memo[n] = fib1(n-1) + fib1(n-2)
    return memo[n]
You'll notice that I am using an empty dict as a default argument. This is usually a bad idea because the same dict is used as the default for every function call.
Here I am taking advantage of exactly that, using it to memoize each result I calculate.
You can also prime the memo with 0 and 1 to avoid needing the n < 2 test
def fib1(n, memo={0: 1, 1: 1}):
    if n in memo:
        return memo[n]
    else:
        memo[n] = fib1(n-1) + fib1(n-2)
    return memo[n]
Which becomes
def fib1(n, memo={0: 1, 1: 1}):
    return memo.setdefault(n, memo.get(n) or fib1(n-1) + fib1(n-2))
Your problem isn't Python, it's your algorithm. fib1 is a good example of tree recursion. You can find a proof here on Stack Overflow that this particular algorithm is roughly Θ(1.6^n).
n=30 (apparently from the comments) takes about a third of a second. If computational time scales up as 1.6^n, we'd expect n=100 to take about 2.1 million years.
Think of the recursion trees in each. The second version is a single branch of recursion, with the addition taking place in the parameter calculations for the recursive calls, and the result then being passed back up. As you have noted, Python doesn't require tail-call optimization, but if it were part of your interpreter, this tail recursion could take advantage of it.
The first version, on the other hand, makes 2 recursive calls at EACH level! So the number of executions of the function skyrockets. Not only that, but most of the work is repeated: the fib1(n-2) subproblem is computed once inside the fib1(n-1) branch and then again directly by the original call frame, and nothing is shared between the two. The same duplication happens at every level below, so the work is needlessly repeated many times over.
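One way to see the duplication concretely (a small sketch of my own, counting calls with a plain global counter):

calls = 0

def fib_counted(n):
    # Same shape as fib1, but counts how many times the function is entered.
    global calls
    calls += 1
    if n < 2:
        return 1
    return fib_counted(n - 1) + fib_counted(n - 2)

fib_counted(30)
print(calls)  # about 2.7 million calls just to compute fib(30)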
