I want to perform addition and multiplication in F_{2^8}.
I currently have this code, which seems to work for addition but not for multiplication. The problem seems to be that the reduction modulo 100011011 (which represents x^8 + x^4 + x^3 + x + 1) is not actually applied. Another idea would be to use numpy.polynomial, but it isn't as intuitive.
def toBinary(self, n):
return ''.join(str(1 & int(n) >> i) for i in range(8)[::-1])
def add(self, x, y):
"""
"10111001" + "10010100" = "00101101"
"""
if len(x)<8:
self.add('0'+x,y)
elif len(y)<8:
self.add(x,'0'+y)
try:
a = int(x,2); b = int(y,2)
z = int(x)+int(y)
s = ''
for i in str(z):
if int(i)%2 == 0:
s+='0'
else:
s+='1'
except:
return '00000000'
return s
def multiply(self, x, y):
"""
"10111001" * "10010100" = "10110010"
"""
if len(x)<8:
self.multiply('0'+x,y)
elif len(y)<8:
self.multiply(x,'0'+y)
result = '00000000'
result = '00000000'
while y!= '00000000' :
print(f'x:{x},y:{y},result:{result}')
if int(y[-1]) == 1 :
result = self.add(result ,x)
y = self.add(y, '00000001')
x = self.add(self.toBinary(int(x,2)<<1),'100011011')
y = self.toBinary(int(y,2)>>1) #b = self.multiply(b,inverse('00000010'))
return result
Python example for add (same as subtract), multiply, divide, and inverse. Assumes the input parameters are 8 bit values, and there is no check for divide by 0.
def add(x, y):                  # add is xor
    return x ^ y

def sub(x, y):                  # sub is xor
    return x ^ y

def mpy(x, y):                  # mpy two 8 bit values
    p = 0b100011011             # mpy modulo x^8+x^4+x^3+x+1
    m = 0                       # m will be product
    for i in range(8):
        m = m << 1
        if m & 0b100000000:
            m = m ^ p
        if y & 0b010000000:
            m = m ^ x
        y = y << 1
    return m

def div(x, y):                  # divide using inverse
    return mpy(x, inv(y))       # (no check for y = 0)

def inv(x):                     # x^254 = 1/x
    p = mpy(x, x)               # p = x^2
    x = mpy(p, p)               # x = x^4
    p = mpy(p, x)               # p = x^(2+4)
    x = mpy(x, x)               # x = x^8
    p = mpy(p, x)               # p = x^(2+4+8)
    x = mpy(x, x)               # x = x^16
    p = mpy(p, x)               # p = x^(2+4+8+16)
    x = mpy(x, x)               # x = x^32
    p = mpy(p, x)               # p = x^(2+4+8+16+32)
    x = mpy(x, x)               # x = x^64
    p = mpy(p, x)               # p = x^(2+4+8+16+32+64)
    x = mpy(x, x)               # x = x^128
    p = mpy(p, x)               # p = x^(2+4+8+16+32+64+128)
    return p

print(hex(add(0b01010101, 0b10101010)))     # returns 0xff
print(hex(mpy(0b01010101, 0b10101010)))     # returns 0x59
print(hex(div(0b01011001, 0b10101010)))     # returns 0x55
For GF(2^n), both add and subtract are XOR. This means multiplies are carryless and divides are borrowless. The X86 has a carryless multiply for XMM registers, PCLMULQDQ. Divide by a constant can be done with carryless multiply by 2^64 / constant and using the upper 64 bits of the product. The inverse constant is generated using a loop for borrowless divide.
The reason for this is GF(2^n) elements are polynomials with 1 bit coefficients, (the coefficients are elements of GF(2)).
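To make the polynomial view concrete, here is a minimal Python sketch that splits the multiplication into the two steps described above: a carryless multiply followed by a borrowless (XOR) long division by the reducing polynomial. The names clmul, poly_mod and gf_mul are made up for this illustration and do not come from any library.

```python
def clmul(a, b):
    """Carryless multiply: the product of two polynomials over GF(2)."""
    result = 0
    while b:
        if b & 1:
            result ^= a        # "add" the shifted copy of a without carries
        a <<= 1
        b >>= 1
    return result


def poly_mod(value, poly=0x11B):
    """Remainder of value divided by poly, using XOR as the borrowless subtract."""
    degree = poly.bit_length() - 1
    while value.bit_length() > degree:
        shift = value.bit_length() - 1 - degree
        value ^= poly << shift  # cancel the current leading term
    return value


def gf_mul(a, b, poly=0x11B):
    """Multiply in GF(2^8) with the reducing polynomial x^8 + x^4 + x^3 + x + 1."""
    return poly_mod(clmul(a, b), poly)


print(hex(gf_mul(0x57, 0x83)))                 # 0xc1
print(bin(gf_mul(0b10111001, 0b10010100)))     # 0b10110010
```

The first call is a commonly quoted AES test value; the second uses the operands from the question.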
For GF(2^8), it would be simpler to generate exponentiation and log tables. Example C code:
#define POLY (0x11b)
/* all non-zero elements are powers of 3 for POLY == 0x11b */
typedef unsigned char BYTE;
/* ... */
static BYTE exp2[512];
static BYTE log2[256];
/* ... */
static void Tbli()
{
int i;
int b;
b = 0x01; /* init exp2 table */
for(i = 0; i < 512; i++){
exp2[i] = (BYTE)b;
b = (b << 1) ^ b; /* powers of 3 */
if(b & 0x100)
b ^= POLY;
}
log2[0] = 0xff; /* init log2 table */
for(i = 0; i < 255; i++)
log2[exp2[i]] = (BYTE)i;
}
/* ... */
static BYTE GFMpy(BYTE m0, BYTE m1) /* multiply */
{
if(0 == m0 || 0 == m1)
return(0);
return(exp2[log2[m0] + log2[m1]]);
}
/* ... */
static BYTE GFDiv(BYTE m0, BYTE m1) /* divide */
{
if(0 == m0)
return(0);
return(exp2[log2[m0] + 255 - log2[m1]]);
}
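For readers following along in Python, the same table-driven idea might be sketched as below. This is only a rough translation of the C code under the same assumptions (POLY == 0x11b, generator 3); the names build_tables, gf_mul and gf_div are illustrative, and unlike the C version the divide raises on a zero divisor.

```python
POLY = 0x11B  # x^8 + x^4 + x^3 + x + 1


def build_tables():
    """Build exp/log tables for GF(2^8) using 3 (x + 1) as the generator."""
    exp = [0] * 512          # doubled length so log sums need no modulo 255
    log = [0] * 256          # log[0] is never used
    b = 1
    for i in range(512):
        exp[i] = b
        b ^= b << 1          # multiply by 3: b*x + b
        if b & 0x100:
            b ^= POLY
    for i in range(255):
        log[exp[i]] = i
    return exp, log


EXP, LOG = build_tables()


def gf_mul(a, b):
    if a == 0 or b == 0:
        return 0
    return EXP[LOG[a] + LOG[b]]


def gf_div(a, b):
    if b == 0:
        raise ZeroDivisionError("division by zero in GF(2^8)")
    if a == 0:
        return 0
    return EXP[LOG[a] + 255 - LOG[b]]


print(hex(gf_mul(0x57, 0x83)))   # 0xc1
print(hex(gf_div(0xC1, 0x83)))   # 0x57
```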
I created a Python package, galois, that extends NumPy arrays over finite fields. Working with GF(2^8) is quite easy; see the example below.
In [1]: import galois
In [2]: GF = galois.GF(2**8, irreducible_poly="x^8 + x^4 + x^3 + x + 1")
In [3]: print(GF.properties)
GF(2^8):
characteristic: 2
degree: 8
order: 256
irreducible_poly: x^8 + x^4 + x^3 + x + 1
is_primitive_poly: False
primitive_element: x + 1
# Your original values from your example
In [4]: a = GF(0b10111001); a
Out[4]: GF(185, order=2^8)
In [5]: b = GF(0b10010100); b
Out[5]: GF(148, order=2^8)
In [6]: c = a * b; c
Out[6]: GF(178, order=2^8)
# You can display the result as a polynomial over GF(2)
In [7]: GF.display("poly");
# This matches 0b10110010
In [8]: c
Out[8]: GF(x^7 + x^5 + x^4 + x, order=2^8)
You can work with arrays too.
In [12]: a = GF([1, 2, 3, 4]); a
Out[12]: GF([1, 2, 3, 4], order=2^8)
In [13]: b = GF([100, 110, 120, 130]); b
Out[13]: GF([100, 110, 120, 130], order=2^8)
In [14]: a * b
Out[14]: GF([100, 220, 136, 62], order=2^8)
It's open source, so you can review all the code. Here's a snippet of multiplication in GF(2^m). All of the inputs are integers. Here's how to perform the "polynomial multiplication" using integers with characteristic 2.
def _multiply_calculate(a, b, CHARACTERISTIC, DEGREE, IRREDUCIBLE_POLY):
"""
a in GF(2^m), can be represented as a degree m-1 polynomial a(x) in GF(2)[x]
b in GF(2^m), can be represented as a degree m-1 polynomial b(x) in GF(2)[x]
p(x) in GF(2)[x] with degree m is the irreducible polynomial of GF(2^m)
a * b = c
= (a(x) * b(x)) % p(x) in GF(2)
= c(x)
= c
"""
ORDER = CHARACTERISTIC**DEGREE
# Re-order operands so that a >= b, which gives the while loop fewer iterations
if b > a:
a, b = b, a
c = 0
while b > 0:
if b & 0b1:
c ^= a # Add a(x) to c(x)
b >>= 1 # Divide b(x) by x
a <<= 1 # Multiply a(x) by x
if a >= ORDER:
a ^= IRREDUCIBLE_POLY # Compute a(x) % p(x)
return c
The same example runs as follows.
In [72]: _multiply_calculate(0b10111001, 0b10010100, 2, 8, 0b100011011)
Out[72]: 178
In [73]: bin(_multiply_calculate(0b10111001, 0b10010100, 2, 8, 0b100011011))
Out[73]: '0b10110010'
I am studying algorithms in Python and solving a question that is:
Let x(k) be a recursively defined string with base case x(1) = "123"
and x(k) is "1" + x(k-1) + "2" + x(k-1) + "3". Given three positive
integers k,s, and t, find the substring x(k)[s:t].
For example, if k = 2, s = 1 and t = 5,x(2) = 112321233 and x(2)[1:5]
= 1232.
I have solved it using a simple recursive function:
def generate_string(k):
    if k == 1:
        return "123"
    part = generate_string(k - 1)
    return "1" + part + "2" + part + "3"

print(generate_string(k)[s:t])
Although my first approach gives the correct answer, it takes too long to build the string x when k is greater than 20. The program needs to finish within 16 seconds for k below 50. I have tried memoization, but it does not help because I am not allowed to cache results between test cases. I therefore think I must avoid the recursive function to speed up the program. Are there any approaches I should consider?
We can see that the string represented by x(k) grows exponentially in length with increasing k:
len(x(1)) == 3
len(x(k)) == len(x(k-1)) * 2 + 3
So:
len(x(k)) == 3 * (2**k - 1)
For k equal to 100, this amounts to a length of more than 10^30. That's more characters than there are atoms in a human body!
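As a quick sanity check of the closed form, the small sketch below compares it with the recurrence and prints the size for k = 100. The function names x_len and x_len_recurrence are made up for this illustration.

```python
def x_len(k):
    """Closed form for len(x(k))."""
    return 3 * (2**k - 1)


def x_len_recurrence(k):
    """len(x(k)) computed from len(x(1)) == 3 and len(x(k)) == 2 * len(x(k-1)) + 3."""
    n = 3
    for _ in range(k - 1):
        n = 2 * n + 3
    return n


# The two definitions agree for every k checked here.
print(all(x_len(k) == x_len_recurrence(k) for k in range(1, 61)))
# Number of decimal digits in len(x(100)).
print(len(str(x_len(100))))
```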
Since the parameters s and t will take (in comparison) a tiny, tiny slice of that, you should not need to produce the whole string. You can still use recursion though, but keep passing an s and t range to each call. Then when you see that this slice will actually be outside of the string you would generate, then you can just exit without recursing deeper, saving a lot of time and (string) space.
Here is how you could do it:
def getslice(k, s, t):
    def recur(xsize, s, t):
        if xsize == 0 or s >= xsize or t <= 0:
            return ""
        smaller = (xsize - 3) // 2
        return ( ("1" if s <= 0 else "")
               + recur(smaller, s-1, t-1)
               + ("2" if s <= smaller+1 < t else "")
               + recur(smaller, s-smaller-2, t-smaller-2)
               + ("3" if t >= xsize else "") )
    return recur(3 * (2**k - 1), s, t)
This doesn't use any caching of x(k) results... In my tests this was fast enough.
Based on @FMc's answer, here's some python3 code that calculates x(k, s, t):
from functools import lru_cache
from typing import *
def f_len(k) -> int:
return 3 * ((2 ** k) - 1)
@lru_cache(None)
def f(k) -> str:
if k == 1:
return "123"
return "1" + f(k - 1) + "2" + f(k - 1) + "3"
def substring_(k, s, t, output) -> None:
# Empty substring.
if s >= t or k == 0:
return
# (An optimization):
# If all the characters need to be included, just calculate the string and cache it.
if s == 0 and t == f_len(k):
output.append(f(k))
return
if s == 0:
output.append("1")
sub_len = f_len(k - 1)
substring_(k - 1, max(0, s - 1), min(sub_len, t - 1), output)
if s <= 1 + sub_len < t:
output.append("2")
substring_(k - 1, max(0, s - sub_len - 2), min(sub_len, t - sub_len - 2), output)
if s <= 2 * (1 + sub_len) < t:
output.append("3")
def substring(k, s, t) -> str:
output: List[str] = []
substring_(k, s, t, output)
return "".join(output)
def test(k, s, t) -> bool:
actual = substring(k, s, t)
expected = f(k)[s:t]
return actual == expected
assert test(1, 0, 3)
assert test(2, 2, 6)
assert test(2, 1, 5)
assert test(2, 0, f_len(2))
assert test(3, 0, f_len(3))
assert test(8, 44, 89)
assert test(10, 1001, 2022)
assert test(14, 12345, 45678)
assert test(17, 12345, 112345)
# print(substring(30, 10000, 10100))
print("Tests passed")
This is an interesting problem. I'm not sure whether I'll have time to write the code, but here's an outline of how you can solve it. Note: see the better answer from trincot.
As discussed in the comments, you cannot generate the actual string: you will quickly run out of memory as k grows. But you can easily compute the length of that string.
First some notation:
f(k) : The generated string.
n(k) : The length of f(k).
nk1 : n(k-1), which is used several times in table below.
For discussion purposes, we can divide the string into the following regions. The start/end values use standard Python slice numbering:
Region | Start | End | Len | Substring | Ex: k = 2
-------------------------------------------------------------------
A | 0 | 1 | 1 | 1 | 0:1 1
B | 1 | 1 + nk1 | nk1 | f(k-1) | 1:4 123
C | 1 + nk1 | 2 + nk1 | 1 | 2 | 4:5 2
D | 2 + nk1 | 2 + nk1 + nk1 | nk1 | f(k-1) | 5:8 123
E | 2 + nk1 + nk1 | 3 + nk1 + nk1 | 1 | 3 | 8:9 3
Given k, s, and t we need to figure out which region of the string is relevant. Take a small example:
k=2, s=6, and t=8.
The substring defined by 6:8 does not require the full f(k). We only need
region D, so we can turn our attention to f(k-1).
To make the shift from k=2 to k=1, we need to adjust s and t: specifically,
we need to subtract the total length of regions A + B + C. For k=2, that
length is 5 (1 + nk1 + 1).
Now we are dealing with: k=1, s=1, and t=3.
Repeat as needed.
Whenever k gets small enough, we stop this nonsense and actually generate the string so we can grab the needed substring directly.
It's possible that some values of s and t could cross region boundaries. In that case, divide the problem into two subparts (one for each region needed). But the general idea is the same.
Here's a commented iterative version in JavaScript that's very easy to convert to Python.
Besides being non-recursive, as you asked, it can handle inputs like f(10000, 10000, 10050), which would exceed Python's default recursion depth in a recursive solution.
// Generates the full string
function g(k){
if (k == 1)
return "123";
const prev = g(k - 1);
return "1" + prev + "2" + prev + "3";
}
function size(k){
return 3 * ((1 << k) - 1);
}
// Given a depth and index,
// we'd like (1) a string to
// output, (2) the possible next
// part of the same depth to
// push to the stack, and (3)
// possibly the current section
// mapped deeper to also push to
// the stack. (2) and (3) can be
// in a single list.
function getParams(depth, i){
const psize = size(depth - 1);
if (i == 0){
return ["1", [[depth, 1 + psize], [depth - 1, 0]]];
} else if (i < 1 + psize){
return ["", [[depth, 1 + psize], [depth - 1, i - 1]]];
} else if (i == 1 + psize){
return ["2", [[depth, 2 + 2 * psize], [depth - 1, 0]]];
} else if (i < 2 + 2 * psize){
return ["", [[depth, 2 + 2 * psize], [depth - 1, i - 2 - psize]]];
} else {
return ["3", []];
}
}
function f(k, s, t){
let len = t - s;
let str = "";
let stack = [[k, s]];
while (str.length < len){
const [depth, i] = stack.pop();
if (depth == 1){
const toTake = Math.min(3 - i, len - str.length);
str = str + "123".substr(i, toTake);
} else {
const [s, rest] = getParams(depth, i);
str = str + s;
stack.push(...rest);
}
}
return str;
}
function test(k, s, t){
const l = g(k).substring(s, t);
const r = f(k, s, t);
console.log(g(k).length);
//console.log(g(k))
console.log(l);
console.log(r);
console.log(l == r);
}
test(1, 0, 3);
test(2, 2, 6);
test(2, 1, 5);
test(4, 44, 45);
test(5, 30, 40);
test(7, 100, 150);
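Since the question is about Python, here is one way the JavaScript above might be ported. This is a sketch rather than a tuned solution: the names size, get_params and substring simply mirror the JavaScript functions, and it assumes 0 <= s <= t <= len(x(k)).

```python
def size(k):
    """Length of x(k)."""
    return 3 * ((1 << k) - 1)


def get_params(depth, i):
    """For index i inside x(depth), return (text to emit, work items to push)."""
    psize = size(depth - 1)
    if i == 0:
        return "1", [(depth, 1 + psize), (depth - 1, 0)]
    if i < 1 + psize:
        return "", [(depth, 1 + psize), (depth - 1, i - 1)]
    if i == 1 + psize:
        return "2", [(depth, 2 + 2 * psize), (depth - 1, 0)]
    if i < 2 + 2 * psize:
        return "", [(depth, 2 + 2 * psize), (depth - 1, i - 2 - psize)]
    return "3", []


def substring(k, s, t):
    """Return x(k)[s:t] without building x(k), using an explicit stack."""
    wanted = t - s
    pieces = []
    collected = 0
    stack = [(k, s)]
    while collected < wanted:
        depth, i = stack.pop()
        if depth == 1:
            piece = "123"[i:i + wanted - collected]
        else:
            piece, rest = get_params(depth, i)
            stack.extend(rest)      # the deeper item is pushed last, so it is popped first
        pieces.append(piece)
        collected += len(piece)
    return "".join(pieces)


print(substring(2, 1, 5))              # 1232
print(substring(10000, 10000, 10050))
```

Collecting the pieces in a list and joining once avoids repeated string concatenation, and the explicit stack sidesteps the recursion limit for large k.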
Is there an integer square root somewhere in python, or in standard libraries? I want it to be exact (i.e. return an integer), and raise an exception if the input isn't a perfect square.
I tried using this code:
def isqrt(n):
    i = int(math.sqrt(n) + 0.5)
    if i**2 == n:
        return i
    raise ValueError('input was not a perfect square')
But it's ugly and I don't really trust it for large integers. I could iterate through the squares and give up if I've exceeded the value, but I assume it would be kinda slow to do something like that. Also, surely this is already implemented somewhere?
See also: Check if a number is a perfect square.
Note: There is now math.isqrt in stdlib, available since Python 3.8.
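The distrust of the float-based version for large integers is justified. The short sketch below (the variable names are illustrative) picks a root that cannot be represented exactly as a double, so rounding math.sqrt cannot recover it, while math.isqrt works on integers throughout.

```python
import math

root = 10**20 + 1            # odd and larger than 2**53, so no float equals it
n = root * root              # a perfect square

float_guess = int(math.sqrt(n) + 0.5)

print(float_guess == root)       # False: the float result has lost the low bits
print(math.isqrt(n) == root)     # True: exact integer arithmetic
```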
Newton's method works perfectly well on integers:
def isqrt(n):
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x
This returns the largest integer x for which x * x does not exceed n. If you want to check if the result is exactly the square root, simply perform the multiplication to check if n is a perfect square.
I discuss this algorithm, and three other algorithms for calculating square roots, at my blog.
Update: Python 3.8 has a math.isqrt function in the standard library!
I benchmarked every (correct) function here on both small (0…2^22) and large (2^50001) inputs. The clear winners in both cases are gmpy2.isqrt suggested by mathmandan in first place, followed by Python 3.8’s math.isqrt in second, followed by the ActiveState recipe linked by NPE in third. The ActiveState recipe has a bunch of divisions that can be replaced by shifts, which makes it a bit faster (but still behind the native functions):
def isqrt(n):
    if n > 0:
        x = 1 << (n.bit_length() + 1 >> 1)
        while True:
            y = (x + n // x) >> 1
            if y >= x:
                return x
            x = y
    elif n == 0:
        return 0
    else:
        raise ValueError("square root not defined for negative numbers")
Benchmark results:
gmpy2.isqrt() (mathmandan): 0.08 µs small, 0.07 ms large
int(gmpy2.isqrt())*: 0.3 µs small, 0.07 ms large
Python 3.8 math.isqrt: 0.13 µs small, 0.9 ms large
ActiveState (optimized as above): 0.6 µs small, 17.0 ms large
ActiveState (NPE): 1.0 µs small, 17.3 ms large
castlebravo long-hand: 4 µs small, 80 ms large
mathmandan improved: 2.7 µs small, 120 ms large
martineau (with this correction): 2.3 µs small, 140 ms large
nibot: 8 µs small, 1000 ms large
mathmandan: 1.8 µs small, 2200 ms large
castlebravo Newton’s method: 1.5 µs small, 19000 ms large
user448810: 1.4 µs small, 20000 ms large
(* Since gmpy2.isqrt returns a gmpy2.mpz object, which behaves mostly but not exactly like an int, you may need to convert it back to an int for some uses.)
Sorry for the very late response; I just stumbled onto this page. In case anyone visits this page in the future, the python module gmpy2 is designed to work with very large inputs, and includes among other things an integer square root function.
Example:
>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)
Granted, everything will have the "mpz" tag, but mpz's are compatible with int's:
>>> gmpy2.mpz(3)*4
mpz(12)
>>> int(gmpy2.mpz(12))
12
See my other answer for a discussion of this method's performance relative to some other answers to this question.
Download: https://code.google.com/p/gmpy/
Here's a very straightforward implementation:
def i_sqrt(n):
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while m*m > n:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x*x <= n:
m = x
return m
This is just a binary search. Initialize the value m to be the largest power of 2 that does not exceed the square root, then check whether each smaller bit can be set while keeping the result no larger than the square root. (Check the bits one at a time, in descending order.)
For reasonably large values of n (say, around 10**6000, or around 20000 bits), this seems to be:
Faster than the Newton's method implementation described by user448810.
Much, much slower than the gmpy2 built-in method in my other answer.
Comparable to, but somewhat slower than, the Longhand Square Root described by nibot.
All of these approaches succeed on inputs of this size, but on my machine, this function takes around 1.5 seconds, while @Nibot's takes about 0.9 seconds, @user448810's takes around 19 seconds, and the gmpy2 built-in method takes less than a millisecond(!). Example:
>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True
This function can be generalized easily, though it's not quite as nice because I don't have quite as precise of an initial guess for m:
def i_root(num, root, report_exactness = True):
i = num.bit_length() / root
m = 1 << i
while m ** root < num:
m <<= 1
i += 1
while m ** root > num:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x ** root <= num:
m = x
if report_exactness:
return m, m ** root == num
return m
However, note that gmpy2 also has an i_root method.
In fact this method could be adapted and applied to any (nonnegative, increasing) function f to determine an "integer inverse of f". However, to choose an efficient initial value of m you'd still want to know something about f.
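A minimal sketch of that generalization might look like the following. It is not the bit-setting loop from i_sqrt: lacking any knowledge of f, it finds an upper bound by doubling and then bisects. The name integer_inverse is made up here, and the sketch assumes f is nondecreasing and unbounded on the nonnegative integers.

```python
def integer_inverse(f, y):
    """Largest integer m >= 0 with f(m) <= y, for a nondecreasing, unbounded f."""
    if f(0) > y:
        raise ValueError("no nonnegative m satisfies f(m) <= y")
    hi = 1
    while f(hi) <= y:        # doubling search for an upper bound
        hi <<= 1
    lo = hi >> 1             # invariant from here on: f(lo) <= y < f(hi)
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if f(mid) <= y:
            lo = mid
        else:
            hi = mid
    return lo


print(integer_inverse(lambda m: m * m, 10**40 + 5) == 10**20)   # True
print(integer_inverse(lambda m: m**3, 26))                      # 2
```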
Edit: Thanks to @Greggo for pointing out that the i_sqrt function can be rewritten to avoid using any multiplications. This yields an impressive performance boost!
def improved_i_sqrt(n):
assert n >= 0
if n == 0:
return 0
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while (m << i) > n: # (m<<i) = m*(2^i) = m*m
m >>= 1
i -= 1
d = n - (m << i) # d = n-m^2
for k in xrange(i-1, -1, -1):
j = 1 << k
new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
if new_diff >= 0:
d = new_diff
m |= j
return m
Note that by construction, the kth bit of m << 1 is not set, so bitwise-or may be used to implement the addition of (m<<1) + (1<<k). Ultimately I have (2*m*(2**k) + 2**(2*k)) written as (((m<<1) | (1<<k)) << k), so it's three shifts and one bitwise-or (followed by a subtraction to get new_diff). Maybe there is still a more efficient way to get this? Regardless, it's far better than multiplying m*m! Compare with above:
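The identity behind new_diff can be checked numerically. The sketch below (check_identity is a hypothetical helper, not part of the answer's code) draws random m with bit k and every lower bit clear, which is the situation inside the loop, and compares the shift-and-or form with the direct difference of squares.

```python
import random


def check_identity(trials=1000, seed=0):
    """Check (m + 2^k)^2 - m^2 == ((m << 1) | (1 << k)) << k when m has no bits at or below k."""
    rng = random.Random(seed)
    for _ in range(trials):
        k = rng.randrange(64)
        m = rng.getrandbits(64) << (k + 1)   # bits 0..k of m are clear
        j = 1 << k
        direct = (m + j) ** 2 - m * m        # 2*m*2^k + 2^(2k)
        shifted = ((m << 1) | j) << k        # shifts and one bitwise-or only
        if direct != shifted:
            return False
    return True


print(check_identity())   # True
```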
>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True
Long-hand square root algorithm
It turns out that there is an algorithm for computing square roots that you can compute by hand, something like long-division. Each iteration of the algorithm produces exactly one digit of the resulting square root while consuming two digits of the number whose square root you seek. While the "long hand" version of the algorithm is specified in decimal, it works in any base, with binary being simplest to implement and perhaps the fastest to execute (depending on the underlying bignum representation).
Because this algorithm operates on numbers digit-by-digit, it produces exact results for arbitrarily large perfect squares, and for non-perfect-squares, can produce as many digits of precision (to the right of the decimal place) as desired.
There are two nice writeups on the "Dr. Math" site that explain the algorithm:
Square Roots in Binary
Longhand Square Roots
And here's an implementation in Python:
def exact_sqrt(x):
"""Calculate the square root of an arbitrarily large integer.
The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
a is the largest integer such that a**2 <= x, and r is the "remainder". If
x is a perfect square, then r will be zero.
The algorithm used is the "long-hand square root" algorithm, as described at
http://mathforum.org/library/drmath/view/52656.html
Tobin Fricke 2014-04-23
Max Planck Institute for Gravitational Physics
Hannover, Germany
"""
N = 0 # Problem so far
a = 0 # Solution so far
# We'll process the number two bits at a time, starting at the MSB
L = x.bit_length()
L += (L % 2) # Round up to the next even number
for i in xrange(L, -1, -1):
# Get the next group of two bits
n = (x >> (2*i)) & 0b11
# Check whether we can reduce the remainder
if ((N - a*a) << 2) + n >= (a<<2) + 1:
b = 1
else:
b = 0
a = (a << 1) | b # Concatenate the next bit of the solution
N = (N << 2) | n # Concatenate the next bit of the problem
return (a, N-a*a)
You could easily modify this function to conduct additional iterations to calculate the fractional part of the square root. I was most interested in computing roots of large perfect squares.
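One simple way to get fractional digits, different from adding iterations to exact_sqrt itself, is to scale the input by an even power of the base before taking the integer root. The sketch below uses math.isqrt (Python 3.8+) for the root; sqrt_digits is a made-up name for this illustration.

```python
import math


def sqrt_digits(x, digits):
    """Return floor(sqrt(x) * 10**digits) as an integer."""
    return math.isqrt(x * 10 ** (2 * digits))


r = sqrt_digits(2, 50)
print(r)   # the decimal digits of sqrt(2): a leading 1 followed by 50 fractional digits
```

Any floor-style integer square root could be substituted for math.isqrt here, including the a component returned by exact_sqrt.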
I'm not sure how this compares to the "integer Newton's method" algorithm. I suspect that Newton's method is faster, since it can in principle generate multiple bits of the solution in one iteration, while the "long hand" algorithm generates exactly one bit of the solution per iteration.
Source repo: https://gist.github.com/tobin/11233492
One option would be to use the decimal module, and do it in sufficiently-precise floats:
import decimal
def isqrt(n):
nd = decimal.Decimal(n)
with decimal.localcontext() as ctx:
ctx.prec = n.bit_length()
i = int(nd.sqrt())
if i**2 != n:
raise ValueError('input was not a perfect square')
return i
which I think should work:
>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
isqrt(11**1000+1)
File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square
Python's default math library has an integer square root function:
math.isqrt(n)
Return the integer square root of the nonnegative integer n. This is the floor of the exact square root of n, or equivalently the greatest integer a such that a² ≤ n.
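Because math.isqrt returns the floor rather than raising, the behaviour asked for in the question needs one extra check. A minimal wrapper could look like this (exact_isqrt is a hypothetical name, and math.isqrt requires Python 3.8 or later):

```python
import math


def exact_isqrt(n):
    """Return the integer square root of n, or raise if n is not a perfect square."""
    root = math.isqrt(n)          # raises ValueError itself for negative n
    if root * root != n:
        raise ValueError("input was not a perfect square")
    return root


print(exact_isqrt((10**100 + 1) ** 2) == 10**100 + 1)   # True
```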
Seems like you could check like this:
if int(math.sqrt(n))**2 == n:
print n, 'is a perfect square'
Update:
As you pointed out, the above fails for large values of n. For those, the following looks promising. It is an adaptation of the example C code by Martin Guy @ UKC, June 1985, for the relatively simple-looking binary digit-by-digit calculation method mentioned in the Wikipedia article Methods of computing square roots:
from math import ceil, log
def isqrt(n):
res = 0
bit = 4**int(ceil(log(n, 4))) if n else 0 # smallest power of 4 >= the argument
while bit:
if n >= res + bit:
n -= res + bit
res = (res >> 1) + bit
else:
res >>= 1
bit >>= 2
return res
if __name__ == '__main__':
from math import sqrt # for comparison purposes
for i in range(17)+[2**53, (10**100+1)**2]:
is_perfect_sq = isqrt(i)**2 == i
print '{:21,d}: math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else '')
Output:
0: math.sqrt= 0, isqrt= 0 (perfect square)
1: math.sqrt= 1, isqrt= 1 (perfect square)
2: math.sqrt= 1.414214, isqrt= 1
3: math.sqrt= 1.732051, isqrt= 1
4: math.sqrt= 2, isqrt= 2 (perfect square)
5: math.sqrt= 2.236068, isqrt= 2
6: math.sqrt= 2.44949, isqrt= 2
7: math.sqrt= 2.645751, isqrt= 2
8: math.sqrt= 2.828427, isqrt= 2
9: math.sqrt= 3, isqrt= 3 (perfect square)
10: math.sqrt= 3.162278, isqrt= 3
11: math.sqrt= 3.316625, isqrt= 3
12: math.sqrt= 3.464102, isqrt= 3
13: math.sqrt= 3.605551, isqrt= 3
14: math.sqrt= 3.741657, isqrt= 3
15: math.sqrt= 3.872983, isqrt= 3
16: math.sqrt= 4, isqrt= 4 (perfect square)
9,007,199,254,740,992: math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001: math.sqrt= 1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)
The script below extracts integer square roots. It uses no divisions, only bitshifts, so it is quite fast. It uses Newton's method on the inverse square root, a technique made famous by Quake III Arena as mentioned in the Wikipedia article, Fast inverse square root.
The strategy of the algorithm to compute s = sqrt(Y) is as follows.
Reduce the argument Y to y in the range [1/4, 1), i.e., y = Y/B, with 1/4 <= y < 1, where B is an even power of 2, so B = 2**(2*k) for some integer k. We want to find X, where x = X/B, and x = 1 / sqrt(y).
Determine a first approximation to X using a quadratic minimax polynomial.
Refine X using Newton's method.
Calculate s = X*Y/(2**(3*k)).
We don't actually create fractions or perform any divisions. All the arithmetic is done with integers, and we use bit shifting to divide by various powers of B.
Range reduction lets us find a good initial approximation to feed to Newton's method. Here's a version of the 2nd degree minimax polynomial approximation to the inverse square root in the interval [1/4, 1):
y = (463x^2 - 896x + 698) / 256
(Sorry, I've reversed the meaning of x & y here, to conform to the usual conventions). The maximum error of this approximation is around 0.0355 ~= 1/28. Here's a graph showing the error:
Using this poly, our initial x starts with at least 4 or 5 bits of precision. Each round of Newton's method doubles the precision, so it doesn't take many rounds to get thousands of bits, if we want them.
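The doubling of precision is easy to observe in ordinary floating point before moving to the integer version. The sketch below is an illustration only: inv_sqrt is a made-up name, the polynomial coefficients are the ones that appear in int_sqrt, and the update is the usual Newton step for 1 / sqrt(y) written as a correction term.

```python
import math


def inv_sqrt(y, rounds=4, trace=False):
    """Approximate 1 / sqrt(y) for 0.25 <= y < 1 with a polynomial seed plus Newton steps."""
    x = (463 * y * y - 896 * y + 698) / 256      # quadratic minimax seed
    for i in range(rounds):
        x += x * (1 - y * x * x) / 2             # Newton step for 1 / sqrt(y)
        if trace:
            print(i + 1, abs(x - 1 / math.sqrt(y)))
    return x


inv_sqrt(0.5, trace=True)   # the error roughly squares on each round
```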
""" Integer square root
Uses no divisions, only shifts
"Quake" style algorithm,
i.e., Newton's method for 1 / sqrt(y)
Uses a quadratic minimax polynomial for the first approximation
Written by PM 2Ring 2022.01.23
"""
def int_sqrt(y):
if y < 0:
raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
if y < 2:
return y
# print("\n*", y, "*")
# Range reduction.
# Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
j = y.bit_length()
# Round k*2 up to the next even number
k2 = j + (j & 1)
# k and some useful multiples
k = k2 >> 1
k3 = k2 + k
k6 = k3 << 1
kd = k6 + 1
# b cubed
b3 = 1 << k6
# Minimax approximation: x/b ~= 1 / sqrt(y/b)
x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
# print(" ", x, h)
# Newton's method for 1 / sqrt(y/b)
epsilon = 1 << k
for i in range(1, 99):
dx = x * (b3 - y * x * x) >> kd
x += dx
# print(f" {i}: {x} {dx}")
if abs(dx) <= epsilon:
break
# s == sqrt(y)
s = x * y >> k3
# Adjust if too low
ss = s + 1
return ss if ss * ss <= y else s
def test(lo, hi, step=1):
for y in range(lo, hi, step):
s = int_sqrt(y)
ss = s + 1
s2, ss2 = s * s, ss * ss
assert s2 <= y < ss2, (y, s2, ss2)
print("ok")
test(0, 100000, 1)
This code is certainly slower than math.isqrt and decimal.Decimal.sqrt. Its purpose is simply to illustrate the algorithm. It would be interesting to see how fast it would be if it were implemented in C...
Here's a live version, running on the SageMathCell server. Set hi <= 0 to calculate and display the results for a single value set in lo. You can put expressions in the input boxes, eg set hi to 0 and lo to 2 * 10**100 to get sqrt(2) * 10**50.
Inspired by all the answers, I decided to implement several of the best methods from them in pure C++, since compiled C++ is usually faster than Python for this kind of work.
To glue C++ and Python together I used Cython. It turns C++ code into a Python module, so C++ functions can be called directly from Python.
As a complement, I provide not only the Python-adapted code but also pure C++ with tests.
Here are timings from pure C++ tests:
Test 'GMP', bits 64, time 0.000001 sec
Test 'AndersKaseorg', bits 64, time 0.000003 sec
Test 'Babylonian', bits 64, time 0.000006 sec
Test 'ChordTangent', bits 64, time 0.000018 sec
Test 'GMP', bits 50000, time 0.000118 sec
Test 'AndersKaseorg', bits 50000, time 0.002777 sec
Test 'Babylonian', bits 50000, time 0.003062 sec
Test 'ChordTangent', bits 50000, time 0.009120 sec
The same C++ functions, wrapped as a Python module, have these timings:
Bits 50000
math.isqrt: 2.819 ms
gmpy2.isqrt: 0.166 ms
ISqrt_GMP: 0.252 ms
ISqrt_AndersKaseorg: 3.338 ms
ISqrt_Babylonian: 3.756 ms
ISqrt_ChordTangent: 10.564 ms
My Cython/C++ code also works as a framework for anyone who wants to write and test their own C++ method directly from Python.
As the timings above show, I used the following methods as examples:
math.isqrt, implementation from standard library.
gmpy2.isqrt, GMPY2 library's implementation.
ISqrt_GMP, the same as GMPY2 but through my Cython module, which uses the C++ GMP library (<gmpxx.h>) directly.
ISqrt_AndersKaseorg, code taken from the answer by @AndersKaseorg.
ISqrt_Babylonian, the so-called Babylonian method from the Wikipedia article, in my own implementation as I understand it.
ISqrt_ChordTangent, my own method, which I call Chord-Tangent because it uses a chord and a tangent line to iteratively shrink the search interval. It is described in moderate detail in my other article. It finds not only square roots but also K-th roots for any K. I drew a small picture showing the details of this algorithm.
To compile the C++/Cython code I used the GMP library, which must be installed first. Under Linux this is easy: sudo apt install libgmp-dev.
Under Windows the easiest route is VCPKG, a package manager similar to APT on Linux. VCPKG compiles all packages from source using Visual Studio (don't forget to install the Community edition of Visual Studio). After installing VCPKG you can install GMP with vcpkg install gmp. You can also install MPIR, an alternative fork of GMP, with vcpkg install mpir.
After GMP is installed under Windows, edit my Python code and replace the paths to the include directory and the library file. At the end of installation VCPKG should show the path to a ZIP file with the GMP library, which contains the .lib and .h files.
The Python code also contains a handy cython_compile() function that I use to compile any C++ code into a Python module. It makes it easy to plug any C++ code into Python and can be reused many times.
If you have any questions or suggestions, or something doesn't work on your PC, please write in the comments.
Below I first show the code in Python, then in C++. See the Try it online! link above the C++ code to run it online on GodBolt servers. Both code snippets are fully runnable from scratch as they are; nothing needs to be edited in them.
def cython_compile(srcs):
import json, hashlib, os, glob, importlib, sys, shutil, tempfile
srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
pdir = 'cyimp'
if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
class ChDir:
def __init__(self, newd):
self.newd = newd
def __enter__(self):
self.curd = os.getcwd()
os.chdir(self.newd)
return self
def __exit__(self, ext, exv, tb):
os.chdir(self.curd)
os.makedirs(pdir, exist_ok = True)
with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
os.makedirs(pdir, exist_ok = True)
for k, v in srcs.items():
with open(f'cys{srch}_{k}', 'wb') as f:
f.write(v.replace('{srch}', srch).encode('utf-8'))
import numpy as np
from setuptools import setup, Extension
from Cython.Build import cythonize
sys.argv += ['build_ext', '--inplace']
setup(
ext_modules = cythonize(
Extension(
f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
extra_compile_args = ['/O2', '/std:c++latest',
'/ID:/dev/_3party/vcpkg_bin/gmp/include/',
],
),
compiler_directives = {'language_level': 3, 'embedsignature': True},
annotate = True,
),
include_dirs = [np.get_include()],
)
del sys.argv[-2:]
for f in glob.glob(f'{pdir}/cy{srch}*'):
shutil.copy(f, f'./../')
print('Cython module:', f'cy{srch}')
return importlib.import_module(f'{pdir}.cy{srch}')
def cython_import():
srcs = {
'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (KthPow(x_begin) > n) { // compare the k-th power, so the final scan is also valid for k != 2
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
""",
'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *
cdef extern from "cys{srch}_lib.h" nogil:
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);
@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
mask64 = (1 << 64) - 1
def ToLimbs():
return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
words = (n.bit_length() + 63) // 64
t = n
r = np.zeros((words,), dtype = np.uint64)
for i in range(words):
r[i] = np.uint64(t & mask64)
t >>= 64
return r
def FromLimbs(x):
return int.from_bytes(x.tobytes(), 'little')
n = 0
for i in range(x.shape[0]):
n |= int(x[i]) << (i * 64)
return n
n = ToLimbs()
cdef uint64_t[:] cn = n
cdef uint64_t ccnt = len(n)
cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
with nogil:
(ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
<uint64_t *>&cn[0], <uint64_t *>&ccnt
)
return FromLimbs(n[:ccnt])
""",
}
return cython_compile(srcs)
def main():
import math, gmpy2, timeit, random
mod = cython_import()
fs = [
('math.isqrt', math.isqrt),
('gmpy2.isqrt', gmpy2.isqrt),
('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
]
times = [0] * len(fs)
ntests = 1 << 6
bits = 50000
for i in range(ntests):
n = random.randrange(1 << (bits - 1), 1 << bits)
ref = None
for j, (fn, f) in enumerate(fs):
timeit_cnt = 3
tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
times[j] += tim
x = f(n)
if j == 0:
ref = x
else:
assert x == ref, (fn, ref, x)
print('Bits', bits)
print('\n'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))
if __name__ == '__main__':
main()
and C++:
Try it online!
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
std::string IntToStr(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return n.get_str();
else {
std::ostringstream ss;
ss << n;
return ss.str();
}
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
//std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (KthPow(x_begin) > n) { // compare the k-th power, so the final scan is also valid for k != 2
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
// Testing
#include <chrono>
#include <random>
#include <vector>
#include <iomanip>
inline double Time() {
static auto const gtb = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
.count();
}
template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
std::mt19937_64 rng{123};
std::vector<T> nums;
for (size_t i = 0; i < ntests; ++i) {
T n = 0;
for (size_t j = 0; j < bits; j += 32) {
size_t const cbits = std::min<size_t>(32, bits - j);
n <<= cbits;
n ^= u32(rng()) >> (32 - cbits);
}
nums.push_back(n);
}
auto tim = Time();
for (auto & n: nums)
n = f(n);
tim = Time() - tim;
std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
<< ", bits " << std::setw(6) << bits << ", time "
<< std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
return nums;
}
void Test() {
auto f = [](auto ty, size_t bits, size_t ntests){
using T = std::decay_t<decltype(ty)>;
auto tim = Time();
auto a = Test0<T>("GMP", bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
ASSERT(b == a);
auto c = Test0<T>("Babylonian", bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
ASSERT(c == a);
auto d = Test0<T>("ChordTangent", bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
ASSERT(d == a);
std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
};
for (auto p: std::vector<std::pair<int, int>>{{15, 1 << 19}, {30, 1 << 19}})
f(u64(), p.first, p.second);
for (auto p: std::vector<std::pair<int, int>>{{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
f(mpz_class(), p.first, p.second);
}
int main() {
try {
Test();
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}
Your function fails for large inputs:
In [26]: isqrt((10**100+1)**2)
ValueError: input was not a perfect square
There is a recipe on the ActiveState site which should hopefully be more reliable since it uses integer maths only. It is based on an earlier StackOverflow question: Writing your own square root function
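The failure is easy to reproduce. The sketch below copies the float-based check from the question under the hypothetical name isqrt_float and compares it with math.isqrt (Python 3.8+), which works on exact integers:

```python
import math

def isqrt_float(n):
    """The float-based check from the question; loses precision for large n."""
    i = int(math.sqrt(n) + 0.5)
    if i**2 == n:
        return i
    raise ValueError('input was not a perfect square')

n = (10**100 + 1)**2

try:
    isqrt_float(n)
except ValueError as e:
    # A float carries only 53 significant bits, so the root is rounded.
    print("float version:", e)

# The integer-only routine recovers the exact root.
print(math.isqrt(n) == 10**100 + 1)  # True
```

For small inputs such as 49 the float version still behaves as intended; the problem only appears once the root no longer fits in a float's mantissa.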
Floats cannot be represented precisely on computers. You can test for the desired proximity by setting epsilon to a small value within the accuracy of Python's floats.
def isqrt(n):
epsilon = .00000000001
i = int(n**.5 + 0.5)
if abs(i**2 - n) < epsilon:
return i
raise ValueError('input was not a perfect square')
Try this condition (no additional computation):
def isqrt(n):
i = math.sqrt(n)
if i != int(i):
raise ValueError('input was not a perfect square')
return i
If you need it to return an int (not a float with a trailing zero) then either assign a 2nd variable or compute int(i) twice.
I have compared the different methods given here with a loop:
for i in range (1000000): # 700 msec
r=int(123456781234567**0.5+0.5)
if r**2==123456781234567:rr=r
else:rr=-1
I found that this one is the fastest and needs no math import. Very long integers might fail, but look at this:
15241576832799734552675677489**0.5 = 123456781234567.0
Is there an integer square root somewhere in python, or in standard libraries? I want it to be exact (i.e. return an integer), and raise an exception if the input isn't a perfect square.
I tried using this code:
def isqrt(n):
i = int(math.sqrt(n) + 0.5)
if i**2 == n:
return i
raise ValueError('input was not a perfect square')
But it's ugly and I don't really trust it for large integers. I could iterate through the squares and give up if I've exceeded the value, but I assume it would be kinda slow to do something like that. Also, surely this is already implemented somewhere?
See also: Check if a number is a perfect square.
Note: There is now math.isqrt in stdlib, available since Python 3.8.
Newton's method works perfectly well on integers:
def isqrt(n):
x = n
y = (x + 1) // 2
while y < x:
x = y
y = (x + n // x) // 2
return x
This returns the largest integer x for which x * x does not exceed n. If you want to check if the result is exactly the square root, simply perform the multiplication to check if n is a perfect square.
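As a sketch of that check, the wrapper below (the name exact_isqrt is hypothetical) multiplies the result back and raises if the input was not a perfect square, which is the behaviour the question asks for:

```python
def isqrt(n):
    # Integer Newton iteration, as in the answer above.
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

def exact_isqrt(n):
    """Return the exact root of n, or raise if n is not a perfect square."""
    r = isqrt(n)
    if r * r != n:
        raise ValueError('input was not a perfect square')
    return r

print(exact_isqrt((10**100 + 1)**2) == 10**100 + 1)  # True
```

Because everything stays in integer arithmetic, the check is exact for arbitrarily large inputs.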
I discuss this algorithm, and three other algorithms for calculating square roots, at my blog.
Update: Python 3.8 has a math.isqrt function in the standard library!
I benchmarked every (correct) function here on both small (0…2^22) and large (2^50001) inputs. The clear winners in both cases are gmpy2.isqrt suggested by mathmandan in first place, followed by Python 3.8’s math.isqrt in second, followed by the ActiveState recipe linked by NPE in third. The ActiveState recipe has a bunch of divisions that can be replaced by shifts, which makes it a bit faster (but still behind the native functions):
def isqrt(n):
if n > 0:
x = 1 << (n.bit_length() + 1 >> 1)
while True:
y = (x + n // x) >> 1
if y >= x:
return x
x = y
elif n == 0:
return 0
else:
raise ValueError("square root not defined for negative numbers")
Benchmark results:
gmpy2.isqrt() (mathmandan): 0.08 µs small, 0.07 ms large
int(gmpy2.isqrt())*: 0.3 µs small, 0.07 ms large
Python 3.8 math.isqrt: 0.13 µs small, 0.9 ms large
ActiveState (optimized as above): 0.6 µs small, 17.0 ms large
ActiveState (NPE): 1.0 µs small, 17.3 ms large
castlebravo long-hand: 4 µs small, 80 ms large
mathmandan improved: 2.7 µs small, 120 ms large
martineau (with this correction): 2.3 µs small, 140 ms large
nibot: 8 µs small, 1000 ms large
mathmandan: 1.8 µs small, 2200 ms large
castlebravo Newton’s method: 1.5 µs small, 19000 ms large
user448810: 1.4 µs small, 20000 ms large
(* Since gmpy2.isqrt returns a gmpy2.mpz object, which behaves mostly but not exactly like an int, you may need to convert it back to an int for some uses.)
Sorry for the very late response; I just stumbled onto this page. In case anyone visits this page in the future, the python module gmpy2 is designed to work with very large inputs, and includes among other things an integer square root function.
Example:
>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)
Granted, everything will have the "mpz" tag, but mpz's are compatible with int's:
>>> gmpy2.mpz(3)*4
mpz(12)
>>> int(gmpy2.mpz(12))
12
See my other answer for a discussion of this method's performance relative to some other answers to this question.
Download: https://code.google.com/p/gmpy/
Here's a very straightforward implementation:
def i_sqrt(n):
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while m*m > n:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x*x <= n:
m = x
return m
This is just a binary search. Initialize the value m to be the largest power of 2 that does not exceed the square root, then check whether each smaller bit can be set while keeping the result no larger than the square root. (Check the bits one at a time, in descending order.)
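For readers on Python 3, the same bit-by-bit search can be sketched as follows (xrange replaced by range; the name i_sqrt_py3 is hypothetical), with a small trace in the comments:

```python
def i_sqrt_py3(n):
    """Bit-by-bit binary search for floor(sqrt(n)); Python 3 port of i_sqrt."""
    if n < 0:
        raise ValueError("square root not defined for negative numbers")
    i = n.bit_length() >> 1
    m = 1 << i
    # Step down to the largest power of two not exceeding the root.
    while m * m > n:
        m >>= 1
        i -= 1
    # Try to set each lower bit, from high to low.
    for k in range(i - 1, -1, -1):
        x = m | (1 << k)
        if x * x <= n:
            m = x
    return m

# Trace for n = 200: m starts at 16, drops to 8 (since 16*16 > 200),
# then 12 (144 <= 200), then 14 (196 <= 200); 15 is rejected (225 > 200).
print(i_sqrt_py3(200))  # 14
```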
For reasonably large values of n (say, around 10**6000, or around 20000 bits), this seems to be:
Faster than the Newton's method implementation described by user448810.
Much, much slower than the gmpy2 built-in method in my other answer.
Comparable to, but somewhat slower than, the Longhand Square Root described by nibot.
All of these approaches succeed on inputs of this size, but on my machine, this function takes around 1.5 seconds, while @Nibot's takes about 0.9 seconds, @user448810's takes around 19 seconds, and the gmpy2 built-in method takes less than a millisecond(!). Example:
>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True
This function can be generalized easily, though it's not quite as nice because I don't have quite as precise of an initial guess for m:
def i_root(num, root, report_exactness = True):
i = num.bit_length() / root
m = 1 << i
while m ** root < num:
m <<= 1
i += 1
while m ** root > num:
m >>= 1
i -= 1
for k in xrange(i-1, -1, -1):
x = m | (1 << k)
if x ** root <= num:
m = x
if report_exactness:
return m, m ** root == num
return m
However, note that gmpy2 also has an iroot method.
In fact this method could be adapted and applied to any (nonnegative, increasing) function f to determine an "integer inverse of f". However, to choose an efficient initial value of m you'd still want to know something about f.
Edit: Thanks to @Greggo for pointing out that the i_sqrt function can be rewritten to avoid using any multiplications. This yields an impressive performance boost!
def improved_i_sqrt(n):
assert n >= 0
if n == 0:
return 0
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while (m << i) > n: # (m<<i) = m*(2^i) = m*m
m >>= 1
i -= 1
d = n - (m << i) # d = n-m^2
for k in xrange(i-1, -1, -1):
j = 1 << k
new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
if new_diff >= 0:
d = new_diff
m |= j
return m
Note that by construction, the kth bit of m << 1 is not set, so bitwise-or may be used to implement the addition of (m<<1) + (1<<k). Ultimately I have (2*m*(2**k) + 2**(2*k)) written as (((m<<1) | (1<<k)) << k), so it's three shifts and one bitwise-or (followed by a subtraction to get new_diff). Maybe there is still a more efficient way to get this? Regardless, it's far better than multiplying m*m! Compare with above:
>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True
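The shift identity behind new_diff can be checked in isolation. In this sketch the helper name square_increment is hypothetical, and m is chosen with no bits at or below position k, as is the case inside the loop:

```python
def square_increment(m, k):
    """(m + 2**k)**2 - m**2, computed with shifts and one bitwise-or.

    Assumes bit k of (m << 1) is clear, so the OR acts as an addition.
    """
    return ((m << 1) | (1 << k)) << k

for k in range(20):
    for t in range(1, 50):
        m = t << (k + 1)          # m has no bits at or below position k
        assert (m + (1 << k)) ** 2 - m * m == square_increment(m, k)
print("identity holds")
```

This is why the loop can keep d = n - m^2 up to date using only a subtraction per bit, without ever squaring m.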
Long-hand square root algorithm
It turns out that there is an algorithm for computing square roots that you can compute by hand, something like long-division. Each iteration of the algorithm produces exactly one digit of the resulting square root while consuming two digits of the number whose square root you seek. While the "long hand" version of the algorithm is specified in decimal, it works in any base, with binary being simplest to implement and perhaps the fastest to execute (depending on the underlying bignum representation).
Because this algorithm operates on numbers digit-by-digit, it produces exact results for arbitrarily large perfect squares, and for non-perfect-squares, can produce as many digits of precision (to the right of the decimal place) as desired.
There are two nice writeups on the "Dr. Math" site that explain the algorithm:
Square Roots in Binary
Longhand Square Roots
And here's an implementation in Python:
def exact_sqrt(x):
"""Calculate the square root of an arbitrarily large integer.
The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
a is the largest integer such that a**2 <= x, and r is the "remainder". If
x is a perfect square, then r will be zero.
The algorithm used is the "long-hand square root" algorithm, as described at
http://mathforum.org/library/drmath/view/52656.html
Tobin Fricke 2014-04-23
Max Planck Institute for Gravitational Physics
Hannover, Germany
"""
N = 0 # Problem so far
a = 0 # Solution so far
# We'll process the number two bits at a time, starting at the MSB
L = x.bit_length()
L += (L % 2) # Round up to the next even number
for i in xrange(L, -1, -1):
# Get the next group of two bits
n = (x >> (2*i)) & 0b11
# Check whether we can reduce the remainder
if ((N - a*a) << 2) + n >= (a<<2) + 1:
b = 1
else:
b = 0
a = (a << 1) | b # Concatenate the next bit of the solution
N = (N << 2) | n # Concatenate the next bit of the problem
return (a, N-a*a)
You could easily modify this function to conduct additional iterations to calculate the fractional part of the square root. I was most interested in computing roots of large perfect squares.
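One way to get fractional digits without touching the loop is to scale the input by an even power of ten first. The sketch below uses math.isqrt (Python 3.8+) for brevity; that is an assumption for illustration, and any floor square root such as exact_sqrt(x)[0] should work the same way. The name sqrt_digits is hypothetical.

```python
import math

def sqrt_digits(x, digits):
    """Return floor(sqrt(x) * 10**digits), using integer arithmetic only."""
    return math.isqrt(x * 10 ** (2 * digits))

r = sqrt_digits(2, 20)
print(r)  # 141421356237309504880, i.e. sqrt(2) = 1.41421356237309504880...
```

The decimal point is then placed by hand: the last `digits` digits of the result are the fractional part.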
I'm not sure how this compares to the "integer Newton's method" algorithm. I suspect that Newton's method is faster, since it can in principle generate multiple bits of the solution in one iteration, while the "long hand" algorithm generates exactly one bit of the solution per iteration.
Source repo: https://gist.github.com/tobin/11233492
One option would be to use the decimal module, and do it in sufficiently-precise floats:
import decimal
def isqrt(n):
nd = decimal.Decimal(n)
with decimal.localcontext() as ctx:
ctx.prec = n.bit_length()
i = int(nd.sqrt())
if i**2 != n:
raise ValueError('input was not a perfect square')
return i
which I think should work:
>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
isqrt(11**1000+1)
File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square
Python's default math library has an integer square root function:
math.isqrt(n)
Return the integer square root of the nonnegative integer n. This is the floor of the exact square root of n, or equivalently the greatest integer a such that a² ≤ n.
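A short sketch of that definition in use; the helper is_perfect_square is a hypothetical name, not part of the math module:

```python
import math

# math.isqrt(n) is the greatest integer a with a*a <= n.
for n in (0, 1, 15, 16, 17, 10**30 + 1):
    a = math.isqrt(n)
    assert a * a <= n < (a + 1) * (a + 1)

def is_perfect_square(n):
    """True if n is the square of an integer (exact for any size of n)."""
    return n >= 0 and math.isqrt(n) ** 2 == n

print(is_perfect_square(16), is_perfect_square(17))  # True False
```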
Seems like you could check like this:
if int(math.sqrt(n))**2 == n:
print n, 'is a perfect square'
Update:
As you pointed out, the above fails for large values of n. For those, the following looks promising. It is an adaptation of the example C code by Martin Guy @ UKC, June 1985, for the relatively simple-looking binary digit-by-digit calculation method mentioned in the Wikipedia article Methods of computing square roots:
from math import ceil, log
def isqrt(n):
res = 0
bit = 4**int(ceil(log(n, 4))) if n else 0 # smallest power of 4 >= the argument
while bit:
if n >= res + bit:
n -= res + bit
res = (res >> 1) + bit
else:
res >>= 1
bit >>= 2
return res
if __name__ == '__main__':
from math import sqrt # for comparison purposes
for i in range(17)+[2**53, (10**100+1)**2]:
is_perfect_sq = isqrt(i)**2 == i
print '{:21,d}: math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else '')
Output:
0: math.sqrt= 0, isqrt= 0 (perfect square)
1: math.sqrt= 1, isqrt= 1 (perfect square)
2: math.sqrt= 1.414214, isqrt= 1
3: math.sqrt= 1.732051, isqrt= 1
4: math.sqrt= 2, isqrt= 2 (perfect square)
5: math.sqrt= 2.236068, isqrt= 2
6: math.sqrt= 2.44949, isqrt= 2
7: math.sqrt= 2.645751, isqrt= 2
8: math.sqrt= 2.828427, isqrt= 2
9: math.sqrt= 3, isqrt= 3 (perfect square)
10: math.sqrt= 3.162278, isqrt= 3
11: math.sqrt= 3.316625, isqrt= 3
12: math.sqrt= 3.464102, isqrt= 3
13: math.sqrt= 3.605551, isqrt= 3
14: math.sqrt= 3.741657, isqrt= 3
15: math.sqrt= 3.872983, isqrt= 3
16: math.sqrt= 4, isqrt= 4 (perfect square)
9,007,199,254,740,992: math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001: math.sqrt= 1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)
The script below extracts integer square roots. It uses no divisions, only bitshifts, so it is quite fast. It uses Newton's method on the inverse square root, a technique made famous by Quake III Arena as mentioned in the Wikipedia article, Fast inverse square root.
The strategy of the algorithm to compute s = sqrt(Y) is as follows.
Reduce the argument Y to y in the range [1/4, 1), i.e., y = Y/B, with 1/4 <= y < 1, where B is an even power of 2, so B = 2**(2*k) for some integer k. We want to find X, where x = X/B, and x = 1 / sqrt(y).
Determine a first approximation to X using a quadratic minimax polynomial.
Refine X using Newton's method.
Calculate s = X*Y/(2**(3*k)).
We don't actually create fractions or perform any divisions. All the arithmetic is done with integers, and we use bit shifting to divide by various powers of B.
Range reduction lets us find a good initial approximation to feed to Newton's method. Here's a version of the 2nd degree minimax polynomial approximation to the inverse square root in the interval [1/4, 1):
(Sorry, I've reversed the meaning of x & y here, to conform to the usual conventions). The maximum error of this approximation is around 0.0355 ~= 1/28. Here's a graph showing the error:
Using this poly, our initial x starts with at least 4 or 5 bits of precision. Each round of Newton's method doubles the precision, so it doesn't take many rounds to get thousands of bits, if we want them.
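The quality of the starting guess can be checked numerically. The coefficients below are read off the integer expression in int_sqrt (463, -896 and 698, all divided by 256); treating them as the polynomial meant here is an assumption, and the function names are hypothetical.

```python
def poly(x):
    # Quadratic approximation to 1/sqrt(x) on [1/4, 1), coefficients taken
    # from the line x = (((463*y*y) >> k2) - (896*y) + (698 << k2)) >> 8.
    return (463 * x * x - 896 * x + 698) / 256

def max_abs_error(samples=10000):
    """Largest |poly(x) - 1/sqrt(x)| over an even grid on [1/4, 1)."""
    worst = 0.0
    for i in range(samples):
        x = 0.25 + 0.75 * i / samples
        worst = max(worst, abs(poly(x) - x ** -0.5))
    return worst

print(max_abs_error())  # roughly 0.036, in line with the ~1/28 quoted above
```

An error of about 1/28 corresponds to a little under 5 correct bits, which matches the "4 or 5 bits of precision" starting point.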
""" Integer square root
Uses no divisions, only shifts
"Quake" style algorithm,
i.e., Newton's method for 1 / sqrt(y)
Uses a quadratic minimax polynomial for the first approximation
Written by PM 2Ring 2022.01.23
"""
def int_sqrt(y):
    if y < 0:
        raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
    if y < 2:
        return y
    # print("\n*", y, "*")
    # Range reduction.
    # Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
    j = y.bit_length()
    # Round k*2 up to the next even number
    k2 = j + (j & 1)
    # k and some useful multiples
    k = k2 >> 1
    k3 = k2 + k
    k6 = k3 << 1
    kd = k6 + 1
    # b cubed
    b3 = 1 << k6
    # Minimax approximation: x/b ~= 1 / sqrt(y/b)
    x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
    # print(" ", x, h)
    # Newton's method for 1 / sqrt(y/b)
    epsilon = 1 << k
    for i in range(1, 99):
        dx = x * (b3 - y * x * x) >> kd
        x += dx
        # print(f" {i}: {x} {dx}")
        if abs(dx) <= epsilon:
            break
    # s == sqrt(y)
    s = x * y >> k3
    # Adjust if too low
    ss = s + 1
    return ss if ss * ss <= y else s

def test(lo, hi, step=1):
    for y in range(lo, hi, step):
        s = int_sqrt(y)
        ss = s + 1
        s2, ss2 = s * s, ss * ss
        assert s2 <= y < ss2, (y, s2, ss2)
    print("ok")

test(0, 100000, 1)
This code is certainly slower than math.isqrt and decimal.Decimal.sqrt. Its purpose is simply to illustrate the algorithm. It would be interesting to see how fast it would be if it were implemented in C...
Here's a live version, running on the SageMathCell server. Set hi <= 0 to calculate and display the results for a single value set in lo. You can put expressions in the input boxes, eg set hi to 0 and lo to 2 * 10**100 to get sqrt(2) * 10**50.
Inspired by the other answers, I decided to implement several of the best methods from them in pure C++, since C++ is usually faster than Python for this kind of arithmetic.
To glue C++ and Python together I used Cython. It lets you build a Python module out of C++ code and then call the C++ functions directly from Python.
As a complement, I provide not only the Python-adapted code but also pure C++ code with tests.
Here are timings from pure C++ tests:
Test 'GMP', bits 64, time 0.000001 sec
Test 'AndersKaseorg', bits 64, time 0.000003 sec
Test 'Babylonian', bits 64, time 0.000006 sec
Test 'ChordTangent', bits 64, time 0.000018 sec
Test 'GMP', bits 50000, time 0.000118 sec
Test 'AndersKaseorg', bits 50000, time 0.002777 sec
Test 'Babylonian', bits 50000, time 0.003062 sec
Test 'ChordTangent', bits 50000, time 0.009120 sec
and the same C++ functions, wrapped as a Python module, have these timings:
Bits 50000
math.isqrt: 2.819 ms
gmpy2.isqrt: 0.166 ms
ISqrt_GMP: 0.252 ms
ISqrt_AndersKaseorg: 3.338 ms
ISqrt_Babylonian: 3.756 ms
ISqrt_ChordTangent: 10.564 ms
My Cython/C++ setup is also useful as a framework for anyone who wants to write and test their own C++ method directly from Python.
As you can see in the timings above, I used the following methods as examples:
math.isqrt, implementation from standard library.
gmpy2.isqrt, GMPY2 library's implementation.
ISqrt_GMP, the same as GMPY2 but called through my Cython module, which uses the C++ GMP library (<gmpxx.h>) directly.
ISqrt_AndersKaseorg, code taken from the answer of @AndersKaseorg.
ISqrt_Babylonian, the so-called Babylonian method from the Wikipedia article, in my own implementation as I understand it.
ISqrt_ChordTangent, my own method, which I call Chord-Tangent because it uses a chord and a tangent line to iteratively shrink the search interval. This method is described in moderate detail in my other article. It is nice because it finds not only the square root but also the K-th root for any K. I drew a small picture showing the details of this algorithm.
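The chord-and-tangent idea can be sketched in a few lines of Python. This is a simplified illustration under my own assumptions, not a transcription of the C++ code: the name `kth_root` is hypothetical, and a bisection step is added as a fallback when neither bound moves. For the convex function x**k, the chord through the two bracket ends never overshoots the root, and the tangent at the upper end never undershoots it, so both bounds stay valid.

```python
import math

def kth_root(n, k=2):
    """Floor of the k-th root of a non-negative integer n (k >= 2)."""
    if n < 2:
        return n
    bits = n.bit_length()
    lo = 1 << ((bits - 1) // k)   # lo**k <= n
    hi = 1 << -(-bits // k)       # hi**k > n
    while hi - lo > 1:
        lo_k, hi_k = lo ** k, hi ** k
        # Chord through (lo, lo_k) and (hi, hi_k): reaches level n at or below the root.
        chord = lo + (n - lo_k) * (hi - lo) // (hi_k - lo_k)
        # Tangent at hi: reaches level n at or above the root.
        tangent = hi - (hi_k - n) // (k * hi ** (k - 1))
        if chord == lo and tangent == hi:
            # Neither bound moved, so fall back to one bisection step.
            mid = (lo + hi) >> 1
            if mid ** k <= n:
                lo = mid
            else:
                hi = mid
        else:
            lo, hi = chord, tangent
    return hi if hi ** k <= n else lo

print(kth_root(10**40), kth_root(10**30, 3))
print(all(kth_root(n) == math.isqrt(n) for n in range(10000)))
```

Recomputing both powers on every pass keeps the sketch short; a tuned version would cache them, as the C++ code does with y_begin and y_end.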
To compile the C++/Cython code I used the GMP library. You need to install it first; under Linux this is easy with sudo apt install libgmp-dev.
Under Windows the easiest way is to install VCPKG, a software package manager similar to APT on Linux. VCPKG compiles all packages from source using Visual Studio (don't forget to install the Community edition of Visual Studio). After installing VCPKG you can install GMP with vcpkg install gmp. You may also install MPIR, an alternative fork of GMP, with vcpkg install mpir.
After GMP is installed under Windows, please edit my Python code and replace the paths to the include directory and the library file. At the end of the installation VCPKG should show you the path to a ZIP file with the GMP library, which contains the .lib and .h files.
You may notice in the Python code that I also wrote a handy cython_compile() function that I use to compile any C++ code into a Python module. It lets you easily plug any C++ code into Python, and it can be reused many times.
If you have any questions or suggestions, or something doesn't work on your PC, please write in the comments.
Below I first show the code in Python, then in C++. See the Try it online! link above the C++ code to run it online on GodBolt servers. Both code snippets are fully runnable from scratch as they are; nothing needs to be edited in them.
def cython_compile(srcs):
    import json, hashlib, os, glob, importlib, sys, shutil, tempfile
    srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
    pdir = 'cyimp'
    if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
        class ChDir:
            def __init__(self, newd):
                self.newd = newd
            def __enter__(self):
                self.curd = os.getcwd()
                os.chdir(self.newd)
                return self
            def __exit__(self, ext, exv, tb):
                os.chdir(self.curd)
        os.makedirs(pdir, exist_ok = True)
        with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
            os.makedirs(pdir, exist_ok = True)
            for k, v in srcs.items():
                with open(f'cys{srch}_{k}', 'wb') as f:
                    f.write(v.replace('{srch}', srch).encode('utf-8'))
            import numpy as np
            from setuptools import setup, Extension
            from Cython.Build import cythonize
            sys.argv += ['build_ext', '--inplace']
            setup(
                ext_modules = cythonize(
                    Extension(
                        f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
                        depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
                        extra_compile_args = ['/O2', '/std:c++latest',
                            '/ID:/dev/_3party/vcpkg_bin/gmp/include/',
                        ],
                    ),
                    compiler_directives = {'language_level': 3, 'embedsignature': True},
                    annotate = True,
                ),
                include_dirs = [np.get_include()],
            )
            del sys.argv[-2:]
            for f in glob.glob(f'{pdir}/cy{srch}*'):
                shutil.copy(f, f'./../')
    print('Cython module:', f'cy{srch}')
    return importlib.import_module(f'{pdir}.cy{srch}')
def cython_import():
    srcs = {
        'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
""",
        'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *
cdef extern from "cys{srch}_lib.h" nogil:
    void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);
@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
    mask64 = (1 << 64) - 1
    def ToLimbs():
        return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
        words = (n.bit_length() + 63) // 64
        t = n
        r = np.zeros((words,), dtype = np.uint64)
        for i in range(words):
            r[i] = np.uint64(t & mask64)
            t >>= 64
        return r
    def FromLimbs(x):
        return int.from_bytes(x.tobytes(), 'little')
        n = 0
        for i in range(x.shape[0]):
            n |= int(x[i]) << (i * 64)
        return n
    n = ToLimbs()
    cdef uint64_t[:] cn = n
    cdef uint64_t ccnt = len(n)
    cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
    with nogil:
        (ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
            <uint64_t *>&cn[0], <uint64_t *>&ccnt
        )
    return FromLimbs(n[:ccnt])
""",
    }
    return cython_compile(srcs)
def main():
    import math, gmpy2, timeit, random
    mod = cython_import()
    fs = [
        ('math.isqrt', math.isqrt),
        ('gmpy2.isqrt', gmpy2.isqrt),
        ('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
        ('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
        ('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
        ('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
    ]
    times = [0] * len(fs)
    ntests = 1 << 6
    bits = 50000
    for i in range(ntests):
        n = random.randrange(1 << (bits - 1), 1 << bits)
        ref = None
        for j, (fn, f) in enumerate(fs):
            timeit_cnt = 3
            tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
            times[j] += tim
            x = f(n)
            if j == 0:
                ref = x
            else:
                assert x == ref, (fn, ref, x)
    print('Bits', bits)
    print('\n'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))

if __name__ == '__main__':
    main()
and C++:
Try it online!
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
std::string IntToStr(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return n.get_str();
else {
std::ostringstream ss;
ss << n;
return ss.str();
}
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
//std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
// Testing
#include <chrono>
#include <random>
#include <vector>
#include <iomanip>
inline double Time() {
static auto const gtb = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
.count();
}
template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
std::mt19937_64 rng{123};
std::vector<T> nums;
for (size_t i = 0; i < ntests; ++i) {
T n = 0;
for (size_t j = 0; j < bits; j += 32) {
size_t const cbits = std::min<size_t>(32, bits - j);
n <<= cbits;
n ^= u32(rng()) >> (32 - cbits);
}
nums.push_back(n);
}
auto tim = Time();
for (auto & n: nums)
n = f(n);
tim = Time() - tim;
std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
<< ", bits " << std::setw(6) << bits << ", time "
<< std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
return nums;
}
void Test() {
auto f = [](auto ty, size_t bits, size_t ntests){
using T = std::decay_t<decltype(ty)>;
auto tim = Time();
auto a = Test0<T>("GMP", bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
ASSERT(b == a);
auto c = Test0<T>("Babylonian", bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
ASSERT(c == a);
auto d = Test0<T>("ChordTangent", bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
ASSERT(d == a);
std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
};
for (auto p: std::vector<std::pair<int, int>>{{15, 1 << 19}, {30, 1 << 19}})
f(u64(), p.first, p.second);
for (auto p: std::vector<std::pair<int, int>>{{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
f(mpz_class(), p.first, p.second);
}
int main() {
try {
Test();
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}
Your function fails for large inputs:
In [26]: isqrt((10**100+1)**2)
ValueError: input was not a perfect square
There is a recipe on the ActiveState site which should hopefully be more reliable since it uses integer maths only. It is based on an earlier StackOverflow question: Writing your own square root function
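As an illustration of what an integer-only approach looks like, here is a minimal sketch. It is not the ActiveState recipe itself; the name `isqrt_exact` is made up, and the loop is the usual integer Newton iteration started from a power of two above the root.

```python
def isqrt_exact(n):
    """Return the integer square root of n, or raise if n is not a perfect square."""
    if n < 0:
        raise ValueError('input was not a perfect square')
    if n == 0:
        return 0
    x = 1 << ((n.bit_length() + 1) // 2)   # starting guess, never below sqrt(n)
    while True:
        y = (x + n // x) // 2              # Newton step using integer division only
        if y >= x:                         # stopped decreasing: x is floor(sqrt(n))
            break
        x = y
    if x * x != n:
        raise ValueError('input was not a perfect square')
    return x

print(isqrt_exact((10**100 + 1) ** 2) == 10**100 + 1)
```

Because no floats are involved, the input from the failing example above is handled exactly.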
Floating-point numbers cannot represent every value exactly. You can test for closeness instead, setting epsilon to a small value within the accuracy of Python's floats.
def isqrt(n):
    epsilon = .00000000001
    i = int(n**.5 + 0.5)
    if abs(i**2 - n) < epsilon:
        return i
    raise ValueError('input was not a perfect square')
Try this condition (no additional computation):
import math

def isqrt(n):
    i = math.sqrt(n)
    if i != int(i):
        raise ValueError('input was not a perfect square')
    return i
If you need it to return an int (not a float with a trailing zero) then either assign a 2nd variable or compute int(i) twice.
I have compared the different methods given here with a loop:
for i in range(1000000):  # 700 msec
    r = int(123456781234567**0.5 + 0.5)
    if r**2 == 123456781234567: rr = r
    else: rr = -1
I found that this one is the fastest, and it needs no math import. It might fail for very large numbers, but look at this:
15241576832799734552675677489**0.5 = 123456781234567.0
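To see where the float shortcut stops being reliable, the sketch below applies it to the 201-digit perfect square used earlier in this thread and compares it with math.isqrt (available from Python 3.8). The helper name `float_root` is only for illustration.

```python
import math

def float_root(n):
    # The float-based shortcut: fast, but limited to 53 bits of precision.
    return int(n ** 0.5 + 0.5)

n = (10**100 + 1) ** 2
print(float_root(n) == 10**100 + 1)    # the float result cannot hold all the digits
print(math.isqrt(n) == 10**100 + 1)    # the integer routine is exact
```

A double carries only 53 significant bits, so once the root itself needs more than that, the low digits of the float result are lost.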
I am given a bunch of expressions in prefix notation in an ANSI text file. I would like to produce another ANSI text file containing the step-by-step evaluation of these expressions. For example:
- + ^ x 2 ^ y 2 1
should be turned into
t1 = x^2
t2 = y^2
t3 = t1 + t2
t4 = t3 - 1
t4 is the result
I also have to identify common subexpressions. For example given
expression_1: z = ^ x 2
expression_2: - + z ^ y 2 1
expression_3: - z y
I have to generate an output saying that x appears in expressions 1, 2 and 3 (through z).
I have to identify dependencies: expression_1 depends only on x, expression_2 depends on x and y, etc.
The original problem is more difficult than the examples above and I have no control over the input format, it is in prefix notation in a much more complicated way than the above ones.
I already have a working implementation in C++ however it is a lot of pain doing such things in C++.
What programming language is best suited for this type of problem?
Could you recommend a tutorial / website / book where I could start?
What keywords should I look for?
UPDATE: Based on the answers, the above examples are somewhat unfortunate: I have unary, binary and n-ary operators in the input. (If you are wondering, exp is a unary operator, and sum over a range is an n-ary operator.)
To give you an idea of what this would look like in Python, here is some example code:
operators = "+-*/^"

def parse(it, count=1):
    token = next(it)
    if token in operators:
        op1, count = parse(it, count)
        op2, count = parse(it, count)
        tmp = "t%s" % count
        print(tmp, "=", op1, token, op2)
        return tmp, count + 1
    return token, count

s = "- + ^ x 2 ^ y 2 1"
a = s.split()
res, dummy = parse(iter(a))
print(res, "is the result")
The output is the same as your example output.
This example aside, I think any of the high-level languages you listed are almost equally suited for the task.
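The same recursive idea extends to the unary and n-ary operators mentioned in the update. The sketch below rests on assumptions about the input that the question does not confirm: operators have a fixed arity listed in a table, and `sum` is followed by an explicit operand count. All names here are illustrative.

```python
# Assumed arities; the real input format may differ.
ARITY = {"+": 2, "-": 2, "*": 2, "/": 2, "^": 2, "exp": 1}

def parse(tokens, steps):
    """Consume one prefix expression from the iterator; return the name holding its value."""
    token = next(tokens)
    if token == "sum":
        n = int(next(tokens))                  # assumed: "sum" is followed by its operand count
        args = [parse(tokens, steps) for _ in range(n)]
        text = "sum(" + ", ".join(args) + ")"
    elif token in ARITY:
        args = [parse(tokens, steps) for _ in range(ARITY[token])]
        if len(args) == 1:
            text = f"{token}({args[0]})"
        else:
            text = f"{args[0]} {token} {args[1]}"
    else:
        return token                           # variable or literal
    name = f"t{len(steps) + 1}"
    steps.append(f"{name} = {text}")
    return name

def evaluate_steps(source):
    steps = []
    result = parse(iter(source.split()), steps)
    return steps, result

steps, result = evaluate_steps("exp sum 3 x y ^ z 2")
print("\n".join(steps))
print(result, "is the result")
```

Collecting the steps in a list instead of printing them directly makes it easier to write them to the output file later.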
The sympy python package does symbolic algebra, including common subexpression elimination and generating evaluation steps for a set of expressions.
See: http://docs.sympy.org/dev/modules/rewriting.html (Look at the cse method at the bottom of the page).
The Python example is elegantly short, but I suspect that you won't actually get enough control over your expressions that way. You're much better off actually building an expression tree, even though it takes more work, and then querying the tree. Here's an example in Scala (suitable for cutting and pasting into the REPL):
object OpParser {
private def estr(oe: Option[Expr]) = oe.map(_.toString).getOrElse("_")
case class Expr(text: String, left: Option[Expr] = None, right: Option[Expr] = None) {
import Expr._
def varsUsed: Set[String] = text match {
case Variable(v) => Set(v)
case Op(o) =>
left.map(_.varsUsed).getOrElse(Set()) ++ right.map(_.varsUsed).getOrElse(Set())
case _ => Set()
}
def printTemp(n: Int = 0, depth: Int = 0): (String,Int) = text match {
case Op(o) =>
val (tl,nl) = left.map(_.printTemp(n,depth+1)).getOrElse(("_",n))
val (tr,nr) = right.map(_.printTemp(nl,depth+1)).getOrElse(("_",n))
val t = "t"+(nr+1)
println(t + " = " + tl + " " + text + " " + tr)
if (depth==0) println(t + " is the result")
(t, nr+1)
case _ => (text, n)
}
override def toString: String = {
if (left.isDefined || right.isDefined) {
"(" + estr(left) + " " + text + " " + estr(right) + ")"
}
else text
}
}
object Expr {
val Digit = "([0-9]+)"r
val Variable = "([a-z])"r
val Op = """([+\-*/^])"""r
def parse(s: String) = {
val bits = s.split(" ")
val parsed = (
if (bits.length > 2 && Variable.unapplySeq(bits(0)).isDefined && bits(1)=="=") {
parseParts(bits,2)
}
else parseParts(bits)
)
parsed.flatMap(p => if (p._2<bits.length) None else Some(p._1))
}
def parseParts(as: Array[String], from: Int = 0): Option[(Expr,Int)] = {
if (from >= as.length) None
else as(from) match {
case Digit(n) => Some(Expr(n), from+1)
case Variable(v) => Some(Expr(v), from+1)
case Op(o) =>
parseParts(as, from+1).flatMap(lhs =>
parseParts(as, lhs._2).map(rhs => (Expr(o,Some(lhs._1),Some(rhs._1)), rhs._2))
)
case _ => None
}
}
}
}
This may be a little much to digest all at once, but then again, this does rather a lot.
Firstly, it's completely bulletproof (note the heavy use of Option where a result might fail). If you throw garbage at it, it will just return None. (With a bit more work, you could make it complain about the problem in an informative way--basically the case Op(o) which then does parseParts nested twice could instead store the results and print out an informative error message if the op didn't get two arguments. Likewise, parse could complain about trailing values instead of just throwing back None.)
Secondly, when you're done with it, you have a complete expression tree. Note that printTemp prints out the temporary variables you wanted, and varsUsed lists the variables used in a particular expression, which you can use to expand to a full list once you parse multiple lines. (You might need to fiddle with the regexp a little if your variables can be more than just a to z.) Note also that the expression tree prints itself out in normal infix notation. Let's look at some examples:
scala> OpParser.Expr.parse("4")
res0: Option[OpParser.Expr] = Some(4)
scala> OpParser.Expr.parse("+ + + + + 1 2 3 4 5 6")
res1: Option[OpParser.Expr] = Some((((((1 + 2) + 3) + 4) + 5) + 6))
scala> OpParser.Expr.parse("- + ^ x 2 ^ y 2 1")
res2: Option[OpParser.Expr] = Some((((x ^ 2) + (y ^ 2)) - 1))
scala> OpParser.Expr.parse("+ + 4 4 4 4") // Too many 4s!
res3: Option[OpParser.Expr] = None
scala> OpParser.Expr.parse("Q#$S!M$#!*)000") // Garbage!
res4: Option[OpParser.Expr] = None
scala> OpParser.Expr.parse("z =") // Assigned nothing?!
res5: Option[OpParser.Expr] = None
scala> res2.foreach(_.printTemp())
t1 = x ^ 2
t2 = y ^ 2
t3 = t1 + t2
t4 = t3 - 1
t4 is the result
scala> res2.map(_.varsUsed)
res10: Option[Set[String]] = Some(Set(x, y))
Now, you could do this in Python also without too much additional work, and in a number of the other languages besides. I prefer to use Scala, but you may prefer otherwise. Regardless, I do recommend creating the full expression tree if you want to retain maximum flexibility for handling tricky cases.
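For comparison, a rough Python counterpart of the expression tree might look like the sketch below. It is not a port of the Scala code: it handles only the five binary operators, treats purely alphabetic tokens as variables, and raises ValueError on bad input instead of returning None. The class and function names are my own.

```python
from dataclasses import dataclass
from typing import Optional

OPERATORS = set("+-*/^")

@dataclass
class Expr:
    text: str
    left: Optional["Expr"] = None
    right: Optional["Expr"] = None

    def vars_used(self):
        """Set of variable names appearing anywhere in this subtree."""
        if self.text in OPERATORS:
            return self.left.vars_used() | self.right.vars_used()
        return {self.text} if self.text.isalpha() else set()

    def __str__(self):
        if self.text in OPERATORS:
            return f"({self.left} {self.text} {self.right})"
        return self.text

def parse(source):
    """Build an expression tree from a space-separated prefix expression."""
    tokens = source.split()

    def walk(pos):
        if pos >= len(tokens):
            raise ValueError("unexpected end of input")
        tok = tokens[pos]
        if tok in OPERATORS:
            left, pos = walk(pos + 1)
            right, pos = walk(pos)
            return Expr(tok, left, right), pos
        if tok.isalnum():
            return Expr(tok), pos + 1
        raise ValueError(f"unknown token {tok!r}")

    tree, end = walk(0)
    if end != len(tokens):
        raise ValueError("trailing tokens after expression")
    return tree

tree = parse("- + ^ x 2 ^ y 2 1")
print(tree)
print(sorted(tree.vars_used()))
```

Once the tree exists, further queries such as emitting temporaries or finding repeated subtrees become separate walks over the same structure.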
Prefix notation is really simple to do with plain recursive parsers. For instance:
object Parser {
val Subexprs = collection.mutable.Map[String, String]()
val Dependencies = collection.mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
val TwoArgsOp = "([-+*/^])".r // - at the beginning, ^ at the end
val Ident = "(\\p{Alpha}\\w*)".r
val Literal = "(\\d+)".r
var counter = 1
def getIdent = {
val ident = "t" + counter
counter += 1
ident
}
def makeOp(op: String) = {
val op1 = expr
val op2 = expr
val ident = getIdent
val subexpr = op1 + " " + op + " " + op2
Subexprs(ident) = subexpr
Dependencies(ident) = Dependencies(op1) ++ Dependencies(op2) + op1 + op2
ident
}
def expr: String = nextToken match {
case TwoArgsOp(op) => makeOp(op)
case Ident(id) => id
case Literal(lit) => lit
case x => error("Unknown token "+x)
}
def nextToken = tokens.next
var tokens: Iterator[String] = _
def parse(input: String) = {
tokens = input.trim split "\\s+" toIterator;
counter = 1
expr
if (tokens.hasNext)
error("Input not fully parsed: "+tokens.mkString(" "))
(Subexprs, Dependencies)
}
}
This will generate output like this:
scala> val (subexpressions, dependencies) = Parser.parse("- + ^ x 2 ^ y 2 1")
subexpressions: scala.collection.mutable.Map[String,String] = Map(t3 -> t1 + t2, t4 -> t3 - 1, t1 -> x ^ 2, t2 -> y ^ 2)
dependencies: scala.collection.mutable.Map[String,Set[String]] = Map(t3 -> Set(x, y, t2, 2, t1), t4 -> Set(x, y, t3, t2, 1, 2, t1), t1 -> Set(x, 2), t2 -> Set(y, 2))
scala> subexpressions.toSeq.sorted foreach println
(t1,x ^ 2)
(t2,y ^ 2)
(t3,t1 + t2)
(t4,t3 - 1)
scala> dependencies.toSeq.sortBy(_._1) foreach println
(t1,Set(x, 2))
(t2,Set(y, 2))
(t3,Set(x, y, t2, 2, t1))
(t4,Set(x, y, t3, t2, 1, 2, t1))
This can be easily expanded. For instance, to handle multiple expression statements you can use this:
object Parser {
val Subexprs = collection.mutable.Map[String, String]()
val Dependencies = collection.mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
val TwoArgsOp = "([-+*/^])".r // - at the beginning, ^ at the end
val Ident = "(\\p{Alpha}\\w*)".r
val Literal = "(\\d+)".r
var counter = 1
def getIdent = {
val ident = "t" + counter
counter += 1
ident
}
def makeOp(op: String) = {
val op1 = expr
val op2 = expr
val ident = getIdent
val subexpr = op1 + " " + op + " " + op2
Subexprs(ident) = subexpr
Dependencies(ident) = Dependencies(op1) ++ Dependencies(op2) + op1 + op2
ident
}
def expr: String = nextToken match {
case TwoArgsOp(op) => makeOp(op)
case Ident(id) => id
case Literal(lit) => lit
case x => error("Unknown token "+x)
}
def assignment: Unit = {
val ident = nextToken
nextToken match {
case "=" =>
val tmpIdent = expr
Dependencies(ident) = Dependencies(tmpIdent)
Subexprs(ident) = Subexprs(tmpIdent)
Dependencies.remove(tmpIdent)
Subexprs.remove(tmpIdent)
case x => error("Expected assignment, got "+x)
}
}
def stmts: Unit = while(tokens.hasNext) tokens.head match {
case TwoArgsOp(_) => expr
case Ident(_) => assignment
case x => error("Unknown statement starting with "+x)
}
def nextToken = tokens.next
var tokens: BufferedIterator[String] = _
def parse(input: String) = {
tokens = (input.trim split "\\s+" toIterator).buffered
counter = 1
stmts
if (tokens.hasNext)
error("Input not fully parsed: "+tokens.mkString(" "))
(Subexprs, Dependencies)
}
}
Yielding:
scala> val (subexpressions, dependencies) = Parser.parse("""
| z = ^ x 2
| - + z ^ y 2 1
| - z y
| """)
subexpressions: scala.collection.mutable.Map[String,String] = Map(t3 -> z + t2, t5 -> z - y, t4 -> t3 - 1, z -> x ^ 2, t2 -> y ^ 2)
dependencies: scala.collection.mutable.Map[String,Set[String]] = Map(t3 -> Set(x, y, t2, 2, z), t5 -> Set(x, 2, z, y), t4 -> Set(x, y, t3, t2, 1, 2, z), z -> Set(x, 2), t2 -> Set(y, 2))
scala> subexpressions.toSeq.sorted foreach println
(t2,y ^ 2)
(t3,z + t2)
(t4,t3 - 1)
(t5,z - y)
(z,x ^ 2)
scala> dependencies.toSeq.sortBy(_._1) foreach println
(t2,Set(y, 2))
(t3,Set(x, y, t2, 2, z))
(t4,Set(x, y, t3, t2, 1, 2, z))
(t5,Set(x, 2, z, y))
(z,Set(x, 2))
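The dependency bookkeeping on its own is small enough to sketch in Python. Unlike the Scala maps above, this version reports only base variables and skips literals and temporaries. It assumes that a named expression such as z is defined before it is used; the function names and the (label, text) input format are made up for the example.

```python
OPERATORS = set("+-*/^")

def dependencies(lines):
    """Map each expression label to the set of base variables it depends on."""
    defs = {}      # name defined by an assignment -> base variables behind it
    result = {}
    for label, text in lines:
        tokens = text.split()
        name = None
        if len(tokens) > 1 and tokens[1] == "=":
            name, tokens = tokens[0], tokens[2:]
        deps = set()
        for tok in tokens:
            if tok in OPERATORS or tok.isdigit():
                continue
            deps |= defs.get(tok, {tok})   # expand names such as z to their variables
        if name is not None:
            defs[name] = deps
        result[label] = deps
    return result

def appearances(deps):
    """Invert the mapping: variable -> labels of the expressions that depend on it."""
    out = {}
    for label, names in deps.items():
        for var in sorted(names):
            out.setdefault(var, []).append(label)
    return out

lines = [
    ("expression_1", "z = ^ x 2"),
    ("expression_2", "- + z ^ y 2 1"),
    ("expression_3", "- z y"),
]
deps = dependencies(lines)
for label, names in deps.items():
    print(label, "depends on", sorted(names))
print(appearances(deps))
```

With the three expressions from the question, x is reported for all three, through z in the last two.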
Ok, since recursive parsers are not your thing, here's an alternative with parser combinators:
object PrefixParser extends JavaTokenParsers {
import scala.collection.mutable
// Maps generated through parsing
val Subexprs = mutable.Map[String, String]()
val Dependencies = mutable.Map[String, Set[String]]().withDefaultValue(Set.empty)
// Initialize, read, parse & evaluate string
def read(input: String) = {
counter = 1
Subexprs.clear
Dependencies.clear
parseAll(stmts, input)
}
// Grammar
def stmts = stmt+
def stmt = assignment | expr
def assignment = (ident <~ "=") ~ expr ^^ assignOp
def expr: P = subexpr | identifier | number
def subexpr: P = twoArgs | nArgs
def twoArgs: P = operator ~ expr ~ expr ^^ twoArgsOp
def nArgs: P = "sum" ~ ("\\d+".r >> args) ^^ nArgsOp
def args(n: String): Ps = repN(n.toInt, expr)
def operator = "[-+*/^]".r
def identifier = ident ^^ (id => Result(id, Set(id)))
def number = wholeNumber ^^ (Result(_, Set.empty))
// Evaluation helper class and types
case class Result(ident: String, dependencies: Set[String])
type P = Parser[Result]
type Ps = Parser[List[Result]]
// Evaluation methods
def assignOp: (String ~ Result) => Result = {
case ident ~ result =>
val value = assign(ident,
Subexprs(result.ident),
result.dependencies - result.ident)
Subexprs.remove(result.ident)
Dependencies.remove(result.ident)
value
}
def assign(ident: String,
value: String,
dependencies: Set[String]): Result = {
Subexprs(ident) = value
Dependencies(ident) = dependencies
Result(ident, dependencies)
}
def twoArgsOp: (String ~ Result ~ Result) => Result = {
case op ~ op1 ~ op2 => makeOp(op, op1, op2)
}
def makeOp(op: String,
op1: Result,
op2: Result): Result = {
val ident = getIdent
assign(ident,
"%s %s %s" format (op1.ident, op, op2.ident),
op1.dependencies ++ op2.dependencies + ident)
}
def nArgsOp: (String ~ List[Result]) => Result = {
case op ~ ops => makeNOp(op, ops)
}
def makeNOp(op: String, ops: List[Result]): Result = {
val ident = getIdent
assign(ident,
"%s(%s)" format (op, ops map (_.ident) mkString ", "),
ops.foldLeft(Set(ident))(_ ++ _.dependencies))
}
var counter = 1
def getIdent = {
val ident = "t" + counter
counter += 1
ident
}
// Debugging helper methods
def printAssignments = Subexprs.toSeq.sorted foreach println
def printDependencies = Dependencies.toSeq.sortBy(_._1) map {
case (id, dependencies) => (id, dependencies - id)
} foreach println
}
This is the kind of results you get:
scala> PrefixParser.read("""
| z = ^ x 2
| - + z ^ y 2 1
| - z y
| """)
res77: PrefixParser.ParseResult[List[PrefixParser.Result]] = [5.1] parsed: List(Result(z,Set(x)), Result(t4,Set(t4, y, t3, t2, z)), Result(t5,Set(z, y, t5)))
scala> PrefixParser.printAssignments
(t2,y ^ 2)
(t3,z + t2)
(t4,t3 - 1)
(t5,z - y)
(z,x ^ 2)
scala> PrefixParser.printDependencies
(t2,Set(y))
(t3,Set(z, y, t2))
(t4,Set(y, t3, t2, z))
(t5,Set(z, y))
(z,Set(x))
n-Ary operator
scala> PrefixParser.read("""
| x = sum 3 + 1 2 * 3 4 5
| * x x
| """)
res93: PrefixParser.ParseResult[List[PrefixParser.Result]] = [4.1] parsed: List(Result(x,Set(t1, t2)), Result(t4,Set(x, t4)))
scala> PrefixParser.printAssignments
(t1,1 + 2)
(t2,3 * 4)
(t4,x * x)
(x,sum(t1, t2, 5))
scala> PrefixParser.printDependencies
(t1,Set())
(t2,Set())
(t4,Set(x))
(x,Set(t1, t2))
It turns out that this sort of parsing is of interest to me also, so I've done a bit more work on it.
There seems to be a sentiment that things like simplification of expressions are hard. I'm not so sure. Let's take a look at a fairly complete solution. (The printing out of tn expressions is not useful for me, and you've got several Scala examples already, so I'll skip that.)
First, we need to extract the various parts of the language. I'll pick regular expressions, though parser combinators could be used also:
object OpParser {
val Natural = "([0-9]+)".r
val Number = """((?:-)?[0-9]+(?:\.[0-9]+)?(?:[eE](?:-)?[0-9]+)?)""".r
val Variable = "([a-z])".r
val Unary = "(exp|sin|cos|tan|sqrt)".r
val Binary = "([-+*/^])".r
val Nary = "(sum|prod|list)".r
Pretty straightforward. We define the various things that might appear. (I've decided that user-defined variables can only be a single lowercase letter, and that numbers can be floating-point since you have the exp function.) The r at the end means this is a regular expression, and it will give us the stuff in parentheses.
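To make the "stuff in parentheses" point concrete, here is a minimal sketch of a regex used as a pattern. The object name RegexExtractorDemo and the classify helper are made up for illustration; only the three regular expressions are taken from the definitions above.

```scala
// Hypothetical demo object; only the regexes come from the answer.
object RegexExtractorDemo {
  val Natural  = "([0-9]+)".r
  val Variable = "([a-z])".r
  val Binary   = "([-+*/^])".r

  // A Regex used as a pattern has to match the whole token, and each
  // capturing group is bound to one pattern variable.
  def classify(token: String): String = token match {
    case Natural(n)  => "natural " + n
    case Variable(v) => "variable " + v
    case Binary(op)  => "binary " + op
    case other       => "unrecognized " + other
  }
}
```

Because the match is anchored to the whole token, something like "xy" falls through to the last case rather than matching Variable on its first letter.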
Now we need to represent our tree. There are a number of ways to do this, but I'll choose an abstract base class with specific expressions as case classes, since this makes pattern matching easy. Furthermore, we might want nice printing, so we'll override toString. Mostly, though, we'll use recursive functions to do the heavy lifting.
abstract class Expr {
def text: String
def args: List[Expr]
override def toString = args match {
case l :: r :: Nil => "(" + l + " " + text + " " + r + ")"
case Nil => text
case _ => args.mkString(text+"(", ",", ")")
}
}
case class Num(text: String, args: List[Expr]) extends Expr {
val quantity = text.toDouble
}
case class Var(text: String, args: List[Expr]) extends Expr {
override def toString = args match {
case arg :: Nil => "(" + text + " <- " + arg + ")"
case _ => text
}
}
case class Una(text: String, args: List[Expr]) extends Expr
case class Bin(text: String, args: List[Expr]) extends Expr
case class Nar(text: String, args: List[Expr]) extends Expr {
override def toString = text match {
case "list" =>
(for ((a,i) <- args.zipWithIndex) yield {
"%3d: %s".format(i+1,a.toString)
}).mkString("List[\n","\n","\n]")
case _ => super.toString
}
}
Mostly this is pretty dull--each case class overrides the base class, and the text and args automatically fill in for the def. Note that I've decided that a list is a possible n-ary function, and that it will be printed out with line numbers. (The reason is that if you have multiple lines of input, it's sometimes more convenient to work with them all together as one expression; this lets them be one function.)
Once our data structures are defined, we need to parse the expressions. It's convenient to represent the stuff to parse as a list of tokens; as we parse, we'll return both an expression and the remaining tokens that we haven't parsed--this is a particularly useful structure for recursive parsing. Of course, we might fail to parse anything, so it had better be wrapped in an Option also.
def parse(tokens: List[String]): Option[(Expr,List[String])] = tokens match {
case Variable(x) :: "=" :: rest =>
for ((expr,remains) <- parse(rest)) yield (Var(x,List(expr)), remains)
case Variable(x) :: rest => Some(Var(x,Nil), rest)
case Number(n) :: rest => Some(Num(n,Nil), rest)
case Unary(u) :: rest =>
for ((expr,remains) <- parse(rest)) yield (Una(u,List(expr)), remains)
case Binary(b) :: rest =>
for ((lexp,lrem) <- parse(rest); (rexp,rrem) <- parse(lrem)) yield
(Bin(b,List(lexp,rexp)), rrem)
case Nary(a) :: Natural(b) :: rest =>
val buffer = new collection.mutable.ArrayBuffer[Expr]
def parseN(tok: List[String], n: Int = b.toInt): List[String] = {
if (n <= 0) tok
else {
for ((expr,remains) <- parse(tok)) yield { buffer += expr; parseN(remains, n-1) }
}.getOrElse(tok)
}
val remains = parseN(rest)
if (buffer.length == b.toInt) Some( Nar(a,buffer.toList), remains )
else None
case _ => None
}
Note that we use pattern matching and recursion to do most of the heavy lifting--we pick off part of the list, figure out how many arguments we need, and pass those along recursively. The N-ary operation is a little less friendly, but we create a little recursive function that will parse N things at a time for us, storing the results in a buffer.
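If the way the leftover tokens get threaded through the recursion is hard to see in the full parser, this stripped-down sketch may help. It is not part of OpParser: PrefixEvalDemo, eval and applyOp are names I made up, and it evaluates plain numbers directly instead of building a tree, but the shape (return a result plus whatever tokens remain, all inside an Option) is the same.

```scala
// Hypothetical, stripped-down illustration of "result plus remaining tokens".
object PrefixEvalDemo {
  private val ops = Set("+", "-", "*", "/")

  private def applyOp(op: String, l: Double, r: Double): Double = op match {
    case "+" => l + r
    case "-" => l - r
    case "*" => l * r
    case "/" => l / r
  }

  // Evaluates the first complete prefix expression and returns it with
  // the tokens that were not consumed; None means the input was malformed.
  def eval(tokens: List[String]): Option[(Double, List[String])] = tokens match {
    case op :: rest if ops.contains(op) =>
      for {
        (l, afterLeft)  <- eval(rest)       // left operand eats some tokens
        (r, afterRight) <- eval(afterLeft)  // right operand starts where left stopped
      } yield (applyOp(op, l, r), afterRight)
    case num :: rest =>
      scala.util.Try(num.toDouble).toOption.map(v => (v, rest))
    case Nil => None
  }
}
```

A caller that wants the whole input consumed just checks that the returned list is empty, which is what the wrapper functions in the answer do.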
Of course, this is a little unfriendly to use, so we add some wrapper functions that let us interface with it nicely:
def parse(s: String): Option[Expr] = parse(s.split(" ").toList).flatMap(x => {
if (x._2.isEmpty) Some(x._1) else None
})
def parseLines(ls: List[String]): Option[Expr] = {
val attempt = ls.map(parse).flatten
if (attempt.length<ls.length) None
else if (attempt.length==1) attempt.headOption
else Some(Nar("list",attempt))
}
Okay, now, what about simplification? One thing we might want to do is numeric simplification, where we precompute the expressions and replace the original expression with the reduced version thereof. That sounds like some sort of a recursive operation--find numbers, and combine them. First we get some helper functions to do calculations on numbers:
def calc(n: Num, f: Double => Double): Num = Num(f(n.quantity).toString, Nil)
def calc(n: Num, m: Num, f: (Double,Double) => Double): Num =
Num(f(n.quantity,m.quantity).toString, Nil)
def calc(ln: List[Num], f: (Double,Double) => Double): Num =
Num(ln.map(_.quantity).reduceLeft(f).toString, Nil)
and then we do the simplification:
def numericSimplify(expr: Expr): Expr = expr match {
case Una(t,List(e)) => numericSimplify(e) match {
case n @ Num(_,_) => t match {
case "exp" => calc(n, math.exp _)
case "sin" => calc(n, math.sin _)
case "cos" => calc(n, math.cos _)
case "tan" => calc(n, math.tan _)
case "sqrt" => calc(n, math.sqrt _)
}
case a => Una(t,List(a))
}
case Bin(t,List(l,r)) => (numericSimplify(l), numericSimplify(r)) match {
case (n @ Num(_,_), m @ Num(_,_)) => t match {
case "+" => calc(n, m, _ + _)
case "-" => calc(n, m, _ - _)
case "*" => calc(n, m, _ * _)
case "/" => calc(n, m, _ / _)
case "^" => calc(n, m, math.pow)
}
case (a,b) => Bin(t,List(a,b))
}
case Nar("list",list) => Nar("list",list.map(numericSimplify))
case Nar(t,list) =>
val simple = list.map(numericSimplify)
val nums = simple.collect { case n @ Num(_,_) => n }
if (simple.length == 0) t match {
case "sum" => Num("0",Nil)
case "prod" => Num("1",Nil)
}
else if (nums.length == simple.length) t match {
case "sum" => calc(nums, _ + _)
case "prod" => calc(nums, _ * _)
}
else Nar(t, simple)
case Var(t,List(e)) => Var(t, List(numericSimplify(e)))
case _ => expr
}
Notice again the heavy use of pattern matching to find when we're in a good case, and to dispatch the appropriate calculation.
Now, surely algebraic substitution is much more difficult! Actually, all you need to do is notice that an expression has already been used, and assign a variable. Since the syntax I've defined above allows in-place variable substitution, we can actually just modify our expression tree to include more variable assignments. So we do (edited to only insert variables if the user hasn't):
def algebraicSimplify(expr: Expr): Expr = {
val all, dup, used = new collection.mutable.HashSet[Expr]
val made = new collection.mutable.HashMap[Expr,Int]
val user = new collection.mutable.HashMap[Expr,Expr]
def findExpr(e: Expr): Unit = {
e match {
case Var(t,List(v)) =>
user += v -> e
if (all contains e) dup += e else all += e
case Var(_,_) | Num(_,_) => // Do nothing in these cases
case _ => if (all contains e) dup += e else all += e
}
e.args.foreach(findExpr)
}
findExpr(expr)
def replaceDup(e: Expr): Expr = {
if (made contains e) Var("x"+made(e),Nil)
else if (used contains e) Var(user(e).text,Nil)
else if (dup contains e) {
val fixed = replaceDupChildren(e)
made += e -> made.size
Var("x"+made(e),List(fixed))
}
else replaceDupChildren(e)
}
def replaceDupChildren(e: Expr): Expr = e match {
case Una(t,List(u)) => Una(t,List(replaceDup(u)))
case Bin(t,List(l,r)) => Bin(t,List(replaceDup(l),replaceDup(r)))
case Nar(t,list) => Nar(t,list.map(replaceDup))
case Var(t,List(v)) =>
used += v
Var(t,List(if (made contains v) replaceDup(v) else replaceDupChildren(v)))
case _ => e
}
replaceDup(expr)
}
That's it--a fully functional algebraic replacement routine. Note that it builds up sets of expressions that it's seen, keeping special track of which ones are duplicates. Thanks to the magic of case classes, all the equalities are defined for us, so it just works. Then we can replace any duplicates as we recurse through to find them. Note that the replace routine is split in half, and that it matches on an unreplaced version of the tree, but uses a replaced version.
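The part that relies on case class equality can be shown in isolation. The sketch below uses a tiny made-up tree type (Tree, Leaf, Node) rather than the Expr classes from the answer, and only does the "find duplicates" half, not the replacement.

```scala
// Hypothetical mini tree, just to show structural equality at work.
object DuplicateSubtreeDemo {
  sealed trait Tree
  case class Leaf(name: String) extends Tree
  case class Node(op: String, left: Tree, right: Tree) extends Tree

  // Collects every non-leaf subtree that occurs more than once.
  // Two separately built Nodes with the same contents are equal and
  // hash alike, so a plain HashSet is enough to spot repeats.
  def duplicates(root: Tree): Set[Tree] = {
    val seen = scala.collection.mutable.HashSet[Tree]()
    val dup  = scala.collection.mutable.HashSet[Tree]()
    def visit(t: Tree): Unit = t match {
      case Node(_, l, r) =>
        if (!seen.add(t)) dup += t   // add returns false if t was already there
        visit(l)
        visit(r)
      case Leaf(_) => ()
    }
    visit(root)
    dup.toSet
  }
}
```

With ordinary (non-case) classes you would have to write equals and hashCode yourself before the set lookups would behave this way.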
Okay, now let's add a few tests:
def main(args: Array[String]): Unit = {
val test1 = "- + ^ x 2 ^ y 2 1"
val test2 = "+ + +" // Bad!
val test3 = "exp sin cos sum 5" // Bad!
val test4 = "+ * 2 3 ^ 3 2"
val test5 = List(test1, test4, "^ y 2").mkString("list 3 "," ","")
val test6 = "+ + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y"
def performTest(test: String) = {
println("Start with: " + test)
val p = OpParser.parse(test)
if (p.isEmpty) println(" Parsing failed")
else {
println("Parsed: " + p.get)
val q = OpParser.numericSimplify(p.get)
println("Numeric: " + q)
val r = OpParser.algebraicSimplify(q)
println("Algebraic: " + r)
}
println
}
List(test1,test2,test3,test4,test5,test6).foreach(performTest)
}
}
How does it do?
$ scalac OpParser.scala; scala OpParser
Start with: - + ^ x 2 ^ y 2 1
Parsed: (((x ^ 2) + (y ^ 2)) - 1)
Numeric: (((x ^ 2) + (y ^ 2)) - 1)
Algebraic: (((x ^ 2) + (y ^ 2)) - 1)
Start with: + + +
Parsing failed
Start with: exp sin cos sum 5
Parsing failed
Start with: + * 2 3 ^ 3 2
Parsed: ((2 * 3) + (3 ^ 2))
Numeric: 15.0
Algebraic: 15.0
Start with: list 3 - + ^ x 2 ^ y 2 1 + * 2 3 ^ 3 2 ^ y 2
Parsed: List[
1: (((x ^ 2) + (y ^ 2)) - 1)
2: ((2 * 3) + (3 ^ 2))
3: (y ^ 2)
]
Numeric: List[
1: (((x ^ 2) + (y ^ 2)) - 1)
2: 15.0
3: (y ^ 2)
]
Algebraic: List[
1: (((x ^ 2) + (x0 <- (y ^ 2))) - 1)
2: 15.0
3: x0
]
Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Parsed: ((x + y) + ((((x + y) * (4 + 5)) + ((x + y) * (4 + y))) + ((x + y) + (4 + y))))
Numeric: ((x + y) + ((((x + y) * 9.0) + ((x + y) * (4 + y))) + ((x + y) + (4 + y))))
Algebraic: ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
So I don't know if that's useful for you, but it turns out to be useful for me. And this is the sort of thing that I would be very hesitant to tackle in C++ because various things that were supposed to be easy ended up being painful instead.
Edit: Here's an example of using this structure to print temporary assignments, just to demonstrate that this structure is perfectly okay for doing such things.
Code:
def useTempVars(expr: Expr): Expr = {
var n = 0
def temp = { n += 1; "t"+n }
def replaceTemp(e: Expr, exempt: Boolean = false): Expr = {
def varify(x: Expr) = if (exempt) x else Var(temp,List(x))
e match {
case Var(t,List(e)) => Var(t,List(replaceTemp(e, exempt = true)))
case Una(t,List(u)) => varify( Una(t, List(replaceTemp(u,false))) )
case Bin(t,lr) => varify( Bin(t, lr.map(replaceTemp(_,false))) )
case Nar(t,ls) => varify( Nar(t, ls.map(replaceTemp(_,false))) )
case _ => e
}
}
replaceTemp(expr)
}
def varCut(expr: Expr): Expr = expr match {
case Var(t,_) => Var(t,Nil)
case Una(t,List(u)) => Una(t,List(varCut(u)))
case Bin(t,lr) => Bin(t, lr.map(varCut))
case Nar(t,ls) => Nar(t, ls.map(varCut))
case _ => expr
}
def getAssignments(expr: Expr): List[Expr] = {
val children = expr.args.flatMap(getAssignments)
expr match {
case Var(t,List(e)) => children :+ expr
case _ => children
}
}
def listAssignments(expr: Expr): List[String] = {
getAssignments(expr).collect(e => e match {
case Var(t,List(v)) => t + " = " + varCut(v)
}) :+ (expr.text + " is the answer")
}
Selected results (from listAssignments(useTempVars(r)).foreach(printf(" %s\n",_))):
Start with: - + ^ x 2 ^ y 2 1
Assignments:
t1 = (x ^ 2)
t2 = (y ^ 2)
t3 = (t1 + t2)
t4 = (t3 - 1)
t4 is the answer
Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Algebraic: ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
Assignments:
x0 = (x + y)
t1 = (x0 * 9.0)
x1 = (4 + y)
t2 = (x0 * x1)
t3 = (t1 + t2)
t4 = (x0 + x1)
t5 = (t3 + t4)
t6 = (x0 + t5)
t6 is the answer
Second edit: finding dependencies is also not too bad.
Code:
def directDepends(expr: Expr): Set[Expr] = expr match {
case Var(t,_) => Set(expr)
case _ => expr.args.flatMap(directDepends).toSet
}
def indirectDepends(expr: Expr) = {
val depend = getAssignments(expr).map(e =>
e -> e.args.flatMap(directDepends).toSet
).toMap
val tagged = for ((k,v) <- depend) yield (k.text -> v.map(_.text))
def percolate(tags: Map[String,Set[String]]): Option[Map[String,Set[String]]] = {
val expand = for ((k,v) <- tags) yield (
k -> (v union v.flatMap(x => tags.get(x).getOrElse(Set())))
)
if (tags.exists(kv => expand(kv._1) contains kv._1)) None // Cyclic dependency!
else if (tags == expand) Some(tags)
else percolate(expand)
}
percolate(tagged)
}
def listDependents(expr: Expr): List[(String,String)] = {
def sayNothing(s: String) = if (s=="") "nothing" else s
val e = expr match {
case Var(_,_) => expr
case _ => Var("result",List(expr))
}
indirectDepends(e).map(_.toList.map(x =>
(x._1, sayNothing(x._2.toList.sorted.mkString(" ")))
)).getOrElse(List((e.text,"cyclic")))
}
And if we add new test cases val test7 = "list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y" and val test8 = "list 2 x = y y = x" and show the answers with for ((v,d) <- listDependents(r)) println(" "+v+" requires "+d) we get (selected results):
Start with: - + ^ x 2 ^ y 2 1
Dependencies:
result requires x y
Start with: list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y
Parsed: List[
1: (z <- (x ^ 2))
2: ((z + (y ^ 2)) - 1)
3: (w <- (z - y))
]
Dependencies:
z requires x
w requires x y z
result requires w x y z
Start with: list 2 x = y y = x
Parsed: List[
1: (x <- y)
2: (y <- x)
]
Dependencies:
result requires cyclic
Start with: + + x y + + * + x y + 4 5 * + x y + 4 y + + x y + 4 y
Algebraic: ((x0 <- (x + y)) + (((x0 * 9.0) + (x0 * (x1 <- (4 + y)))) + (x0 + x1)))
Dependencies:
x0 requires x y
x1 requires y
result requires x x0 x1 y
So I think that on top of this sort of structure, all of your individual requirements are met by blocks of one or two dozen lines of Scala code.
Edit: here's expression evaluation, if you're given a mapping from vars to values:
def numericEvaluate(expr: Expr, initialValues: Map[String,Double]) = {
val chain = new collection.mutable.ArrayBuffer[(String,Double)]
val evaluated = new collection.mutable.HashMap[String,Double]
def note(xv: (String,Double)): Unit = { chain += xv; evaluated += xv }
evaluated ++= initialValues
def substitute(expr: Expr): Expr = expr match {
case Var(t,List(n @ Num(v,_))) => { note(t -> v.toDouble); n }
case Var(t,_) if (evaluated contains t) => Num(evaluated(t).toString,Nil)
case Var(t,ls) => Var(t,ls.map(substitute))
case Una(t,List(u)) => Una(t,List(substitute(u)))
case Bin(t,ls) => Bin(t,ls.map(substitute))
case Nar(t,ls) => Nar(t,ls.map(substitute))
case _ => expr
}
def recurse(e: Expr): Expr = {
val sub = numericSimplify(substitute(e))
if (sub == e) e else recurse(sub)
}
(recurse(expr), chain.toList)
}
and it's used like so in the testing routine:
val (num,ops) = numericEvaluate(r,Map("x"->3,"y"->1.5))
println("Evaluated:")
for ((v,n) <- ops) println(" "+v+" = "+n)
println(" result = " + num)
giving results like these (with input of x = 3 and y = 1.5):
Start with: list 3 - + ^ x 2 ^ y 2 1 + * 2 3 ^ 3 2 ^ y 2
Algebraic: List[
1: (((x ^ 2) + (x0 <- (y ^ 2))) - 1)
2: 15.0
3: x0
]
Evaluated:
x0 = 2.25
result = List[
1: 10.25
2: 15.0
3: 2.25
]
Start with: list 3 z = ^ x 2 - + z ^ y 2 1 w = - z y
Algebraic: List[
1: (z <- (x ^ 2))
2: ((z + (y ^ 2)) - 1)
3: (w <- (z - y))
]
Evaluated:
z = 9.0
w = 7.5
result = List[
1: 9.0
2: 10.25
3: 7.5
]
The other challenge--picking out the vars that haven't already been used--is just set subtraction off of the dependencies result list. diff is the name of the set subtraction method.
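As a rough sketch of that set subtraction, assuming "unused" means names that were supplied but that the expression never depends on (the helper name unused and the sample sets are made up):

```scala
// Hypothetical helper; the only library call used is Set.diff.
object UnusedVarsDemo {
  // Names in `supplied` that do not appear in `required`.
  def unused(supplied: Set[String], required: Set[String]): Set[String] =
    supplied.diff(required)
}
```

For instance, with supplied = Set("x", "y", "q") and required = Set("w", "x", "y", "z"), the call returns Set("q"). Swapping the two arguments gives the opposite question, namely which required names have no value yet.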
The problem consists of two subproblems: parsing and symbolic manipulation. It seems to me the answer boils down to two possible solutions.
One is to implement everything from scratch: "I do recommend creating the full expression tree if you want to retain maximum flexibility for handling tricky cases." - proposed by Rex. As Sven points out: "any of the high-level languages you listed are almost equally suited for the task," however "Python (or any of the high-level languages you listed) won't take away the complexity of the problem."
I have received very nice solutions in Scala (many thanks to Rex and Daniel) and a nice little example in Python (from Sven). However, I am still interested in Lisp, Haskell or Erlang solutions.
The other solution is to use some existing library/software for the task, with all the implied pros and cons. Candidates are Maxima (Common Lisp), SymPy (Python, proposed by payne) and GiNaC (C++).