I am currently stuck on Project Euler problem 60. The problem goes like this:
The primes 3, 7, 109, and 673, are quite remarkable. By taking any two primes and concatenating them in any order the result will always be prime. For example, taking 7 and 109, both 7109 and 1097 are prime. The sum of these four primes, 792, represents the lowest sum for a set of four primes with this property.
Find the lowest sum for a set of five primes for which any two primes concatenate to produce another prime.
In Python, I used and created the following functions.
def isprime(n):
    """Primality test using 6k+-1 optimization."""
    if n <= 3:
        return n > 1
    elif n % 2 == 0 or n % 3 == 0:
        return False
    i = 5
    while i * i <= n:
        if n % i == 0 or n % (i + 2) == 0:
            return False
        i += 6
    return True
def prime_pair_sets(n):
    heap = [[3,7]]
    x = 11
    #Phase 1: Create a list of length n.
    while getMaxLength(heap)<n:
        if isprime(x):
            m = [x]
            for lst in heap:
                tmplst = [x]
                for p in lst:
                    if primePair(x,p):
                        tmplst.append(p)
                if len(tmplst)>len(m):
                    m=tmplst
            heap.append(m)
        x+=2
    s = sum(maxList(heap))
    #Phase 2: Find the lowest sum.
    for li in heap:
        y=x
        while s>(sum(li)+y) and len(li)<n:
            b = True
            for k in li:
                if not primePair(k,y):
                    b = False
            if b == True:
                li.append(y)
            y+=2
        if len(li)>=n:
            s = sum(li)
    return s
def getMaxLength(h):
    m = 0
    for s in h:
        if len(s) > m:
            m = len(s)
    return m
def primePair(x, y):
    return isprime(int(str(x)+str(y))) and isprime(int(str(y)+str(x)))
def maxList(h):
    result = []
    for s in h:
        if len(s)> len(result):
            result = s
    return result
I executed the prime_pair_sets function using the first phase only. After about an hour of waiting I got a sum of 74,617 (33647 + 23003 + 16451 + 1493 + 23). It turns out that this is not the sum we're looking for. I tried the problem again using both phases. After about two hours, I ended up with the same wrong result. Evidently, the answer is less than 74,617. Can somebody help me find a more efficient way to solve this problem?
Solution found
Primes: 13, 5197, 5701, 6733, 8389
Sum: 26,033
Run time reduced to ~10 seconds vs. 1-2 hours reported by OP solution.
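As a quick sanity check on the result above, the property can be verified directly. This is a minimal sketch, not part of the solution itself; the helper names is_prime_simple and is_prime_pair_set are made up for illustration, and plain trial division is assumed to be fast enough for the 8-digit concatenations involved.

```python
from itertools import permutations

def is_prime_simple(n):
    # Plain trial division; adequate for numbers of up to about 8 digits.
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    i = 3
    while i * i <= n:
        if n % i == 0:
            return False
        i += 2
    return True

def is_prime_pair_set(primes):
    # Every ordered pair (a, b) must concatenate to a prime.
    return all(is_prime_simple(int(f"{a}{b}")) for a, b in permutations(primes, 2))

found = [13, 5197, 5701, 6733, 8389]
print(is_prime_pair_set([3, 7, 109, 673]))    # the example from the problem statement
print(is_prime_pair_set(found), sum(found))
```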
Approach
Backtracking algorithm that extends a path holding a list of primes
A prime can be added to the path if it forms a prime pair with every prime already in the path
Use the Sieve of Eratosthenes to precompute primes up to the maximum we expect to need
For each prime, precompute which other primes it forms a prime pair with
Code
import functools
import time
# Decorator for function timing
def timer(func):
    """Print the runtime of the decorated function"""
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        start_time = time.perf_counter() # 1
        value = func(*args, **kwargs)
        end_time = time.perf_counter() # 2
        run_time = end_time - start_time # 3
        print(f"Finished {func.__name__!r} in {run_time:.4f} secs")
        return value
    return wrapper_timer
#*************************************************************
# Prime number algorithms
#-------------------------------------------------------------
def _try_composite(a, d, n, s):
    if pow(a, d, n) == 1:
        return False
    for i in range(s):
        if pow(a, 2**i * d, n) == n-1:
            return False
    return True # n is definitely composite
def is_prime(n, _precision_for_huge_n=16, _known_primes = set([2, 3, 5, 7])):
    if n in _known_primes:
        return True
    if n > 6 and not (n % 6) in (1, 5):
        # Check of form 6*q +/- 1
        return False
    if any((n % p) == 0 for p in _known_primes) or n in (0, 1):
        return False
    d, s = n - 1, 0
    while not d % 2:
        d, s = d >> 1, s + 1
    # Returns exact according to http://primes.utm.edu/prove/prove2_3.html
    if n < 1373653:
        return not any(_try_composite(a, d, n, s) for a in (2, 3))
    if n < 25326001:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5))
    if n < 118670087467:
        if n == 3215031751:
            return False
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7))
    if n < 2152302898747:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11))
    if n < 3474749660383:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11, 13))
    if n < 341550071728321:
        return not any(_try_composite(a, d, n, s) for a in (2, 3, 5, 7, 11, 13, 17))
    # otherwise (a set cannot be sliced, so sort it into a list first)
    return not any(_try_composite(a, d, n, s)
                   for a in sorted(_known_primes)[:_precision_for_huge_n])
def primes235(limit):
    ' Prime generator using Sieve of Eratosthenes with factorization wheel of 2, 3, 5 '
    # Source: https://rosettacode.org/wiki/Sieve_of_Eratosthenes#Python
    yield 2; yield 3; yield 5
    if limit < 7: return
    modPrms = [7,11,13,17,19,23,29,31]
    gaps = [4,2,4,2,4,6,2,6,4,2,4,2,4,6,2,6] # 2 loops for overflow
    ndxs = [0,0,0,0,1,1,2,2,2,2,3,3,4,4,4,4,5,5,5,5,5,5,6,6,7,7,7,7,7,7]
    lmtbf = (limit + 23) // 30 * 8 - 1 # integral number of wheels rounded up
    lmtsqrt = (int(limit ** 0.5) - 7)
    lmtsqrt = lmtsqrt // 30 * 8 + ndxs[lmtsqrt % 30] # round down on the wheel
    buf = [True] * (lmtbf + 1)
    for i in range(lmtsqrt + 1):
        if buf[i]:
            ci = i & 7; p = 30 * (i >> 3) + modPrms[ci]
            s = p * p - 7; p8 = p << 3
            for j in range(8):
                c = s // 30 * 8 + ndxs[s % 30]
                buf[c::p8] = [False] * ((lmtbf - c) // p8 + 1)
                s += p * gaps[ci]; ci += 1
    for i in range(lmtbf - 6 + (ndxs[(limit - 7) % 30])): # adjust for extras
        if buf[i]: yield (30 * (i >> 3) + modPrms[i & 7])
def prime_pair(x, y):
    ' Checks if two primes are prime pair (i.e. concatenation of two in either order is also a prime)'
    return is_prime(int(str(x)+str(y))) and is_prime(int(str(y)+str(x)))
def find_pairs(primes):
    ' Creates dictionary of what primes can go with others as a pair'
    prime_pairs = {}
    for i, p in enumerate(primes):
        pairs = set()
        for j in range(i+1, len(primes)):
            if prime_pair(p, primes[j]):
                pairs.add(primes[j])
        prime_pairs[p] = pairs
    return prime_pairs
#*************************************************************
# Main functionality
#-------------------------------------------------------------
@timer
def find_groups(max_prime = 9000, n = 5):
    '''
    Find the group of primes with the smallest sum that are pairwise prime
    max_prime - max prime to consider
    n - the size of the group
    '''
    def fully_connected(p, path):
        '''
        checks if p is connected to all elements of the path
        (i.e. group of primes)
        '''
        return all(p in prime_pairs.get(path_item, set()) for path_item in path)
    def backtracking(prime_pairs, n, path = None):
        if path is None:
            path = []
        if len(path) == n:
            yield path[:]
        else:
            if not path:
                for p, v in prime_pairs.items():
                    if v:
                        yield from backtracking(prime_pairs, n, path + [p])
            else:
                p = path[-1]
                for t in sorted(prime_pairs[p]):
                    if t > p and fully_connected(t, path):
                        yield from backtracking(prime_pairs, n, path + [t])
    primes = list(primes235(max_prime)) # Sieve for list of primes up to max_prime
    set_primes = set(primes) # set of primes (for easy test if number is prime)
    prime_pairs = find_pairs(primes) # Table of primes and set of other primes they can pair with
    return next(backtracking(prime_pairs, n), None)
Test Runs
for n, max_prime in [(2, 1000), (3, 1000), (4, 1000), (5, 9000)]:
    print(find_groups(max_prime = max_prime, n = n))
Run Times
n   max_prime   Primes Found                    Run Time (secs)
2   1000        (3, 7)                          0.1280
3   1000        (3, 37, 67)                     0.1240
4   1000        (3, 7, 109, 673)                0.1233
5   40,000      (13, 5197, 5701, 6733, 8389)    7.1142
Note: The timings above were measured on an older Windows desktop computer,
namely:
~7-year-old HP-Pavilion Desktop i7 CPU 920 @ 2.67 GHz
I wrote this code based on a sieve of Eratosthenes that stores only the numbers congruent to 1 mod 6 and -1 mod 6. It is very fast.
def find_lowest_sum (pmax,num):
    n=int(str(pmax)+str(pmax))
    sieve5m6 = [True] * (n//6+1)
    sieve1m6 = [True] * (n//6+1)
    for i in range(1,int((n**0.5+1)/6)+1):
        if sieve5m6[i]:
            sieve5m6[6*i*i::6*i-1]=[False]*(((n//6+1)-6*i*i-1)//(6*i-1)+1)
            sieve1m6[6*i*i-2*i::6*i-1]=[False]*(((n//6+1)-6*i*i+2*i-1)//(6*i-1)+1)
        if sieve1m6[i]:
            sieve5m6[6*i*i::6*i+1]=[False]*(((n//6+1)-6*i*i-1)//(6*i+1)+1)
            sieve1m6[6*i*i+2*i::6*i+1]=[False]*(((n//6+1)-6*i*i-2*i-1)//(6*i+1)+1)
    def test_concatenate (p1,p2):
        ck=0
        p3=int(str(p1)+str(p2))
        if sieve1m6[(p3-1)//6] and p3%6==1:
            ck=1
        elif sieve5m6[(p3+1)//6]and p3%6==5:
            ck=1
        if ck==1:
            p3=int(str(p2)+str(p1))
            if sieve1m6[(p3-1)//6] and p3%6==1:
                ck=2
            elif sieve5m6[(p3+1)//6]and p3%6==5:
                ck=2
        if ck==2:
            return True
        else:
            return False
    kmax=(pmax+1)//6
    s1=5*pmax
    P1=[]
    P2=[]
    p=3
    kmin=0
    while p<s1//5:
        if p==3 or(p%6==1 and sieve1m6[(p-1)//6])or(p%6==5 and sieve5m6[(p+1)//6]):
            P=[]
            if p%6==1:
                kmin=p//6
            elif p%6==5:
                kmin=(p+1)//6
            if sieve1m6[kmin]:
                if test_concatenate(p,6*kmin+1):
                    P.append(6*kmin+1)
            for k in range(kmin+1,kmax):
                if sieve5m6[k]:
                    if test_concatenate(p,6*k-1):
                        P.append(6*k-1)
                if sieve1m6[k]:
                    if test_concatenate(p,6*k+1):
                        P.append(6*k+1)
            i1=0
            while i1<len(P) and P[i1]<s1:
                P1=[p]
                s=p
                P1.append(P[i1])
                s+=P[i1]
                for i2 in range(i1+1,len(P)):
                    for i3 in range(1,len(P1)):
                        ck=test_concatenate(P[i2],P1[i3])
                        if ck==False:
                            break
                    if len(P1)-1==i3 and ck==True:
                        P1.append(P[i2])
                        s+=P[i2]
                        if len(P1)==num:
                            if s<s1:
                                P2=P1
                                s1=s
                            break
                i1+=1
        p+=2
    return P2
P=find_lowest_sum(9001,5)
print(P)
s=0
for i in range(0,len(P)):
    s+=P[i]
print(s)
I am currently trying to solve the problem "Is my friend cheating?" on Codewars.
Here are the details of the problem:
A friend of mine takes the sequence of all numbers from 1 to n (where n > 0).
Within that sequence, he chooses two numbers, a and b.
He says that the product of a and b should be equal to the sum of all numbers in the sequence, excluding a and b.
Given a number n, could you tell me the numbers he excluded from the sequence?
The function takes the parameter: n (n is always strictly greater than 0) and returns an array or a string (depending on the language) of the form:
[(a, b), ...] or [[a, b], ...] or {{a, b}, ...} or [{a, b}, ...]
with all (a, b) which are the possible removed numbers in the sequence 1 to n.
[(a, b), ...] or [[a, b], ...] or {{a, b}, ...} or ... will be sorted in increasing order of the "a".
It happens that there are several possible (a, b). The function returns an empty array (or an empty string) if no possible numbers are found which will prove that my friend has not told the truth! (Go: in this case return nil).
Examples:
removNb(26) should return [(15, 21), (21, 15)]
or
removNb(26) should return { {15, 21}, {21, 15} }
or
removeNb(26) should return [[15, 21], [21, 15]]
or
removNb(26) should return [ {15, 21}, {21, 15} ]
or
removNb(26) should return "15 21, 21 15"
or
in C:
removNb(26) should return {{15, 21}{21, 15}} tested by way of strings.
Function removNb should return a pointer to an allocated array of Pair pointers, each
one also allocated.
The code I tried to solve the problem with:
def removNb(n):
    liste = [i+1 for i in range(n)]
    result = []
    for k in range(n):
        for t in range(k,n):
            m,n = liste[k], liste[t]
            if m * n == sum(liste)-(m+n):
                result.append((m, n))
                result.append((n, m))
    return result
This solution works, but is apparently not optimised enough.
Does anyone have an idea how to optimise the code further?
Since you never mutate this list:
liste = [i+1 for i in range(n)]
it's always just [1, ..., n], so you can replace any expression of the form liste[x] with just x + 1. And instead of computing sum(liste) repeatedly, you can just do it once. Applying those changes leaves you with:
def removNb(n):
    total = sum(range(n + 1))
    result = []
    for k in range(n):
        for t in range(k, n):
            m, n = k + 1, t + 1
            if m * n == total - (m + n):
                result.append((m, n))
                result.append((n, m))
    return result
We can further eliminate the + 1s by modifying our ranges:
def removNb(n):
    total = sum(range(n + 1))
    result = []
    for k in range(1, n + 1):
        for t in range(k, n + 1):
            m, n = k, t
            if m * n == total - (m + n):
                result.append((m, n))
                result.append((n, m))
    return result
which also makes the m, n assignment superfluous (also shadowing n inside the loop was confusing):
def removNb(n):
    total = sum(range(n + 1))
    result = []
    for k in range(1, n + 1):
        for t in range(k, n + 1):
            if k * t == total - (k + t):
                result.append((k, t))
                result.append((t, k))
    return result
You're recomputing sum(liste) for every iteration of the inner loop. This takes your runtime from O(n^2) to O(n^3).
Here's a quick fix:
def removNb(n):
    nums = list(range(1, n + 1))
    result = []
    total = sum(nums)  # computed once instead of inside the inner loop
    for k in range(n):
        for t in range(k, n):
            a, b = nums[k], nums[t]  # renamed so the parameter n is not shadowed
            if a * b == total - (a + b):
                result.append((a, b))
                result.append((b, a))
    return result
print(removNb(26))
If that still isn't fast enough, you can also cut down the runtime from O(n^2) to O(n log n) by using a binary rather than linear search in the inner loop.
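A related alternative, which is not the binary search mentioned above, removes the inner search entirely by solving the equation for b. From a * b = total - a - b it follows that b = (total - a) / (a + 1), so each a has at most one candidate. The sketch below assumes the two removed numbers must be distinct; the name remov_nb_direct is made up for illustration.

```python
def remov_nb_direct(n):
    # Sum of 1..n in closed form.
    total = n * (n + 1) // 2
    result = []
    for a in range(1, n + 1):
        # a * b == total - a - b  <=>  b == (total - a) / (a + 1)
        b, remainder = divmod(total - a, a + 1)
        # Assumption: a and b must be two different numbers within 1..n.
        if remainder == 0 and 1 <= b <= n and b != a:
            result.append((a, b))
    return result

print(remov_nb_direct(26))  # [(15, 21), (21, 15)]
```

Because the loop visits a in increasing order, the result is already sorted by a, and the whole search is linear in n.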
There is a problem we need to solve at my university where we need to print the 10 smallest prime Fibonacci numbers in ascending order. So far I have found this code, but it takes about 2 minutes to print them, and I was wondering if there is a faster way to print them.
import math
def isSquare(n):
    sr = (int)(math.sqrt(n))
    return (sr * sr == n)
def printPrimeAndFib(n):
    prime = [True] * (n + 1)
    p = 2
    while (p * p <= n):
        if (prime[p] == True):
            for i in range(p * 2, n + 1, p):
                prime[i] = False
        p = p + 1
    list=[]
    for i in range(2, n + 1):
        if (prime[i] and (isSquare(5 * i * i + 4) > 0 or
                          isSquare(5 * i * i - 4) > 0)):
            list.append(i)
    print(list)
n = 500000000
printPrimeAndFib(n)
With a Fibonacci generator and a prime filter. Takes about 0.002 seconds.
from itertools import islice
from math import isqrt
def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b
def is_prime(n):
    return n > 1 and all(map(n.__mod__, range(2, isqrt(n) + 1)))
fibonacci_primes = filter(is_prime, fibonacci())
print(list(islice(fibonacci_primes, 10)))
Output:
[2, 3, 5, 13, 89, 233, 1597, 28657, 514229, 433494437]
This approach generates Fibonacci numbers until it finds 10 that are prime; takes approx 0.001-0.002 seconds
from math import sqrt
def isprime(x):
    #deal with prime special cases 1 and 2
    if x==2:
        return True
    if x == 1 or x%2==0:
        return False
    #Fast-ish prime checking - only scan odd numbers up to sqrt(x)
    for n in range(3,int(sqrt(x))+1,2):
        if x%n==0:
            return False
    return True
fib_primes = []
current=1
previous=1
while (len(fib_primes)<10):
    if isprime(current):
        fib_primes.append(current)
    next = current+previous
    previous = current
    current = next
print(fib_primes)
My assignment:
Given a number, say prod (for product), we search two Fibonacci
numbers F(n) and F(n+1) verifying
F(n) * F(n+1) = prod
Your function productFib takes an integer (prod) and returns an array:
[F(n), F(n+1), True] if F(n) * F(n+1) = prod
else
[F(m), F(m+1), False]
F(m) being the smallest one such that F(m) * F(m+1) > prod
Examples:
productFib(714) # should return [21, 34, True],
# since F(8) = 21, F(9) = 34 and 714 = 21 * 34
productFib(800) # should return [34, 55, False],
# since F(8) = 21, F(9) = 34, F(10) = 55 and 21 * 34 < 800 < 34 * 55
My code:
def f(i):
    if i == 0 :
        return 0
    if i == 1 :
        return 1
    return f(i-2) + f(i-1)
def productFib(prod):
    i=1
    final1 = 0
    final2 = 0
    while(f(i)*f(i+1) != prod and f(i)*f(i+1) < prod):
        i += 1
        final1 = f(i)
        final2 = f(i+1)
    if(final1*final2 == prod):
        return [final1,final2,True]
    else:
        return [final1,final2,False]
I am new to programming; this runs very slow for large numbers. How can I reduce the time complexity?
def productFib(prod):
    a, b = 0, 1
    while prod > a * b:
        a, b = b, a + b
    return [a, b, prod == a * b]
First and foremost, your f function is horridly time-consuming: it computes f(n) for low n many times. Memoize the function: keep results in a list, and just refer to that list when you compute again.
memo = [1, 1]
def f(i):
    global memo
    if i >= len(memo):
        # Compute all f(k) where k < i
        f_of_i = f(i-2) + f(i-1)
        memo.append(f_of_i)
    return memo[i]
Note that this sequence still guarantees that you will fill in memo in numerical order: f(i-2) is called before f(i-1), and both are called before adding f(i) to the list.
Calling f(100000000000000) (10^14) with this returns instantaneously. I haven't tried it with higher numbers.
UPDATE
I ran it with increasing powers of 10. At 10^1650, it was still printing output at full speed, and I interrupted the run.
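The same memoization idea can also be expressed with the standard library's functools.lru_cache, which keeps the cache for you. This is only a sketch of an alternative, not the code above; the name fib_cached is made up, and it follows the question's convention F(0) = 0, F(1) = 1.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def fib_cached(i):
    # Question's convention: F(0) = 0, F(1) = 1.
    if i < 2:
        return i
    # Each value is computed once; later calls hit the cache.
    return fib_cached(i - 2) + fib_cached(i - 1)

print(fib_cached(50))  # 12586269025
```

Like any recursive version, a single call with a very large i on an empty cache can still hit Python's recursion limit, so it suits the incremental way productFib asks for f(i).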
IMPROVEMENT
You can do even better (for many applications) by directly computing f(i) from the closed-form equation:
root5 = 5 ** 0.5
phi = (1 + root5) / 2
psi = (1 - root5) / 2
def f(i):
    return int(round((phi ** i - psi ** i) / root5))
MORE IMPROVEMENT
Directly compute the proper value of i.
f(i) * f(i+1) is very close to phi**(2i+1) / 5.
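import math
log_phi = math.log(phi)
def productFib(prod):
    power = math.log(prod*5) / log_phi
    i = int(round(power-1, 7)/2) + 1
    # The estimate can be off by one (e.g. prod = 6), so step to the
    # smallest i with f(i) * f(i+1) >= prod
    while i > 0 and f(i-1) * f(i) >= prod:
        i -= 1
    while f(i) * f(i+1) < prod:
        i += 1
    low = f(i)
    high = f(i+1)
    answer = [low, high, low*high == prod]
    return answer
print(productFib(714))
print(productFib(800))
print(productFib(100000000000000000))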
Output:
[21, 34, True]
[34, 55, False]
[267914296, 433494437, False]
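The approximation used above follows from the closed form: F(i) * F(i+1) = (phi^(2i+1) + psi^(2i+1) - (-1)^i) / 5, so the gap to phi^(2i+1) / 5 stays below about 0.33 in magnitude. A small numeric check, as a sketch with made-up helper names (fib_pairs, max_approx_error) and exact integer Fibonacci numbers on one side:

```python
phi = (1 + 5 ** 0.5) / 2

def fib_pairs(count):
    # Yield (i, F(i), F(i+1)) using exact integer arithmetic.
    a, b = 0, 1
    for i in range(count):
        yield i, a, b
        a, b = b, a + b

def max_approx_error(count=30):
    # Largest absolute gap between F(i)*F(i+1) and phi**(2i+1)/5.
    return max(abs(a * b - phi ** (2 * i + 1) / 5) for i, a, b in fib_pairs(count))

print(max_approx_error())
```

For much larger i the floating-point value of phi ** (2 * i + 1) loses precision, which is why the estimate is best treated as a starting guess rather than an exact answer.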
You can increase the performance quite a bit by making use of a generator.
from time import perf_counter
def fib():
    # Generator that yields consecutive pairs of Fibonacci numbers.
    # Initialize
    np = 0
    n = 1
    # Loop forever. The main function will break the loop.
    while True:
        # Calculate the next number
        nn = np + n
        # Yield the previous number and the new one.
        yield n, nn
        # Update the generator state.
        np = n
        n = nn
def product_fib(prod):
    # Loop through the generator.
    for f_0, f_1 in fib():
        # Stop as soon as the product of the pair reaches or exceeds prod.
        if f_0 * f_1 >= prod:
            # The last element in the list is the result of the equality check.
            return [f_0, f_1, f_0 * f_1 == prod]
t0 = perf_counter()
res = productFib(8000000000)
t = perf_counter() - t0
print("Original:", t, res)
t0 = perf_counter()
res = product_fib(8000000000) # 8000000000
t = perf_counter() - t0
print("New:", t, res)
The output of this is
Original: 0.8113621789962053 [75025, 121393, False]
New: 1.3276992831379175e-05 [75025, 121393, False]
Edit
If you want a single line, the following works, but don't use it; it's not really a practical solution. The first one is faster anyway.
print((lambda prod: (lambda n: next(([a, b, (a * b) == prod] for a, b in ([n.append(n[0] + n[1]), n.pop(0), n][-1] for _ in iter(int, 1)) if (a * b) >= prod)))([0, 1]))(714))
Your fib function is recalculating from the bottom every time. You should save values you already know. My python is rusty, but I think this will do it:
memo = {}
def f(i):
    # dict.has_key() no longer exists in Python 3; use the "in" operator
    if i in memo:
        return memo[i]
    if i == 0 :
        return 0
    if i == 1 :
        return 1
    total = f(i-2) + f(i-1)
    memo[i] = total
    return total
def productFib(prod):
    i=1
    final1 = 0
    final2 = 0
    while(f(i)*f(i+1) != prod and f(i)*f(i+1) < prod):
        i += 1
        final1 = f(i)
        final2 = f(i+1)
    if(final1*final2 == prod):
        return [final1,final2,True]
    else:
        return [final1,final2,False]
Here are a couple of different algorithms that are increasingly optimized (starting with the original algorithm).
Algorithm 3 uses an iterative fibonacci with caching along with guessing the square root (which should be between the two multiples of the fibonacci product).
Timing Results:
Original algorithm: 4.43634369118898
Algorithm 2 (recursive fib): 1.5160450420565503
Algorithm 2 (iter fib): 0.03543769357344395
Algorithm 2 (iter fib + caching): 0.013537414072276377
Algorithm 3 (iter fib + caching + guessing): 0.0017255337946799898
Setup:
import timeit
from random import randint
from math import sqrt
from bisect import bisect
cache = [0, 1]
Nth Fibonacci Number Algorithms:
def f(n):
    """Recursive solution: O(2^n)"""
    if n == 0:
        return 0
    if n == 1:
        return 1
    return f(n-2) + f(n-1)
def ff(n):
    """Iterative solution: O(n)"""
    a, b = 0, 1
    for _ in range(0, n):
        a, b = b, a + b
    return a
def fff(n):
    """Iterative solution + caching"""
    length = len(cache)
    if n >= length:
        for i in range(length, n+1):
            cache.append(cache[i-1] + cache[i-2])
    return cache[n]
Fibonacci Product Algorithms:
def pf1(prod):
    """Original algorithm."""
    i=1
    final1 = 0
    final2 = 0
    while(f(i)*f(i+1) != prod and f(i)*f(i+1) < prod):
        i += 1
        final1 = f(i)
        final2 = f(i+1)
    if(final1*final2 == prod):
        return [final1,final2,True]
    else:
        return [final1,final2,False]
def pf2(prod, fib_func):
    """Algorithm 2.
    Removes extra logic and duplicate function calls."""
    i = 1
    guess = 0
    while guess < prod:
        final1 = fib_func(i)
        final2 = fib_func(i + 1)
        guess = final1 * final2
        i += 1
    return [final1, final2, guess == prod]
def pf3(prod, fib_func):
    """Algorithm 3.
    Implements square root as a guess."""
    guess = sqrt(prod) # Numbers will be near square root
    i = 2
    while cache[-1] < guess:
        fff(i)
        i += 1
    insertion_spot = bisect(cache, guess)
    test1 = cache[insertion_spot-1]
    test2 = cache[insertion_spot]
    return [test1, test2, (test1 * test2) == prod]
Testing & Output:
prods = [randint(3, 99999) for _ in range(1000)] # 1000 random products
print("Original algorithm: ", timeit.timeit("for prod in prods: pf1(prod)", number=1, setup="from __main__ import pf1, prods"))
print("Algorithm 2 (recursive fib): ", timeit.timeit("for prod in prods: pf2(prod, f)", number=1, setup="from __main__ import pf2, prods, f"))
print("Algorithm 2 (iter fib): ", timeit.timeit("for prod in prods: pf2(prod, ff)", number=1, setup="from __main__ import pf2, prods, ff"))
print("Algorithm 2 (iter fib + caching): ", timeit.timeit("for prod in prods: pf2(prod, fff)", number=1, setup="from __main__ import pf2, prods, fff"))
print("Algorithm 3 (iter fib + caching + guessing):", timeit.timeit("for prod in prods: pf3(prod, fff)", number=1, setup="from __main__ import pf3, prods, fff, cache; cache = [0, 1]"))
Helpful Time Complexities of Python "List" object: https://wiki.python.org/moin/TimeComplexity
You may use the following code for your productFib function:
def productFib2(prod):
    i=0
    final1 = f(i)
    final2 = f(i+1)
    while(final1 * final2 < prod):
        i += 1
        final1 = final2
        final2 = f(i+1)
    if(final1*final2 == prod):
        return [final1,final2,True]
    else:
        return [final1,final2,False]
Two ways to do it: a slow way and a fast way
Slow:
def productFib(prod):
    i=0
    final1 = f(i)
    final2 = f(i+1)
    while(final1 * final2 < prod):
        i += 1
        final1 = final2 #this is the same as final1 = f(i) again, but the value is already in memory so there is no need to recalculate it
        final2 = f(i+1)
    if(final1*final2 == prod):
        return [final1,final2,True]
    else:
        return [final1,final2,False] #if the product is not equal to prod
def f(i): #fib(0) == 0 fib(1) == 1
    if i == 0 :
        return 0
    if i == 1 :
        return 1
    return f(i-2) + f(i-1)
output of the product in the while loop:
0
1
2
6
15
Fast:
def productFib(prod):
    a, b = 0, 1
    while prod > a * b:
        a, b = b, a + b #a takes the value of b, and b becomes the sum of a and b
    return [a, b, prod == a * b] #return a and b, plus whether a times b is equal to the product
You don't need to check whether something is a part of the Fibonacci series here, because we can simply iterate over the series within the while loop without having to call another function recursively.
The Fibonacci function is defined as F(n) = F(n-1) + F(n-2) with F(0) = 0 and F(1) = 1. Therefore, you can generate the series with the while loop above:
first a = 0, b = 1, next a = 1, b = 1, next a = 1, b = 2, next a = 2, b = 3,
next a = 3, b = 5 ...
I've been trying to implement Dixon's factorization method in python, and I'm a bit confused. I know that you need to give some bound B and some number N and search for numbers between sqrtN and N whose squares are B-smooth, meaning all their factors are in the set of primes less than or equal to B. My question is, given N of a certain size, what determines B so that the algorithm will produce non-trivial factors of N? Here is a wikipedia article about the algorithm, and if it helps, here is my code for my implementation:
def factor(N, B):
    def isBsmooth(n, b):
        factors = []
        for i in b:
            while n % i == 0:
                n = int(n / i)
                if not i in factors:
                    factors.append(i)
        if n == 1 and factors == b:
            return True
        return False
    factor1 = 1
    while factor1 == 1 or factor1 == N:
        Bsmooth = []
        BsmoothMod = []
        for i in range(int(N ** 0.5), N):
            if len(Bsmooth) < 2 and isBsmooth(i ** 2 % N, B):
                Bsmooth.append(i)
                BsmoothMod.append(i ** 2 % N)
        gcd1 = (Bsmooth[0] * Bsmooth[1]) % N
        gcd2 = int((BsmoothMod[0] * BsmoothMod[1]) ** 0.5)
        factor1 = gcd(gcd1 - gcd2, N)
        factor2 = int(N / factor1)
    return (factor1, factor2)
Maybe someone could help clean my code up a bit, too? It seems very inefficient.
This article discusses the optimal size for B: https://web.archive.org/web/20160205002504/https://vmonaco.com/dixons-algorithm-and-the-quadratic-sieve/. Briefly, the optimal value is thought to be exp((logN loglogN)^(1/2)).
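As a rough illustration of that formula, here is a minimal sketch that evaluates exp((log N log log N)^(1/2)) with natural logarithms. The function name dixon_bound is made up, and the result is only a heuristic starting point for B, not a guarantee that a given run finds a non-trivial factor.

```python
import math

def dixon_bound(n):
    # Heuristic smoothness bound exp(sqrt(ln n * ln ln n)); assumes n >= 3
    # so that ln ln n is positive.
    log_n = math.log(n)
    return math.exp(math.sqrt(log_n * math.log(log_n)))

for n in (143, 10**6 + 3, 2**64 + 1):
    print(n, round(dixon_bound(n)))
```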
[ I wrote this for a different purpose, but you might find it interesting. ]
Given x^2 ≡ y^2 (mod n) with x ≠ ± y, about half the time gcd(x−y, n) is a factor of n. This congruence of squares, observed by Maurice Kraitchik in the 1920s, is the basis for several factoring methods. One of those methods, due to John Dixon, is important in theory because its sub-exponential run time can be proven, though it is too slow to be useful in practice.
Dixon's method begins by choosing a bound b ≈ e^√(log n log log n) and identifying the factor base of all primes less than b that are quadratic residues of n (their jacobi symbol is 1).
function factorBase(n, b)
    fb := [2]
    for p in tail(primes(b))
        if jacobi(n, p) == 1
            append p to fb
    return fb
Then repeatedly choose an integer r on the range 1 < r < n, calculate its square modulo n, and if the square is smooth over the factor base add it to a list of relations, stopping when there are more relations than factors in the factor base, plus a small reserve for those cases that fail. The idea is to identify a set of relations, using linear algebra, where the factor base primes combine to form a square. Then take the square root of the product of all the factor base primes in the relations, take the product of the related r, and calculate the gcd to identify the factor.
struct rel(x, ys)
function dixon(n, fb, count)
    r, rels := floor(sqrt(n)), []
    while count > 0
        fs := smooth((r * r) % n, fb)
        if fs is not null
            append rel(r, fs) to rels
            count := count - 1
        r := r + 1
    return rels
A number n is smooth if all its factors are in the factor base, which is determined by trial division; the smooth function returns a list of factors, which is null if n doesn't completely factor over the factor base.
function smooth(n, fb)
    fs := []
    for f in fb
        while n % f == 0
            append f to fs
            n := n / f
        if n == 1 return fs
    return null
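For readers who prefer runnable code, the trial-division smoothness test can be sketched in Python as follows. This is an illustrative translation under the assumption that n is a positive integer and that "null" maps to None; it is not taken from the original pseudocode's source.

```python
def smooth(n, factor_base):
    # Return the prime factors of n (with multiplicity) if n factors
    # completely over factor_base, otherwise None. Assumes n >= 1.
    factors = []
    for p in factor_base:
        while n % p == 0:
            factors.append(p)
            n //= p
        if n == 1:
            return factors
    return None

print(smooth(75, [2, 3, 5]))   # [3, 5, 5]
print(smooth(14, [2, 3, 5]))   # None
```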
A factor is determined by submitting the accumulated relations to the linear algebra of the congruence of square solver.
For example, consider the factorization of 143. Choose r = 17, so r^2 ≡ 3 (mod 143). Then choose r = 19, so r^2 ≡ 75 ≡ 3 · 5^2. Those two relations can be combined as (17 · 19)^2 ≡ 3^2 · 5^2 ≡ 15^2 (mod 143), and the two factors are gcd(17·19 − 15, 143) = 11 and gcd(17·19 + 15, 143) = 13. This sometimes fails; for instance, the relation 21^2 ≡ 12 ≡ 2^2 · 3 (mod 143) can be combined with the relation on 19, giving (19 · 21)^2 ≡ 30^2 (mod 143), but the two factors produced, 1 and 143, are trivial.
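The arithmetic in this example can be checked with a few lines of Python. This is a minimal sketch; the helper name congruence_factors is made up for illustration.

```python
from math import gcd

def congruence_factors(n, x, y):
    # Given x^2 ≡ y^2 (mod n), return the two candidate factors of n.
    if (x * x - y * y) % n != 0:
        raise ValueError("x^2 and y^2 are not congruent modulo n")
    return gcd(x - y, n), gcd(x + y, n)

print(congruence_factors(143, 17 * 19, 15))  # (11, 13)
print(congruence_factors(143, 19 * 21, 30))  # (1, 143), the trivial case
```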
Thanks for the very interesting question!
In pure Python I implemented the Dixon factorization algorithm from scratch in 3 different flavors:
The first solution uses the simplest sieve. I create a u64 array covering all numbers in the range [N; N * 2), which represent z^2 values. This array holds products of prime numbers. During sieving I iterate over all factor-base primes and do array[k] *= p at those positions k that are divisible by p. When the sieved array is ready I check that a) the array index k is a perfect square, and b) array[k] == k - N. Condition b) means that the multiplied primes p give the final number, which is true only if the number is divisible by factor-base primes alone, i.e. it is B-smooth. This is the simplest and slowest of my 3 solutions.
The second solution uses the SymPy library to factorize every z^2. I iterate over all possible z and call sympy.factorint(z * z), which gives the factorization of z^2. If this factorization contains only small primes, i.e. primes from the factor base, I collect such z and z^2 for later processing. This version of the algorithm is also slow, but much faster than the first one.
The third solution uses the kind of sieving used in the Quadratic Sieve. This sieving process is the fastest of the three algorithms. It finds all roots of the equation x^2 = N (mod p) for every prime in the factor base. As I have just a few primes, root finding is done by a simple loop over all candidates; for bigger primes one can use the Tonelli-Shanks algorithm for finding roots, which is really fast. Only around 50% of primes give a root at all, hence only half of the primes are actually used in the Quadratic Sieve. Roots of this equation can be used to generate lots of solutions at once, because root + k * p is also a valid solution for all k. Sieving is done through array[offset(root) :: p] += Log2(p). Here, instead of the multiplication used in the first algorithm, I add the logarithm of the prime. First, adding a number is a bit faster than multiplying. Second, and more importantly, it supports numbers of any size, e.g. even 256-bit, while multiplication works only up to 64-bit numbers because NumPy has no 128-bit or 256-bit integer support. After the logarithms are added, I check which sums are equal to the logarithm of the original z^2 number; these numbers are the final sieved numbers.
After any of the three algorithms above has sieved the z^2 values, I do the linear algebra stage with the Gaussian elimination algorithm. This stage finds combinations of B-smooth z^2 numbers whose prime factors, multiplied together, give a number in which all prime powers are EVEN.
Let's call a Relation a triple: z, z^2, and the prime factors of z^2. All relations are passed to the Gaussian elimination stage, where the even combinations are found.
Even powers of prime numbers give us the congruence a^2 = b^2 (mod N), from which we can get a factor by computing factor = GCD(a + b, N), where GCD is the Greatest Common Divisor found with the Euclidean algorithm. This GCD sometimes gives the trivial factors 1 and N; in that case other even combinations should be checked.
To be sure of getting even combinations, I run the sieving stage until the number of relations slightly exceeds the number of primes, around 105% of the number of primes. The extra 5% of relations guarantees linearly dependent equations in the Gaussian stage, and each such dependency forms an even combination.
We need more than just one relation beyond the number of primes, around 5%-10% more, because some dependencies (50-60% of them, as far as I can see experimentally) give only the trivial factor 1 or N. Hence the extra equations are needed.
Take a look at the console output at the end of my post. It shows everything my program prints. There I run both the 2nd (Sieve_B) and 3rd (Sieve_C) algorithms in parallel (multi-threaded). The 1st one (Sieve_A) is not run by my program because it is so slow that you would wait forever for it to finish.
At the very end of the source file you can change the variable bits = 64 to some other size, like bits = 96. This is the number of bits in the composite number N. N is created as a product of just two random primes of equal size. Such a composite of two equal-sized primes is usually called an RSA number.
Also find B = 1 << 10; this sets the degree of B-smoothness: the factor base consists of all primes < B. You may increase this limit B, which makes B-smooth sieved z^2 values more frequent, so the whole factoring becomes much faster. The only drawback of a huge B is the Linear Algebra stage (Gaussian Elimination), because with a bigger factor base you have to solve more linear equations of bigger size. My Gaussian elimination is also not implemented optimally: for example, instead of storing one bit per np.uint8 you could pack the bits densely into np.uint64 words, which should speed up the Linear Algebra by roughly 8x.
You may also find the variable M = 1 << 23, which sets the sieving array size, in other words the block size that is processed at once. A bigger block is a bit faster, but not by much, because M only determines the size of the tasks the sieving process is split into; it doesn't change the total amount of computation. Moreover, a bigger M occupies more memory, so you can't increase it indefinitely, only as far as your memory allows.
Besides all the algorithms mentioned above I also used the Fermat Primality Test (for generating the random primes) and the Sieve of Eratosthenes (for generating the prime factor base).
I also implemented my own algorithm for filtering square numbers. For this I take a composite modulus that looks close to a Primorial, like mod = 2 * 2 * 2 * 3 * 3 * 5 * 7 * 11 * 13, and in a boolean array I mark all residues modulo mod that are squares. Later, when a number K should be checked for being a square, I look up flag_array[K % mod]: if it is True the number is "possibly" a square, while if it is False the number is "definitely" not a square. Thus this filter sometimes gives false positives but never false negatives. The filter stage rejects more than 95% of non-squares (about 98% for uniformly random inputs with this modulus); the remaining possible squares can be double-checked through math.isqrt().
Please click the Try it online! link below to test-run my program on the ReplIt online server. This will give you the best impression, especially if you have no Python or no personal laptop. My code below can be run straight away after PIP-installing the dependencies: python -m pip install numpy sympy.
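Here is a stand-alone sketch of that filter, separate from the IsSquare function in the full program below. The names MOD, is_square_residue and is_square are chosen for this illustration only; the modulus is the one mentioned above.

```python
import math
import numpy as np

MOD = 2 * 2 * 2 * 3 * 3 * 5 * 7 * 11 * 13  # 360360

# Mark every residue r for which some x satisfies x^2 = r (mod MOD).
residues = np.arange(MOD, dtype=np.uint64)
is_square_residue = np.zeros(MOD, dtype=bool)
is_square_residue[(residues * residues) % MOD] = True

def is_square(k):
    """Exact square test with a cheap table lookup in front of isqrt."""
    if not is_square_residue[k % MOD]:
        return False                 # definitely not a square
    root = math.isqrt(k)             # possibly a square: confirm exactly
    return root * root == k

square_residue_count = int(is_square_residue.sum())
print(square_residue_count, 'of', MOD, 'residues can be squares')
```

The table lookup can only say "definitely not" or "maybe", so the exact math.isqrt() check stays in place for the "maybe" case.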
Try it online!
import threading
def GenPrimes_SieveOfEratosthenes(end):
    import numpy as np
    composites = np.zeros((end,), dtype = np.uint8)
    for p in range(2, len(composites)):
        if composites[p]:
            continue
        if p * p >= end:
            break
        composites[p * p :: p] = 1
    primes = []
    for p in range(2, len(composites)):
        if not composites[p]:
            primes.append(p)
    return np.array(primes, dtype = np.uint32)
def Print(*pargs, __state = (threading.RLock(),), **nargs):
    with __state[0]:
        print(*pargs, flush = True, **nargs)
def IsSquare(n, *, state = []):
    if len(state) == 0:
        import numpy as np
        Print('Pre-computing squares filter...')
        squares_filter = 2 * 2 * 2 * 3 * 3 * 5 * 7 * 11 * 13
        squares = np.zeros((squares_filter,), dtype = np.uint8)
        squares[(np.arange(0, squares_filter, dtype = np.uint64) ** 2) % squares_filter] = 1
        state.extend([squares_filter, squares])
    if not state[1][n % state[0]]:
        return False, None
    import math
    root = math.isqrt(n)
    return root ** 2 == n, root
def FactorRef(x):
    import sympy
    return dict(sorted(sympy.factorint(x).items()))
def CheckZ(z, N, primes):
    z2 = pow(z, 2, N)
    factors = FactorRef(z2)
    assert all(p <= primes[-1] for p in factors), (primes[-1], factors, N, z, z2)
    return z
def SieveSimple(N, primes):
    import time, math, numpy as np
    Print('Simple Sieve of B-smooth z^2...')
    sieve_block = 1 << 21
    rep0_time = 0
    for iiblock, iblock in enumerate(range(N, N * 2, sieve_block)):
        if time.time() - rep0_time >= 30:
            Print(f'Block {iiblock:>3} (2^{math.log2(max(iblock - N, 1)):>5.2f})')
            rep0_time = time.time()
        iblock_end = iblock + sieve_block
        sieve_arr = np.ones((sieve_block,), dtype = np.uint64)
        iblock_modN = iblock % N
        for p in primes:
            mp = 1
            while True:
                if mp * p >= sieve_block:
                    break
                mp *= p
                off = (mp - iblock_modN % mp) % mp
                sieve_arr[off :: mp] *= p
        for i in range(1 if iblock == N else 0, sieve_block):
            num = iblock + i
            z2 = num - N
            if sieve_arr[i] < z2:
                continue
            assert sieve_arr[i] == z2, (sieve_arr[i], round(math.log2(sieve_arr[i]), 3), z2)
            is_square, z = IsSquare(num)
            if not is_square:
                continue
            #Print('z', z, 'z^2', z2)
            yield CheckZ(z, N, primes)
def SieveFactor(N, primes):
    import math
    Print('Factor Sieve of B-smooth z^2...')
    for iz, z in enumerate(range(math.isqrt(N - 1) + 1, math.isqrt(N * 2 - 1) + 1)):
        z2 = z ** 2 - N
        assert 0 <= z2 and z2 < N, (z, z2)
        factors = FactorRef(z2)
        if any(p > primes[-1] for p in factors):
            continue
        #Print('iz', iz, 'z', z, 'z^2', z2, 'z^2 factors', factors)
        yield CheckZ(z, N, primes)
def BinarySearch(begin, end, Test):
    while begin + 1 < end:
        mid = (begin + end - 1) >> 1
        if Test(mid):
            end = mid + 1
        else:
            begin = mid + 1
    assert begin + 1 == end and Test(begin), (begin, end, Test(begin))
    return begin
def ModSqrt(n, p):
    n %= p
    def Ret(x):
        if pow(x, 2, p) != n:
            return []
        nx = (p - x) % p
        if x == nx:
            return [x]
        elif x <= nx:
            return [x, nx]
        else:
            return [nx, x]
    #if p % 4 == 3 and sympy.isprime(p):
    #    return Ret(pow(n, (p + 1) // 4, p))
    for i in range(p):
        if pow(i, 2, p) == n:
            return Ret(i)
    return []
def SieveQuadratic(N, primes):
    import math, numpy as np
    # https://en.wikipedia.org/wiki/Quadratic_sieve
    # https://www.rieselprime.de/ziki/Multiple_polynomial_quadratic_sieve
    M = 1 << 23
    def Log2I(x):
        return int(round(math.log2(max(1, x)) * (1 << 24)))
    def Log2IF(li):
        return li / (1 << 24)
    Print('Quadratic Sieve of B-smooth z^2...')
    plogs = {}
    for p in primes:
        plogs[int(p)] = Log2I(int(p))
    qprimes = []
    B = int(primes[-1]) + 1
    for p in primes:
        p = int(p)
        res = []
        mp = 1
        while True:
            if mp * p >= B:
                break
            mp *= p
            roots = ModSqrt(N, mp)
            if len(roots) == 0:
                if mp == p:
                    break
                continue
            res.append((mp, tuple(roots)))
        if len(res) > 0:
            qprimes.append(res)
    qprimes_lin = np.array([pinfo[0][0] for pinfo in qprimes], dtype = np.uint32)
    yield qprimes_lin
    Print('QSieve num primes', len(qprimes), f'({len(qprimes) * 100 / len(primes):.1f}%)')
    x_begin0 = math.isqrt(N - 1) + 1
    assert N <= x_begin0 ** 2
    for iblock in range(1 << 30):
        if (x_begin0 + (iblock + 1) * M) ** 2 - N >= N:
            break
        x_begin = x_begin0 + iblock * M
        if iblock != 0:
            Print('\n', end = '')
        Print(f'Block {iblock} (2^{math.log2(max(1, x_begin ** 2 - N)):>6.2f})...')
        a = np.zeros((M,), np.uint32)
        for pinfo in qprimes:
            p = pinfo[0][0]
            plog = np.uint32(plogs[p])
            for imp, (mp, roots) in enumerate(pinfo):
                off_done = set()
                for root in roots:
                    for off in range(mp):
                        if ((x_begin + off) ** 2 - N) % mp == 0 and off not in off_done:
                            break
                    else:
                        continue
                    a[off :: mp] += plog
                    off_done.add(off)
        logs = np.log2(np.array((np.arange(M).astype(np.float64) + x_begin) ** 2 - N, dtype = np.float64))
        logs2if = Log2IF(a.astype(np.float64))
        logs_diff = np.abs(logs - logs2if)
        for ix in range(M):
            if logs_diff[ix] > 0.3:
                continue
            z = x_begin + ix
            z2 = z * z - N
            factors = FactorRef(z2)
            assert all(p <= primes[-1] for p, c in factors.items())
            #Print('iz', ix, 'z', z, 'z^2', z2, f'(2^{math.log2(max(1, z2)):>6.2f})', ', z^2 factors', factors)
            yield CheckZ(z, N, primes)
def LinAlg(N, zs, primes):
    import numpy as np
    Print('Linear algebra...')
    Print('Factoring...')
    m = np.zeros((len(zs), len(primes) + len(zs)), dtype = np.uint8)
    def SwapRows(i, j):
        t = np.copy(m[i])
        m[i][...] = m[j][...]
        m[j][...] = t[...]
    def MatToStr(m):
        s = '\n'
        for i in range(len(m)):
            for j in range(len(m[i])):
                s += str(m[i, j])
            s += '\n'
        return s[1:-1]
    for iz, z in enumerate(zs):
        z2 = z * z - N
        fs = FactorRef(z2)
        for p, c in fs.items():
            i = np.searchsorted(primes, p, 'right') - 1
            assert i >= 0 and i < len(primes) and primes[i] == p, (i, primes[i])
            m[iz, i] = (int(m[iz, i]) + c) % 2
        m[iz, len(primes) + iz] = 1
    Print('Gaussian elemination...')
    #Print(MatToStr(m)); Print()
    one_col, one_rows = 0, 0
    while True:
        while True:
            for i in range(one_rows, len(m)):
                if m[i, one_col]:
                    break
            else:
                one_col += 1
                if one_col >= len(primes):
                    break
                continue
            break
        if one_col >= len(primes):
            break
        assert m[i, one_col]
        assert np.all(m[i, :one_col] == 0)
        for j in range(len(m)):
            if i == j:
                continue
            if not m[j, one_col]:
                continue
            m[j][...] ^= m[i][...]
        SwapRows(one_rows, i)
        one_rows += 1
        one_col += 1
    assert np.all(m[one_rows:, :len(primes)] == 0)
    zeros = m[one_rows:, len(primes):]
    Print(f'Even combinations ({len(m) - one_rows}):')
    Print(MatToStr(zeros))
    return zeros
def ProcessResults(N, zs, la_zeros):
    import math
    Print('Computing final results...')
    factors = []
    for i in range(len(la_zeros)):
        zero = la_zeros[i]
        assert len(zero) == len(zs)
        cz = []
        for j in range(len(zero)):
            if not zero[j]:
                continue
            z = zs[j]
            z2 = z * z - N
            cz.append((z, z2, FactorRef(z2)))
        a = 1
        for z, z2, fs in cz:
            a = (a * z) % N
        cnts = {}
        for z, z2, fs in cz:
            for p, c in fs.items():
                cnts[p] = cnts.get(p, 0) + c
        cnts = dict(sorted(cnts.items()))
        b = 1
        for p, c in cnts.items():
            assert c % 2 == 0, (p, c, cnts)
            b = (b * pow(p, c // 2, N)) % N
        factor = math.gcd(a + b, N)
        Print('a', str(a).rjust(len(str(N))), ' b', str(b).rjust(len(str(N))), ' factor', factor if factor != N else 'N')
        if factor != 1 and factor != N:
            factors.append(factor)
    return factors
def SieveCollectResults(N, its):
    import time, threading, queue, traceback, math
    K = len(its)
    qs = [queue.Queue() for i in range(K)]
    last_dot, finish = False, False
    def Get(it, ty, need, compul):
        nonlocal last_dot, finish
        try:
            cnt = 0
            for iz, z in enumerate(it):
                if finish:
                    break
                if iz < 4:
                    z2 = z * z - N
                    Print(('\n' if last_dot else '') + 'Sieve_' + ('C', 'B', 'A')[K - 1 - ty], ' iz', iz,
                        'z', z, 'z^2', z2, f'(2^{math.log2(max(1, z2)):>6.2f})', ', z^2 factors', FactorRef(z2))
                    last_dot = False
                else:
                    Print(('.', 'b', 'a')[K - 1 - ty], end = '')
                    last_dot = True
                qs[ty].put(z)
                cnt += 1
                if cnt >= need:
                    break
        except:
            Print(traceback.format_exc())
    thr = []
    for ty, (it, need, compul) in enumerate(its):
        thr.append(threading.Thread(target = Get, args = (it, ty, need, compul), daemon = True))
        thr[-1].start()
    for ithr, t in enumerate(thr):
        if its[ithr][2]:
            t.join()
    finish = True
    if last_dot:
        Print()
    zs = [[] for i in range(K)]
    for iq, q in enumerate(qs):
        while not qs[iq].empty():
            zs[iq].append(qs[iq].get())
    return zs
def DixonFactor(N):
    import time, math, numpy as np, sys
    B = 1 << 10
    primes = GenPrimes_SieveOfEratosthenes(B)
    Print('Num primes', len(primes), 'last prime', primes[-1])
    IsSquare(0)
    it = SieveQuadratic(N, primes)
    qprimes = next(it)
    zs = SieveCollectResults(N, [
        #(SieveSimple(N, primes), 3, False),
        (SieveFactor(N, primes), 3, False),
        (it, round(len(qprimes) * 1.06 + 0.5), True),
    ])[-1]
    la_zeros = LinAlg(N, zs, qprimes)
    fs = ProcessResults(N, zs, la_zeros)
    if len(fs) > 0:
        Print('Factored, factors', sorted(set(fs)))
    else:
        Print('Failed to factor! Try running program again...')
def IsPrime_Fermat(n, *, ntrials = 32):
    import random
    if n <= 16:
        return n in (2, 3, 5, 7, 11, 13)
    for i in range(ntrials):
        if pow(random.randint(2, n - 2), n - 1, n) != 1:
            return False
    return True
def GenRandom(bits):
    import random
    return random.randrange(1 << (bits - 1), 1 << bits)
def RandPrime(bits):
    while True:
        n = GenRandom(bits) | 1
        if IsPrime_Fermat(n):
            return n
def Main():
    import math
    bits = 64
    N = RandPrime(bits // 2) * RandPrime((bits + 1) // 2)
    Print('N to factor', N, f'(2^{math.log2(N):>5.1f})')
    DixonFactor(N)
if __name__ == '__main__':
    Main()
Console output:
N to factor 10086068308526249063 (2^ 63.1)
Num primes 172 last prime 1021
Pre-computing squares filter...
Quadratic Sieve of B-smooth z^2...
Factor Sieve of B-smooth z^2...
QSieve num primes 78 (45.3%)
Block 0 (2^ 32.14)...
Sieve_C iz 0 z 3175858067 z^2 6153202727426 (2^ 42.48) , z^2 factors {2: 1, 29: 2, 67: 1, 191: 1, 487: 1, 587: 1}
Sieve_C iz 1 z 3175859246 z^2 13641877439453 (2^ 43.63) , z^2 factors {31: 1, 61: 1, 167: 1, 179: 1, 373: 1, 647: 1}
Sieve_C iz 2 z 3175863276 z^2 39239319203113 (2^ 45.16) , z^2 factors {31: 1, 109: 1, 163: 1, 277: 1, 311: 1, 827: 1}
Sieve_C iz 3 z 3175867115 z^2 63623612174162 (2^ 45.85) , z^2 factors {2: 1, 29: 1, 41: 1, 47: 1, 61: 1, 127: 1, 197: 1, 373: 1}
.........................................................................
Sieve_B iz 0 z 3175858067 z^2 6153202727426 (2^ 42.48) , z^2 factors {2: 1, 29: 2, 67: 1, 191: 1, 487: 1, 587: 1}
......
Linear algebra...
Factoring...
Gaussian elemination...
Even combinations (7):
01000000000000000000000000000000000000000000000000001100000000000000000000000000000
11010100000010000100100000010011100000000001001001001001011001000000110001010000000
11001011000101111100011111001011010011000111101000001001011000001111100101001110000
11010010010000110110101100110101000100001100010011100011101000100010011011001001000
00010110111010000010000010000111010001010010111001000011011011101110110001001100100
00000010111000110010100110001111010101001000011010110011101000110001101101100100010
10010001111111101100011110111110110100000110111011010001010001100000010100000100001
Computing final results...
a 9990591196683978238 b 9990591196683978238 factor 1
a 936902490212600845 b 3051457985176300292 factor 3960321451
a 1072293684177681642 b 8576178744296269655 factor 2546780213
a 1578121372922149955 b 1578121372922149955 factor 1
a 2036768191033218175 b 8049300117493030888 factor N
a 1489997751586754228 b 2231890938565281666 factor 3960321451
a 9673227070299809069 b 3412883990935144956 factor 3960321451
Factored, factors [2546780213, 3960321451]
I have been trying to implement the Rabin-Miller Strong Pseudoprime Test today.
I have used Wolfram MathWorld as a reference; lines 3-5 there pretty much sum up my code.
However, when I run the program, it (sometimes) says that primes (even small ones such as 5, 7, 11) are not prime. I've looked over the code for a very long while and cannot figure out what is wrong.
For help I've looked at this site as well as many other sites, but most use another definition (probably the same, but since I'm new to this kind of math, I can't see the connection).
My Code:
import random
def RabinMiller(n, k):
    # obviously not prime
    if n < 2 or n % 2 == 0:
        return False
    # special case
    if n == 2:
        return True
    s = 0
    r = n - 1
    # factor n - 1 as 2^(r)*s
    while r % 2 == 0:
        s = s + 1
        r = r // 2 # floor
    # k = accuracy
    for i in range(k):
        a = random.randrange(1, n)
        # a^(s) mod n = 1?
        if pow(a, s, n) == 1:
            return True
        # a^(2^(j) * s) mod n = -1 mod n?
        for j in range(r):
            if pow(a, 2**j*s, n) == -1 % n:
                return True
    return False
print(RabinMiller(7, 5))
How does this differ from the definition given at Mathworld?
1. Comments on your code
A number of the points I'll make below were noted in other answers, but it seems useful to have them all together.
In the section
s = 0
r = n - 1
# factor n - 1 as 2^(r)*s
while r % 2 == 0:
    s = s + 1
    r = r // 2 # floor
you've got the roles of r and s swapped: you've actually factored n − 1 as 2^s * r. If you want to stick to the MathWorld notation, then you'll have to swap r and s in this section of the code:
# factor n - 1 as 2^(r)*s, where s is odd.
r, s = 0, n - 1
while s % 2 == 0:
    r += 1
    s //= 2
In the line
for i in range(k):
the variable i is unused: it's conventional to name such variables _.
You pick a random base between 1 and n − 1 inclusive:
a = random.randrange(1, n)
This is what it says in the MathWorld article, but that article is written from the mathematician's point of view. In fact it is useless to pick the base 1, since 1^s = 1 (mod n) and you'll waste a trial. Similarly, it's useless to pick the base n − 1, since s is odd and so (n − 1)^s = −1 (mod n). Mathematicians don't have to worry about wasted trials, but programmers do, so write instead:
a = random.randrange(2, n - 1)
(n needs to be at least 4 for this optimization to work, but we can easily arrange that by returning True at the top of the function when n = 3, just as you do for n = 2.)
As noted in other replies, you've misunderstood the MathWorld article. When it says that "n passes the test" it means that "n passes the test for the base a". The distinguishing fact about primes is that they pass the test for all bases. So when you find that a^s = 1 (mod n), what you should do is to go round the loop and pick the next base to test against.
# a^(s) = 1 (mod n)?
x = pow(a, s, n)
if x == 1:
    continue
There's an opportunity for optimization here. The value x that we've just computed is a^(2^0 * s) (mod n). So we could test it immediately and save ourselves one loop iteration:
# a^(s) = ±1 (mod n)?
x = pow(a, s, n)
if x == 1 or x == n - 1:
    continue
In the section where you calculate a^(2^j * s) (mod n), each of these numbers is the square of the previous number (modulo n). It's wasteful to calculate each from scratch when you could just square the previous value. So you should write this loop as:
# a^(2^(j) * s) = -1 (mod n)?
for _ in range(r - 1):
    x = pow(x, 2, n)
    if x == n - 1:
        break
else:
    return False
It's a good idea to test for divisibility by small primes before trying Miller–Rabin. For example, in Rabin's 1977 paper he says:
In implementing the algorithm we incorporate some laborsaving steps. First we test for divisibility by any prime p < N, where, say N = 1000.
2. Revised code
Putting all this together:
from random import randrange
small_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] # etc.
def probably_prime(n, k):
    """Return True if n passes k rounds of the Miller-Rabin primality
    test (and is probably prime). Return False if n is proved to be
    composite.
    """
    if n < 2: return False
    for p in small_primes:
        if n < p * p: return True
        if n % p == 0: return False
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    for _ in range(k):
        a = randrange(2, n - 1)
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True
In addition to what Omri Barel has said, there is also a problem with your for loop. You will return true if you find one a that passes the test. However, all a have to pass the test for n to be a probable prime.
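To see why a single passing base proves nothing, here is a small sketch using 2047 = 23 * 89, which is known to be a strong pseudoprime to base 2. The function name is_strong_probable_prime is mine, for illustration only; it runs one round of the test for one fixed base on an odd n > 2.

```python
def is_strong_probable_prime(n, a):
    """True if odd n > 2 passes the strong test for the single base a."""
    # Write n - 1 as 2^r * s with s odd.
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    x = pow(a, s, n)
    if x == 1 or x == n - 1:
        return True
    # Square repeatedly, looking for -1 (mod n).
    for _ in range(r - 1):
        x = pow(x, 2, n)
        if x == n - 1:
            return True
    return False

n = 2047  # composite: 23 * 89
print(is_strong_probable_prime(n, 2))  # base 2 is fooled
print(is_strong_probable_prime(n, 3))  # base 3 exposes n as composite
```

Returning True after the first passing base would call 2047 prime whenever base 2 happens to be drawn first, which is why every chosen base has to pass.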
I'm wondering about this piece of code:
# factor n - 1 as 2^(r)*s
while r % 2 == 0:
    s = s + 1
    r = r // 2 # floor
Let's take n = 7. So n - 1 = 6. We can express n - 1 as 2^1 * 3. In this case r = 1 and s = 3.
But the code above finds something else. It starts with r = 6, so r % 2 == 0. Initially, s = 0 so after one iteration we have s = 1 and r = 3. But now r % 2 != 0 and the loop terminates.
We end up with s = 1 and r = 3 which is clearly incorrect: 2^r * s = 8.
You should not update s in the loop. Instead, you should count how many times you can divide by 2 (this will be r) and the result after the divisions will be s. In the example of n = 7, n - 1 = 6, we can divide it once (so r = 1) and after the division we end up with 3 (so s = 3).
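A minimal sketch of the corrected decomposition as a stand-alone helper (the name decompose is hypothetical and not from the question), so the n = 7 example can be checked directly:

```python
def decompose(n):
    """Return (r, s) with n - 1 == 2**r * s and s odd, for n >= 2."""
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1       # count how many times 2 divides n - 1
        s //= 2      # what is left after the divisions
    return r, s

r, s = decompose(7)
print(r, s, 2 ** r * s)
```

For n = 7 this gives r = 1 and s = 3, and 2^r * s is 6 = n - 1 as required.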
Here's my version:
# miller-rabin pseudoprimality checker
from random import randrange
def isStrongPseudoprime(n, a):
    # write n - 1 as 2^s * d with d odd
    d, s = n-1, 0
    while d % 2 == 0:
        d, s = d//2, s+1  # integer division keeps d an int for pow()
    t = pow(a, d, n)
    if t == 1:
        return True
    while s > 0:
        if t == n - 1:
            return True
        t, s = pow(t, 2, n), s - 1
    return False
def isPrime(n, k):
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    for _ in range(k):  # k rounds, each with a random base
        a = randrange(2, n)
        if not isStrongPseudoprime(n, a):
            return False
    return True
If you want to know more about programming with prime numbers, I modestly recommend this essay on my blog.
You should also have a look at Wikipedia, which lists fixed sets of bases (used instead of random ones) that give guaranteed answers for all n below a given bound:
if n < 1,373,653, it is enough to test a = 2 and 3;
if n < 9,080,191, it is enough to test a = 31 and 73;
if n < 4,759,123,141, it is enough to test a = 2, 7, and 61;
if n < 2,152,302,898,747, it is enough to test a = 2, 3, 5, 7, and 11;
if n < 3,474,749,660,383, it is enough to test a = 2, 3, 5, 7, 11, and 13;
if n < 341,550,071,728,321, it is enough to test a = 2, 3, 5, 7, 11, 13, and 17;
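A sketch of how the list above could be turned into a deterministic test. The names BASE_SETS, TRIAL_PRIMES, passes_base and is_prime_deterministic are my own; the bounds and bases are copied from the list, and the function refuses inputs at or beyond the largest bound instead of guessing.

```python
# (limit, bases): for n < limit it is enough to test these bases.
BASE_SETS = [
    (1_373_653, (2, 3)),
    (9_080_191, (31, 73)),
    (4_759_123_141, (2, 7, 61)),
    (2_152_302_898_747, (2, 3, 5, 7, 11)),
    (3_474_749_660_383, (2, 3, 5, 7, 11, 13)),
    (341_550_071_728_321, (2, 3, 5, 7, 11, 13, 17)),
]
# Every prime that appears as a base; dividing by them first guarantees
# that n is odd and shares no factor with any base used afterwards.
TRIAL_PRIMES = (2, 3, 5, 7, 11, 13, 17, 31, 61, 73)

def passes_base(n, a):
    """One strong probable-prime round for odd n > 2 and base a."""
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2
    x = pow(a, s, n)
    if x == 1 or x == n - 1:
        return True
    for _ in range(r - 1):
        x = pow(x, 2, n)
        if x == n - 1:
            return True
    return False

def is_prime_deterministic(n):
    if n < 2:
        return False
    if n >= BASE_SETS[-1][0]:
        raise ValueError('n is beyond the largest bound in the table')
    for p in TRIAL_PRIMES:
        if n % p == 0:
            return n == p
    for limit, bases in BASE_SETS:
        if n < limit:
            return all(passes_base(n, a) for a in bases)

print(is_prime_deterministic(2_147_483_647))  # 2^31 - 1
```

The first matching bound is used, so small inputs only ever need the bases 2 and 3.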