Any way to speed up this Python code? - python

I've written some Python code to do some image processing work, but it takes a huge amount of time to run. I've spent the last few hours trying to optimize it, but I think I've reached the end of my abilities.
Looking at the output from the profiler, the function below accounts for a large proportion of the overall run time of my code. Is there any way that it can be sped up?
def make_ellipse(x, x0, y, y0, theta, a, b):
    c = np.cos(theta)
    s = np.sin(theta)
    a2 = a**2
    b2 = b**2
    xnew = x - x0
    ynew = y - y0
    ellipse = (xnew * c + ynew * s)**2/a2 + (xnew * s - ynew * c)**2/b2 <= 1
    return ellipse
For context, it is called with x and y as the output from np.meshgrid with a fairly large grid size, and all of the other parameters as simple integer values.
Although that function seems to be taking a lot of the time, there are probably ways that the rest of the code can be sped up too. I've put the rest of the code at this gist.
Any ideas would be gratefully received. I've tried using numba and autojit on the main functions, but that doesn't help much.

Let's try to optimize make_ellipse in conjunction with its caller.
First, notice that a and b are the same over many calls. Since make_ellipse squares them each time, just have the caller do that instead.
Second, notice that np.cos(np.arctan(theta)) is 1 / np.sqrt(1 + theta**2) which seems slightly faster on my system. A similar trick can be used to compute the sine, either from theta or from cos(theta) (or vice versa).
Third, and less concretely, think about short-circuiting some of the final ellipse formula evaluations. For example, wherever (xnew * c + ynew * s)**2/a2 is greater than 1, the ellipse value must be False. If this happens often, you can "mask" out the second half of the (expensive) calculation of the ellipse at those locations. I haven't planned this thoroughly, but see numpy.ma for some possible leads.
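As a minimal sketch of the second point, the identities cos(arctan(t)) = 1 / sqrt(1 + t^2) and sin(arctan(t)) = t / sqrt(1 + t^2) can be checked numerically. The helper name cos_sin_of_arctan is made up for this illustration and is not part of the original code.

```python
import numpy as np

def cos_sin_of_arctan(t):
    """Return (cos(arctan(t)), sin(arctan(t))) without calling a trig function."""
    t = np.asarray(t, dtype=float)
    cos_val = 1.0 / np.sqrt(1.0 + t * t)   # cos(arctan(t))
    sin_val = t * cos_val                  # sin(arctan(t)) = tan * cos
    return cos_val, sin_val

# Compare against the direct trig evaluation on a few sample values.
t = np.linspace(-5.0, 5.0, 11)
c, s = cos_sin_of_arctan(t)
print(np.allclose(c, np.cos(np.arctan(t))), np.allclose(s, np.sin(np.arctan(t))))
```

Whether this is actually faster depends on the system and on whether the argument is a scalar or an array, so it is worth timing before adopting it.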

It won't speed things up in every case, but if your ellipses don't take up the whole image, you should limit your search for points inside the ellipse to its bounding rectangle. I am lazy with the math, so I googled it and reused @JohnZwinck's neat cosine-of-an-arctangent trick to come up with this function:
def ellipse_bounding_box(x0, y0, theta, a, b):
    # Parameter value at which the x coordinate of the ellipse is extremal
    x_tan_t = -b * np.tan(theta) / a
    if np.isinf(x_tan_t):
        x_cos_t = 0
        x_sin_t = np.sign(x_tan_t)
    else:
        x_cos_t = 1 / np.sqrt(1 + x_tan_t*x_tan_t)
        x_sin_t = x_tan_t * x_cos_t
    # Half-width of the box, measured from the centre x0
    dx = abs(a*x_cos_t*np.cos(theta) - b*x_sin_t*np.sin(theta))
    # Parameter value at which the y coordinate of the ellipse is extremal
    y_tan_t = b / np.tan(theta) / a
    if np.isinf(y_tan_t):
        y_cos_t = 0
        y_sin_t = np.sign(y_tan_t)
    else:
        y_cos_t = 1 / np.sqrt(1 + y_tan_t*y_tan_t)
        y_sin_t = y_tan_t * y_cos_t
    # Half-height of the box, measured from the centre y0
    dy = abs(b*y_sin_t*np.cos(theta) + a*y_cos_t*np.sin(theta))
    return (x0 - dx, x0 + dx), (y0 - dy, y0 + dy)
You can now modify your original function to something like this:
def make_ellipse(x, x0, y, y0, theta, a, b):
    c = np.cos(theta)
    s = np.sin(theta)
    a2 = a**2
    b2 = b**2
    x_box, y_box = ellipse_bounding_box(x0, y0, theta, a, b)
    # Only points inside the bounding box can be inside the ellipse
    indices = ((x >= x_box[0]) & (x <= x_box[1]) &
               (y >= y_box[0]) & (y <= y_box[1]))
    xnew = x[indices] - x0
    ynew = y[indices] - y0
    ellipse = np.zeros_like(x, dtype=bool)
    ellipse[indices] = ((xnew * c + ynew * s)**2/a2 +
                        (xnew * s - ynew * c)**2/b2 <= 1)
    return ellipse
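As a cross-check on the bounding box, the half-width and half-height of a rotated ellipse also have a short closed form: sqrt(a^2 cos^2(theta) + b^2 sin^2(theta)) and sqrt(a^2 sin^2(theta) + b^2 cos^2(theta)). The sketch below is not the function from this answer; ellipse_half_extents is a hypothetical helper, compared against a brute-force sampling of the ellipse boundary.

```python
import numpy as np

def ellipse_half_extents(theta, a, b):
    """Half-width and half-height of the axis-aligned box around a rotated ellipse."""
    c, s = np.cos(theta), np.sin(theta)
    half_width = np.hypot(a * c, b * s)    # sqrt(a^2 c^2 + b^2 s^2)
    half_height = np.hypot(a * s, b * c)   # sqrt(a^2 s^2 + b^2 c^2)
    return half_width, half_height

# Brute-force check: sample the boundary of an ellipse centred at the origin.
theta, a, b = 0.7, 5.0, 2.0
t = np.linspace(0.0, 2.0 * np.pi, 100001)
px = a * np.cos(t) * np.cos(theta) - b * np.sin(t) * np.sin(theta)
py = a * np.cos(t) * np.sin(theta) + b * np.sin(t) * np.cos(theta)
hw, hh = ellipse_half_extents(theta, a, b)
print(hw, px.max(), hh, py.max())
```

The box is then x0 ± half_width by y0 ± half_height, with no special case needed when theta is a multiple of pi/2.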

Since everything but x and y are integers, you can try to minimize the number of array computations. I imagine most of the time is spent in this statement:
ellipse = (xnew * c + ynew * s)**2/a2 + (xnew * s - ynew * c)**2/b2 <= 1
A simple rewriting like so should reduce the number of array operations:
a = float(a)
b = float(b)
ellipse = (xnew * (c/a) + ynew * (s/a))**2 + (xnew * (s/b) - ynew * (c/b))**2 <= 1
What was 12 array operations is now 10 (plus 4 scalar ops). I'm not sure if numba's jit would have tried this. It might just do all the broadcasting first, then jit the resulting operations. In this case, reordering so common operations are done at once should help.
Going further, you can rewrite this again as
ellipse = ((xnew + ynew * (s/c)) * (c/a))**2 + ((xnew * (s/c) - ynew) * (c/b))**2 <= 1
Or
t = np.tan(theta)
ellipse = ((xnew + ynew * t) * (b/a))**2 + (xnew * t - ynew)**2 <= (b/c)**2
This replaces one more array operation with a scalar one and eliminates other scalar ops, leaving 9 array operations and 2 scalar ops.
As always, be aware of the range of the inputs to avoid rounding errors.
Unfortunately there's no good way to do a running sum and bail out early if either of the two addends is greater than the right-hand side of the comparison. That would be an obvious speed-up, but one you'd need Cython (or C/C++) to code.
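A quick way to gain confidence in the algebra is to evaluate the original formula and the tangent-based rewrite on the same grid and compare the masks. This is only a sketch with made-up parameter values; the function names are hypothetical, and the rewrite assumes cos(theta) != 0 and b != 0.

```python
import numpy as np

def ellipse_original(xnew, ynew, theta, a, b):
    c, s = np.cos(theta), np.sin(theta)
    return (xnew * c + ynew * s)**2 / a**2 + (xnew * s - ynew * c)**2 / b**2 <= 1

def ellipse_fewer_ops(xnew, ynew, theta, a, b):
    # Assumes cos(theta) != 0 and b != 0, since both appear as divisors.
    t = np.tan(theta)
    c = np.cos(theta)
    return ((xnew + ynew * t) * (b / a))**2 + (xnew * t - ynew)**2 <= (b / c)**2

# Offsets from the ellipse centre on a small integer grid.
x, y = np.meshgrid(np.arange(-20.0, 21.0), np.arange(-20.0, 21.0))
mask_a = ellipse_original(x, y, 0.3, 5.0, 3.0)
mask_b = ellipse_fewer_ops(x, y, 0.3, 5.0, 3.0)
print("pixels that differ:", np.count_nonzero(mask_a != mask_b))
```

Points lying exactly on the boundary could in principle be classified differently because of rounding, which is the caveat about input ranges mentioned above.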

You can speed it up considerably by using Cython. There is very good documentation on how to do this.

Related

Solve integral in an annular domain in Python

I am trying to solve a function in an annular domain that has a change of phase with respect to the angular direction of the annulus.
My attempt to solve it is the following:
import numpy as np
from scipy import integrate
def f(x0, y0):
    r = np.sqrt(x0**2 + y0**2)
    if r >= rIn and r <= rOut:
        theta = np.arctan(y0 / x0)
        R = np.sqrt((x - x0)**2 + (y - y0)**2 + z**2)
        integrand = (np.exp(-1j * (k*R + theta))) / R
        return integrand
    else:
        return 0
# Test
rIn = 0.5
rOut = 1.5
x = 1
y = 1
z = 1
k = 3.66
I = integrate.dblquad(f, -rOut, rOut, lambda x0: -rOut, lambda x0: rOut)
My problem is that I don't know how to get rid of the division by zero occurring when I evaluate theta.
Any help will be more than appreciated!
Use numpy.arctan2 instead. It takes y and x as separate arguments, so it never divides by x, and it returns the angle in the correct quadrant. It only has a problem if both x and y are zero, in which case the angle is undetermined.
Also, I see that your integrand is complex. In this case you will probably have to handle the real and imaginary parts separately, as done here.
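The two suggestions can be sketched as follows. The helper names phase and complex_dblquad are made up for this illustration; the only library calls used are np.arctan2, np.real, np.imag and scipy.integrate.dblquad, which expects the integrand as func(y, x).

```python
import numpy as np
from scipy import integrate

def phase(x0, y0):
    """Angle of the point (x0, y0); defined even when x0 == 0."""
    return np.arctan2(y0, x0)

def complex_dblquad(func, a, b, gfun, hfun):
    """Integrate a complex-valued func(y, x) by splitting real and imaginary parts."""
    real_part, _ = integrate.dblquad(lambda y, x: np.real(func(y, x)), a, b, gfun, hfun)
    imag_part, _ = integrate.dblquad(lambda y, x: np.imag(func(y, x)), a, b, gfun, hfun)
    return real_part + 1j * imag_part

# No division by zero on the y axis, and the quadrant is preserved.
print(phase(0.0, 1.0), phase(-1.0, -1.0))

# Toy check: integrating the constant 1 + 2j over the unit square.
print(complex_dblquad(lambda y, x: 1 + 2j, 0, 1, lambda x: 0, lambda x: 1))
```

For the annulus itself, an integrand with a hard cutoff at rIn and rOut is discontinuous, so switching to polar coordinates (integrating over r and theta with the extra factor r) may behave better; that is a suggestion rather than something tested here.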

Solving this rectangular, nonlinear system with SciPy

Background.
I'm attempting to write a python implementation of this answer over on Math SE. You may find the following background to be useful.
Problem
I have an experimental setup consisting of three (3) receivers, with known locations [xi, yi, zi], and a transmitter with unknown location [x,y,z] emitting a signal at known velocity v. This signal arrives at the receivers at known times ti. The time of emission, t, is unknown.
I wish to find the angle of arrival (i.e. the transmitter's polar coordinates theta and phi), given only this information.
Solution
It is not possible to locate the transmitter exactly with only three (3) receivers, except in a handful of unique cases (there are several great answers across Math SE explaining why this is the case). In general, at least four (and, in practice, >>4) receivers are required to uniquely determine the rectangular coordinates of the transmitter.
The direction to the transmitter, however, may be "reliably" estimated. Letting vi be the vector representing the location of receiver i, ti being the time of signal arrival at receiver i, and n be the vector representing the unit vector pointing in the (approximate) direction of the transmitter, we obtain the following equations:
<n, vj - vi> = v(ti - tj)
(where < > denotes the scalar product)
...for all pairs of indices i,j. Together with |n| = 1, the system has 2 solutions in general, symmetric by reflection in the plane through vi/vj/vk. We may then determine phi and theta by simply writing n in polar coordinates.
Implementation.
I've attempted to write a python implementation of the above solution, using scipy's fsolve.
from dataclasses import dataclass
import scipy.optimize
import random
import math
c = 299792
@dataclass
class Vertexer:
    roc: list
    def fun(self, var, dat):
        (x,y,z) = var
        eqn_0 = (x * (self.roc[0][0] - self.roc[1][0])) + (y * (self.roc[0][1] - self.roc[1][1])) + (z * (self.roc[0][2] - self.roc[1][2])) - c * (dat[1] - dat[0])
        eqn_1 = (x * (self.roc[0][0] - self.roc[2][0])) + (y * (self.roc[0][1] - self.roc[2][1])) + (z * (self.roc[0][2] - self.roc[2][2])) - c * (dat[2] - dat[0])
        eqn_2 = (x * (self.roc[1][0] - self.roc[2][0])) + (y * (self.roc[1][1] - self.roc[2][1])) + (z * (self.roc[1][2] - self.roc[2][2])) - c * (dat[2] - dat[1])
        norm = math.sqrt(x**2 + y**2 + z**2) - 1
        return [eqn_0, eqn_1, eqn_2, norm]
    def find(self, dat):
        result = scipy.optimize.fsolve(self.fun, (0,0,0), args=dat)
        print('Solution ', result)
# Crude code to simulate a source, receivers at random locations
x0 = random.randrange(0,50); y0 = random.randrange(0,50); z0 = random.randrange(0,50)
x1 = random.randrange(0,50); x2 = random.randrange(0,50); x3 = random.randrange(0,50);
y1 = random.randrange(0,50); y2 = random.randrange(0,50); y3 = random.randrange(0,50);
z1 = random.randrange(0,50); z2 = random.randrange(0,50); z3 = random.randrange(0,50);
t1 = math.sqrt((x0-x1)**2 + (y0-y1)**2 + (z0-z1)**2)/c
t2 = math.sqrt((x0-x2)**2 + (y0-y2)**2 + (z0-z2)**2)/c
t3 = math.sqrt((x0-x3)**2 + (y0-y3)**2 + (z0-z3)**2)/c
print('Actual coordinates ', x0,y0,z0)
myVertexer = Vertexer([[x1,y1,z1], [x2,y2,z2], [x3,y3,z3]])
myVertexer.find([t1,t2,t3])
Unfortunately, I have far more experience solving such problems in C/C++ using GSL, and have limited experience working with scipy and the like. I'm getting the error:
TypeError: fsolve: there is a mismatch between the input and output shape of the 'func' argument 'fun'.Shape should be (3,) but it is (4,).
...which seems to suggest that fsolve expects a square system.
How may I solve this rectangular system? I can't seem to find anything useful in the scipy docs.
If necessary, I'm open to using other (Python) libraries.
As you already mentioned, fsolve expects a system with N variables and N equations, i.e. it finds a root of the function F: R^N -> R^N. Since you have four equations, you simply need to add a fourth variable. Note also that fsolve is a legacy function, and it's recommended to use root instead. Last but not least, note that sqrt(x^2+y^2+z^2) = 1 is equivalent to x^2+y^2+z^2=1 and that the latter is much less susceptible to rounding errors caused by the finite differences when approximating the jacobian of F.
Long story short, your class should look like this:
from dataclasses import dataclass
from scipy.optimize import root
@dataclass
class Vertexer:
    roc: list
    def fun(self, var, dat):
        # The fourth variable is a dummy that only makes the system square
        x,y,z, *_ = var
        eqn_0 = (x * (self.roc[0][0] - self.roc[1][0])) + (y * (self.roc[0][1] - self.roc[1][1])) + (z * (self.roc[0][2] - self.roc[1][2])) - c * (dat[1] - dat[0])
        eqn_1 = (x * (self.roc[0][0] - self.roc[2][0])) + (y * (self.roc[0][1] - self.roc[2][1])) + (z * (self.roc[0][2] - self.roc[2][2])) - c * (dat[2] - dat[0])
        eqn_2 = (x * (self.roc[1][0] - self.roc[2][0])) + (y * (self.roc[1][1] - self.roc[2][1])) + (z * (self.roc[1][2] - self.roc[2][2])) - c * (dat[2] - dat[1])
        norm = x**2 + y**2 + z**2 - 1
        return [eqn_0, eqn_1, eqn_2, norm]
    def find(self, dat):
        result = root(self.fun, (0,0,0,0), args=dat)
        if result.success:
            print('Solution ', result.x[:3])
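A different route, not the one used in this answer, is to exploit the structure of the system: the three pairwise equations are linear in n and only two of them are independent (eqn_2 equals eqn_1 minus eqn_0), while |n| = 1 fixes the remaining component up to sign. The sketch below solves this directly with NumPy. The function name arrival_directions and the test geometry are made up for illustration, and it assumes the three receivers are not collinear.

```python
import numpy as np

def arrival_directions(receivers, times, speed):
    """Return the two unit vectors n satisfying <n, v_j - v_i> = speed * (t_i - t_j)."""
    v = np.asarray(receivers, dtype=float)
    t = np.asarray(times, dtype=float)
    # Two independent linear equations (pairs 0/1 and 0/2).
    A = np.array([v[0] - v[1], v[0] - v[2]])
    rhs = speed * np.array([t[1] - t[0], t[2] - t[0]])
    # The minimum-norm solution is the component of n parallel to the receiver plane.
    in_plane, *_ = np.linalg.lstsq(A, rhs, rcond=None)
    normal = np.cross(A[0], A[1])
    normal /= np.linalg.norm(normal)
    # |n| = 1 fixes the out-of-plane component up to sign (clamped against noise).
    out = np.sqrt(max(0.0, 1.0 - in_plane @ in_plane))
    return in_plane + out * normal, in_plane - out * normal

# Synthetic plane wave arriving from direction (0.6, 0, 0.8).
speed = 299792.0
receivers = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]
n_true = np.array([0.6, 0.0, 0.8])
times = [-np.dot(n_true, r) / speed for r in receivers]
n_plus, n_minus = arrival_directions(receivers, times, speed)
print(n_plus, n_minus)
```

The two returned vectors are mirror images in the plane through the receivers, which matches the "2 solutions in general" statement in the question; theta and phi follow by converting either vector to polar coordinates.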

Determine parabola with given arc length between two known points

Let (0,0) and (Xo,Yo) be two points on a Cartesian plane. We want to determine the parabolic curve, Y = AX^2 + BX + C, which passes through these two points and has a given arc length equal to S. Obviously, S > sqrt(Xo^2 + Yo^2). As the curve must pass through (0,0), we must have C = 0. Hence, the curve equation reduces to: Y = AX^2 + BX. How can I determine {A,B} knowing {Xo,Yo,S}? There are two solutions; I want the one with A>0.
I have an analytical solution (complex) that gives S for a given set of {A,B,Xo,Yo}, though here the problem is inverted... I can proceed by solving numerically a complex system of equations... but perhaps there is a numerical routine out there that does exactly this?
Any useful Python library? Other ideas?
Thanks a lot :-)
Note that the arc length (line integral) of the quadratic y = a*x^2 + b*x from x = 0 to x = x0 is given by the integral of sqrt(1 + (2ax + b)^2) over that interval. Evaluating the integral gives 0.5 * (I(u) - I(l)) / a, where u = 2a*x0 + b; l = b; and I(t) = 0.5 * (t * sqrt(1 + t^2) + log(t + sqrt(1 + t^2))) is the antiderivative of sqrt(1 + t^2).
Since y0 = a * x0^2 + b * x0, b = y0/x0 - a*x0. Substituting the value of b in u and l, u = y0/x0 + a*x0, l = y0/x0 - a*x0. Substituting u and l in the solution of the line integral (arc length), we get the arc length as a function of a:
s(a) = 0.5 * (I(y0/x0 + a*x0) - I(y0/x0 - a*x0)) / a
Now that we have the arc length as a function of a, we simply need to find the value of a for which s(a) = S. This is where my favorite root-finding algorithm, the Newton-Raphson method, comes into play yet again.
The working algorithm for the Newton-Raphson method of finding roots is as follows:
For a function f(x) whose root is to be obtained, if x(i) is the ith guess for the root,
x(i+1) = x(i) - f(x(i)) / f'(x(i))
Where f'(x) is the derivative of f(x). This process is continued till the difference between two consecutive guesses is very small.
In our case, f(a) = s(a) - S and f'(a) = s'(a). By simple application of the chain rule and the quotient rule,
s'(a) = 0.5 * (a*x0 * (I'(u) + I'(l)) + I(l) - I(u)) / (a^2)
Where I'(t) = sqrt(1 + t^2).
The only problem that remains is calculating a good initial guess. Due to the nature of the graph of s(a), the function is an excellent candidate for the Newton-Raphson method, and an initial guess of y0 / x0 converges to the solution in about 5-6 iterations for a tolerance/epsilon of 1e-10.
Once the value of a is found, b is simply y0/x0 - a*x0.
Putting this into code:
from math import sqrt, log
def find_coeff(x0, y0, s0):
    def dI(t):
        return sqrt(1 + t*t)
    def I(t):
        rt = sqrt(1 + t*t)
        return 0.5 * (t * rt + log(t + rt))
    def s(a):
        u = y0/x0 + a*x0
        l = y0/x0 - a*x0
        return 0.5 * (I(u) - I(l)) / a
    def ds(a):
        u = y0/x0 + a*x0
        l = y0/x0 - a*x0
        return 0.5 * (a*x0 * (dI(u) + dI(l)) + I(l) - I(u)) / (a*a)
    N = 1000
    EPSILON = 1e-10
    guess = y0 / x0
    # Newton-Raphson iteration on f(a) = s(a) - s0
    for i in range(N):
        dguess = (s(guess) - s0) / ds(guess)
        guess -= dguess
        if abs(dguess) <= EPSILON:
            print("Break:", abs((s(guess) - s0)))
            break
        print(i+1, ":", guess)
    a = guess
    b = y0/x0 - a*x0
    print(a, b, s(a))
    return a, b
Run the example on CodeSkulptor.
Note that due to the rational approximation of the arc lengths given as input to the function in the examples, the coefficients obtained may ever so slightly differ from the expected values.
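One way to sanity-check the closed-form arc length used above is to compare it with a brute-force polyline approximation of the same parabola. This is a minimal sketch; the function names are made up for the check and the sample values (a = 1, b = 0, x0 = 1) are arbitrary.

```python
from math import sqrt, log, hypot

def arc_length_closed_form(a, b, x0):
    """Arc length of y = a*x^2 + b*x on [0, x0], using I(t) as defined above."""
    def I(t):
        rt = sqrt(1 + t * t)
        return 0.5 * (t * rt + log(t + rt))
    return 0.5 * (I(2 * a * x0 + b) - I(b)) / a

def arc_length_polyline(a, b, x0, n=20000):
    """Approximate the same arc length by summing n straight segments."""
    total = 0.0
    prev_x, prev_y = 0.0, 0.0
    for i in range(1, n + 1):
        x = x0 * i / n
        y = a * x * x + b * x
        total += hypot(x - prev_x, y - prev_y)
        prev_x, prev_y = x, y
    return total

print(arc_length_closed_form(1.0, 0.0, 1.0), arc_length_polyline(1.0, 0.0, 1.0))
```

The same check can be run on the (a, b) pair produced by the Newton-Raphson iteration to confirm that the resulting curve really has length S.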

Python double integral taking too long to compute

I am trying to compute the Fresnel integral over a grid of coordinates using dblquad. But it's taking very long, and in the end it doesn't give any result.
Below is my code. In this code I integrated only over a 10 x 10 grid but I need to integrate at least over a 500 x 500 grid.
import time
st = time.time()
import pylab
import scipy.integrate as inte
import numpy as np
print('imhere 0')
def sinIntegrand(y,x, X , Y):
    a = 0.0001
    R = 2e-3
    z = 10e-3
    Lambda = 0.5e-6
    alpha = 0.01
    k = np.pi * 2 / Lambda
    return np.cos(k * (((x-R)**2)*a + (R-(x**2 + y**2)) * np.tan(np.radians(alpha)) + ((x - X)**2 + (y - Y)**2) / (2 * z)))
print('im here 1')
def cosIntegrand(y,x,X,Y):
    a = 0.0001
    R = 2e-3
    z = 10e-3
    Lambda = 0.5e-6
    alpha = 0.01
    k = np.pi * 2 / Lambda
    return np.sin(k * (((x-R)**2)*a + (R-(x**2 + y**2)) * np.tan(np.radians(alpha)) + ((x - X)**2 + (y - Y)**2) / (2 * z)))
def y1(x,R = 2e-3):
    return (R**2 - x**2)**0.5
def y2(x, R = 2e-3):
    return -1*(R**2 - x**2)**0.5
points = np.linspace(-1e-3,1e-3,10)
points2 = np.linspace(1e-3,-1e-3,10)
yv,xv = np.meshgrid(points , points2)
#def integrate_on_grid(func, lo, hi,y1,y2):
#    """Returns a callable that can be evaluated on a grid."""
#    return np.vectorize(lambda n,m: dblquad(func, lo, hi,y1,y2,(n,m))[0])
#
#intensity = abs(integrate_on_grid(sinIntegrand,-1e-3 ,1e-3,y1, y2)(yv,xv))**2 + abs(integrate_on_grid(cosIntegrand,-1e-3 ,1e-3,y1, y2)(yv,xv))**2
Intensity = []
print('im here2')
for i in points:
    row = []
    for j in points2:
        print('im here')
        intensity = abs(inte.dblquad(sinIntegrand,-1e-3 ,1e-3,y1, y2,(i,j))[0])**2 + abs(inte.dblquad(cosIntegrand,-1e-3 ,1e-3,y1, y2,(i,j))[0])**2
        row.append(intensity)
    Intensity.append(row)
Intensity = np.asarray(Intensity)
pylab.imshow(Intensity,cmap = 'gray')
pylab.show()
print(str(time.time() - st))
I would really appreciate it if you could suggest a better way of doing this.
Using scipy.integrate.dblquad to calculate every pixel of your image is going to be slow in any case.
You should try rewriting your mathematical problem so you can use some classical function in scipy.special instead. For instance, scipy.special.fresnel might work, although it is 1D and your problem seems to be in 2D. Otherwise, there is a relationship between the Fresnel integral and the incomplete Gamma function (scipy.special.gammainc), if that helps.
If none of this works, as a last resort you can spend time optimizing your code and adapting it to Cython. This will probably give a speed-up of a factor of 10 to 100 (see this answer). That still wouldn't be sufficient to go from a 10x10 grid to a 500x500 grid, though, since the larger grid has 2500 times as many pixels.

Karatsuba algorithm too much recursion

I am trying to implement the Karatsuba multiplication algorithm in c++ but right now I am just trying to get it to work in python.
Here is my code:
def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    x0 = x // bm
    x1 = x % bm
    y0 = y // bm
    y1 = y % bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
What I don't get is: how should z2, z1, and z0 be created? Is using the mult function recursively correct? If so, I'm messing up somewhere because the recursion isn't stopping.
Can someone point out where the error is?
NB: the response below addresses directly the OP's question about
excessive recursion, but it does not attempt to provide a correct
Karatsuba algorithm. The other responses are far more informative in
this regard.
Try this version:
def mult(x, y, b, m):
    bm = pow(b, m)
    if min(x, y) <= bm:
        return x * y
    # NOTE the following 4 lines
    x0 = x % bm
    x1 = x // bm
    y0 = y % bm
    y1 = y // bm
    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval
The most serious problem with your version is that your calculations of x0 and x1, and of y0 and y1 are flipped. Also, the algorithm's derivation does not hold if x1 and y1 are 0, because in this case, a factorization step becomes invalid. Therefore, you must avoid this possibility by ensuring that both x and y are greater than b**m.
EDIT: fixed a typo in the code; added clarifications
EDIT2:
To be clearer, commenting directly on your original version:
def mult(x, y, b, m):
    # The termination condition will never be true when the recursive
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    # Even without the recursion problem, the next four lines are wrong
    x0 = x // bm # RHS should be x % bm
    x1 = x % bm # RHS should be x // bm
    y0 = y // bm # RHS should be y % bm
    y1 = y % bm # RHS should be y // bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
Usually big numbers are stored as arrays of integers, where each integer represents one digit. This approach makes it possible to multiply any number by a power of the base with a simple left shift of the array.
Here is my list-based implementation (may contain bugs):
import operator
def normalize(l,b):
    # Propagate carries (or borrows) so that every digit ends up in [0, b)
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    res = list(map(operator.add,x[:l],y[:l]))
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    res = list(map(operator.sub,x[:len(y)],y))
    res.extend(x[len(y):])
    return normalize(res,b)
def lshift(x,n):
    # Multiply by b**n by prepending n zero digits (least significant first)
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)
The element-wise addition and subtraction inside sum_lists and sub_lists produce unnormalized results: a single digit can be greater than the base value (or negative). The normalize function solves this problem by propagating the carries.
All functions expect a list of digits in reverse order. For example, 12 in base 10 should be written as [2,1]. Let's take the square of 987654321.
>>> a = [1,2,3,4,5,6,7,8,9]
>>> res = mult_lists(a,a,10)
>>> res.reverse()
>>> res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
The goal of Karatsuba multiplication is to improve on the divide-and-conquer multiplication algorithm by making three recursive calls instead of four. Therefore, the only lines in your script that should contain a recursive call to the multiplication are those assigning z0, z1 and z2. Anything else will give you a worse complexity. You can't use pow to compute bm when you haven't defined multiplication yet (and a fortiori exponentiation), either.
For that, the algorithm crucially relies on the fact that it is using a positional notation system. If you have a representation x of a number in base b, then x*b^m is simply obtained by shifting the digits of that representation m times to the left. That shifting operation is essentially "free" with any positional notation system. That also means that if you want to implement it, you have to reproduce this positional notation and the "free" shift. Either you choose to compute in base b=2 and use Python's bit operators (or the bit operators of a given decimal, hex, ... base if your test platform has them), or you decide to implement, for educational purposes, something that works for an arbitrary b, and you reproduce this positional arithmetic with something like strings, arrays, or lists.
You have a solution with lists already. I like to work with strings in Python, since int(s, base) will give you the integer corresponding to the string s seen as a number representation in base base: it makes tests easy. I have posted a heavily commented string-based implementation as a gist here, including string-to-number and number-to-string primitives for good measure.
You can test it by providing padded strings with the base and their (equal) length as arguments to mult:
In [169]: mult("987654321","987654321",10,9)
Out[169]: '966551847789971041'
If you don't want to figure out the padding or count string lengths, a padding function can do it for you:
In [170]: padding("987654321","2")
Out[170]: ('987654321', '000000002', 9)
And of course it works with b>10:
In [171]: mult('987654321', '000000002', 16, 9)
Out[171]: '130eca8642'
(Check with wolfram alpha)
I believe that the idea behind the technique is that the zi terms are computed using the recursive algorithm, but the results are not combined that way. The net result that you want is
z0 B^(2m) + z1 B^m + z2
Assuming that you choose a suitable value of B (say, 2), you can compute B^m without doing any multiplications. For example, when using B = 2, you can compute B^m using bit shifts rather than multiplications. This means that the last step can be done without doing any multiplications at all.
One more thing - I noticed that you've picked a fixed value of m for the whole algorithm. Typically, you would implement this algorithm by choosing m at each level so that B^m splits x and y roughly in half, that is, m is about half the number of digits of x and y when they are written in base B. If you're using powers of two, this would be done by picking m = ceil((log2 x) / 2).
Hope this helps!
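The points above (three recursive products, recombination with shifts only, and m chosen per call as half the operand length) can be sketched in Python 3 with B = 2. This is an illustrative version under those assumptions, not the code from any of the answers here; the cutoff_bits parameter is a made-up knob for the base case, and z2 denotes the high product as in the question's code.

```python
def karatsuba(x, y, cutoff_bits=64):
    """Multiply non-negative ints x and y using three recursive products."""
    if x.bit_length() <= cutoff_bits or y.bit_length() <= cutoff_bits:
        return x * y  # base case: built-in multiply on a small operand
    # Split position: half the bit length of the larger operand.
    m = max(x.bit_length(), y.bit_length()) // 2
    mask = (1 << m) - 1
    x1, x0 = x >> m, x & mask   # x = x1 * 2**m + x0
    y1, y0 = y >> m, y & mask   # y = y1 * 2**m + y0
    z2 = karatsuba(x1, y1, cutoff_bits)
    z0 = karatsuba(x0, y0, cutoff_bits)
    z1 = karatsuba(x1 + x0, y1 + y0, cutoff_bits) - z2 - z0
    # Recombine with shifts only; no further multiplications are needed.
    return (z2 << (2 * m)) + (z1 << m) + z0

a, b = 3**500, 7**400
print(karatsuba(a, b) == a * b)
```

Because the operands shrink to roughly half their bit length on every call and the final step uses only shifts and additions, the recursion terminates without the fixed-m problem described in the question.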
In Python 2.7: Save this file as Karatsuba.py (the // operator keeps the integer divisions valid on Python 3 as well)
def karatsuba(x,y):
    """Karatsuba multiplication algorithm.
    Return the product of two numbers in an efficient manner
    @author Shashank
    date: 23-09-2018
    Parameters
    ----------
    x : int
        First Number
    y : int
        Second Number
    Returns
    -------
    prod : int
        The product of two numbers
    Examples
    --------
    >>> import Karatsuba
    >>> a = 1234567899876543211234567899876543211234567899876543211234567890
    >>> b = 9876543211234567899876543211234567899876543211234567899876543210
    >>> Karatsuba.karatsuba(a,b)
    12193263210333790590595945731931108068998628253528425547401310676055479323014784354458161844612101832860844366209419311263526900
    """
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    else:
        n = max(len(str(x)), len(str(y)))
        m = n//2
        a = x//10**m
        b = x%10**m
        c = y//10**m
        d = y%10**m
        ac = karatsuba(a,c) #step 1
        bd = karatsuba(b,d) #step 2
        ad_plus_bc = karatsuba(a+b, c+d) - ac - bd #step 3
        prod = ac*10**(2*m) + bd + ad_plus_bc*10**m #step 4
        return prod
