Very simply, given a point A(x, y) and another point B(m, n), I need a function that returns, as any iterable, a list [k, z] of all the points in between.
I am only interested in integer points, so there is no need for floats.
I need the most Pythonic approach possible, because this 'little' function will be run heavily and is a key pillar of a larger system.
@roippi, thanks for pointing out the gotcha concerning the integers. From my code below, you can see that I step across the x axis and get the corresponding y, then do the same for y. My set of points will not contain any non-integer coordinates, so for the moment I can afford to overlook that small flaw.
import itertools

origin = {'x':0, 'y':0}

def slope(origin, target):
    if target['x'] == origin['x']:
        return 0
    else:
        m = (target['y'] - origin['y']) / (target['x'] - origin['x'])
        return m

def line_eqn(origin, target):
    x = origin['x']
    y = origin['y']
    c = y - (slope(origin, target)*x)
    #return 'y = ' + str(slope(target)) + 'x + ' + str(c)
    m = slope(origin, target)
    return {'m':m, 'c':c}

def get_y(x, slope, c):
    # y = mx + c
    y = (slope*x) + c
    return y

def get_x(y, slope, c):
    #x = (y-c)/m
    if slope == 0:
        c = 0 #vertical lines never intersect with y-axis
        slope = 1 #Do NOT divide by zero
    x = (y - c)/slope
    return x

def get_points(origin, target):
    coord_list = []
    #Step along x-axis
    for i in range(origin['x'], target['x']+1):
        eqn = line_eqn(origin, target)
        y = get_y(i, eqn['m'], eqn['c'])
        coord_list.append([i, y])
    #Step along y-axis
    for i in range(origin['y'], target['y']+1):
        eqn = line_eqn(origin, target)
        x = get_x(i, eqn['m'], eqn['c'])
        coord_list.append([x, i])
    #return unique list
    return list(k for k,_ in itertools.groupby(sorted(coord_list)))

origin = {'x':1, 'y':3}
target = {'x':1, 'y':6}
print(get_points(origin, target))
def get_line(x1, y1, x2, y2):
    points = []
    issteep = abs(y2-y1) > abs(x2-x1)
    if issteep:
        x1, y1 = y1, x1
        x2, y2 = y2, x2
    rev = False
    if x1 > x2:
        x1, x2 = x2, x1
        y1, y2 = y2, y1
        rev = True
    deltax = x2 - x1
    deltay = abs(y2-y1)
    error = int(deltax / 2)
    y = y1
    ystep = None
    if y1 < y2:
        ystep = 1
    else:
        ystep = -1
    for x in range(x1, x2 + 1):
        if issteep:
            points.append((y, x))
        else:
            points.append((x, y))
        error -= deltay
        if error < 0:
            y += ystep
            error += deltax
    # Reverse the list if the coordinates were reversed
    if rev:
        points.reverse()
    return points
Let's assume you know how to work out the equation of a line, so you have
m: your gradient,
c: your constant.
You also have your two points, a and b, with the x-value of a lower than the x-value of b.
for x in range(a[0], b[0]):
    y = m*x + c
    # y is usually a float, so test whether it is a whole number
    if float(y).is_integer() and (x, y) not in [a, b]:
        print((x, int(y)))
The Bresenham line segment, or variants thereof is related to the parametric equation
X = X0 + t.Dx
Y = Y0 + t.Dy,
where Dx=X1-X0 and Dy=Y1-Y0, and t is a parameter in [0, 1].
It turns out that this equation can be written for an integer lattice, as
X = X0 + (T.Dx) \ D
Y = Y0 + (T.Dy) \ D,
where \ denotes integer division, D = Max(|Dx|, |Dy|) and T is an integer in the range [0, D].
As you can see, depending on which of Dx and Dy has the largest absolute value and on its sign, one of the equations simplifies to X = X0 + T (let us assume for now that Dx >= Dy >= 0).
To implement this, you have three options:
use floating-point numbers for the Y equation, Y = Y0 + T.dy, where dy = Dy/D, preferably rounding the result for better symmetry; as you increment T, update with Y += dy;
use a fixed-point representation of the slope, choosing a power of 2 for scaling, say 2^B; set Y' = Y0 << B, Dy' = (Dy << B) \ D; and every time you perform Y' += Dy', retrieve Y = Y' >> B;
use pure integer arithmetic.
In the case of integer arithmetic, you can obtain the rounding effect easily by computing Y0 + (T.Dy + D/2) \ D instead of Y0 + (T.Dy \ D). Indeed, as you divide by D, this is equivalent to Y0 + T.dy + 1/2.
Division is a slow operation. You can trade it for a comparison by means of a simple trick: Y increases by 1 every time T.Dy increases by D. You can maintain a "remainder" variable, equal to (T.Dy) modulo D (or T.Dy + D/2, for rounding), and decrease it by D every time it exceeds D.
Y = Y0
R = 0
for X in range(X0, X1 + 1):
    # Pixel(X, Y)
    R += Dy
    if R >= D:
        R -= D
        Y += 1
For a well optimized version, you should consider separately the nine cases corresponding to the combination of signs of Dx and Dy (-, 0, +).
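As a rough sketch of the pure-integer option with rounding, the formula Y0 + (T.Dy + D/2) \ D can be applied to both coordinates, which covers every sign combination of Dx and Dy without case analysis, at the cost of the divisions that the remainder trick avoids. The name `lattice_line` is hypothetical, and the D/2 term is folded in by doubling the numerator and denominator so that everything stays in integers.

```python
def lattice_line(x0, y0, x1, y1):
    """Integer points from (x0, y0) to (x1, y1), endpoints included."""
    dx, dy = x1 - x0, y1 - y0
    d = max(abs(dx), abs(dy))          # D = Max(|Dx|, |Dy|)
    if d == 0:                         # both endpoints coincide
        return [(x0, y0)]
    points = []
    for t in range(d + 1):             # T runs over [0, D]
        # (2*T*Dx + D) // (2*D) is floor(T*Dx/D + 1/2) in pure integers
        x = x0 + (2 * t * dx + d) // (2 * d)
        y = y0 + (2 * t * dy + d) // (2 * d)
        points.append((x, y))
    return points


print(lattice_line(0, 0, 3, 1))    # [(0, 0), (1, 0), (2, 1), (3, 1)]
```

Along the dominant axis the expression reduces to a step of exactly one unit per value of T, so only the minor axis is really being rounded.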
def getLine(x1,y1,x2,y2):
    if x1==x2: ## Perfectly vertical line, can be solved easily
        return [(x1,i) for i in range(y1,y2,int(abs(y2-y1)/(y2-y1)))]
    else: ## More of a problem, ratios can be used instead
        if x1>x2: ## If the line goes "backwards", flip the positions, to go "forwards" down it.
            x1,y1,x2,y2 = x2,y2,x1,y1
        slope=(y2-y1)/(x2-x1) ## Calculate the slope of the line
        line=[]
        i=0
        while x1+i < x2: ## Keep iterating until the end of the line is reached
            line.append((x1+i,y1+slope*i)) ## Add the next point on the line
            i+=1
        return line ## Finally, return the line!
Here is a C++ equivalent of user1048839's answer for anyone interested:
#include <algorithm>
#include <cstdlib>
#include <tuple>
#include <utility>
#include <vector>

std::vector<std::tuple<int, int>> bresenhamsLineGeneration(int x1, int y1, int x2, int y2) {
    std::vector<std::tuple<int, int>> points;
    bool issteep = (std::abs(y2 - y1) > std::abs(x2 - x1));
    if (issteep) {
        std::swap(x1, y1);
        std::swap(x2, y2);
    }
    bool rev = false;
    if (x1 > x2) {
        std::swap(x1, x2);
        std::swap(y1, y2);
        rev = true;
    }
    int deltax = x2 - x1;
    int deltay = std::abs(y2 - y1);
    int error = int(deltax / 2);
    int y = y1;
    int ystep;
    if (y1 < y2) {
        ystep = 1;
    } else {
        ystep = -1;
    }
    for (int x = x1; x < x2 + 1; ++x) {
        if (issteep) {
            std::tuple<int, int> pt = std::make_tuple(y, x);
            points.push_back(pt);
        } else {
            std::tuple<int, int> pt = std::make_tuple(x, y);
            points.push_back(pt);
        }
        error -= deltay;
        if (error < 0) {
            y += ystep;
            error += deltax;
        }
    }
    // Reverse the list if the coordinates were reversed
    if (rev) {
        std::reverse(points.begin(), points.end());
    }
    return points;
}
I looked into this as a project to learn C. The integer points of a straight line follow this pattern: a run of "major number" horizontal steps followed by one step up, repeated n times, and then a run of "minor number" horizontal steps followed by one step up. The minor number is one more or one less than the major number. The major number is effectively the gradient, and the minor number corrects the rounding.
I made a Fourier Series/Transform Tkinter app, and so far everything works as I want it to, except that I am having issues with the circles misaligning.
Here is an image explaining my issue (the green and pink were added after the fact to better explain the issue):
I have narrowed down the problem to the start of the lines, as it seems that they end in the correct place, and the circles are in their correct places.
The distance between the correct positions and the position where the lines start seems to grow, but is actually proportional to the speed of the circle rotating, as the circle rotates by larger amounts, thus going faster.
Here is the code:
from tkinter import *
import time
import math
import random
root = Tk()
myCanvas = Canvas(root, width=1300, height=750)
global x,y, lines, xList, yList
NumOfCircles = 4
rList = [200]
for i in range(0, NumOfCircles):
num = 250/sum(rList)
for i in range(0, NumOfCircles):
rList[i] = rList[i]*num
lines = []
circles = []
centerXList = [300]
for i in range(0,NumOfCircles):
centerYList = [300]
for i in range(0,NumOfCircles):
xList = [0]*NumOfCircles
yList = [0]*NumOfCircles
waveLines = []
wavePoints = []
endCoord = []
for i in range(0, NumOfCircles):
lastX = 0
lastY = 0
count = 0
randlist = []
for i in range(0, NumOfCircles):
def createCircle(x, y, r, canvasName):
    x0 = x - r
    y0 = y - r
    x1 = x + r
    y1 = y + r
    return canvasName.create_oval(x0, y0, x1, y1, width=r/50, outline="#094F9A")
def updateCircle(i):
    newX = endCoord[i-1][0]
    newY = endCoord[i-1][1]
    centerXList[i] = newX
    centerYList[i] = newY
    x0 = newX - rList[i]
    y0 = newY - rList[i]
    x1 = newX + rList[i]
    y1 = newY + rList[i]
    myCanvas.coords(circles[i], x0, y0, x1, y1)
def circleWithLine(i):
    global line, lines
    circle = createCircle(centerXList[i], centerYList[i], rList[i], myCanvas)
    line = myCanvas.create_line(centerXList[i], centerYList[i], centerXList[i], centerYList[i], width=2, fill="#1581B7")
def update(i, x, y):
    endCoord[i][0] = x+(rList[i]*math.cos(xList[i]))
    endCoord[i][1] = y+(rList[i]*math.sin(yList[i]))
    myCanvas.coords(lines[i], x, y, endCoord[i][0], endCoord[i][1])
    xList[i] += (math.pi/randlist[i])
    yList[i] += (math.pi/randlist[i])
def lineBetweenTwoPoints(x, y, x2, y2):
    line = myCanvas.create_line(x, y, x2, y2, fill="white")
    return line
def lineForWave(y1, y2, y3, y4, con):
    l = myCanvas.create_line(700+con, y1, 702+con, y2, 704+con, y3, 706+con, y4, smooth=1, fill="white")
for i in range(0,NumOfCircles):
myCanvas.create_line(700, 20, 700, 620, fill="black", width = 3)
myCanvas.create_line(700, 300, 1250, 300, fill="red")
myCanvas.create_line(0, 300, 600, 300, fill="red", width = 0.5)
myCanvas.create_line(300, 0, 300, 600, fill="red", width = 0.5)
while True:
for i in range(0, len(lines)):
update(i, centerXList[i], centerYList[i])
for i in range(1, len(lines)):
if count >= 8:
lineBetweenTwoPoints(lastX, lastY, endCoord[i][0], endCoord[i][1])
if count % 6 == 0 and con<550:
lineForWave(wavePoints[-7],wavePoints[-5],wavePoints[-3],wavePoints[-1], con)
con += 6
lastX = endCoord[i][0]
lastY = endCoord[i][1]
if count != 108:
count += 1
count = 8
I am aware that this is not the best way to achieve what I am trying to achieve, as using classes would be much better. I plan to do that in case nobody can find a solution, and hope that when it is re-written, this issue does not persist.
The main problem you are facing is that your calculations produce floating-point numbers, but pixels can only be addressed with integers. Below I show where it goes wrong and the quickest way to solve the issue.
First, your goal is to have connected lines, and you calculate the points here:
def update(i, x, y):
    endCoord[i][0] = x+(rList[i]*math.cos(xList[i]))
    endCoord[i][1] = y+(rList[i]*math.sin(yList[i]))
    myCanvas.coords(lines[i], x, y, endCoord[i][0], endCoord[i][1])
    xList[i] += (math.pi/randlist[i])
    yList[i] += (math.pi/randlist[i])
When you add the following code to this function, you can see that it fails there:
    if i != 0:
        print(i,endCoord[i-1][0], endCoord[i-1][1])
This is because x and y should always match the last point (the end of the previous line), which is endCoord[i-1][0] and endCoord[i-1][1].
To solve your problem, I simply skipped the calculation of the starting point for the follow-up lines and took the end coordinates of the previous line instead, with the following altered function:
def update(i, x, y):
    endCoord[i][0] = x+(rList[i]*math.cos(xList[i]))
    endCoord[i][1] = y+(rList[i]*math.sin(yList[i]))
    if i == 0:
        points = x, y, endCoord[i][0], endCoord[i][1]
    else:
        points = endCoord[i-1][0], endCoord[i-1][1], endCoord[i][0], endCoord[i][1]
    myCanvas.coords(lines[i], *points)
    xList[i] += (math.pi/randlist[i])
    yList[i] += (math.pi/randlist[i])
Additional proposals are:
don't use wildcard imports
import only what you really use in the code; random isn't used in your example
using global in the global namespace is useless
create functions to avoid repetitive code
def listinpt_times_circles(inpt):
    return [inpt]*CIRCLES
x_list = listinpt_times_circles(0)
y_list = listinpt_times_circles(0)
center_x_list = listinpt_times_circles(0)
center_y_list = listinpt_times_circles(0)
use .after(ms, func, *args) instead of a blocking while loop and the blocking call time.sleep
def animate():
global count,con,lastX,lastY
for i in range(0, len(lines)):
update(i, centerXList[i], centerYList[i])
for i in range(1, len(lines)):
if count >= 8:
lineBetweenTwoPoints(lastX, lastY, endCoord[i][0], endCoord[i][1])
if count % 6 == 0 and con<550:
lineForWave(wavePoints[-7],wavePoints[-5],wavePoints[-3],wavePoints[-1], con)
con += 6
lastX = endCoord[i][0]
lastY = endCoord[i][1]
if count != 108:
count += 1
count = 8
read PEP 8 -- Style Guide for Python Code
use intuitive variable names to make your code easier to read, for others and for yourself in the future
list_of_radii = [200]  # instead of rList
as said, pixels are expressed with integers, not with floating-point numbers
myCanvas.create_line(0, 300, 600, 300, fill="red", width=1)  # 0.5 has no effect; compare 0.1 to 1
using classes and a canvas for each animation will come in handy if you want to show more cycles
don't use tkinter's update method
As @Thingamabobs said, the main reason for the misalignment is that pixel coordinates work with integer values. I got excited about your project and decided to make an example using matplotlib, so that I do not have to work with integer values for the coordinates. The example is meant to work with any function; I implemented samples with sine, square and sawtooth functions.
I also tried to follow some good practices for naming, type annotations and so on. I hope this helps you.
from numbers import Complex
from typing import Callable, Iterable, List
import matplotlib.pyplot as plt
import numpy as np
def fourier_series_coeff_numpy(f: Callable, T: float, N: int) -> List[Complex]:
    """Get the coefficients of the Fourier series of a function.

    Args:
        f (Callable): function to get the Fourier series coefficients of.
        T (float): period of the function.
        N (int): number of coefficients to get.

    Returns:
        List[Complex]: list of coefficients of the Fourier series.
    """
    f_sample = 2 * N
    t, dt = np.linspace(0, T, f_sample + 2, endpoint=False, retstep=True)
    y = np.fft.fft(f(t)) / t.size
    return y

def evaluate_fourier_series(coeffs: List[Complex], ang: float, period: float) -> List[Complex]:
    """Evaluate a Fourier series at a given angle.

    Args:
        coeffs (List[Complex]): list of coefficients of the Fourier series.
        ang (float): angle to evaluate the Fourier series at.
        period (float): period of the Fourier series.

    Returns:
        List[Complex]: list of complex numbers representing the Fourier series.
    """
    N = np.fft.fftfreq(len(coeffs), d=1/len(coeffs))
    N = filter(lambda x: x >= 0, N)
    y = 0
    radius = []
    for n, c in zip(N, coeffs):
        r = 2 * c * np.exp(1j * n * ang / period)
        y += r
        radius.append(r)  # one rotating vector per harmonic
    return radius

def square_function_factory(period: float):
    """Builds a square function with given period.

    Args:
        period (float): period of the square function.
    """
    def f(t):
        if isinstance(t, Iterable):
            return [1.0 if x % period < period / 2 else -1.0 for x in t]
        elif isinstance(t, float):
            return 1.0 if t % period < period / 2 else -1.0
    return f

def saw_tooth_function_factory(period: float):
    """Builds a saw-tooth function with given period.

    Args:
        period (float): period of the saw-tooth function.
    """
    def f(t):
        if isinstance(t, Iterable):
            return [1.0 - 2 * (x % period / period) for x in t]
        elif isinstance(t, float):
            return 1.0 - 2 * (t % period / period)
    return f
def main():
f = square_function_factory(PERIOD)
# f = lambda t: np.sin(2 * np.pi * t / PERIOD)
# f = saw_tooth_function_factory(PERIOD)
coeffs = fourier_series_coeff_numpy(f, 1, N_COEFFS)
radius = evaluate_fourier_series(coeffs, 0, 1)
fig, axs = plt.subplots(nrows=1, ncols=2, sharey=True, figsize=(10, 5))
ang_cum = []
amp_cum = []
for ang in np.linspace(0, 2*np.pi * PERIOD * 3, 200):
radius = evaluate_fourier_series(coeffs, ang, 1)
x = np.cumsum([x.imag for x in radius])
y = np.cumsum([x.real for x in radius])
x = np.insert(x, 0, 0)
y = np.insert(y, 0, 0)
axs[0].plot(x, y)
axs[0].set_ylim(-GRAPH_RANGE, GRAPH_RANGE)
axs[0].set_xlim(-GRAPH_RANGE, GRAPH_RANGE)
axs[1].plot(ang_cum, amp_cum)
xmin=x[-1] / (2 * GRAPH_RANGE) + 0.5,
min_x, max_x = axs[1].get_xlim()
line_end_x = (ang - min_x) / (max_x - min_x)
if __name__ == '__main__':
I changed some things, but I still have a similar problem. I am working on a Mandelbrot zoom. I try to zoom in deeper at branches: I count the consecutive points in the set and return the branch if it reaches the defined length. Then I zoom into that area and repeat. But the change gets smaller each time, and the returned branch is nearly the same as the last one.
X0, Y0, X1 = vertices of the current area
branch = current branch
n = number of iterations
p = how many times I want to zoom
b = minimum length of the branch I am looking for
c = current complex number
z = current value
k = pixel size
The function that returns the branch:
def fractal(n, X0, Y0, X1): # searching for new branch
    branch = []
    k = (X1 - X0) / x_axis # pixel size
    b = 5 # the number of black points in the set I am looking for
    try:
        for y in range(y_axis):
            for x in range(x_axis):
                c = X0 + k * x + Y0 - k * y * 1.j # new c
                z = c
                for m in range(n):
                    if abs(z) <= 2:
                        z = z * z + c # new z
                    else:
                        break
                if abs(z) <= 2:
                    branch.append((x,y)) # the coordinates for those points
                else:
                    if b < len(branch) < 50:
                        raise BreakOutOfALoop # break if a branch was found
                    branch = []
    except BreakOutOfALoop:
        return branch, k
The function that calculates the vertices of the new area:
def new_area(branch, X0, Y0, X1, k): # calculating new area
    X0 = X0 + k * branch[0][0]
    Y0 = Y0 - k * branch[0][1]
    X1 = X1 + k * branch[-1][0] # new area vertices
    return X0, Y0, X1
The loop that calls the functions:
for i in range(p):
    areaImage = Image.new('RGB', (x_axis,y_axis), "white")
    area_pixels = areaImage.load() # image load
    branch, k = fractal(n,X0, Y0, X1)
    X0, Y0, X1 = new_area(branch, X0, Y0, X1, k)
    file_name = "/" + str(i) + "_" + str(n) + "_" + str(x_axis) + "_" + str(y_axis) + ".png" # image save
    areaImage.save(f"{areaImage_path}{file_name}")
What am I overlooking?
Here are some pictures for different p.
p = 3
p = 10
I have a clicking application on a phone.
I want to sample the last N points so that it won't click the same "place" over and over again.
I guess it should be something like this:
This is what I want to avoid.
I guess the circle's center should be the centroid of all the points?
How do I determine the radius?
I think the calculation for the last N dots could be reused to calculate the "new" N dots once a new click is made, to reduce the computation cost.
Any suggestions?
def clicking_loop_protection(self, x, y, methods):
    """
    :param x:
    :param y:
    :param method: 'xpath'/'point' it can be both
    """
    def centroid(points):
        _len = len(points)
        x_coords = [p[0] for p in points]
        y_coords = [p[1] for p in points]
        centroid_x = sum(x_coords) / _len
        centroid_y = sum(y_coords) / _len
        return centroid_x, centroid_y
    def calculate_points_distance(x1, y1, x2, y2):
        dist = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
        return dist
    point_res = True
    if 'point' in methods:
        # add current point
        _point = (x, y)
        try:
            self.points_history[self.points_history_index] = _point
        except IndexError:
            self.points_history.append(_point)  # history not full yet
        if len(self.points_history) == self.points_history_length:
            centroid_point_x, centroid_point_y = centroid(self.points_history)
            radius = self.displayWidth/10
            for point in self.points_history:
                # distance from centroid should be less than radius (to fail the test)
                distance = calculate_points_distance(centroid_point_x, centroid_point_y, point[0], point[1])
                if distance > radius:
                    point_res = True # pass test !
                    break
                else:
                    point_res = False
        self.points_history_index += 1
        self.points_history_index %= self.points_history_length
    # xpath_res comes from the 'xpath' check, which is not shown here
    return xpath_res and point_res
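One possible way to package the same idea, as a sketch only: keep the last N clicks in a fixed-size `collections.deque` (which drops the oldest click automatically) and report a loop when every stored click lies within `radius` of their centroid. The class name `ClickLoopDetector`, its parameters and the sample numbers are hypothetical; `radius` remains a tuning value, much like `self.displayWidth/10` above.

```python
import math
from collections import deque


class ClickLoopDetector:
    """Flags a run of clicks that all fall inside one small circle."""

    def __init__(self, history_length, radius):
        self.history = deque(maxlen=history_length)  # oldest click drops out
        self.radius = radius

    def register(self, x, y):
        """Record a click; return True if the recent clicks look stuck."""
        self.history.append((x, y))
        if len(self.history) < self.history.maxlen:
            return False                      # not enough samples yet
        n = len(self.history)
        cx = sum(p[0] for p in self.history) / n
        cy = sum(p[1] for p in self.history) / n
        # stuck only if every recent click is close to the centroid
        return all(math.hypot(px - cx, py - cy) <= self.radius
                   for px, py in self.history)


detector = ClickLoopDetector(history_length=3, radius=10)
print(detector.register(100, 100))   # False
print(detector.register(101, 100))   # False
print(detector.register(100, 101))   # True
print(detector.register(500, 500))   # False
```

The centroid is recomputed from scratch on each click here; for a small N that cost is usually negligible, and running sums could replace it if it ever mattered.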
I'm looking for a short, smart way to find all integer points on a line segment. The two endpoints are also integers, and the line can be at an angle of 0, 45, 90, 135, etc. degrees.
Here is my long code (so far only the 90-degree cases):
def getPoints(p1, p2):
    if p1[0] == p2[0]:
        if p1[1] < p2[1]:
            return [(p1[0],x) for x in range(p1[1],p2[1])]
        else:
            return [(p1[0],x) for x in range(p1[1],p2[1],-1)]
    if p1[1] == p2[1]:
        if p1[0] < p2[0]:
            return [(x,p1[1]) for x in range(p1[0],p2[0])]
        else:
            return [(x,p1[1]) for x in range(p1[0],p2[0],-1)]
EDIT: I didn't state it clearly enough, but the slope will always be an integer (-1, 0 or 1); there are 8 cases that need to be checked.
Reduce the slope to lowest terms (p/q), then step from one endpoint of the line segment to the other in increments of p vertically and q horizontally. The same code can work for vertical line segments if your reduce-to-lowest-terms code reduces 5/0 to 1/0.
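The reduce-to-lowest-terms idea can be sketched with `math.gcd`, assuming Python 3.5 or later. The function name `integer_points` is hypothetical, and unlike the question's code this version includes both endpoints.

```python
from math import gcd


def integer_points(p1, p2):
    """All lattice points on the segment p1 -> p2, endpoints included."""
    (x1, y1), (x2, y2) = p1, p2
    dx, dy = x2 - x1, y2 - y1
    steps = gcd(abs(dx), abs(dy))      # gcd(5, 0) == 5, so 5/0 reduces to 1/0
    if steps == 0:                     # p1 and p2 are the same point
        return [(x1, y1)]
    step_x, step_y = dx // steps, dy // steps   # slope in lowest terms
    return [(x1 + i * step_x, y1 + i * step_y) for i in range(steps + 1)]


print(integer_points((0, 0), (4, 2)))   # [(0, 0), (2, 1), (4, 2)]
print(integer_points((1, 3), (1, 6)))   # [(1, 3), (1, 4), (1, 5), (1, 6)]
```

Because the step keeps the signs of dx and dy, all eight directions from the question go through the same code path.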
Do a little bit of maths: for each pair of points, calculate m and c for y = mx + c and compare it to the formulae for the lines you are considering. (N.B. you will get some divisions by zero to cope with.)
I could write code that works, but the amount of repeated code is
throwing me off; that's why I turned to you guys
This could stand a lot of improvement, but maybe it gets you on track.
(Sorry, I don't have time to make it better just now!)
def points(p1,p2):
    slope = (p2[1]-p1[1])/float(p2[0]-p1[0])
    return [(x,x*slope) for x in range (p1[0], p2[0]) if int(x*slope) == x*slope]
Extending the answer of @Jon Kiparsky.
def points_on_line(p1, p2):
    fx, fy = p1
    sx, sy = p2
    if fx == sx and fy == sy:
        return []
    elif fx == sx:
        return [(fx, y) for y in range(fy+1, sy)]
    elif fy == sy:
        return [(x, fy) for x in range(fx+1, sx)]
    elif fx > sx and fy > sy:
        p1, p2 = p2, p1
    slope = (p2[1] - p1[1]) / float(p2[0] - p1[0])
    return [(x, int(x*slope)) for x in range(p1[0], p2[0]) if int(x*slope) == x*slope and (x, int(x*slope)) != p1]
def getpoints(p1, p2):
    # Sort both points first.
    (x1, y1), (x2, y2) = sorted([p1, p2])
    a = b = 0.0
    # Not interesting case.
    if x1 == x2:
        yield p1
        return  # stop here, otherwise the slope below divides by zero
    # First point is in (0, y).
    if x1 == 0.0:
        b = y1
        a = (y2 - y1) / x2
    elif x2 == 0.0:
        # Second point is in (0, y).
        b = y2
        a = (y1 - y2) / x1
    else:
        # Both points are valid.
        b = (y2 - (y1 * x2) / x1) / (1 - (x2 / x1))
        a = (y1 - b) / x1
    for x in range(int(x1), int(x2) + 1):
        y = a * float(x) + b
        # Delta could be increased for lower precision.
        if abs(y - round(y)) == 0:
            yield (x, y)
I am trying to implement the Karatsuba multiplication algorithm in C++, but right now I am just trying to get it to work in Python.
Here is my code:
def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
What I don't get is: how should z2, z1, and z0 be created? Is using the mult function recursively correct? If so, I'm messing up somewhere because the recursion isn't stopping.
Can someone point out where the error is?
NB: the response below addresses directly the OP's question about
excessive recursion, but it does not attempt to provide a correct
Karatsuba algorithm. The other responses are far more informative in
this regard.
Try this version:
def mult(x, y, b, m):
    bm = pow(b, m)
    if min(x, y) <= bm:
        return x * y
    # NOTE the following 4 lines
    x0 = x % bm
    x1 = x / bm
    y0 = y % bm
    y1 = y / bm
    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval
The most serious problem with your version is that your calculations of x0 and x1, and of y0 and y1 are flipped. Also, the algorithm's derivation does not hold if x1 and y1 are 0, because in this case, a factorization step becomes invalid. Therefore, you must avoid this possibility by ensuring that both x and y are greater than b**m.
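To make the low/high split concrete, here is a small illustration. The helper name `split` is hypothetical; it relies on the built-in `divmod`, which returns the quotient and the remainder in one call, so the two halves cannot be swapped by accident.

```python
def split(x, b, m):
    """Return (x1, x0) with x == x1 * b**m + x0 and 0 <= x0 < b**m."""
    return divmod(x, b ** m)


x1, x0 = split(1234, 10, 2)
print(x1, x0)                       # 12 34
print(x1 * 10 ** 2 + x0 == 1234)    # True
```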
EDIT: fixed a typo in the code; added clarifications
To be clearer, commenting directly on your original version:
def mult(x, y, b, m):
    # The termination condition will never be true when the recursive
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
Usually big numbers are stored as arrays of integers, where each integer represents one digit. This approach makes it possible to multiply any number by a power of the base with a simple left shift of the array.
Here is my list-based implementation (may contain bugs):
import operator

def normalize(l,b):
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    # list() keeps this working on Python 3, where map returns an iterator
    res = list(map(operator.add,x[:l],y[:l]))
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    res = list(map(operator.sub,x[:len(y)],y))
    return normalize(res,b)
def lshift(x,n):
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)
sum_lists and sub_lists can produce unnormalized intermediate results, in which a single digit is greater than the base value; the normalize function takes care of this.
All functions expect a list of digits in reverse order. For example, 12 in base 10 should be written as [2,1]. Let's take the square of 987654321.
» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
The goal of Karatsuba multiplication is to improve on the divide-and-conquer multiplication algorithm by making three recursive calls instead of four. Therefore, the only lines in your script that should contain a recursive call to the multiplication are those assigning z0, z1 and z2. Anything else will give you a worse complexity. You also can't use pow to compute bm when you haven't defined multiplication yet (and, a fortiori, exponentiation).
For that, the algorithm crucially relies on the fact that it is using a positional notation system. If you have a representation x of a number in base b, then x*b^m is obtained simply by shifting the digits of that representation m times to the left. That shifting operation is essentially "free" with any positional notation system. It also means that if you want to implement this, you have to reproduce the positional notation and the "free" shift. Either you choose to compute in base b=2 and use Python's bit operators (or the bit operators of a given decimal, hex, ... base if your test platform has them), or you decide to implement, for educational purposes, something that works for an arbitrary b, and you reproduce the positional arithmetic with something like strings, arrays, or lists.
You have a solution with lists already. I like to work with strings in Python, since int(s, base) will give you the integer corresponding to the string s seen as a number representation in base base: it makes tests easy. I have posted a heavily commented string-based implementation as a gist here, including string-to-number and number-to-string primitives for good measure.
You can test it by providing padded strings with the base and their (equal) length as arguments to mult:
In [169]: mult("987654321","987654321",10,9)
Out[169]: '966551847789971041'
If you don't want to figure out the padding or count string lengths, a padding function can do it for you:
In [170]: padding("987654321","2")
Out[170]: ('987654321', '000000002', 9)
And of course it works with b>10:
In [171]: mult('987654321', '000000002', 16, 9)
Out[171]: '130eca8642'
(Check with wolfram alpha)
I believe that the idea behind the technique is that the zi terms are computed using the recursive algorithm, but the results are not combined that way. The net result that you want is
z0 B^2m + z1 B^m + z2
Assuming that you choose a suitable value of B (say, 2), you can compute B^m without doing any multiplications. For example, when using B = 2, you can compute B^m using bit shifts rather than multiplications. This means that the last step can be done without doing any multiplications at all.
One more thing: I noticed that you've picked a fixed value of m for the whole algorithm. Typically, you would implement this algorithm by choosing m at each level to be about half the number of digits in x and y when they are written in base B. If you're using powers of two, this would be done by picking m = ceil((log2 x) / 2).
Hope this helps!
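As a sketch of those two points for B = 2 (not a tuned implementation): m is recomputed at every level as half the bit length of the larger operand, and the pieces are split and recombined with shifts and masks only. The cutoff of 16 for the base case and the names `high`, `mid` and `low` are arbitrary choices made for this illustration.

```python
def karatsuba(x, y):
    """Multiply two non-negative integers using three recursive products."""
    if x < 16 or y < 16:                # small operand: multiply directly
        return x * y
    m = max(x.bit_length(), y.bit_length()) // 2
    mask = (1 << m) - 1
    x1, x0 = x >> m, x & mask           # x == (x1 << m) + x0
    y1, y0 = y >> m, y & mask           # y == (y1 << m) + y0
    high = karatsuba(x1, y1)
    low = karatsuba(x0, y0)
    mid = karatsuba(x1 + x0, y1 + y0) - high - low
    # recombination uses shifts only, no recursive multiplication
    return (high << (2 * m)) + (mid << m) + low


print(karatsuba(987654321, 987654321) == 987654321 ** 2)   # True
```

The only recursive calls are the three that produce `high`, `low` and `mid`, which is what gives the algorithm its advantage over the four-product scheme.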
In Python 2.7: Save this file as
def karatsuba(x,y):
    """Karatsuba multiplication algorithm.

    Return the product of two numbers in an efficient manner

    @author Shashank
    date: 23-09-2018

    Parameters
    ----------
    x : int
        First Number
    y : int
        Second Number

    Returns
    -------
    prod : int
        The product of two numbers

    Examples
    --------
    >>> import Karatsuba.karatsuba
    >>> a = 1234567899876543211234567899876543211234567899876543211234567890
    >>> b = 9876543211234567899876543211234567899876543211234567899876543210
    >>> Karatsuba.karatsuba(a,b)
    """
    if len(str(x)) == 1 or len(str(y)) == 1:
        return x*y
    n = max(len(str(x)), len(str(y)))
    # // is floor division; it matches / on ints in Python 2.7
    m = n//2
    a = x//10**m
    b = x%10**m
    c = y//10**m
    d = y%10**m
    ac = karatsuba(a,c) #step 1
    bd = karatsuba(b,d) #step 2
    ad_plus_bc = karatsuba(a+b, c+d) - ac - bd #step 3
    prod = ac*10**(2*m) + bd + ad_plus_bc*10**m #step 4
    return prod