Why Numba doesn't improve this recursive function - python

I have an array of true/false values with a very simple structure:
# the real array has hundreds of thousands of items
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=bool)
I want to traverse this array and output the places where changes happen (true becomes false or the contrary). For this purpose, I've put together two different approaches:
a recursive binary search (see if all values are the same, if not, split in two, then recurse)
a purely iterative search (loop through all elements and compare with the previous/next one)
Both versions give exactly the result that I want, but Numba has a much greater effect on one than on the other. With a dummy array of 300k values, here are the performance results:
Performance results with array of 300k elements
pure Python binary-search runs in 11 ms
pure Python iterative-search runs in 1.1 s (100x slower than binary-search)
Numba binary-search runs in 5 ms (2 times faster than pure Python equivalent)
Numba iterative-search runs in 900 µs (1,200 times faster than pure Python equivalent)
As a result, when using Numba, binary_search is 5x slower than iterative_search, while in theory it should be 100x faster (it should be expected to run in 9 µs if it was properly accelerated).
What can be done to make Numba accelerate binary-search as much as it accelerates iterative-search?
Code for both approaches (along with a sample position array) is available on this public gist: https://gist.github.com/JivanRoquet/d58989aa0a4598e060ec2c705b9f3d8f
Note: Numba is not running binary_search() in object mode, because when I specify nopython=True, it doesn't complain and happily compiles the function.

You can find the positions of value changes by using np.diff; there is no need to run a more complicated algorithm or to use numba:
positions = np.array([True, False, False, False, True, True, True, True, False, False, False], dtype=bool)
dpos = np.diff(positions)
# array([ True, False, False, True, False, False, False, True, False, False])
This works because, for boolean arrays, np.diff compares neighbouring elements with != (a logical XOR) instead of subtracting them, so every True in the result marks a change between positions[i] and positions[i + 1].
It performs quite well on my battery-powered (and therefore throttled by energy-saving mode), several-years-old laptop:
In [52]: positions = np.random.randint(0, 2, size=300_000, dtype=bool)
In [53]: %timeit np.diff(positions)
633 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
I'd imagine that writing your own diff in numba should yield similar performance.
EDIT: The last statement is false. I implemented a simple diff function using numba, and it is more than a factor of 10 faster than the numpy one (it obviously has far fewer features, but it should be sufficient for this task):
@numba.njit
def ndiff(x):
    s = x.size - 1
    r = np.empty(s, dtype=x.dtype)
    for i in range(s):
        r[i] = x[i+1] - x[i]
    return r
In [68]: np.all(ndiff(positions) == np.diff(positions))
Out[68]: True
In [69]: %timeit ndiff(positions)
46 µs ± 138 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
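If what is needed is the list of indices where the value flips rather than the boolean mask, the mask returned by np.diff can be converted with np.flatnonzero. This is a minimal sketch using the sample array from the question; the names change_mask and change_idx are only illustrative, and the + 1 reflects my assumption that a change should be reported at the first element holding the new value.

```python
import numpy as np

positions = np.array([True, False, False, False, True, True, True, True,
                      False, False, False], dtype=bool)

# change_mask[i] is True when positions[i] != positions[i + 1]
change_mask = np.diff(positions)

# Shift by one so each index points at the first element with the new value
change_idx = np.flatnonzero(change_mask) + 1

print(change_idx)  # [1 4 8]
```

Dropping the + 1 gives the index of the last element before each change instead.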

The main issue is that you are not performing an apples-to-apples comparison.
What you provide is not an iterative and a recursive version of the same algorithm.
You are proposing two fundamentally different algorithms, which happen to be recursive/iterative.
In particular, you are using NumPy built-ins a lot more in the recursive approach, so it is no wonder that there is such a staggering difference between the two approaches. It should also come as no surprise that Numba's JIT compilation is more effective when you avoid NumPy built-ins.
Ultimately, the recursive algorithm seems to be less efficient because there is some hidden nested looping in the np.all() and np.any() calls that the iterative approach avoids, so even if you were to write all your code in pure Python to be accelerated with Numba more effectively, the recursive approach would be slower.
In general, iterative approaches are faster than their recursive equivalents, because they avoid the function call overhead (which is minimal for JIT-accelerated functions compared to pure Python ones).
So I would advise against trying to rewrite the algorithm in recursive form, only to discover that it is slower.
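To make the "hidden looping" point concrete, here is a rough sketch that counts how many elements each strategy has to inspect. The recursive function is a hypothetical stand-in for the split-and-recurse idea (it is not the code from the gist), and it models each uniformity check as a full scan of the slice, which is a pessimistic assumption.

```python
def recursive_inspections(values):
    """Elements inspected by a split-and-recurse scan (hypothetical model)."""
    # The uniformity check is modelled as touching every element of the slice.
    inspected = len(values)
    if len(values) <= 1 or len(set(values)) == 1:
        return inspected
    mid = len(values) // 2
    return (inspected
            + recursive_inspections(values[:mid])
            + recursive_inspections(values[mid:]))


def iterative_inspections(values):
    """Comparisons made by a single pass over neighbouring pairs."""
    return max(len(values) - 1, 0)


alternating = [i % 2 == 0 for i in range(8)]  # a change at every position
long_runs = [True] * 4 + [False] * 4          # a single change

print(recursive_inspections(alternating), iterative_inspections(alternating))  # 32 7
print(recursive_inspections(long_runs), iterative_inspections(long_runs))      # 16 7
```

Under this model the recursive scan never touches fewer elements than the single pass; where it can still win in pure Python is that each of its scans runs inside NumPy rather than in the interpreter, which is exactly the part Numba has little left to speed up.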
EDIT
On the premise that a simple np.diff() would do the trick, Numba can still be quite beneficial:
import numpy as np
import numba as nb
@nb.jit
def diff(arr):
    n = arr.size
    result = np.empty(n - 1, dtype=arr.dtype)
    for i in range(n - 1):
        # XOR of neighbouring booleans is True exactly where the value changes
        result[i] = arr[i + 1] ^ arr[i]
    return result
positions = np.random.randint(0, 2, size=300_000, dtype=bool)
print(np.allclose(np.diff(positions), diff(positions)))
# True
%timeit np.diff(positions)
# 1000 loops, best of 3: 603 µs per loop
%timeit diff(positions)
# 10000 loops, best of 3: 43.3 µs per loop
with the Numba approach being some 13x faster (in this test, mileage may vary, of course).

The gist is, only the part of logic that uses Python machinery can be accelerated -- by replacing it with some equivalent C logic that strips away most of the complexity (and flexibility) of Python runtime (I presume this is what Numba does).
All the heavy lifting in NumPy operations is already implemented in C and very simple (since NumPy arrays are contiguous chunks of memory holding regular C types) so Numba can only strip the parts that interface with Python machinery.
Your "binary search" algorithm does much more work and makes much heavier use of NumPy's vector operations while at it, so less of it can be accelerated this way.

Related

Sum a numpy array in chunks

Let's say I have a numpy array:
x = np.array([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
And I want to sum it in groups of, say, 3, so that the result is as follows:
np.array([12, 21, 30, 39])
Here is one way to do it:
n = x.size
out = x.reshape(n//3, 3) @ np.ones(3)
Is there a quicker way? I feel like this could be improved.
EDIT: just wanted to give an update for some of the methods described here
n = int(1e6)
arr = np.random.random(4*n)
def method1(arr):
    return arr.reshape(n, 4) @ np.ones(4)
def method2(arr):
    return arr.reshape(n, 4).sum(-1)
def method3(arr):
    return np.add.reduceat(arr, np.arange(0, 4*n, 4))
%timeit method1(arr)
1.53 ms ± 85.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit method2(arr)
14.6 ms ± 867 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit method3(arr)
14.2 ms ± 369 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
method2 is the basic way to do this in Numpy. That being said, it is not yet very well optimized internally for such a case. Indeed, the reduction is done along a very small number of items, while the internal reduction is optimized for a relatively large number of items. AFAIR, compilers like GCC tend to auto-vectorize the code using SIMD instructions, resulting in a much slower execution for small reductions. It might be optimized in the future, but this is tricky to do since the problem is mainly related to the way compilers optimize the code and the assumptions they make during the optimization steps. Thus, it is not really a problem of Numpy, though there are ways to specifically optimize this use-case at the expense of less maintainable code.
method3 is not very efficient since np.add.reduceat is currently not very well optimized internally in Numpy. We plan to improve it, but one should not expect a drastic improvement since the method is fundamentally not very efficient on modern CPUs anyway.
method1 is clever because it makes use of BLAS, which is very well optimized internally. The default implementation on most platforms, OpenBLAS, carefully optimizes many use-cases, including small matrix/vector multiplications, resulting in a much faster execution. That being said, it is not optimal due to the unneeded multiplications by ones (BLAS does not optimize the computation based on the values it operates on).
AFAIK, there is no way to write a faster implementation than method1 in pure Numpy. As a result, the only option left to speed up the code is to execute natively-compiled code specifically designed to solve your use-case. This is possible using Numba or Cython. Here is a naive implementation:
import numba as nb
@nb.njit('(float64[::1],)')
def method4(arr):
    # n is the global defined above (number of output items)
    res = np.empty(n)
    for i in range(n):
        res[i] = arr[i*4] + arr[i*4+1] + arr[i*4+2] + arr[i*4+3]
    return res
If you run this code, you will likely get performance similar to that of BLAS, which demonstrates how good BLAS implementations are (in fact, OpenBLAS is a bit faster on my machine). This code is not optimal because it is mainly memory-bound and page faults slow things down on most systems (see this related post). You can mitigate their overhead using multiple threads. This is still not optimal, as page faults do not scale well on all platforms (quite fine on Linux but poor on Windows). Alternatively, you can preallocate the output array so as to pay this overhead only once. You can even mix both approaches depending on your needs (using multiple threads can be useful to ensure the memory is saturated whatever the target platform, though creating threads can be expensive). Here is the naive parallel implementation and an optimized parallel implementation:
# Naive parallel implementation mitigating a bit the page-faults overhead
@nb.njit('(float64[::1],)', parallel=True)
def method5(arr):
    res = np.empty(n)
    for i in nb.prange(n):
        res[i] = arr[i*4] + arr[i*4+1] + arr[i*4+2] + arr[i*4+3]
    return res
# Parallel implementation avoiding completely page-faults
# (assuming `res` is preallocated and filled)
@nb.njit('(float64[::1],float64[::1])', parallel=True)
def method6(arr, res):
    for i in nb.prange(n):
        res[i] = arr[i*4] + arr[i*4+1] + arr[i*4+2] + arr[i*4+3]
Benchmark
method1: 3.64 ms
method2: 11.7 ms
method3: 16.0 ms
method4: 3.88 ms
method5: 2.05 ms
method6: 0.84 ms <----
This last method is nearly optimal and 4.3 times faster than the previously fastest BLAS one.
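As a sanity check, the three pure-NumPy variants discussed above can be compared on the small example from the question. This is only a sketch: the chunk size k and the variable names are mine, and it assumes the array length is an exact multiple of k.

```python
import numpy as np

x = np.array([3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
k = 3  # chunk size; x.size is assumed to be a multiple of k

# Reduction along the last axis of the reshaped view (method2 above)
by_sum = x.reshape(-1, k).sum(axis=-1)

# Matrix-vector product with a vector of ones (method1 above)
by_matmul = x.reshape(-1, k) @ np.ones(k, dtype=x.dtype)

# Segment-wise reduction starting at every k-th index (method3 above)
by_reduceat = np.add.reduceat(x, np.arange(0, x.size, k))

print(by_sum)  # [12 21 30 39]
```

All three produce the same values; they differ only in which internal code path NumPy takes, which is what the timings above measure.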

Is there a way to speed up indexing a vector with JAX?

I am indexing vectors and using JAX, but I have noticed a considerable slow-down compared to numpy when simply indexing arrays. For example, consider making a basic array in JAX numpy and ordinary numpy:
import jax.numpy as jnp
import numpy as onp
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)
Then simply indexing between two integers, for JAX (on GPU) this gives a time of:
%timeit jax_array[435:852]
1000 loops, best of 5: 1.38 ms per loop
And for numpy this gives a time of:
%timeit numpy_array[435:852]
1000000 loops, best of 5: 271 ns per loop
So numpy is 5000 times faster than JAX. When JAX is on a CPU, then
%timeit jax_array[435:852]
1000 loops, best of 5: 577 µs per loop
So faster, but still 2000 times slower than numpy. I am using Google Colab notebooks for this, so there should not be a problem with the installation/CUDA.
Am I missing something? I realise that indexing is different for JAX and numpy, as given by the JAX 'sharp edges' documentation, but I cannot find any way to perform assignment such as
new_array = jax_array[435:852]
without a considerable slowdown. I cannot avoid indexing the arrays as it is necessary in my program.
The short answer: to speed things up in JAX, use jit.
The long answer:
You should generally expect single operations using JAX in op-by-op mode to be slower than similar operations in numpy. This is because JAX execution has some amount of fixed per-python-function-call overhead involved in pushing compilations down to XLA.
Even seemingly simple operations like indexing are implemented in terms of multiple XLA operations, which (outside JIT) will each add their own call overhead. You can see this sequence using the make_jaxpr transform to inspect how the function is expressed in terms of primitive operations:
from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda ; a.
# let b = broadcast_in_dim[ broadcast_dimensions=( )
# shape=(1,) ] 435
# c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
# indices_are_sorted=True
# slice_sizes=(417,)
# unique_indices=True ] a b
# d = broadcast_in_dim[ broadcast_dimensions=(0,)
# shape=(417,) ] c
# in (d,) }
(See Understanding Jaxprs for info on how to read this).
Where JAX outperforms numpy is not in single small operations (in which JAX dispatch overhead dominates), but rather in sequences of operations compiled via the jit transform. So, for example, compare the JIT-compiled versus not-JIT-compiled version of the indexing:
%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop
from jax import jit
f_jit = jit(f)
f_jit(jax_array) # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop
(note that block_until_ready() is required for accurate micro-benchmarks because of JAX's asynchronous dispatch)
JIT-compiling this code gives a roughly 140x speedup. It's still not as fast as numpy because of JAX's per-call dispatch overhead (a few microseconds in this measurement), but with JIT that overhead is incurred only once per call rather than once per primitive operation. And when you move past microbenchmarks to more complicated sequences of real-world computations, that overhead will no longer dominate, and the optimization provided by the XLA compiler can make JAX far faster than the equivalent numpy computation.

Optimizing Many Matrix Operations in Python / Numpy

In writing some numerical analysis code, I have bottle-necked at a function that requires many Numpy calls. I am not entirely sure how to approach further performance optimization.
Problem:
The function determines error by calculating the following,
Code:
def foo(B_Mat, A_Mat):
    Temp = np.absolute(B_Mat)
    Temp /= np.amax(Temp)
    return np.sqrt(np.sum(np.absolute(A_Mat - Temp*Temp))) / B_Mat.shape[0]
What would be the best way to squeeze some extra performance out of the code? Would my best course of action be performing the majority of the operations in a single for loop with Cython to cut down on the temporary arrays?
Some functions in the implementation could be offloaded to the numexpr module, which is known to be very efficient for arithmetic computations. In our case, specifically, we could perform the squaring, summation and absolute-value computations with it. Thus, a numexpr-based solution to replace the last step in the original code would look like this -
import numexpr as ne
out = np.sqrt(ne.evaluate('sum(abs(A_Mat - Temp**2))'))/B_Mat.shape[0]
A further performance boost could be achieved by embedding the normalization step into the numexpr's evaluate expression. Thus, the entire function modified to use numexpr would be -
def numexpr_app1(B_Mat, A_Mat):
    Temp = np.absolute(B_Mat)
    M = np.amax(Temp)
    return np.sqrt(ne.evaluate('sum(abs(A_Mat*M**2-Temp**2))'))/(M*B_Mat.shape[0])
Runtime test -
In [198]: # Random arrays
     ...: A_Mat = np.random.randn(4000,5000)
     ...: B_Mat = np.random.randn(4000,5000)
     ...:
In [199]: np.allclose(foo(B_Mat, A_Mat),numexpr_app1(B_Mat, A_Mat))
Out[199]: True
In [200]: %timeit foo(B_Mat, A_Mat)
1 loops, best of 3: 891 ms per loop
In [201]: %timeit numexpr_app1(B_Mat, A_Mat)
1 loops, best of 3: 400 ms per loop
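The second variant relies on a small algebraic rearrangement: with M = max|B_Mat|, the normalized square (Temp/M)**2 equals Temp**2 / M**2, so sum|A - Temp**2/M**2| = sum|A*M**2 - Temp**2| / M**2, and taking the square root turns the 1/M**2 into the extra M in the denominator. The following NumPy-only sketch checks that identity without needing numexpr; the array sizes, the seed and the short names are arbitrary choices of mine.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((40, 50))
B = rng.standard_normal((40, 50))

T = np.abs(B)
M = T.max()

# Original form: normalize first, then square
original = np.sqrt(np.sum(np.abs(A - (T / M) ** 2))) / B.shape[0]

# Rearranged form: normalization folded into the final division
rearranged = np.sqrt(np.sum(np.abs(A * M**2 - T**2))) / (M * B.shape[0])

print(np.isclose(original, rearranged))  # True
```

The rearranged form avoids one full pass over the array (the in-place division), which is part of why folding it into a single evaluate call helps.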

sign() much slower in python than matlab?

I have a function in Python that basically takes the sign of an array of shape (75, 150), for example.
I'm coming from Matlab, and the execution time looks more or less the same except for this function.
I'm wondering whether sign() is very slow and whether you know an alternative that does the same thing.
Thx,
I can't tell you if this is faster or slower than Matlab, since I have no idea what numbers you're seeing there (you provided no quantitative data at all). However, as far as alternatives go:
import numpy as np
a = np.random.randn(75, 150)
aSign = np.sign(a)
Testing using %timeit in IPython:
In [15]: %timeit np.sign(a)
10000 loops, best of 3: 180 µs per loop
Because the loop over the array (and what happens inside it) is implemented in optimized C code rather than generic Python code, it tends to be about an order of magnitude faster than an explicit Python loop, which puts it in the same ballpark as Matlab.
Comparing the exact same code as a numpy vectorized operation vs. a Python loop:
In [276]: %timeit [np.sign(x) for x in a]
1000 loops, best of 3: 276 us per loop
In [277]: %timeit np.sign(a)
10000 loops, best of 3: 63.1 us per loop
So, only 4x as fast here. (But then a is pretty small here.)
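For completeness, here is a small sketch showing that the whole-array call and the row-by-row list comprehension timed above compute the same thing, so the only difference is how much of the looping happens in Python. The seed and the use of timeit from the standard library are my own choices; no timings are quoted because they depend on the machine.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((75, 150))

whole_array = np.sign(a)                             # one call, loop runs in C
row_by_row = np.array([np.sign(row) for row in a])   # 75 Python-level calls

print(np.array_equal(whole_array, row_by_row))  # True

# Time both variants; the absolute numbers will differ from machine to machine
t_vec = timeit.timeit(lambda: np.sign(a), number=1000)
t_rows = timeit.timeit(lambda: [np.sign(row) for row in a], number=1000)
print(f"vectorized: {t_vec:.4f} s, row by row: {t_rows:.4f} s")
```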
