Why is JAX's `split()` so slow at first call? - python

jax.numpy.split can be used to segment an array into equal-length segments with a remainder in the last element. e.g. splitting an array of 5000 elements into segments of 10:
array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)
segments = jnp.split(array, split_indices)
This takes around 10 seconds to execute on Google Colab and on my local machine. This seems unreasonable for such a simple task on a small array. Am I doing something wrong to make this slow?
Further Details (JIT caching, maybe?)
Subsequent calls to .split are very fast, provided an array of the same shape and the same split indices. e.g. the first iteration of the following loop is extremely slow, but all others fast. (11 seconds vs 40 milliseconds)
from timeit import default_timer as timer
import jax.numpy as jnp
array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)
for k in range(5):
    start = timer()
    segments = jnp.split(array, split_indices)
    end = timer()
    print(f'call {k}: {end - start:0.2f} s')
Output:
call 0: 11.79 s
call 1: 0.04 s
call 2: 0.04 s
call 3: 0.05 s
call 4: 0.04 s
I assume that the subsequent calls are faster because JAX is caching jitted versions of split for each combination of arguments. If that's the case, then I assume split is slow (on its first such call) because of compilation overhead.
Is that true? If yes, how should I split a JAX array without incurring the performance hit?

This is slow because there are tradeoffs in the implementation of split(), and your function happens to be on the wrong side of the tradeoff.
There are several ways to compute slices in XLA, including XLA:Slice (i.e. lax.slice), XLA:DynamicSlice (i.e. lax.dynamic_slice), and XLA:Gather (i.e. lax.gather).
The main difference between these concerns whether the start and ending indices are static or dynamic. Static indices essentially mean you're specializing your computation for specific index values: this incurs some small compilation overhead on the first call, but subsequent calls can be very fast. Dynamic indices, on the other hand, don't include such specialization, so there is less compilation overhead, but each execution takes slightly longer. You may be able to guess where this is going...
jnp.split currently is implemented in terms of lax.slice (see code), meaning it uses static indices. This means that the first use of jnp.split will incur compilation cost proportional to the number of outputs, but repeated calls will execute very quickly. This seemed like the best approach for common uses of split, where a handful of arrays are produced.
In your case, you're generating hundreds of arrays, so the compilation cost far dominates over the execution.
To illustrate this, here are some timings for three approaches to the same array split, based on gather, slice, and dynamic_slice. You might wish to use one of these directly rather than using jnp.split if your program benefits from different implementations:
from timeit import default_timer as timer
from jax import lax
import jax.numpy as jnp
import jax
def f_slice(x, step=10):
    return [lax.slice(x, (N,), (N + step,)) for N in range(0, x.shape[0], step)]
def f_dynamic_slice(x, step=10):
    return [lax.dynamic_slice(x, (N,), (step,)) for N in range(0, x.shape[0], step)]
def f_gather(x, step=10):
    step = jnp.asarray(step)
    return [x[N: N + step] for N in range(0, x.shape[0], step)]
def time(f, x):
    print(f.__name__)
    for k in range(5):
        start = timer()
        segments = jax.block_until_ready(f(x))
        end = timer()
        print(f' call {k}: {end - start:0.2f} s')
x = jnp.ones(5000)
time(f_slice, x)
time(f_dynamic_slice, x)
time(f_gather, x)
Here's the output on a Colab CPU runtime:
f_slice
call 0: 7.78 s
call 1: 0.05 s
call 2: 0.04 s
call 3: 0.04 s
call 4: 0.04 s
f_dynamic_slice
call 0: 0.15 s
call 1: 0.12 s
call 2: 0.14 s
call 3: 0.13 s
call 4: 0.16 s
f_gather
call 0: 0.55 s
call 1: 0.54 s
call 2: 0.51 s
call 3: 0.58 s
call 4: 0.59 s
You can see here that static indices (lax.slice) lead to the fastest execution after compilation. However, for generating many slices, dynamic_slice and gather avoid repeated compilations. It may be that we should re-implement jnp.split in terms of dynamic_slice, but that wouldn't come without tradeoffs: for example, it would lead to a slowdown in the (possibly more common?) case of few splits, where lax.slice would be faster on both initial and subsequent runs. Also, dynamic_slice only avoids recompilation if each slice is the same size, so generating many slices of varying sizes would incur a large compilation overhead similar to lax.slice.
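As a rough sketch of the dynamic-index case: under jax.jit, a start index passed as an argument is traced rather than baked into the compiled code, so one compiled function can serve every segment of the same length. The helper name take_segment and the fixed length of 10 are choices made for this illustration; they are not part of jnp.split.

```python
import jax
import jax.numpy as jnp
from jax import lax

SEGMENT = 10  # slice length; dynamic_slice needs a static size

@jax.jit
def take_segment(x, start):
    # `start` is a traced value, so a new start does not need a new compilation
    return lax.dynamic_slice(x, (start,), (SEGMENT,))

x = jnp.arange(5000.0)
first = take_segment(x, 0)     # compiled on this call
last = take_segment(x, 4990)   # same argument types, so the compiled code is reused
print(first.shape, float(last[0]))
```

The slice size stays static here, which is why this only helps when all segments have the same length.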
These kinds of tradeoffs are actively discussed in JAX development channels; a recent, very similar example can be found in PR #12219. If you wish to weigh in on this particular issue, I'd invite you to file a new JAX issue on the topic.
A final note: if you're truly just interested in generating equal-length sequential slices of an array, you would be much better off just calling reshape:
out = x.reshape(len(x) // 10, 10)
The result is now a 2D array where each row corresponds to a slice from the above functions, and this will far out-perform anything that's generating a list of array slices.
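Note that a bare reshape like the one above only works when the length is a multiple of the segment size. A minimal sketch of one way to handle a remainder, written with NumPy so it runs without JAX (jax.numpy arrays accept the same slicing and reshape calls); the helper name split_equal is made up for this example:

```python
import numpy as np

def split_equal(x, segment_size):
    """Return (body, tail): full segments as rows of a 2D array, plus the remainder."""
    n_full = x.shape[0] // segment_size
    cut = n_full * segment_size
    body = x[:cut].reshape(n_full, segment_size)
    tail = x[cut:]  # empty when the length divides evenly
    return body, tail

x = np.arange(5003)
body, tail = split_equal(x, 10)
print(body.shape, tail.shape)
```

Row k of body holds the same values as the k-th segment produced by split, and tail holds whatever is left over.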

JAX built-in functions are also JIT compiled
Benchmarking JAX code
JAX code is Just-In-Time (JIT) compiled. Most code written in JAX can
be written in such a way that it supports JIT compilation, which can
make it run much faster (see To JIT or not to JIT). To get maximum
performance from JAX, you should apply jax.jit() on your outer-most
function calls.
Keep in mind that the first time you run JAX code, it will be slower
because it is being compiled. This is true even if you don’t use jit
in your own code, because JAX’s builtin functions are also JIT
compiled.
So the first time you run it, it is compiling jnp.split (or at least, compiling some of the functions used within jnp.split):
%%timeit -n1 -r1
jnp.split(array, split_indices)
1min 15s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
The second time, it is calling the compiled function
%%timeit -n1 -r1
jnp.split(array, split_indices)
131 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
jnp.split is fairly complicated, calling other jax.numpy functions, so I assume it can take quite a while to compile (1 minute on my machine!)

Related

Is it possible to improve python performance for this code?

I have some simple code that:
Reads a trajectory file that can be seen as a list of 2D arrays (list of positions in space) stored in Y
Then computes the RMSD for each pair (scipy.pdist style)
My code works fine:
trajectory = read("test.lammpstrj", index="::")
m = len(trajectory)
#.get_positions() return a 2d numpy array
Y = np.array([snapshot.get_positions() for snapshot in trajectory])
b = [np.sqrt(((((Y[i]- Y[j])**2))*3).mean()) for i in range(m) for j in range(i + 1, m)]
This code executes in 0.86 seconds using Python 3.10; in Julia 1.8, the same kind of code executes in 0.46 seconds.
I plan to use much larger trajectories (~200,000 elements). Would it be possible to get a speed-up in Python, or should I stick to Julia?
You've mentioned that snapshot.get_positions() returns some 2D array, suppose of shape (p, q). So I expect that Y is a 3D array with some shape (m, p, q), where m is the number of snapshots in the trajectory. You also expect m to scale rather high.
Let's see a basic way to speed up the distance calculation, on the setting m=1000:
import numpy as np
# dummy inputs
m = 1000
p, q = 4, 5
Y = np.random.randn(m, p, q)
# your current method
def foo():
    return [np.sqrt(((((Y[i]- Y[j])**2))*3).mean()) for i in range(m) for j in range(i + 1, m)]
# vectorized approach -> compute the upper triangle of the pairwise distance matrix
def bar():
    u, v = np.triu_indices(Y.shape[0], 1)
    return np.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
# Check for correctness
out_1 = foo()
out_2 = bar()
print(np.allclose(out_1, out_2))
# True
If we test the time required:
%timeit -n 10 -r 3 foo()
# 3.16 s ± 50.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
The first method is really slow, it takes over 3 seconds for this calculation. Let's check the second method:
%timeit -n 10 -r 3 bar()
# 97.5 ms ± 405 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
So we have a ~30x speedup here, which would make your large calculation in python much more feasible than using the original code. Feel free to test out with other sizes of Y to see how it scales compared to the original.
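One caveat when scaling up: bar() gathers Y[u] and Y[v], each holding one (p, q) block per pair, so its temporaries grow with m(m-1)/2. A possible middle ground, sketched below under the assumption that a Python loop over the m rows is acceptable, vectorizes over j only, so each temporary has at most m rows. The name rmsd_condensed is made up for this example, and the output itself still has m(m-1)/2 entries.

```python
import numpy as np

def rmsd_condensed(Y):
    """Same values and ordering as bar(), built one row of pairs at a time."""
    m = Y.shape[0]
    flat = Y.reshape(m, -1)                 # shape (m, p*q)
    out = np.empty(m * (m - 1) // 2)
    pos = 0
    for i in range(m - 1):
        diff = flat[i + 1:] - flat[i]       # pairs (i, j) for all j > i
        row = np.sqrt(3.0 * (diff ** 2).mean(axis=1))
        out[pos:pos + row.size] = row
        pos += row.size
    return out

# Compare against the fully vectorized expression on a small dummy input
rng = np.random.default_rng(0)
Y = rng.standard_normal((50, 4, 5))
u, v = np.triu_indices(Y.shape[0], 1)
reference = np.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
print(np.allclose(rmsd_condensed(Y), reference))
```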
JIT
In addition, you can also try out JIT, mainly jax or numba. It is fairly simple to port the function bar with jax.numpy, for example:
import jax
import jax.numpy as jnp
@jax.jit
def jit_bar(Y):
    u, v = jnp.triu_indices(Y.shape[0], 1)
    return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
# check for correctness
print(np.allclose(bar(), jit_bar(Y)))
# True
If we test the time of the jitted jnp op:
%timeit -n 10 -r 3 jit_bar(Y)
# 10.6 ms ± 678 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
So compared to the original, we could reach even up to ~300x speed.
Note that not every operation can be converted to jax/jit so easily (this particular problem is conveniently suitable), so the general advice is to simply avoid python loops and use numpy's broadcasting/vectorization capabilities, like in bar().
Stick to Julia.
If you already made it in a language which runs faster, why are you trying to use python in the first place?
Your question is about speeding up Python, relative to Julia, so I'd like to offer some Julia code for comparison.
Since your data is most naturally expressed as a list of 4x5 arrays, I suggest expressing it as a vector of SMatrixes:
sumdiff2(A, B) = sum((A[i] - B[i])^2 for i in eachindex(A, B))
function dists(Y)
    M = length(Y)
    V = Vector{float(eltype(eltype(Y)))}(undef, sum(1:M-1))
    Threads.@threads for i in eachindex(Y)
        ii = sum(M-i+1:M-1) # don't worry about this sum
        for j in i+1:lastindex(Y)
            ind = ii + (j-i)
            V[ind] = sqrt(3 * sumdiff2(Y[i], Y[j])/length(Y[i]))
        end
    end
    return V
end
using Random: randn
using StaticArrays: SMatrix
Ys = [randn(SMatrix{4,5,Float64}) for _ in 1:1000];
Benchmarks:
# single-threaded
julia> using BenchmarkTools
julia> @btime dists($Ys);
6.561 ms (2 allocations: 3.81 MiB)
# multi-threaded with 6 cores
julia> @btime dists($Ys);
1.606 ms (75 allocations: 3.82 MiB)
I was not able to install jax on my computer, but when comparing with @Mercury's numpy code I got
foo: 5.5 s
bar: 179 ms
i.e. the multi-threaded Julia version is approximately 3400x faster than foo.
It is possible to write this as a one-liner at a ~2-3x performance cost.
While Python tends to be slower than Julia for many tasks, it is possible to write numerical code in Python that is as fast as Julia by using Numba and plain loops. Indeed, Numba is based on llvmlite, which is basically a JIT compiler built on the LLVM toolchain. The standard implementation of Julia also uses a JIT and the LLVM toolchain. This means the two should perform very similarly, apart from language overheads, which become negligible once the computation is performed in parallel (because the resulting computation will be memory-bound on nearly all modern platforms).
This computation can be parallelized in both Julia and Python (still using Numba). While writing a sequential computation is quite straightforward, writing a parallel one is a bit more complex. Indeed, computing the upper triangular values line by line results in an imbalanced workload and therefore a sub-optimal execution time. An efficient strategy is to compute, in each iteration, a pair of lines: one from the top of the upper triangular part and one from the bottom. Top line i contains m-i-1 items, while the matching bottom line m-i-1 contains i items. In total there are m-1 items to compute per iteration, so the amount of work is independent of the iteration number. This results in much better load balancing. When m is odd, the middle line needs to be computed separately.
Here is the final implementation:
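import numba as nb
import numpy as np
# Fill the part of the condensed result that belongs to line i (pairs (i, j) with j > i)
@nb.njit(inline='always', fastmath=True)
def compute_line(tmp, res, i, m, n):
    # Position of pair (i, i+1) in the condensed upper-triangular output
    offset = (i * (2 * m - i - 1)) // 2
    factor = 3.0 / n
    for j in range(i + 1, m):
        s = 0.0
        for k in range(n):
            s += (tmp[i, k] - tmp[j, k]) ** 2
        res[offset] = np.sqrt(s * factor)
        offset += 1
    return res
# Y is read as a global array
@nb.njit('()', parallel=True, fastmath=True)
def fastest():
    m, n = Y.shape[0], Y.shape[1] * Y.shape[2]
    res = np.empty(m*(m-1)//2)
    tmp = Y.reshape(m, n)
    # Pair a long line from the top with a short one from the bottom for load balancing
    for i in nb.prange(m//2):
        compute_line(tmp, res, i, m, n)
        compute_line(tmp, res, m-i-1, m, n)
    # For odd m, line m//2 is the one not covered by the pairs above
    if m % 2 == 1:
        compute_line(tmp, res, m//2, m, n)
    return res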
# [...] same as others
%timeit -n 100 fastest()
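The pairing used in the parallel loop can be checked with plain Python. Counting from the inner loop bounds range(i + 1, m), line i holds m-i-1 entries, so a line from the top plus its partner from the bottom always adds up to the same amount of work. The helper names below are made up for this check; the offset formula is the one from compute_line.

```python
def line_length(i, m):
    # Number of pairs (i, j) with i < j < m, i.e. len(range(i + 1, m))
    return m - i - 1

def line_offset(i, m):
    # Start of line i in the condensed output, as computed in compute_line
    return (i * (2 * m - i - 1)) // 2

m = 11
paired = [line_length(i, m) + line_length(m - i - 1, m) for i in range(m // 2)]
unpaired = [line_length(i, m) for i in range(m)]
print(paired)     # every entry equals m - 1
print(unpaired)   # shrinks from m - 1 down to 0

# The closed-form offset agrees with a running total of the line lengths
running = 0
for i in range(m):
    assert line_offset(i, m) == running
    running += line_length(i, m)
```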
Results
Here are performance results on my machine (with a i5-9600KF having 6 cores):
foo (seq, Python, Mercury): 4910.7 ms
bar (seq, Python, Mercury): 134.2 ms
jit_bar (seq, Python, Mercury): ???
dists (seq, Julia, DNF) 6.9 ms
dists (par, Julia, DNF) 2.2 ms
fastest (par, Python, me): 1.5 ms <-----
(Jax does not work on my machine so I cannot test it yet)
This implementation is the fastest one and succeeds in beating the best Julia code so far.
Optimal implementation
Note that for large arrays like (200_000,4,5), all implementations provided so far are inefficient since they are not cache friendly. Indeed, the input array takes about 32 MB and will not fit in the cache of most modern processors (and even if it could, one needs to consider the space needed for the output and the fact that caches are not perfect). This can be fixed using tiling, at the expense of even more complex code. I think such an implementation should be optimal if you use Z-order curves.

numpy array: fast assign short array to large array with index

I want to assign values to large array from short arrays with indexing. Simple codes are as follows:
import numpy as np
def assign_x():
    a = np.zeros((int(3e6), 20))
    index_a = np.random.randint(int(3e6), size=(int(3e6), 20))
    b = np.random.randn(1000, 20)
    for i in range(20):
        index_b = np.random.randint(1000, size=int(3e6))
        a[index_a[:, i], i] = b[index_b, i]
    return a
%timeit x = assign_x()
# 2.79 s ± 18.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I have tried other ways which may be relevant, for example np.take and numba's jit, but it seems the above is the fastest way. It could possibly be sped up using multiprocessing. I have profiled the code, and most of the time is spent on the line below, as it runs many times (20 here):
a[index_a[:, i], i] = b[index_b, i]
Any chance I can make this faster before using multiprocessing?
Why this is slow
This is slow because the memory access pattern is very inefficient. Indeed, random accesses are slow because the processor cannot predict them. As a result, they cause expensive cache misses (if the array does not fit in the L1/L2 cache) that cannot be avoided by prefetching data ahead of time. The problem is that the arrays are too big to fit in the caches: index_a and a each take 457 MiB, and b takes 156 KiB. As a result, accesses to b are typically served from the L2 cache, with a higher latency, and accesses to the two other arrays go to RAM. This is slow because current DDR RAM has a huge latency of 60-100 ns on a typical PC. Even worse, this latency is unlikely to get much smaller in the near future: RAM latency has not changed much over the last two decades. This is called the memory wall. Note also that modern processors fetch a full cache line, usually 64 bytes, from RAM when a value at a random location is requested (resulting in 56/64 = 87.5% of the bandwidth being wasted). Finally, generating random numbers is quite expensive, especially for large integers, and np.random.randint can generate either 32-bit or 64-bit integers depending on the target platform.
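The sizes and the wasted-bandwidth figure quoted above follow from simple arithmetic, reproduced here assuming 8-byte elements (float64 values and 64-bit indices) and 64-byte cache lines:

```python
MIB = 1024 ** 2
KIB = 1024

rows, cols = 3_000_000, 20
item = 8  # bytes per float64 or int64 element

big_array_mib = rows * cols * item / MIB   # a, and index_a with 64-bit integers
b_kib = 1000 * cols * item / KIB           # b
cache_line = 64
wasted = (cache_line - item) / cache_line  # share of each fetched line that is not used

print(f"a / index_a: {big_array_mib:.1f} MiB")
print(f"b: {b_kib:.2f} KiB")
print(f"wasted per random access: {wasted:.1%}")
```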
How to improve this
The first improvement is to perform the indirection on the most contiguous dimension, which is generally the last one, since a[:,i] is slower than a[i,:]. You can transpose the arrays and swap the indexed values. However, the NumPy transposition function only returns a view and does not actually transpose the array in memory. Thus an explicit copy is currently required. The best option here is simply to generate the arrays directly so that accesses are efficient (rather than using expensive transpositions). Note that you can use single precision so the arrays fit better in the caches, at the expense of lower precision.
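A small check of the view-versus-copy point, on a toy array (the shape is arbitrary and only chosen to keep the example small):

```python
import numpy as np

a = np.zeros((6, 4))
view = a.T                         # only strides change, no data is moved
copy = np.ascontiguousarray(a.T)   # transposed layout materialized in memory

print(np.shares_memory(a, view), view.flags['C_CONTIGUOUS'])
print(np.shares_memory(a, copy), copy.flags['C_CONTIGUOUS'])

# Single precision halves the memory footprint
print(a.astype(np.float32).nbytes, a.nbytes)
```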
Here is an example that returns a transposed array:
import numpy as np
def assign_x():
    a = np.zeros((20, int(3e6)))
    index_a = np.random.randint(int(3e6), size=(20, int(3e6)))
    b = np.random.randn(20, 1000)
    for i in range(20):
        index_b = np.random.randint(1000, size=int(3e6))
        a[i, index_a[i, :]] = b[i, index_b]
    return a
%timeit x = assign_x()
The code can be improved further using Numba to run it in parallel (one core is unlikely to saturate the memory because of the RAM latency, but many cores can make better use of it because multiple fetches can be done concurrently). Moreover, it can help avoid the creation of big temporary arrays.
Here is an optimized Numba version:
import numpy as np
import numba as nb
import random
@nb.njit('float64[:,:]()', parallel=True)
def assign_x():
    a = np.zeros((20, int(3e6)))
    b = np.random.randn(20, 1000)
    for i in nb.prange(20):
        for j in range(3_000_000):
            # random.randint includes both endpoints, so the upper bound is size - 1
            index_a = random.randint(0, 3_000_000 - 1)
            index_b = random.randint(0, 1000 - 1)
            a[i, index_a] = b[i, index_b]
    return a
%timeit x = assign_x()
Here are results on a 10-core Skylake Xeon processor:
Initial code: 2798 ms
Better memory access pattern: 1741 ms
With Numba: 318 ms
Note that parallelizing the innermost loop would theoretically be faster because one line of a is more likely to fit in the last-level cache. However, doing this will cause a race condition that can only be fixed efficiently with atomic stores, which are not yet available in Numba (on CPU).
Note that the final code does not scale well because it is memory-bound. This is due to 87.5% of the memory throughput being wasted, as explained before. Additionally, on many processors (like all Intel and AMD Zen processors) the write-allocate cache policy forces data to be read from memory for each store in this case. This makes the computation much more inefficient, raising the wasted throughput to 93.7%... AFAIK, there is no way to prevent this in Python. In C/C++, the write-allocate issue can be fixed using low-level instructions. The rule of thumb is to avoid random memory access patterns on big arrays like the plague.
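One way to arrive at the 93.7% figure, assuming 8-byte values, 64-byte cache lines, and that every random store first reads its cache line and later writes the whole line back:

```python
cache_line = 64   # bytes moved per transfer between RAM and cache
useful = 8        # bytes of the one float64 actually stored

# write-allocate: the line is read before the store, then written back
traffic = 2 * cache_line
write_allocate_waste = (traffic - useful) / traffic

print(f"{write_allocate_waste:.2%}")
```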

Is there a way to speed up indexing a vector with JAX?

I am indexing vectors and using JAX, but I have noticed a considerable slow-down compared to numpy when simply indexing arrays. For example, consider making a basic array in JAX numpy and ordinary numpy:
import jax.numpy as jnp
import numpy as onp
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)
Then simply indexing between two integers, for JAX (on GPU) this gives a time of:
%timeit jax_array[435:852]
1000 loops, best of 5: 1.38 ms per loop
And for numpy this gives a time of:
%timeit numpy_array[435:852]
1000000 loops, best of 5: 271 ns per loop
So numpy is 5000 times faster than JAX. When JAX is on a CPU, then
%timeit jax_array[435:852]
1000 loops, best of 5: 577 µs per loop
So faster, but still 2000 times slower than numpy. I am using Google Colab notebooks for this, so there should not be a problem with the installation/CUDA.
Am I missing something? I realise that indexing is different for JAX and numpy, as given by the JAX 'sharp edges' documentation, but I cannot find any way to perform assignment such as
new_array = jax_array[435:852]
without a considerable slowdown. I cannot avoid indexing the arrays as it is necessary in my program.
The short answer: to speed things up in JAX, use jit.
The long answer:
You should generally expect single operations using JAX in op-by-op mode to be slower than similar operations in numpy. This is because JAX execution has some amount of fixed per-python-function-call overhead involved in pushing compilations down to XLA.
Even seemingly simple operations like indexing are implemented in terms of multiple XLA operations, which (outside JIT) will each add their own call overhead. You can see this sequence using the make_jaxpr transform to inspect how the function is expressed in terms of primitive operations:
from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda ; a.
# let b = broadcast_in_dim[ broadcast_dimensions=( )
# shape=(1,) ] 435
# c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
# indices_are_sorted=True
# slice_sizes=(417,)
# unique_indices=True ] a b
# d = broadcast_in_dim[ broadcast_dimensions=(0,)
# shape=(417,) ] c
# in (d,) }
(See Understanding Jaxprs for info on how to read this).
Where JAX outperforms numpy is not in single small operations (in which JAX dispatch overhead dominates), but rather in sequences of operations compiled via the jit transform. So, for example, compare the JIT-compiled versus not-JIT-compiled version of the indexing:
%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop
from jax import jit
f_jit = jit(f)
f_jit(jax_array) # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop
(note that block_until_ready() is required for accurate micro-benchmarks because of JAX's asynchronous dispatch)
JIT-compiling this code gives a roughly 140x speedup. It's still not as fast as numpy because of JAX's few-microsecond dispatch overhead, but with JIT that overhead is incurred only once. And when you move past microbenchmarks to more complicated sequences of real-world computations, those few microseconds will no longer dominate, and the optimization provided by the XLA compiler can make JAX far faster than the equivalent numpy computation.

How to determine if numba's prange actually works correctly?

In another Q+A (Can I perform dynamic cumsum of rows in pandas?) I made a comment regarding the correctness of using prange about this code (of this answer):
from numba import njit, prange
@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum
The comment was:
I wouldn't recommend parallelizing a loop that isn't pure. In this case the running variable makes it impure. There are 4 possible outcomes: (1) numba decides that it cannot parallelize it and just processes the loop as if it was range instead of prange; (2) it can lift the variable outside the loop and use parallelization on the remainder; (3) numba incorrectly inserts synchronization between the parallel executions and the result may be bogus; (4) numba inserts the necessary synchronizations around running, which may impose more overhead than you gain by parallelizing it in the first place.
And the later addition:
Of course both the running and cumsum variable make the loop "impure", not just the running variable as stated in the previous comment
Then I was asked:
This might sound like a silly question, but how can I figure out which of the 4 things it did and improve it? I would really like to become better with numba!
Given that it could be useful for future readers I decided to create a self-answered Q+A here. Spoiler: I cannot really answer the question which of the 4 outcomes is produced (or if numba produces a totally different outcome) so I highly encourage other answers.
TL;DR: First: prange is identical to range, except when you add parallel to the jit, for example njit(parallel=True). If you try that you'll see an exception about an "unsupported reduction". That's because Numba limits the scope of prange to "pure" loops and "impure" loops with Numba-supported reductions, and puts the responsibility of making sure that the loop falls into either of these categories on the user.
This is clearly stated in the documentation of Numba's prange (version 0.42):
1.10.2. Explicit Parallel Loops
Another feature of this code transformation pass is support for explicit parallel loops. One can use Numba’s prange instead of range to specify that a loop can be parallelized. The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.
What the comments refer to as "impure" is called "cross iteration dependencies" in that documentation. Such a "cross-iteration dependency" is a variable that changes between loops. A simple example would be:
def func(n):
    a = 0
    for i in range(n):
        a += 1
    return a
Here the variable a depends on the value it had before the loop started and how many iterations of the loop had been executed. That's what is meant by a "cross iteration dependency" or an "impure" loop.
The problem when explicitly parallelizing such a loop is that iterations are performed in parallel but each iteration needs to know what the other iterations are doing. Failure to do so would result in a wrong result.
Let's for a moment assume that prange would spawn 4 workers and we pass 4 as n to the function. What would a completely naive implementation do?
Worker 1 starts, gets i = 0 from `prange`, and reads a = 0
Worker 2 starts, gets i = 1 from `prange`, and reads a = 0
Worker 3 starts, gets i = 2 from `prange`, and reads a = 0
Worker 1 executes the loop body and sets `a = a + 1` (=> 1)
Worker 3 executes the loop body and sets `a = a + 1` (=> 1)
Worker 4 starts, gets i = 3 from `prange`, and reads a = 1
Worker 2 executes the loop body and sets `a = a + 1` (=> 1)
Worker 4 executes the loop body and sets `a = a + 1` (=> 2)
=> Loop ended, function returns 2 (the correct result would be 4)
The order in which the different workers read, execute and write to a can be arbitrary, this was just one example. It could also produce (by accident) the correct result! That's generally called a Race condition.
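An interleaving of this kind can be replayed deterministically by treating each worker's read and write of a as separate events. This is only a toy model of a lost update (the replay function and the schedule format are invented for this illustration), not a description of how Numba schedules threads:

```python
def replay(schedule):
    """Run `a += 1` once per worker, split into a read step and a write step."""
    a = 0
    seen = {}  # value of `a` each worker observed when it read
    for worker, action in schedule:
        if action == "read":
            seen[worker] = a
        else:  # "write": store the stale value plus one
            a = seen[worker] + 1
    return a

racy = [(1, "read"), (2, "read"), (3, "read"),
        (1, "write"), (3, "write"),
        (4, "read"), (2, "write"), (4, "write")]
serial = [(w, step) for w in (1, 2, 3, 4) for step in ("read", "write")]

print(replay(racy), replay(serial))
```

Four increments run in both cases, but in the racy schedule several workers overwrite each other's result because they all started from the same stale value.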
What would a more sophisticated prange do that recognizes that there is such a cross iteration dependency?
There are three options:
Simply don't parallelize it.
Implement a mechanism where the workers share the variable. Typical examples here are Locks (this can incur a high overhead).
Recognize that it's a reduction that can be parallelized.
Given my understanding of the numba documentation (repeated again):
The user is required to make sure that the loop does not have cross iteration dependencies except for supported reductions.
Numba does:
If it's a known reduction then use patterns to parallelize it
If it's not a known reduction throw an exception
Unfortunately it's not clear what "supported reductions" are. But the documentation hints that it's binary operators that operate on the previous value in the loop body:
A reduction is inferred automatically if a variable is updated by a binary function/operator using its previous value in the loop body. The initial value of the reduction is inferred automatically for += and *= operators. For other functions/operators, the reduction variable should hold the identity value right before entering the prange loop. Reductions in this manner are supported for scalars and for arrays of arbitrary dimensions.
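The idea behind parallelizing a reduction can be sketched with the standard library: each worker accumulates into its own private partial result, and the partials are combined at the end, so no worker ever touches a shared running value. This illustrates the concept only; it is not how Numba implements it, and the function names are made up.

```python
from concurrent.futures import ThreadPoolExecutor

def partial_sum(values, start, stop):
    total = 0  # private to this call, so there is no cross-worker dependency
    for i in range(start, stop):
        total += values[i]
    return total

def chunked_sum(values, workers=4):
    n = len(values)
    # Contiguous, non-overlapping index ranges that together cover 0..n
    bounds = [(w * n // workers, (w + 1) * n // workers) for w in range(workers)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        partials = pool.map(lambda b: partial_sum(values, *b), bounds)
        return sum(partials)  # combine step

data = list(range(1000))
print(chunked_sum(data) == sum(data))
```

This works for += because addition is associative, which is what allows the partial results to be combined in any grouping. A conditional reset like the one applied to running in the original code has no such combine step.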
The code in the OP uses a list as cross iteration dependency and calls list.append in the loop body. Personally I wouldn't call list.append a reduction and it's not using a binary operator so my assumption would be that it's very likely not supported. As for the other cross iteration dependency running: It's using addition on the result of the previous iteration (which would be fine) but also conditionally resets it to zero if it exceeds a threshold (which is probably not fine).
Numba provides ways to inspect the intermediate code (LLVM and ASM):
dynamic_cumsum.inspect_types()
dynamic_cumsum.inspect_llvm()
dynamic_cumsum.inspect_asm()
But even if I had the required understanding of the results to make any statement about the correctness of the emitted code, in general it's highly nontrivial to "prove" that multi-threaded/multi-process code works correctly. Given that I lack the LLVM and ASM knowledge even to see whether it tries to parallelize the loop, I cannot actually answer your specific question about which outcome it produces.
Back to the code, as mentioned it throws an exception (unsupported reduction) if I use parallel=True, so I assume that numba doesn't parallelize anything in the example:
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum
dynamic_cumsum(np.ones(100), np.arange(100), 10)
AssertionError: Invalid reduction format
During handling of the above exception, another exception occurred:
LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Invalid reduction format
File "<>", line 7:
def dynamic_cumsum(seq, index, max_value):
<source elided>
running = 0
for i in prange(len(seq)):
^
[1] During: lowering "id=2[LoopNest(index_variable = parfor_index.192, range = (0, seq_size0.189, 1))]{56: <ir.Block at <> (10)>, 24: <ir.Block at <> (7)>, 34: <ir.Block at <> (8)>}Var(parfor_index.192, <> (7))" at <> (7)
So what is left to say: prange does not provide any speed advantage in this case over a normal range (because it's not executing in parallel). So in that case I would not "risk" potential problems and/or confusing the readers - given that it's not supported according to the numba documentation.
from numba import njit, prange
@njit
def p_dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in prange(len(seq)):
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum
@njit
def dynamic_cumsum(seq, index, max_value):
    cumsum = []
    running = 0
    for i in range(len(seq)): # <-- here is the only change
        if running > max_value:
            cumsum.append([index[i], running])
            running = 0
        running += seq[i]
    cumsum.append([index[-1], running])
    return cumsum
Just a quick timing that supports the "not faster than" statement I made earlier:
import numpy as np
seq = np.random.randint(0, 100, 10_000_000)
index = np.arange(10_000_000)
max_ = 500
# Correctness and warm-up
assert p_dynamic_cumsum(seq, index, max_) == dynamic_cumsum(seq, index, max_)
%timeit p_dynamic_cumsum(seq, index, max_)
# 468 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dynamic_cumsum(seq, index, max_)
# 470 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Possible to use numba.guvectorize to emulate parallel forall / prange?

As a user of Python for data analysis and numerical calculations, rather than a real "coder", I had been missing a really low-overhead way of distributing embarrassingly parallel loop calculations on several cores.
As I learned, there used to be the prange construct in Numba, but it was abandoned because of "instability and performance issues".
Playing with the newly open-sourced @guvectorize decorator I found a way to use it for virtually no-overhead emulation of the functionality of the late prange.
I am very happy to have this tool at hand now, thanks to the guys at Continuum Analytics, and did not find anything on the web explicitly mentioning this use of @guvectorize. Although it may be trivial to people who have been using NumbaPro earlier, I'm posting this for all those fellow non-coders out there (see my answer to this "question").
Consider the example below, where a two-level nested for loop with a core doing some numerical calculation involving two input arrays and a function of the loop indices is executed in four different ways. Each variant is timed with Ipython's %timeit magic:
1. naive for loop, compiled using numba.jit
2. forall-like construct using numba.guvectorize, executed in a single thread (target = "cpu")
3. forall-like construct using numba.guvectorize, executed in as many threads as there are cpu "cores" (in my case hyperthreads) (target = "parallel")
4. same as 3., however calling the "guvectorized" forall with the sequence of "parallel" loop indices randomly permuted
The last one is done because (in this particular example) the inner loop's range depends on the value of the outer loop's index. I don't know exactly how the dispatch of gufunc calls is organized inside numpy, but it appears as if the randomization of "parallel" loop indices achieves slightly better load balancing.
On my (slow) machine (1st gen core i5, 2 cores, 4 hyperthreads) I get the timings:
1 loop, best of 3: 8.19 s per loop
1 loop, best of 3: 8.27 s per loop
1 loop, best of 3: 4.6 s per loop
1 loop, best of 3: 3.46 s per loop
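The load-balancing remark can be made concrete with a toy model. It assumes, hypothetically (the actual gufunc scheduling is not known here), that the index sequence is cut into equal contiguous chunks, one per thread, and that the cost of outer index i is the inner-loop trip count n - i:

```python
import random

n = 20_000
workers = 4
cost = [n - i for i in range(n)]  # inner-loop trip count for outer index i

def chunk_loads(order):
    # Total work per thread if `order` is cut into equal contiguous chunks
    size = len(order) // workers
    return [sum(cost[i] for i in order[w * size:(w + 1) * size])
            for w in range(workers)]

in_order = list(range(n))
shuffled = in_order[:]
random.Random(0).shuffle(shuffled)  # fixed seed so the run is repeatable

for name, order in (("in order", in_order), ("shuffled", shuffled)):
    loads = chunk_loads(order)
    print(name, max(loads) / min(loads))
```

Under this assumption the busiest thread gets several times the work of the idlest one when indices are in order, while the shuffled order spreads the work almost evenly.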
Note: I'd be interested if this recipe readily applies to target="gpu" (it should do, but I don't have access to a suitable graphics card right now), and what's the speedup. Please post!
And here's the example:
import numpy as np
from numba import jit, guvectorize, float64, int64
@jit
def naive_for_loop(some_input_array, another_input_array, result):
    for i in range(result.shape[0]):
        for k in range(some_input_array.shape[0] - i):
            result[i] += some_input_array[k+i] * another_input_array[k] * np.sin(0.001 * (k+i))
@guvectorize([(float64[:],float64[:],int64[:],float64[:])],'(n),(n),()->()', nopython=True, target='parallel')
def forall_loop_body_parallel(some_input_array, another_input_array, loop_index, result):
    i = loop_index[0] # just a shorthand
    # do some nontrivial calculation involving elements from the input arrays and the loop index
    for k in range(some_input_array.shape[0] - i):
        result[0] += some_input_array[k+i] * another_input_array[k] * np.sin(0.001 * (k+i))
@guvectorize([(float64[:],float64[:],int64[:],float64[:])],'(n),(n),()->()', nopython=True, target='cpu')
def forall_loop_body_cpu(some_input_array, another_input_array, loop_index, result):
    i = loop_index[0] # just a shorthand
    # do some nontrivial calculation involving elements from the input arrays and the loop index
    for k in range(some_input_array.shape[0] - i):
        result[0] += some_input_array[k+i] * another_input_array[k] * np.sin(0.001 * (k+i))
arg_size = 20000
input_array_1 = np.random.rand(arg_size)
input_array_2 = np.random.rand(arg_size)
result_array = np.zeros_like(input_array_1)
# do single-threaded naive nested for loop
# reset result_array inside %timeit call
%timeit -r 3 result_array[:] = 0.0; naive_for_loop(input_array_1, input_array_2, result_array)
result_1 = result_array.copy()
# do single-threaded forall loop (loop indices in-order)
# reset result_array inside %timeit call
loop_indices = range(arg_size)
%timeit -r 3 result_array[:] = 0.0; forall_loop_body_cpu(input_array_1, input_array_2, loop_indices, result_array)
result_2 = result_array.copy()
# do multi-threaded forall loop (loop indices in-order)
# reset result_array inside %timeit call
loop_indices = range(arg_size)
%timeit -r 3 result_array[:] = 0.0; forall_loop_body_parallel(input_array_1, input_array_2, loop_indices, result_array)
result_3 = result_array.copy()
# do forall loop (loop indices scrambled for better load balancing)
# reset result_array inside %timeit call
loop_indices_scrambled = np.random.permutation(range(arg_size))
loop_indices_unscrambled = np.argsort(loop_indices_scrambled)
%timeit -r 3 result_array[:] = 0.0; forall_loop_body_parallel(input_array_1, input_array_2, loop_indices_scrambled, result_array)
result_4 = result_array[loop_indices_unscrambled].copy()
# check validity
print(np.all(result_1 == result_2))
print(np.all(result_1 == result_3))
print(np.all(result_1 == result_4))
