How to reduce JAX compile time when using for loop? - python

This is a basic example.
@jax.jit
def block(arg1, arg2):
    for x1 in range(cons1):
        for x2 in range(cons2):
            for x3 in range(cons3):
                ...  # --do something--
    return result
When the cons are small, the compile time is around a minute. With larger cons, the compile time is much higher: tens of minutes. And I need even higher cons. What can be done?
From what I am reading, the loops are the cause: they are unrolled at compile time.
Are there any workarounds? There is also jax.lax.fori_loop, but I don't understand how to use it. There is the jax.experimental.loops module, but again I'm not able to understand it.
I am very new to all this. Hence, all help is appreciated.
If you can provide some examples of how to use jax loops, that will be much appreciated.
Also, what is an ok compile time? Is it ok for it to be in minutes?
In one of the examples, compile time is 262 seconds and remaining runs are ~0.1-0.2 seconds.
Any gain in runtime is overshadowed by the compile time.

JAX's JIT compiler flattens all Python loops. To see what I mean, take a look at this simple function run through jax.make_jaxpr, which is a way to examine how JAX's tracer interprets python code (see Understanding Jaxprs for more):
import jax
def f(x):
    for i in range(5):
        x += i
    return x
print(jax.make_jaxpr(f)(0))
# { lambda ; a.
#   let b = add a 0
#       c = add b 1
#       d = add c 2
#       e = add d 3
#       f = add e 4
#   in (f,) }
Notice that the loop is flattened: every step becomes an explicit operation sent to the XLA compiler. The XLA compile time increases as you increase the number of operations in the function, so it makes sense that a triply-nested for-loop would lead to long compile times.
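To see that growth directly, the sketch below counts the primitive operations that `jax.make_jaxpr` records for the same loop at a few lengths. It assumes a reasonably recent JAX in which the object returned by `make_jaxpr` exposes its equations as `.jaxpr.eqns`; the helper `num_ops` is hypothetical, written only for this illustration.

```python
import jax


def f(x, n):
    # Same loop as above, with the trip count passed as a Python int
    for i in range(n):
        x += i
    return x


def num_ops(n):
    # Trace f for a given n and count the primitive operations JAX records
    closed_jaxpr = jax.make_jaxpr(lambda x: f(x, n))(0)
    return len(closed_jaxpr.jaxpr.eqns)


for n in (5, 50, 500):
    print(n, num_ops(n))
```

The count grows with `n`, and that whole unrolled program is what XLA has to compile.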
So, how to address this? Well, unfortunately the answer depends on what your --do something-- is doing, so I can't guess that.
In general, the best option is to use vectorized array operations rather than loops over the values in those vectors; for example, here is a very slow way of adding two vectors:
import jax.numpy as jnp
def f_slow(x, y):
    z = []
    for xi, yi in zip(x, y):
        z.append(xi + yi)
    return jnp.array(z)
and here is a much faster way to do the same thing:
def f_fast(x, y):
    return x + y
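The same idea extends to nested loops like the one in the question. As a hedged sketch, assume the loop body is an elementwise expression of two loop variables (a squared difference, chosen only for illustration). Broadcasting or a nested `jax.vmap` then replaces the double loop with a few array operations whose size in the jaxpr does not depend on the array lengths.

```python
import jax
import jax.numpy as jnp


def pairwise_loop(x, y):
    # Double Python loop: len(x) * len(y) operations get traced and compiled
    return jnp.array([[(xi - yj) ** 2 for yj in y] for xi in x])


def pairwise_broadcast(x, y):
    # x[:, None] has shape (n, 1) and y[None, :] has shape (1, m) -> result (n, m)
    return (x[:, None] - y[None, :]) ** 2


def sq_diff(xi, yj):
    # Scalar version of the loop body
    return (xi - yj) ** 2


# Inner vmap maps over y with xi held fixed; outer vmap maps over x
pairwise_vmap = jax.vmap(jax.vmap(sq_diff, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.arange(3.0)
y = jnp.arange(4.0)
print(pairwise_broadcast(x, y).shape)  # (3, 4)
```

`vmap` is convenient when the body is easier to write for scalars; broadcasting is the more direct choice when the body is already a simple array expression.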
If your operations don't lend themselves to vectorization, another option is to use lax control flow operators in place of the for loops: this will push the loop down into XLA. This can have quite good performance on CPU, but is slower on accelerators when compared to equivalent vectorized array operations.
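A minimal sketch of `jax.lax.fori_loop`: the `body(i, carry)` function receives the loop index and the value carried over from the previous iteration and returns the new carry, so the body is traced once no matter how many iterations run. In the triple loop, the body `acc + x1 * x2 * x3` is a hypothetical placeholder for `--do something--`; the bounds are Python ints, so they are marked static under `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_fori(x, n):
    # Equivalent of: for i in range(n): x += i
    return lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.asarray(x, dtype=jnp.int32))


def triple_loop(n1, n2, n3):
    # Nested fori_loops replacing three nested Python for-loops.
    # The body (acc + x1 * x2 * x3) is a hypothetical placeholder.
    def body1(x1, acc):
        def body2(x2, acc):
            def body3(x3, acc):
                return acc + x1 * x2 * x3
            return lax.fori_loop(0, n3, body3, acc)
        return lax.fori_loop(0, n2, body2, acc)
    return lax.fori_loop(0, n1, body1, jnp.int32(0))


# n1, n2, n3 are used as loop bounds, so they are treated as static arguments
triple_loop_jit = jax.jit(triple_loop, static_argnums=(0, 1, 2))

print(f_fori(0, 5))  # 10
print(triple_loop_jit(3, 4, 5))  # 180
```

Unlike the Python loop, the traced program for `f_fori` has the same number of operations for `n = 5` and `n = 50`; only a loop-length parameter changes.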
For more discussion on JAX and Python control flow statements (such as for, if, while, etc.), see 🔪 JAX - The Sharp Bits 🔪: Control Flow.

I am not sure whether this will be the same as with Numba, but this might be a similar case.
When I use the numba.jit compiler and have a big data input, I first compile the function on some small example data, then use it.
Pseudo-code:
func_being_compiled(small_amount_of_data) # compile-only purpose
func_being_compiled(large_amount_of_data)

Related

Cache performance of Python array (not list)

I understand that Python's array, provided by the array module, stores the actual values contiguously (not pointers). Hence I would expect that, when elements of such an array are read in order, the CPU cache would play a role.
Thus I would expect that Code A below should be faster than Code B (the difference between the two is in the order of reading the elements).
Code A:
import array
import time
arr = array.array('l', range(100000000))
sum = 0
begin = time.time()
for i in range(10000):
    for j in range(10000):
        sum += arr[i * 10000 + j]
print(sum)
print(time.time() - begin)
Code B:
import array
import time
arr = array.array('l', range(100000000))
sum = 0
begin = time.time()
for i in range(10000):
    for j in range(10000):
        sum += arr[j * 10000 + i]
print(sum)
print(time.time() - begin)
The two versions' timings are almost identical (a difference of only ~3%). Am I missing something about the workings of the array?
The two codes are completely dominated by the overhead of CPython (by a very large margin). Let's try to understand why.
First of all, CPython is an interpreter, so it optimizes (nearly) nothing. This means an operation like i * 10000 is recomputed over and over even though it could be precomputed in the outer loop. It also means every bytecode instruction has to be fetched and decoded, which is slow (and causes many memory accesses and branches).
Additionally, accessing a global variable is significantly slower in CPython because the interpreter needs to fetch the variable from a global dictionary, which is much slower than a CPU cache access.
Moreover, most CPython operations allocate/free objects, and this is expensive (again, far more than a cache access). Indeed, allocating an object requires fetching a bucket data structure and finding some available space in it. Note that small integers are cached, so they are not allocated; this means looping over small ranges is actually a bit faster. The checks for caching are always done, so they add some overhead even when objects cannot be cached. Such operations require several memory operations (twice, since the free is needed too). Not to mention the reference counting of each object, which also requires memory operations (and the global interpreter lock operations).
In addition, CPython integer operations are pretty slow because CPython deals with variable-sized integers and not native ones. This means CPython does additional checks when integers are large. Bad news: sum is a large integer.
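Two of these points can be observed from Python itself. The sketch below relies on CPython implementation details (the small-integer cache and the variable-length int layout), not on language guarantees, and the exact byte sizes printed depend on the CPython version.

```python
import sys

# Small ints come from a preallocated cache in CPython, so building
# the same small value twice yields the very same object.
a = int("100")
b = int("100")
print(a is b)  # True on CPython (implementation detail)

# Larger values are freshly allocated objects each time.
c = int("100000")
d = int("100000")
print(c is d)  # False on CPython

# CPython ints are arbitrary precision: storage grows with magnitude.
n = 100_000_000
total = n * (n - 1) // 2  # the final value of `sum` in Code A and Code B
print(total)  # 4999999950000000
print(sys.getsizeof(1), sys.getsizeof(total))
```

The final `sum` needs more than one internal digit, so every `+=` in the loops above works on a multi-digit integer and allocates a new object.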
The following code is actually about 2.5 times faster than the original one, and it still spends a lot of time in CPython overheads (lots of object allocations/frees, reference counting, C calls, etc.):
import array
import time
arr = array.array('l', range(100000000))
def compute(arr):
    sum = 0
    for i in range(10000):
        tmp = i * 10000
        for j in range(100):
            tmp2 = tmp + j * 100
            for k in range(100):
                sum += arr[tmp2 + k]
    print(sum)
begin = time.time()
compute(arr)
print(time.time() - begin)
Pure-Python code running on the CPython interpreter is so slow that you often cannot see the impact of caches. Thus, using pure Python to benchmark such effects is a terrible idea. The only way to see such an impact is to use vectorized functions, that is, C functions doing the job far more efficiently than interpreted Python code. NumPy is able to do that. Here is equivalent code for the two original versions:
import numpy as np
import time
arr = np.arange(100_000_000).astype(np.int64)
begin = time.time()
sum = 0
for i in range(10000):
    sum += arr[i*10000:i*10000+10000].sum()
print(sum)
print(time.time() - begin)
begin = time.time()
sum = 0
for i in range(10000):
    sum += arr[i:100_000_000+i:10000].sum()
print(sum)
print(time.time() - begin)
The above code gives the following timings:
Original first code: 19.581 s
First Numpy code: 0.064 s
Second Numpy code: 0.725 s
The first NumPy code is about 300 times faster than the original one, showing how inefficient the pure-Python code was. Indeed, this shows that about 99.7% of the original code's running time was pure overhead. We can also see that the second NumPy code is slower than the first due to the strided access pattern (but it is still 27 times faster than the original first code).
Nearly all the time is spent in the same section of the same internal function in Numpy for both variants. That being said, the second one is much slower because of the strided access. Here is the executed assembly code:
Block 6:
0x180198970 add r9, qword ptr [rcx]
0x180198973 add rcx, r11
0x180198976 add r10, qword ptr [rcx]
0x180198979 add rcx, r11
0x18019897c sub rdx, 0x1
0x180198980 jnz 0x180198970 <Block 6>
This code is not optimal when the array slice is contiguous: the compiler could have generated significantly faster SIMD code for this case. Not to mention, the dependency chain prevents the processor from executing more instructions in parallel (on the same core). That being said, it enables us to see the impact of the strided access using the exact same assembly code. Thus, this is a pretty good benchmark unless you want to include the benefit of SIMD instructions. SIMD instructions can make this code about 2-3 times faster on my machine, but they can only speed up the non-strided case on mainstream platforms.
If you want to measure cache effects, it is generally better to use a natively compiled code. This can be done using Numba in Python (JIT compiler using LLVM) or simply natively compiled languages like C or C++.

Using numba @jit to speed up my multiprocessing loops that use queue

I am learning the ways of Numba and have not figured out how to use, or whether I need, multiprocessing.Queue to combine all my loop data from separate processes.
Do I even want to use the multiprocessing module to break up big loops into multiple smaller ones to run in separate processes, or does Numba do this automatically?
The code below runs under the multiprocessing module, which starts one process per CPU core. So many instances of the code run at once, each looping through a different segment of the overall calculation, and the result, 0 or 1, is sent back to the parent function.
My guess is that Numba does this differently on its own and I don't need the queue or the multiprocessing module?
from numba import jit
@jit(nopython=True)
def prime_multiprocess(n, c, q):
    a, b, c = n[0], n[1], c
    for i in range(a, b):
        if c % i == 0:
            return q.put(0)
    return q.put(1)
This error may have been caused by the following argument(s):
- argument 2: cannot determine Numba type of <class 'multiprocessing.queues.Queue'>
I appreciate any explanation or link that explains using numba with parallel loops that speed things up.
I did some testing and it appears that a nested function solved the problem:
I rewrote it to:
from numba import jit
def prime_multiprocess(n, c, q):
    a, b, c = n[0], n[1], c
    @jit(nopython=True)
    def speed_comp():
        for i in range(a, b):
            if c % i == 0:
                return 0
        return 1
    q.put(speed_comp())
It is faster!
Edit:
It appears there is a downside: I am limited in the size of the integers I can use. "sigh" "Why is there always a trade-off :("
I wonder if it's possible to work around this with NumPy, and whether it would slow things down. The answer might be here: Numba support for big integers?
The way Numba works is that it converts integers into fixed-width machine integers, limited to your system's native size (for example 64 bits). This is what makes it run faster, because there is no overhead on top of the calculations. Unfortunately, without that overhead you cannot compute with bigger integers.
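If the goal is simply to spread the trial division over cores, one option (a sketch, not the asker's code) is `multiprocessing.Pool.map`: each worker returns its result and `map` collects them in order, so no explicit `Queue` is needed. Because this version stays in plain Python, its integers are arbitrary precision, at the cost of interpreter speed. The helper names `has_divisor`, `split_range` and `is_prime_parallel` are hypothetical.

```python
import math
from multiprocessing import Pool


def has_divisor(task):
    # task = (start, stop, c): True if some i in [start, stop) divides c
    start, stop, c = task
    for i in range(start, stop):
        if c % i == 0:
            return True
    return False


def split_range(start, stop, parts):
    # Split [start, stop) into at most `parts` contiguous chunks
    step = max(1, -(-(stop - start) // parts))  # ceiling division
    return [(lo, min(lo + step, stop)) for lo in range(start, stop, step)]


def is_prime_parallel(c, processes=4):
    # Trial division by 2..isqrt(c), spread over a process pool.
    # Python ints are arbitrary precision, so c is not limited to 64 bits.
    if c < 2:
        return False
    limit = math.isqrt(c) + 1
    tasks = [(lo, hi, c) for lo, hi in split_range(2, limit, processes)]
    if not tasks:  # c is 2 or 3: no candidate divisors to check
        return True
    with Pool(processes) as pool:
        # map collects each worker's return value, in order; no Queue needed
        return not any(pool.map(has_divisor, tasks))
```

Call it from under an `if __name__ == "__main__":` guard, for example `is_prime_parallel(1_000_000_007)`, because the spawn and forkserver start methods re-import the main module in each worker. Trial division still takes time proportional to the square root of `c`, so very large inputs remain slow.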

Why is this piece of Python code taking an unusually long time to compute the sum?

I'd like to explore this in terms of asymptotic notations: Big O, Omega and Theta.
Here is a small piece of Python code. I tried running it with a larger value each time. If you look at the third scenario, the code takes an unusually long time to calculate the sum.
I wonder: if I rewrite it sequentially, would there be any difference? How can I optimize this code to handle larger values? Thank you
You can make the same algorithm a lot faster by using numba:
from numba import jit
@jit
def compute(n):
    x = 0.0
    for i in range(1, n+1):
        x += 1/i
    return x
print(compute(1000000000))
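On the asymptotic side, assume the original code is the same sum of 1/i as above (the question's code was only shown in an image). The loop does one addition per term, so its running time is Θ(n) and a 10× larger n takes roughly 10× longer. NumPy keeps the work Θ(n) but moves the loop into C. For very large n, the standard expansion H_n ≈ ln n + γ + 1/(2n) - 1/(12n²) gives an O(1) approximation. The function names below are illustrative.

```python
import math

import numpy as np

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def harmonic_loop(n):
    # Theta(n): one division and one addition per term, in the interpreter
    x = 0.0
    for i in range(1, n + 1):
        x += 1 / i
    return x


def harmonic_numpy(n, chunk=10_000_000):
    # Still Theta(n) work, but each chunk is summed in C; chunking bounds memory
    total = 0.0
    for start in range(1, n + 1, chunk):
        stop = min(start + chunk, n + 1)
        total += float(np.sum(1.0 / np.arange(start, stop, dtype=np.float64)))
    return total


def harmonic_asymptotic(n):
    # O(1) approximation for large n: ln n + gamma + 1/(2n) - 1/(12 n^2)
    return math.log(n) + EULER_GAMMA + 1 / (2 * n) - 1 / (12 * n * n)


print(harmonic_loop(1000), harmonic_numpy(1000), harmonic_asymptotic(1000))
```

The approximation is poor for tiny n but already agrees with the loop to about ten decimal places at n = 1000.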

Vectorize a simple function in python: avoid double for loop

I'm completely new to python. I'm trying to do a very simple thing, evaluate a non-trivial function that takes floats as input on a 2D mesh. The following code does exactly what I want, but it is slow, due to the double for loop.
import numpy as np
from galpy.potential import RazorThinExponentialDiskPotential
R = np.logspace(0., 2., 10)
z = R
# initialize with default values for this example
potfunc = RazorThinExponentialDiskPotential()
pot = np.zeros((R.size, z.size))
for i in range(0, R.size):
    for j in range(0, z.size):
        pot[i,j] = potfunc(R[i], z[j])
At the end, the array pot contains all the information I want, but now I want to increase the efficiency. I know that pure Python is slow, especially on loops (like IDL), so I checked np.vectorize, but it's just a Python loop under the hood.
The problem is that potfunc does not seem to accept arrays, only plain scalars.
How can I optimize this simple program?
Many thanks in advance.
The standard way to do that is using meshgrid (this assumes potfunc accepts arrays):
r, zz = np.meshgrid(R, z, indexing='ij')
pot = potfunc(r, zz)
You must avoid looping over NumPy arrays, or you will lose all the vectorization efficiency.
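galpy is not needed to see how these pieces fit together. The sketch below uses a hypothetical `toy_potential` built from NumPy ufuncs (it is not galpy's formula) and checks that `meshgrid` with `indexing="ij"` and plain broadcasting both reproduce the double loop's `pot[i, j]` layout. If the real function only accepts scalars, `np.vectorize` gives the same array but, as noted in the question, still loops in Python.

```python
import numpy as np


def toy_potential(R, z):
    # Hypothetical stand-in for potfunc (NOT galpy's formula); built from
    # NumPy ufuncs, so it works on scalars and arrays alike
    return -np.exp(-R / 10.0) / np.sqrt(1.0 + z ** 2)


R = np.logspace(0.0, 2.0, 10)
z = R

# Reference: the double loop from the question
pot_loop = np.zeros((R.size, z.size))
for i in range(R.size):
    for j in range(z.size):
        pot_loop[i, j] = toy_potential(R[i], z[j])

# indexing="ij" gives rr[i, j] == R[i] and zz[i, j] == z[j], matching pot[i, j];
# the default indexing="xy" would give the transposed layout
rr, zz = np.meshgrid(R, z, indexing="ij")
pot_grid = toy_potential(rr, zz)

# Broadcasting gives the same (R.size, z.size) array without building the grids
pot_bcast = toy_potential(R[:, None], z[None, :])

# Scalar-only fallback: same result, but np.vectorize still loops in Python
pot_vec = np.vectorize(toy_potential)(rr, zz)
```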
In case you cannot vectorize the function by hand (maybe you could subclass the Razor.. class and rewrite the function), you could use multiprocessing. Instead of my simple worker function, you could use any function you like:
from multiprocessing import pool
import numpy as np
def worker(x):
    ai, bj = x
    return ai + bj
def run_pool():
    a = np.linspace(0, 10, 10)
    b = np.logspace(0, 10, len(a))
    vec = [(a[i], b[j]) for i in range(len(a)) for j in range(len(b))]
    p = pool.Pool(processes=4)  # as many cores as you have
    print(p.map(worker, vec))
    p.close()
    p.join()
if __name__ == "__main__":
    # guard needed when worker processes are spawned (e.g. on Windows and macOS)
    run_pool()
But before you think about speeding things up, profiling would be a good idea. I am pretty sure that in your case the function itself is the bottleneck. So either rewrite it in a compiled language, vectorize it, or use all of your cores.

Python: Splitting up a sum with threads

I have a costly calculation to do for fitting some experimental data. The fitting function is a sum over eigenmodes, each of them containing a specific surface integral. As it is rather slow if you do it the classical way, I thought about threading it. I'm using Python, by the way.
The function I want to calculate is something like
def fit_func(params, Mmin, Mmax):
    values = np.zeros(1000)
    for m in range(Mmin, Mmax):
        # Fancy calculation for each mode
        # some calculation with all modes, adding them up to 'values'
        ...
    return values
How can I split this up? I did something like
data1 = thread.start_new_thread(fit_func, (params, 0, 13))
data2 = thread.start_new_thread(fit_func, (params, 13, 25))
but then the sum of data1 and data2 is not the same as fit_func(params, 0, 25)...
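Try multiprocessing. It creates separate Python processes behind a thread-like interface; plain threads usually will not speed up CPU-bound pure-Python code because of the global interpreter lock. (Also note that thread.start_new_thread returns the new thread's identifier, not the function's return value, which is why data1 + data2 does not match.) However, profile your computation first and make sure that it is the problem, not something else like I/O. Starting processes is slow, so keep them around for a while if you plan to reuse them.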
You can also use numpy for those functions. They're written in C code, so they're stupid fast. Check them both out and see what fits best. I would go for the numpy solution myself...
Use a multiprocessing pool:
import multiprocessing as mp
p = mp.Pool(10)
res = p.map(your_function, range(Mmin, Mmax))
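Here `p.map` returns a list with one result per mode, so the fit values are that list summed over its first axis. The sketch below keeps the question's `fit_func(params, Mmin, Mmax)` signature, splits the mode range into contiguous chunks and adds the partial sums. The per-mode calculation, the sample grid `X` and the `(amplitude, decay)` parameters are hypothetical placeholders for the "fancy calculation".

```python
from multiprocessing import Pool

import numpy as np

X = np.linspace(0.0, 1.0, 1000)  # hypothetical sample points (1000, as in the question)


def mode_contribution(m, params):
    # Hypothetical placeholder for the "fancy calculation" of mode m
    amplitude, decay = params
    return amplitude * np.sin((m + 1) * np.pi * X) * np.exp(-decay * m)


def fit_func(params, Mmin, Mmax):
    # Serial version: sum of the contributions of modes Mmin .. Mmax-1
    values = np.zeros_like(X)
    for m in range(Mmin, Mmax):
        values += mode_contribution(m, params)
    return values


def fit_func_parallel(params, Mmin, Mmax, processes=4):
    # Split [Mmin, Mmax) into contiguous chunks; each worker returns a partial sum
    bounds = np.linspace(Mmin, Mmax, processes + 1).astype(int)
    tasks = [(params, int(lo), int(hi)) for lo, hi in zip(bounds[:-1], bounds[1:])]
    with Pool(processes) as pool:
        partials = pool.starmap(fit_func, tasks)
    # The full sum over modes is the sum of the partial sums
    return np.sum(partials, axis=0)
```

Call `fit_func_parallel` from under `if __name__ == "__main__":`. Because each chunk's partial sum is simply added, `fit_func(params, 0, 13) + fit_func(params, 13, 25)` equals `fit_func(params, 0, 25)` up to floating-point rounding.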
