Finding sqrt of big integers in Python [duplicate]

Is there an integer square root somewhere in python, or in standard libraries? I want it to be exact (i.e. return an integer), and raise an exception if the input isn't a perfect square.
I tried using this code:
import math

def isqrt(n):
    i = int(math.sqrt(n) + 0.5)
    if i**2 == n:
        return i
    raise ValueError('input was not a perfect square')
But it's ugly and I don't really trust it for large integers. I could iterate through the squares and give up if I've exceeded the value, but I assume it would be kinda slow to do something like that. Also, surely this is already implemented somewhere?
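To see concretely why the float-based version is fragile, here is a small check (my own illustration, not from the thread; isqrt_float and float_version_fails are names made up for this demo). It relies on CPython converting ints to correctly rounded doubles, and raising OverflowError for ints beyond the double range:

```python
import math

def isqrt_float(n):
    # The float-based approach from the question.
    i = int(math.sqrt(n) + 0.5)
    if i**2 == n:
        return i
    raise ValueError('input was not a perfect square')

def float_version_fails(n):
    """True if isqrt_float cannot recover the root of the perfect square n."""
    try:
        isqrt_float(n)
    except (ValueError, OverflowError):
        return True
    return False

# (2**53 + 1)**2 is a perfect square, but it is not exactly representable as a
# double, so math.sqrt works on a nearby value and the rounded root is off by one.
print(float_version_fails((2**53 + 1) ** 2))   # True
# 10**400 == (10**200)**2 is too large to convert to a float at all.
print(float_version_fails(10 ** 400))          # True
```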
See also: Check if a number is a perfect square.

Note: There is now math.isqrt in stdlib, available since Python 3.8.
Newton's method works perfectly well on integers:
def isqrt(n):
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x
This returns the largest integer x for which x * x does not exceed n. If you want to check if the result is exactly the square root, simply perform the multiplication to check if n is a perfect square.
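Wrapping that loop to get the exact-or-raise behaviour the question asks for takes only a few extra lines. This is a sketch; the negative-input guard is my addition, since the loop by itself just returns n unchanged for negative n:

```python
def isqrt_newton(n: int) -> int:
    """Largest integer x with x * x <= n, via integer Newton's method."""
    if n < 0:
        raise ValueError("square root not defined for negative numbers")
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

def exact_isqrt(n: int) -> int:
    """Exact square root of n, or ValueError if n is not a perfect square."""
    x = isqrt_newton(n)
    if x * x != n:
        raise ValueError('input was not a perfect square')
    return x

print(exact_isqrt((10**100 + 1) ** 2) == 10**100 + 1)  # True
```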
I discuss this algorithm, and three other algorithms for calculating square roots, at my blog.

Update: Python 3.8 has a math.isqrt function in the standard library!
I benchmarked every (correct) function here on both small (0…2**22) and large (2**50001) inputs. The clear winners in both cases are gmpy2.isqrt suggested by mathmandan in first place, followed by Python 3.8’s math.isqrt in second, followed by the ActiveState recipe linked by NPE in third. The ActiveState recipe has a bunch of divisions that can be replaced by shifts, which makes it a bit faster (but still behind the native functions):
def isqrt(n):
    if n > 0:
        x = 1 << (n.bit_length() + 1 >> 1)
        while True:
            y = (x + n // x) >> 1
            if y >= x:
                return x
            x = y
    elif n == 0:
        return 0
    else:
        raise ValueError("square root not defined for negative numbers")
Benchmark results:
gmpy2.isqrt() (mathmandan): 0.08 µs small, 0.07 ms large
int(gmpy2.isqrt())*: 0.3 µs small, 0.07 ms large
Python 3.8 math.isqrt: 0.13 µs small, 0.9 ms large
ActiveState (optimized as above): 0.6 µs small, 17.0 ms large
ActiveState (NPE): 1.0 µs small, 17.3 ms large
castlebravo long-hand: 4 µs small, 80 ms large
mathmandan improved: 2.7 µs small, 120 ms large
martineau (with this correction): 2.3 µs small, 140 ms large
nibot: 8 µs small, 1000 ms large
mathmandan: 1.8 µs small, 2200 ms large
castlebravo Newton’s method: 1.5 µs small, 19000 ms large
user448810: 1.4 µs small, 20000 ms large
(* Since gmpy2.isqrt returns a gmpy2.mpz object, which behaves mostly but not exactly like an int, you may need to convert it back to an int for some uses.)

Sorry for the very late response; I just stumbled onto this page. In case anyone visits this page in the future, the Python module gmpy2 is designed to work with very large inputs and includes, among other things, an integer square root function.
Example:
>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)
Granted, everything will have the "mpz" tag, but mpz values are compatible with ints:
>>> gmpy2.mpz(3)*4
mpz(12)
>>> int(gmpy2.mpz(12))
12
See my other answer for a discussion of this method's performance relative to some other answers to this question.
Download: https://code.google.com/p/gmpy/

Here's a very straightforward implementation:
def i_sqrt(n):
    i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
    m = 1 << i    # m = 2^i
    #
    # Fact: (2^(i + 1))^2 > n, so m has at least as many bits
    # as the floor of the square root of n.
    #
    # Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
    # >= 2^(ceil(log_2(n)) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
    #
    while m*m > n:
        m >>= 1
        i -= 1
    for k in range(i-1, -1, -1):
        x = m | (1 << k)
        if x*x <= n:
            m = x
    return m
This is just a binary search. Initialize the value m to be the largest power of 2 that does not exceed the square root, then check whether each smaller bit can be set while keeping the result no larger than the square root. (Check the bits one at a time, in descending order.)
For reasonably large values of n (say, around 10**6000, or around 20000 bits), this seems to be:
Faster than the Newton's method implementation described by user448810.
Much, much slower than the gmpy2 built-in method in my other answer.
Comparable to, but somewhat slower than, the Longhand Square Root described by nibot.
All of these approaches succeed on inputs of this size, but on my machine, this function takes around 1.5 seconds, while @nibot's takes about 0.9 seconds, @user448810's takes around 19 seconds, and the gmpy2 built-in method takes less than a millisecond(!). Example:
>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True
This function can be generalized easily, though it's not quite as nice because I don't have as precise an initial guess for m:
def i_root(num, root, report_exactness = True):
    i = num.bit_length() // root
    m = 1 << i
    while m ** root < num:
        m <<= 1
        i += 1
    while m ** root > num:
        m >>= 1
        i -= 1
    for k in range(i-1, -1, -1):
        x = m | (1 << k)
        if x ** root <= num:
            m = x
    if report_exactness:
        return m, m ** root == num
    return m
However, note that gmpy2 also has an iroot function.
In fact this method could be adapted and applied to any (nonnegative, increasing) function f to determine an "integer inverse of f". However, to choose an efficient initial value of m you'd still want to know something about f.
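Here is a sketch of that generalization. integer_inverse is a hypothetical helper name; it assumes f is nondecreasing on the nonnegative integers and eventually exceeds target. Instead of a tailored initial value for m, it finds an upper bound by doubling, which costs a few extra evaluations of f but works for any such f:

```python
from typing import Callable

def integer_inverse(f: Callable[[int], int], target: int) -> int:
    """Largest integer m >= 0 with f(m) <= target, for nondecreasing f."""
    if f(0) > target:
        raise ValueError("f(0) exceeds target; no nonnegative solution")
    # Double until f(hi) > target; the answer is then below hi (a power of two).
    hi = 1
    while f(hi) <= target:
        hi <<= 1
    # Binary search one bit at a time, from the highest bit below hi downwards.
    m = 0
    bit = hi >> 1
    while bit:
        if f(m | bit) <= target:
            m |= bit
        bit >>= 1
    return m

print(integer_inverse(lambda x: x ** 3, 10 ** 30))  # cube root: 10000000000
```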
Edit: Thanks to @Greggo for pointing out that the i_sqrt function can be rewritten to avoid using any multiplications. This yields an impressive performance boost!
def improved_i_sqrt(n):
    assert n >= 0
    if n == 0:
        return 0
    i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
    m = 1 << i    # m = 2^i
    #
    # Fact: (2^(i + 1))^2 > n, so m has at least as many bits
    # as the floor of the square root of n.
    #
    # Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
    # >= 2^(ceil(log_2(n)) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
    #
    while (m << i) > n: # (m<<i) = m*(2^i) = m*m
        m >>= 1
        i -= 1
    d = n - (m << i) # d = n-m^2
    for k in range(i-1, -1, -1):
        j = 1 << k
        new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
        if new_diff >= 0:
            d = new_diff
            m |= j
    return m
Note that by construction, the kth bit of m << 1 is not set, so bitwise-or may be used to implement the addition of (m<<1) + (1<<k). Ultimately I have (2*m*(2**k) + 2**(2*k)) written as (((m<<1) | (1<<k)) << k), so it's three shifts and one bitwise-or (followed by a subtraction to get new_diff). Maybe there is still a more efficient way to get this? Regardless, it's far better than multiplying m*m! Compare with above:
>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True
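Written out, the identity behind the multiplication-free update is the following (my own restatement of the comment on the new_diff line, with d = n - m^2):

```latex
\begin{aligned}
n - (m + 2^k)^2 &= n - m^2 - 2^{k+1} m - 2^{2k} \\
                &= d - (2m + 2^k)\,2^k \\
                &= d - \big((2m) \mathbin{|} 2^k\big) \ll k .
\end{aligned}
```

The last step is valid because, during iteration k, m only has bits at positions above k (it starts as 2^i with i > k, and earlier iterations only set higher bits), so bit k of 2m is clear and adding 2^k is the same as or-ing it in.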

Long-hand square root algorithm
It turns out that there is an algorithm for computing square roots that you can compute by hand, something like long-division. Each iteration of the algorithm produces exactly one digit of the resulting square root while consuming two digits of the number whose square root you seek. While the "long hand" version of the algorithm is specified in decimal, it works in any base, with binary being simplest to implement and perhaps the fastest to execute (depending on the underlying bignum representation).
Because this algorithm operates on numbers digit-by-digit, it produces exact results for arbitrarily large perfect squares, and for non-perfect-squares, can produce as many digits of precision (to the right of the decimal place) as desired.
There are two nice writeups on the "Dr. Math" site that explain the algorithm:
Square Roots in Binary
Longhand Square Roots
And here's an implementation in Python:
def exact_sqrt(x):
    """Calculate the square root of an arbitrarily large integer.

    The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
    a is the largest integer such that a**2 <= x, and r is the "remainder". If
    x is a perfect square, then r will be zero.

    The algorithm used is the "long-hand square root" algorithm, as described at
    http://mathforum.org/library/drmath/view/52656.html

    Tobin Fricke 2014-04-23
    Max Planck Institute for Gravitational Physics
    Hannover, Germany
    """
    N = 0   # Problem so far
    a = 0   # Solution so far
    # We'll process the number two bits at a time, starting at the MSB
    L = x.bit_length()
    L += (L % 2)          # Round up to the next even number
    for i in range(L, -1, -1):
        # Get the next group of two bits
        n = (x >> (2*i)) & 0b11
        # Check whether we can reduce the remainder
        if ((N - a*a) << 2) + n >= (a<<2) + 1:
            b = 1
        else:
            b = 0
        a = (a << 1) | b   # Concatenate the next bit of the solution
        N = (N << 2) | n   # Concatenate the next bit of the problem
    return (a, N-a*a)
You could easily modify this function to conduct additional iterations to calculate the fractional part of the square root. I was most interested in computing roots of large perfect squares.
I'm not sure how this compares to the "integer Newton's method" algorithm. I suspect that Newton's method is faster, since it can in principle generate multiple bits of the solution in one iteration, while the "long hand" algorithm generates exactly one bit of the solution per iteration.
Source repo: https://gist.github.com/tobin/11233492
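As a sketch of the "fractional part" extension mentioned above: appending 2f zero bits to x appends f bits to its root, so running the same binary long-hand loop on x << 2f yields floor(sqrt(x) * 2**f). The function below is my own compact restatement (it tracks the remainder directly instead of recomputing N - a*a, and skips the leading all-zero bit pairs), not the gist's code; longhand_sqrt and frac_bits are hypothetical names:

```python
def longhand_sqrt(x: int, frac_bits: int = 0) -> int:
    """Return floor(sqrt(x) * 2**frac_bits) using the binary long-hand method."""
    if x < 0:
        raise ValueError("square root not defined for negative numbers")
    x <<= 2 * frac_bits           # 2f extra zero bits in x give f extra bits of root
    root = 0                      # solution so far
    rem = 0                       # (prefix of x processed so far) - root**2
    pairs = (x.bit_length() + 1) // 2
    for i in range(pairs - 1, -1, -1):
        rem = (rem << 2) | ((x >> (2 * i)) & 0b11)   # bring down the next two bits
        trial = (root << 2) | 1                       # (2*root + 1)**2 - (2*root)**2
        root <<= 1
        if rem >= trial:
            rem -= trial
            root |= 1
    return root

print(longhand_sqrt(2, 10))  # 1448
```

For example, 1448 / 2**10 = 1.4140625, which is sqrt(2) truncated to ten fractional bits.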

One option would be to use the decimal module, and do it in sufficiently-precise floats:
import decimal

def isqrt(n):
    nd = decimal.Decimal(n)
    with decimal.localcontext() as ctx:
        ctx.prec = max(1, n.bit_length())  # precision must be >= 1, even for n == 0
        i = int(nd.sqrt())
    if i**2 != n:
        raise ValueError('input was not a perfect square')
    return i
which I think should work:
>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
isqrt(11**1000+1)
File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square

Since Python 3.8, the standard library's math module has an integer square root function:
math.isqrt(n)
Return the integer square root of the nonnegative integer n. This is the floor of the exact square root of n, or equivalently the greatest integer a such that a² ≤ n.
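On Python 3.8+, the question's exact-or-raise function becomes a thin wrapper around math.isqrt. This sketch also returns the remainder, in the spirit of the (a, r) tuple returned by exact_sqrt elsewhere on this page; isqrt_rem is a name chosen here, not a standard-library function:

```python
import math
from typing import Tuple

def isqrt_rem(n: int) -> Tuple[int, int]:
    """Return (r, n - r*r) with r = math.isqrt(n); remainder 0 iff n is a perfect square."""
    r = math.isqrt(n)  # raises ValueError for negative n
    return r, n - r * r

def exact_isqrt(n: int) -> int:
    """Exact square root, raising ValueError as the question requests."""
    r, rem = isqrt_rem(n)
    if rem:
        raise ValueError('input was not a perfect square')
    return r

print(isqrt_rem(10))                      # (3, 1)
print(exact_isqrt(11**1000) == 11**500)   # True
```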

Seems like you could check like this:
if int(math.sqrt(n))**2 == n:
    print(n, 'is a perfect square')
Update:
As you pointed out, the above fails for large values of n. For those, the following looks promising. It is an adaptation of the example C code by Martin Guy @ UKC, June 1985, for the relatively simple-looking binary digit-by-digit calculation method mentioned in the Wikipedia article Methods of computing square roots:
from math import ceil, log

def isqrt(n):
    res = 0
    bit = 4**int(ceil(log(n, 4))) if n else 0  # smallest power of 4 >= the argument
    while bit:
        if n >= res + bit:
            n -= res + bit
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2
    return res

if __name__ == '__main__':
    from math import sqrt  # for comparison purposes

    for i in list(range(17)) + [2**53, (10**100+1)**2]:
        is_perfect_sq = isqrt(i)**2 == i
        print('{:21,d}: math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
            i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else ''))
Output:
0: math.sqrt= 0, isqrt= 0 (perfect square)
1: math.sqrt= 1, isqrt= 1 (perfect square)
2: math.sqrt= 1.414214, isqrt= 1
3: math.sqrt= 1.732051, isqrt= 1
4: math.sqrt= 2, isqrt= 2 (perfect square)
5: math.sqrt= 2.236068, isqrt= 2
6: math.sqrt= 2.44949, isqrt= 2
7: math.sqrt= 2.645751, isqrt= 2
8: math.sqrt= 2.828427, isqrt= 2
9: math.sqrt= 3, isqrt= 3 (perfect square)
10: math.sqrt= 3.162278, isqrt= 3
11: math.sqrt= 3.316625, isqrt= 3
12: math.sqrt= 3.464102, isqrt= 3
13: math.sqrt= 3.605551, isqrt= 3
14: math.sqrt= 3.741657, isqrt= 3
15: math.sqrt= 3.872983, isqrt= 3
16: math.sqrt= 4, isqrt= 4 (perfect square)
9,007,199,254,740,992: math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001: math.sqrt= 1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)
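The only floating-point step left in that function is the log used to pick the starting power of 4. A sketch of an all-integer variant (my adaptation; isqrt_digits is a hypothetical name) starts from the highest power of 4 that does not exceed n, computed from n.bit_length(). Starting from a larger power of 4, as the version above does, only adds leading iterations that change nothing:

```python
def isqrt_digits(n: int) -> int:
    """Binary digit-by-digit integer square root with an integer-only starting bit."""
    if n < 0:
        raise ValueError("square root not defined for negative numbers")
    res = 0
    # Highest power of 4 <= n: round (bit_length - 1) down to an even exponent.
    bit = 1 << ((n.bit_length() - 1) & ~1) if n else 0
    while bit:
        if n >= res + bit:
            n -= res + bit
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2
    return res

print(isqrt_digits((10**100 + 1) ** 2) == 10**100 + 1)  # True
```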

The script below extracts integer square roots. It uses no divisions: apart from multiplications and additions, all scaling is done with bit shifts. It uses Newton's method on the inverse square root, a technique made famous by Quake III Arena as mentioned in the Wikipedia article, Fast inverse square root.
The strategy of the algorithm to compute s = sqrt(Y) is as follows.
Reduce the argument Y to y in the range [1/4, 1), i.e., y = Y/B, with 1/4 <= y < 1, where B is an even power of 2, so B = 2**(2*k) for some integer k. We want to find X, where x = X/B, and x = 1 / sqrt(y).
Determine a first approximation to X using a quadratic minimax polynomial.
Refine X using Newton's method.
Calculate s = X*Y/(2**(3*k)).
We don't actually create fractions or perform any divisions. All the arithmetic is done with integers, and we use bit shifting to divide by various powers of B.
Range reduction lets us find a good initial approximation to feed to Newton's method. Here's a version of the 2nd degree minimax polynomial approximation to the inverse square root in the interval [1/4, 1), written with the coefficients (scaled by 256) that the code below uses:
y ~= (463*x**2 - 896*x + 698) / 256 ~= 1 / sqrt(x), for 1/4 <= x < 1
(Sorry, I've reversed the meaning of x & y here, to conform to the usual conventions.) The maximum error of this approximation is around 0.0355 ~= 1/28.
[Figure: graph of the approximation error over [1/4, 1)]
Using this poly, our initial x starts with at least 4 or 5 bits of precision. Each round of Newton's method doubles the precision, so it doesn't take many rounds to get thousands of bits, if we want them.
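For reference, here is how the update in the code follows from Newton's method (my own derivation, using the names from the description above: y = Y/B, x = X/B, B = 2^(2k)). Newton's method on f(x) = x^(-2) - y, whose positive root is x = 1/sqrt(y), gives

```latex
\begin{aligned}
x_{n+1} &= x_n - \frac{x_n^{-2} - y}{-2x_n^{-3}}
         = x_n + \frac{x_n\,(1 - y x_n^2)}{2}, \\
\Delta X &= B\,\Delta x
         = \frac{X\,(B^3 - Y X^2)}{2B^3}
         = \big(X\,(B^3 - Y X^2)\big) \gg (6k + 1), \\
s &= \sqrt{Y} = \frac{Y}{\sqrt{yB}} = \frac{X\,Y}{2^{3k}} .
\end{aligned}
```

These correspond to dx = x * (b3 - y * x * x) >> kd (with kd = 6k + 1) and s = x * y >> k3 in the code. If x_n = (1 + e)/sqrt(y), one step gives x_{n+1} = (1 - (3/2)e^2 - e^3/2)/sqrt(y), so the relative error is roughly squared each round, which is where the doubling of precision comes from.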
""" Integer square root
Uses no divisions, only shifts
"Quake" style algorithm,
i.e., Newton's method for 1 / sqrt(y)
Uses a quadratic minimax polynomial for the first approximation
Written by PM 2Ring 2022.01.23
"""
def int_sqrt(y):
    if y < 0:
        raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
    if y < 2:
        return y
    # print("\n*", y, "*")
    # Range reduction.
    # Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
    j = y.bit_length()
    # Round k*2 up to the next even number
    k2 = j + (j & 1)
    # k and some useful multiples
    k = k2 >> 1
    k3 = k2 + k
    k6 = k3 << 1
    kd = k6 + 1
    # b cubed
    b3 = 1 << k6
    # Minimax approximation: x/b ~= 1 / sqrt(y/b)
    x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
    # print("  ", x)
    # Newton's method for 1 / sqrt(y/b)
    epsilon = 1 << k
    for i in range(1, 99):
        dx = x * (b3 - y * x * x) >> kd
        x += dx
        # print(f"  {i}: {x} {dx}")
        if abs(dx) <= epsilon:
            break
    # s == sqrt(y)
    s = x * y >> k3
    # Adjust if too low
    ss = s + 1
    return ss if ss * ss <= y else s

def test(lo, hi, step=1):
    for y in range(lo, hi, step):
        s = int_sqrt(y)
        ss = s + 1
        s2, ss2 = s * s, ss * ss
        assert s2 <= y < ss2, (y, s2, ss2)
    print("ok")

test(0, 100000, 1)
This code is certainly slower than math.isqrt and decimal.Decimal.sqrt. Its purpose is simply to illustrate the algorithm. It would be interesting to see how fast it would be if it were implemented in C...
Here's a live version, running on the SageMathCell server. Set hi <= 0 to calculate and display the results for a single value set in lo. You can put expressions in the input boxes, e.g., set hi to 0 and lo to 2 * 10**100 to get sqrt(2) * 10**50.

Inspired by all the answers, I decided to implement several of the best methods from them in pure C++, since compiled C++ is usually much faster than equivalent pure-Python code.
To glue C++ and Python together I used Cython. It lets you build a Python module out of C++ code and then call the C++ functions directly from Python.
In addition to the Python-adapted code, I also provide pure C++ code with tests.
Here are timings from pure C++ tests:
Test 'GMP', bits 64, time 0.000001 sec
Test 'AndersKaseorg', bits 64, time 0.000003 sec
Test 'Babylonian', bits 64, time 0.000006 sec
Test 'ChordTangent', bits 64, time 0.000018 sec
Test 'GMP', bits 50000, time 0.000118 sec
Test 'AndersKaseorg', bits 50000, time 0.002777 sec
Test 'Babylonian', bits 50000, time 0.003062 sec
Test 'ChordTangent', bits 50000, time 0.009120 sec
And here are the timings of the same C++ functions wrapped as a Python module:
Bits 50000
math.isqrt: 2.819 ms
gmpy2.isqrt: 0.166 ms
ISqrt_GMP: 0.252 ms
ISqrt_AndersKaseorg: 3.338 ms
ISqrt_Babylonian: 3.756 ms
ISqrt_ChordTangent: 10.564 ms
My Cython-C++ code is also handy as a framework for anyone who wants to write their own C++ method and test it directly from Python.
As the timings above show, I used the following methods:
math.isqrt, the implementation from the standard library.
gmpy2.isqrt, the GMPY2 library's implementation.
ISqrt_GMP, the same as GMPY2 but called through my Cython module, where I use the C++ GMP library (<gmpxx.h>) directly.
ISqrt_AndersKaseorg, code taken from the answer by @AndersKaseorg.
ISqrt_Babylonian, the so-called Babylonian method from the Wikipedia article; this is my own implementation as I understand it.
ISqrt_ChordTangent, my own method, which I call Chord-Tangent because it uses a chord and a tangent line to iteratively shrink the search interval. It is described in moderate detail in my other article. A nice property of this method is that it finds not only the square root but the K-th root for any K. I drew a small picture showing the details of this algorithm.
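To make the chord-and-tangent idea concrete without the C++ build, here is a simplified Python sketch of my own (it is not a port of KthRoot_ChordTangent below, and kth_root_chord_tangent is a hypothetical name). It keeps a bracket lo**k <= n < hi**k; because x**k is convex for x >= 0, the chord through the bracket endpoints crosses n at or left of the true root, and the tangent at hi crosses at or right of it, so both lines can tighten the bracket. A bisection step is used as a fallback whenever neither endpoint moves:

```python
def kth_root_chord_tangent(n: int, k: int = 2) -> int:
    """Floor of the k-th root of n >= 0 using chord (secant) and tangent (Newton) steps."""
    if n < 0 or k < 1:
        raise ValueError("need n >= 0 and k >= 1")
    if n < 2 or k == 1:
        return n
    b = n.bit_length()
    lo = 1 << ((b - 1) // k)          # lo**k <= 2**(b-1) <= n
    hi = 1 << ((b + k - 1) // k)      # hi**k >= 2**b > n
    while hi - lo > 1:
        lo_k, hi_k = lo ** k, hi ** k
        # Chord through (lo, lo**k) and (hi, hi**k): by convexity it meets n
        # at or left of the root, so its floor is a valid new lower bound.
        chord = lo + (n - lo_k) * (hi - lo) // (hi_k - lo_k)
        # Tangent at hi: a Newton step from the right stays at or right of the
        # root, so floor(that point) + 1 is a valid new (strict) upper bound.
        step = -((n - hi_k) // (k * hi ** (k - 1)))   # ceil((hi**k - n) / (k * hi**(k-1)))
        tangent = hi - step + 1
        if chord == lo and tangent == hi:
            # Neither line made progress: fall back to one bisection step.
            mid = (lo + hi) // 2
            if mid ** k <= n:
                lo = mid
            else:
                hi = mid
        else:
            lo, hi = chord, tangent
    return lo

print(kth_root_chord_tangent(10**90, 5) == 10**18)  # True
```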
To compile the C++/Cython code you need the GMP library, so install it first. Under Linux this is easy: sudo apt install libgmp-dev.
Under Windows the easiest way is to install VCPKG, a really great software package manager, similar to APT on Linux. VCPKG compiles all packages from source using Visual Studio (don't forget to install the Community edition of Visual Studio). After installing VCPKG you can install GMP with vcpkg install gmp. You may also install MPIR, an alternative fork of GMP, with vcpkg install mpir.
After GMP is installed under Windows, edit my Python code and replace the path to the include directory and to the library file. At the end of the installation VCPKG should show you the path to a ZIP file with the GMP library, which contains the .lib and .h files.
You may notice that the Python code also contains a handy cython_compile() function that I use to compile any C++ code into a Python module. It makes it easy to plug any C++ code into Python and can be reused many times.
If you have any questions or suggestions, or something doesn't work on your PC, please write in the comments.
Below I first show the Python code, then the C++ code. Use the "Try it online!" link above the C++ code to run it online on the Godbolt servers. Both snippets are runnable from scratch as they are; the only things you may need to edit are the Windows GMP paths mentioned above (note that the compiler flags in cython_compile() are MSVC-style).
def cython_compile(srcs):
    import json, hashlib, os, glob, importlib, sys, shutil, tempfile
    srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
    pdir = 'cyimp'
    if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
        class ChDir:
            def __init__(self, newd):
                self.newd = newd
            def __enter__(self):
                self.curd = os.getcwd()
                os.chdir(self.newd)
                return self
            def __exit__(self, ext, exv, tb):
                os.chdir(self.curd)
        os.makedirs(pdir, exist_ok = True)
        with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
            os.makedirs(pdir, exist_ok = True)
            for k, v in srcs.items():
                with open(f'cys{srch}_{k}', 'wb') as f:
                    f.write(v.replace('{srch}', srch).encode('utf-8'))
            import numpy as np
            from setuptools import setup, Extension
            from Cython.Build import cythonize
            sys.argv += ['build_ext', '--inplace']
            setup(
                ext_modules = cythonize(
                    Extension(
                        f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
                        depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
                        extra_compile_args = ['/O2', '/std:c++latest',
                            '/ID:/dev/_3party/vcpkg_bin/gmp/include/',
                        ],
                    ),
                    compiler_directives = {'language_level': 3, 'embedsignature': True},
                    annotate = True,
                ),
                include_dirs = [np.get_include()],
            )
            del sys.argv[-2:]
            for f in glob.glob(f'{pdir}/cy{srch}*'):
                shutil.copy(f, f'./../')
        print('Cython module:', f'cy{srch}')
    return importlib.import_module(f'{pdir}.cy{srch}')
def cython_import():
    srcs = {
        'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
""",
        'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *
cdef extern from "cys{srch}_lib.h" nogil:
    void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);
@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
    def ToLimbs():
        # n as little-endian 64-bit limbs
        return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
    def FromLimbs(x):
        return int.from_bytes(x.tobytes(), 'little')
    n = ToLimbs()
    cdef uint64_t[:] cn = n
    cdef uint64_t ccnt = len(n)
    cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
    with nogil:
        (ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
            <uint64_t *>&cn[0], <uint64_t *>&ccnt
        )
    return FromLimbs(n[:ccnt])
""",
}
return cython_compile(srcs)
def main():
    import math, gmpy2, timeit, random
    mod = cython_import()
    fs = [
        ('math.isqrt', math.isqrt),
        ('gmpy2.isqrt', gmpy2.isqrt),
        ('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
        ('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
        ('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
        ('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
    ]
    times = [0] * len(fs)
    ntests = 1 << 6
    bits = 50000
    for i in range(ntests):
        n = random.randrange(1 << (bits - 1), 1 << bits)
        ref = None
        for j, (fn, f) in enumerate(fs):
            timeit_cnt = 3
            tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
            times[j] += tim
            x = f(n)
            if j == 0:
                ref = x
            else:
                assert x == ref, (fn, ref, x)
    print('Bits', bits)
    print('\n'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))

if __name__ == '__main__':
    main()
and C++:
Try it online!
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
std::string IntToStr(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return n.get_str();
else {
std::ostringstream ss;
ss << n;
return ss.str();
}
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
//std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
// Testing
#include <chrono>
#include <random>
#include <vector>
#include <iomanip>
inline double Time() {
static auto const gtb = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
.count();
}
template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
std::mt19937_64 rng{123};
std::vector<T> nums;
for (size_t i = 0; i < ntests; ++i) {
T n = 0;
for (size_t j = 0; j < bits; j += 32) {
size_t const cbits = std::min<size_t>(32, bits - j);
n <<= cbits;
n ^= u32(rng()) >> (32 - cbits);
}
nums.push_back(n);
}
auto tim = Time();
for (auto & n: nums)
n = f(n);
tim = Time() - tim;
std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
<< ", bits " << std::setw(6) << bits << ", time "
<< std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
return nums;
}
void Test() {
auto f = [](auto ty, size_t bits, size_t ntests){
using T = std::decay_t<decltype(ty)>;
auto tim = Time();
auto a = Test0<T>("GMP", bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
ASSERT(b == a);
auto c = Test0<T>("Babylonian", bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
ASSERT(c == a);
auto d = Test0<T>("ChordTangent", bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
ASSERT(d == a);
std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
};
for (auto p: std::vector<std::pair<int, int>>{{15, 1 << 19}, {30, 1 << 19}})
f(u64(), p.first, p.second);
for (auto p: std::vector<std::pair<int, int>>{{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
f(mpz_class(), p.first, p.second);
}
int main() {
try {
Test();
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}
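As an independent cross-check for the C++ KthRoot_ChordTangent above (especially for k > 2), here is a minimal Python sketch of an integer k-th root. Note that it uses Newton's method on integers, not the chord-tangent scheme; iroot is a hypothetical helper name.

```python
def iroot(n: int, k: int = 2) -> int:
    """Largest integer x with x**k <= n, for n >= 0 and k >= 1 (integer Newton iteration)."""
    if n < 0:
        raise ValueError("n must be non-negative")
    if n < 2:
        return n
    # Start strictly above the root: 2**ceil(bit_length / k) > n**(1/k).
    x = 1 << -(-n.bit_length() // k)
    while True:
        # Integer Newton step for f(x) = x**k - n; never drops below the true floor root.
        y = ((k - 1) * x + n // x ** (k - 1)) // k
        if y >= x:  # the sequence stopped decreasing, so x is the floor root
            return x
        x = y
```

The iterates decrease strictly while x**k > n and never fall below the floor of the real root, so the loop ends exactly at the answer.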

Your function fails for large inputs:
In [26]: isqrt((10**100+1)**2)
ValueError: input was not a perfect square
There is a recipe on the ActiveState site which should hopefully be more reliable since it uses integer maths only. It is based on an earlier StackOverflow question: Writing your own square root function
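To see why the float-based version fails, here is a small demonstration: math.sqrt converts the integer to a 53-bit-significand float, so the trailing "+1" is lost, and for integers beyond the float range the conversion raises OverflowError instead.

```python
import math

n = (10**100 + 1) ** 2
approx = int(math.sqrt(n) + 0.5)
print(approx == 10**100 + 1)  # False: the float result lost the trailing "+1"

try:
    math.sqrt(10**400)  # too large to convert to a float at all
except OverflowError as exc:
    print("OverflowError:", exc)
```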

Not every number can be represented exactly as a float. You can instead test whether the result is within a chosen tolerance, setting epsilon to a small value within the accuracy of Python's floats:
def isqrt(n):
    epsilon = .00000000001
    i = int(n**.5 + 0.5)
    if abs(i**2 - n) < epsilon:
        return i
    raise ValueError('input was not a perfect square')

Try this condition (no additional computation):
import math

def isqrt(n):
    i = math.sqrt(n)
    if i != int(i):
        raise ValueError('input was not a perfect square')
    return i
If you need it to return an int (rather than a float such as 4.0), either assign int(i) to a second variable or compute int(i) twice.

I have compared the different methods given here using a loop:
for i in range(1000000):  # 700 msec
    r = int(123456781234567**0.5 + 0.5)
    if r**2 == 123456781234567:
        rr = r
    else:
        rr = -1
and found that this one is the fastest and needs no math import. Very long inputs might fail, but look at this:
15241576832799734552675677489**0.5 = 123456781234567.0

Related

Is there an efficient algorithm to compute the Jacobsthal matrix or quadratic character in GF(q)?

Is there an efficient algorithm to compute the Jacobsthal matrix [WP] or equivalently the quadratic character χ in GF(q),
J[i, j] = χ(i - j) = 0 if i = j, else 1 if i - j is a square in GF(q), else -1,
where i, j run over the elements of GF(q)?
The order of the elements <=> rows/columns does not really matter, so it's mainly to know whether an element of GF(q) is a square.
Unfortunately, when q = p^n with n > 1, one cannot just take i, j ∈ Z/qZ (which works well iff q is a prime <=> n = 1).
On the other hand, implementing arithmetics in GF(q) appears a nontrivial task to me, at least the naive way (constructing an irreducible polynomial P of degree n over Z/pZ and implementing multiplication through multiplication of polynomials modulo P...).
The problem is easily solved in Python using the galois package (see here), but this is quite heavy artillery which I'd like to avoid to deploy.
Of course dedicated number theory software may also have GF arithmetics implemented. But I needed this just to produce Hadamard matrices through the Paley construction [WP], so I'd like to be able to compute this without using sophisticated software (and anyway I think it would be interesting to know whether there's a simple algorithm to do this).
Since we only need to know which elements are squares, I hoped there might be an efficient way to determine that.
EDIT: Let me clarify again that the question is whether there exists an efficient way of implementing this function (for arbitrary q = p^k) without implementing general arithmetic in GF(q). It's not difficult to solve the problem using dedicated software: for example, Python's galois package provides the is_quadratic_residue() function, which immediately gives the matrix elements. (Despite its name, what is needed here are squares in GF(p^k), which are not the same as quadratic residues mod p^k. Indeed, default modular arithmetic, i.e., issquare(Mod(i-j, p^k)), will usually yield incorrect results when k > 1: for example, in GF(2^k) every element is a square, but 2 and 3 aren't squares mod 2^2.) A crude check is to compute J J^T, which should equal q I - U (for p > 2), where U is the all-ones matrix.
Here is a basic math table setup for GF(3^4) based on the polynomial 1 x^4 + 1 x^3 + 1 x^2 + 2 x + 2. At the end of this answer is a brute-force search for any primitive polynomial (one where the powers of alpha = x, stored as the integer 3, map to all non-zero elements).
Numbers are stored as integer equivalents; for example, x^3 + 2 x + 1 = 1*(3^3) + 2*(3) + 1 = 34, so I store this as 34. Add and subtract just map from integer to vector and back. Multiply and divide use exp and log tables. The exp table is generated by taking successive powers of alpha (multiplying by x each time), and the log table is the reverse-mapped exp table. InitGF initializes the exp table using GFMpyA (multiply by alpha == multiply by x).
Here is the math starting at integer 27 = 1 x^3, multiplied by x, showing the long-hand division of
ex = e0 * x modulo polynomial
1 q = 1 = quotient
-----------
1 1 1 2 2 | 1 0 0 0 0 poly | ex
1 1 1 2 2 poly * q
---------
2 2 1 1 remainder
2 q = 2 = quotient
-----------
1 1 1 2 2 | 2 2 1 1 0 poly | ex
2 2 2 1 1 poly * 2
---------
0 2 0 2 remainder
Basic math code with initialization:
typedef unsigned char BYTE;
/* GFS(3) */
#define GFS 3
/* GF(3^4) */
#define GF 81
/* alpha = 1x + 0 */
#define ALPHA 3
typedef struct{ /* element of field */
int d; /* = dx^3 + cx^2 + bx + a */
int c;
int b;
int a;
}ELEM;
typedef struct{ /* extended element of field */
int e; /* = ex^4 + dx^3 + cx^2 +bx + a */
int d;
int c;
int b;
int a;
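}ELEMX;
typedef struct{ /* polynomial, highest-degree coefficient first */
int size;
int data[5];
}POLY;
static ELEM aiI2E[GF]; /* index to element table */
static int aiExp[2*GF]; /* exp table (doubled so GFMpy/GFDiv need no modulo) */
static int aiLog[GF]; /* log table */
static POLY pGFPoly; /* field polynomial */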
/*----------------------------------------------------------------------*/
/* GFAdd(i0, i1) */
/*----------------------------------------------------------------------*/
static int GFAdd(int i0, int i1)
{
ELEM e0, e1;
e0 = aiI2E[i0];
e1 = aiI2E[i1];
e0.d = (e0.d + e1.d);
if(e0.d >= GFS)e0.d -= GFS;
e0.c = (e0.c + e1.c);
if(e0.c >= GFS)e0.c -= GFS;
e0.b = (e0.b + e1.b);
if(e0.b >= GFS)e0.b -= GFS;
e0.a = (e0.a + e1.a);
if(e0.a >= GFS)e0.a -= GFS;
return (((((e0.d*GFS)+e0.c)*GFS)+e0.b)*GFS)+e0.a;
}
/*----------------------------------------------------------------------*/
/* GFSub(i0, i1) */
/*----------------------------------------------------------------------*/
static int GFSub(int i0, int i1)
{
ELEM e0, e1;
e0 = aiI2E[i0];
e1 = aiI2E[i1];
e0.d = (e0.d - e1.d);
if(e0.d < 0)e0.d += GFS;
e0.c = (e0.c - e1.c);
if(e0.c < 0)e0.c += GFS;
e0.b = (e0.b - e1.b);
if(e0.b < 0)e0.b += GFS;
e0.a = (e0.a - e1.a);
if(e0.a < 0)e0.a += GFS;
return (((((e0.d*GFS)+e0.c)*GFS)+e0.b)*GFS)+e0.a;
}
/*----------------------------------------------------------------------*/
/* GFMpy(i0, i1) i0*i1 using logs */
/*----------------------------------------------------------------------*/
static int GFMpy(int i0, int i1)
{
if(i0 == 0 || i1 == 0)
return(0);
return(aiExp[aiLog[i0]+aiLog[i1]]);
}
/*----------------------------------------------------------------------*/
/* GFDiv(i0, i1) i0/i1 */
/*----------------------------------------------------------------------*/
static int GFDiv(int i0, int i1)
{
if(i0 == 0)
return(0);
return(aiExp[(GF-1)+aiLog[i0]-aiLog[i1]]);
}
/*----------------------------------------------------------------------*/
/* GFPow(i0, i1) i0^i1 */
/*----------------------------------------------------------------------*/
static int GFPow(int i0, int i1)
{
if(i1 == 0)
return (1);
if(i0 == 0)
return (0);
return(aiExp[(aiLog[i0]*i1)%(GF-1)]);
}
/*----------------------------------------------------------------------*/
/* GFMpyA(i0) i0*ALPHA using low level math */
/*----------------------------------------------------------------------*/
/* hard coded for elements of size 4 */
static int GFMpyA(int i0)
{
ELEM e0;
ELEMX ex;
int q; /* quotient */
e0 = aiI2E[i0]; /* e0 = i0 split up */
ex.e = e0.d; /* ex = e0*x */
ex.d = e0.c;
ex.c = e0.b;
ex.b = e0.a;
ex.a = 0;
q = ex.e;
/* ex.e -= q * pGFPoly.data[0] % GFS; ** always == 0 */
/* if(ex.e < 0)ex.e += GFS; ** always == 0 */
ex.d -= q * pGFPoly.data[1] % GFS;
if(ex.d < 0)ex.d += GFS;
ex.c -= q * pGFPoly.data[2] % GFS;
if(ex.c < 0)ex.c += GFS;
ex.b -= q * pGFPoly.data[3] % GFS;
if(ex.b < 0)ex.b += GFS;
ex.a -= q * pGFPoly.data[4] % GFS;
if(ex.a < 0)ex.a += GFS;
return (((((ex.d*GFS)+ex.c)*GFS)+ex.b)*GFS)+ex.a;
}
/*----------------------------------------------------------------------*/
/* InitGF Initialize Galois Stuff */
/*----------------------------------------------------------------------*/
static void InitGF(void)
{
int i;
int t;
for(i = 0; i < GF; i++){ /* init index to element table */
t = i;
aiI2E[i].a = t%GFS;
t /= GFS;
aiI2E[i].b = t%GFS;
t /= GFS;
aiI2E[i].c = t%GFS;
t /= GFS;
aiI2E[i].d = t;
}
pGFPoly.size = 5; /* init GF() polynomial */
pGFPoly.data[0] = 1;
pGFPoly.data[1] = 1;
pGFPoly.data[2] = 1;
pGFPoly.data[3] = 2;
pGFPoly.data[4] = 2;
t = 1; /* init aiExp[] */
for(i = 0; i < GF*2; i++){
aiExp[i] = t;
t = GFMpyA(t);
}
aiLog[0] = -1; /* init aiLog[] */
for(i = 0; i < GF-1; i++)
aiLog[aiExp[i]] = i;
}
/*----------------------------------------------------------------------*/
/* main */
/*----------------------------------------------------------------------*/
int main()
{
InitGF();
return(0);
}
Code to display a list of primitive polynomials for GF(3^4)
pGFPoly.size = 5; /* display primitive polynomials */
pGFPoly.data[0] = 1;
pGFPoly.data[1] = 0;
pGFPoly.data[2] = 0;
pGFPoly.data[3] = 0;
pGFPoly.data[4] = 1;
while(1){
i = 0;
t = 1;
do{
i++;
t = GFMpyA(t);}
while(t != 1);
if(i == (GF-1)){
printf("pGFPoly: ");
ShowVector(&pGFPoly);}
pGFPoly.data[4] += 1;
if(pGFPoly.data[4] == GFS){
pGFPoly.data[4] = 1;
pGFPoly.data[3] += 1;
if(pGFPoly.data[3] == GFS){
pGFPoly.data[3] = 0;
pGFPoly.data[2] += 1;
if(pGFPoly.data[2] == GFS){
pGFPoly.data[2] = 0;
pGFPoly.data[1] += 1;
if(pGFPoly.data[1] == GFS){
break;}}}}}
This produces this list:
1 0 0 1 2 The one normally used x^4 + x + 2
1 0 0 2 2
1 1 0 0 2
1 1 1 2 2 I used this to test all 5 terms
1 1 2 2 2
1 2 0 0 2
1 2 1 1 2
1 2 2 1 2
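Building on the exp/log tables above, the square test itself needs no further field arithmetic: for odd p, the nonzero squares of GF(q) are exactly the even powers of a primitive element, so χ(a) = 1 when log(a) is even and -1 when it is odd. Below is a minimal Python sketch of that idea using the same integer encoding as the C code (gf_quadratic_character and jacobsthal are hypothetical names). It assumes poly is a monic primitive polynomial over GF(p), given highest-degree coefficient first, and raises ValueError if x does not generate the multiplicative group.

```python
def gf_quadratic_character(p, poly):
    """Return (q, sub, chi) for GF(p**n), n = len(poly) - 1, with p an odd prime."""
    if p == 2:
        raise ValueError("p must be odd: in GF(2**n) every element is a square")
    n = len(poly) - 1
    q = p ** n

    def to_digits(a):  # integer code -> coefficients, lowest degree first
        digits = []
        for _ in range(n):
            a, r = divmod(a, p)
            digits.append(r)
        return digits

    def from_digits(digits):  # inverse of to_digits
        a = 0
        for c in reversed(digits):
            a = a * p + c
        return a

    def mul_by_x(a):  # a * x reduced modulo poly (the analogue of GFMpyA)
        d = [0] + to_digits(a)
        top = d[n]  # coefficient of x**n after the shift
        for i in range(n + 1):
            d[i] = (d[i] - top * poly[n - i]) % p  # poly[n - i] is the coefficient of x**i
        return from_digits(d[:n])

    # dlog[a] = k with x**k == a; fails if x is not a primitive element
    dlog = [None] * q
    t = 1
    for k in range(q - 1):
        if t == 0 or dlog[t] is not None:
            raise ValueError("poly is not primitive")
        dlog[t] = k
        t = mul_by_x(t)

    def sub(a, b):  # coefficient-wise subtraction mod p
        return from_digits([(u - v) % p for u, v in zip(to_digits(a), to_digits(b))])

    def chi(a):  # nonzero a = x**k is a square exactly when k is even
        if a == 0:
            return 0
        return 1 if dlog[a] % 2 == 0 else -1

    return q, sub, chi


def jacobsthal(p, poly):
    q, sub, chi = gf_quadratic_character(p, poly)
    return [[chi(sub(i, j)) for j in range(q)] for i in range(q)]
```

For GF(81), jacobsthal(3, [1, 1, 1, 2, 2]) would use the polynomial from the answer above; the crude check from the question (J J^T = q I - U) is a cheap way to validate the result.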

Unusual behaviour of Ant Colony Optimization for Closest String Problem in Python and C++

This is probably going to be a long question; I apologize in advance.
I'm working on a project with the goal of researching different solutions for the closest string problem.
Let s_1, ..., s_n be strings of length m. Find a string s of length m that minimizes max{d(s, s_i) | i = 1, ..., n}, where d is the Hamming distance.
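To make the objective concrete, here is a tiny exhaustive reference solver (closest_string_bruteforce is a hypothetical name). It tries all |alphabet|^m candidates, so it is only usable for very small m, but it can be handy for checking a heuristic such as ACO on toy inputs.

```python
from itertools import product


def hamming_distance(s1, s2):
    return sum(c1 != c2 for c1, c2 in zip(s1, s2))


def closest_string_bruteforce(strings, alphabet):
    """Exhaustive search: returns (best string, its max Hamming distance to the inputs)."""
    m = len(strings[0])
    best, best_metric = None, m + 1
    for chars in product(alphabet, repeat=m):
        candidate = ''.join(chars)
        metric = max(hamming_distance(candidate, s) for s in strings)
        if metric < best_metric:  # keep the first candidate reaching a new minimum
            best, best_metric = candidate, metric
    return best, best_metric
```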
One solution that has been tried is one using ant colony optimization, as described here.
The paper itself does not go into implementation details, so I've done my best on efficiency. However, efficiency is not the only unusual behaviour.
I'm not sure whether it's common practice to do so, but I will present my code through pastebin, since I believe putting it directly here would overwhelm the thread. If that turns out to be a problem, I don't mind editing the thread to put it here directly. As with all the previous algorithms I've experimented with, I initially wrote this one in Python. Here's the code:
def solve_(self, problem: CSProblem) -> CSSolution:
    m, n, alphabet, strings = problem.m, problem.n, problem.alphabet, problem.strings
    A = len(alphabet)
    rho = self.config['RHO']
    colony_size = self.config['COLONY_SIZE']
    global_best_ant = None
    global_best_metric = m
    ants = np.full((colony_size, m), '')
    world_trails = np.full((m, A), 1 / A)
    for iteration in range(self.config['MAX_ITERS']):
        local_best_ant = None
        local_best_metric = m
        for ant_idx in range(colony_size):
            for next_character_index in range(m):
                ants[ant_idx][next_character_index] = random.choices(alphabet, weights=world_trails[next_character_index], k=1)[0]
            ant_metric = utils.problem_metric(ants[ant_idx], strings)
            if ant_metric < local_best_metric:
                local_best_metric = ant_metric
                # copy: ants[ant_idx] is a view, and the next iteration overwrites that row
                local_best_ant = ants[ant_idx].copy()
        # First we perform pheromone evaporation
        for i in range(m):
            for j in range(A):
                world_trails[i][j] = world_trails[i][j] * (1 - rho)
        # Now, using the elitist strategy, only the best ant is allowed to update its pheromone trails
        best_ant_ys = (alphabet.index(a) for a in local_best_ant)
        best_ant_xs = range(m)
        for x, y in zip(best_ant_xs, best_ant_ys):
            world_trails[x][y] = world_trails[x][y] + (1 - local_best_metric / m)
        if local_best_metric < global_best_metric:
            global_best_metric = local_best_metric
            global_best_ant = local_best_ant
    return CSSolution(''.join(global_best_ant), global_best_metric)
The utils.problem_metric function looks like this:
def hamming_distance(s1, s2):
    return sum(c1 != c2 for c1, c2 in zip(s1, s2))

def problem_metric(string, references):
    return max(hamming_distance(string, r) for r in references)
I've seen that there are a lot more tweaks and other parameters you can add to ACO, but I've kept it simple for now. The configuration I'm using is 250 iterations, a colony size of 10 ants and rho = 0.1. The problem I'm testing it on is from here: http://tcs.informatik.uos.de/research/csp_cssp , the one called 2-10-250-1-0.csp (the first one). The alphabet consists only of '0' and '1', the strings are of length 250, and there are 10 strings in total.
For the ACO configuration that I've mentioned, this problem, using the python solver, gets solved on average in 5 seconds, and the average target function value is 108.55 (simulated 20 times). The correct target function value is 96. Ironically, the 5-second average is good compared to what it used to be in my first attempt of implementing this solution. However, it's still surprisingly slow.
After doing all kinds of optimizations, I decided to implement the exact same solution in C++ to see whether there would be a significant difference between the running times. Here's the C++ solution:
#include <iostream>
#include <vector>
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <numeric> // std::accumulate
#include <sstream>
#include <string>
#include <random>
#include <chrono>
#include <map>
class CSPProblem{
public:
int m;
int n;
std::vector<char> alphabet;
std::vector<std::string> strings;
CSPProblem(int m, int n, std::vector<char> alphabet, std::vector<std::string> strings)
: m(m), n(n), alphabet(alphabet), strings(strings)
{
}
static CSPProblem from_csp(std::string filepath){
std::ifstream file(filepath);
std::string line;
std::vector<std::string> input_lines;
while (std::getline(file, line)){
input_lines.push_back(line);
}
int alphabet_size = std::stoi(input_lines[0]);
int n = std::stoi(input_lines[1]);
int m = std::stoi(input_lines[2]);
std::vector<char> alphabet;
for (int i = 3; i < 3 + alphabet_size; i++){
alphabet.push_back(input_lines[i][0]);
}
std::vector<std::string> strings;
for (int i = 3 + alphabet_size; i < input_lines.size(); i++){
strings.push_back(input_lines[i]);
}
return CSPProblem(m, n, alphabet, strings);
}
int hamm(const std::string& s1, const std::string& s2) const{
int h = 0;
for (int i = 0; i < s1.size(); i++){
if (s1[i] != s2[i])
h++;
}
return h;
}
int measure(const std::string& sol) const{
int mm = 0;
for (const auto& s: strings){
int h = hamm(sol, s);
if (h > mm){
mm = h;
}
}
return mm;
}
friend std::ostream& operator<<(std::ostream& out, CSPProblem problem){
out << "m: " << problem.m << std::endl;
out << "n: " << problem.n << std::endl;
out << "alphabet_size: " << problem.alphabet.size() << std::endl;
out << "alphabet: ";
for (const auto& a: problem.alphabet){
out << a << " ";
}
out << std::endl;
out << "strings:" << std::endl;
for (const auto& s: problem.strings){
out << "\t" << s << std::endl;
}
return out;
}
};
std::random_device rd;
std::mt19937 gen(rd());
int get_from_distrib(const std::vector<float>& weights){
std::discrete_distribution<> d(std::begin(weights), std::end(weights));
return d(gen);
}
int max_iter = 250;
float rho = 0.1f;
int colony_size = 10;
int ant_colony_solver(const CSPProblem& problem){
srand(time(NULL));
int m = problem.m;
int n = problem.n;
auto alphabet = problem.alphabet;
auto strings = problem.strings;
int A = alphabet.size();
float init_pher = 1.0 / A;
std::string global_best_ant;
int global_best_matric = m;
std::vector<std::vector<float>> world_trails(m, std::vector<float>(A, 0.0f));
for (int i = 0; i < m; i++){
for (int j = 0; j < A; j++){
world_trails[i][j] = init_pher;
}
}
std::vector<std::string> ants(colony_size, std::string(m, ' '));
for (int iteration = 0; iteration < max_iter; iteration++){
std::string local_best_ant;
int local_best_metric = m;
for (int ant_idx = 0; ant_idx < colony_size; ant_idx++){
for (int next_character_idx = 0; next_character_idx < m; next_character_idx++){
char next_char = alphabet[get_from_distrib(world_trails[next_character_idx])];
ants[ant_idx][next_character_idx] = next_char;
}
int ant_metric = problem.measure(ants[ant_idx]);
if (ant_metric < local_best_metric){
local_best_metric = ant_metric;
local_best_ant = ants[ant_idx];
}
}
// Evaporation
for (int i = 0; i < m; i++){
for (int j = 0; j < A; j++){
world_trails[i][j] = world_trails[i][j] + (1.0 - rho); // typo: should be * (1.0 - rho), see MAJOR EDIT below
}
}
std::vector<int> best_ant_xs;
for (int i = 0; i < m; i++){
best_ant_xs.push_back(i);
}
std::vector<int> best_ant_ys;
for (const auto& c: local_best_ant){
auto loc = std::find(std::begin(alphabet), std::end(alphabet), c);
int idx = loc- std::begin(alphabet);
best_ant_ys.push_back(idx);
}
for (int i = 0; i < m; i++){
int x = best_ant_xs[i];
int y = best_ant_ys[i];
world_trails[x][y] = world_trails[x][y] + (1.0 - static_cast<float>(local_best_metric) / m);
}
if (local_best_metric < global_best_matric){
global_best_matric = local_best_metric;
global_best_ant = local_best_ant;
}
}
return global_best_matric;
}
int main(){
auto problem = CSPProblem::from_csp("in.csp");
int TRIES = 20;
std::vector<int> times;
std::vector<int> measures;
for (int i = 0; i < TRIES; i++){
auto start = std::chrono::high_resolution_clock::now();
int m = ant_colony_solver(problem);
auto stop = std::chrono::high_resolution_clock::now();
int duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
times.push_back(duration);
measures.push_back(m);
}
float average_time = static_cast<float>(std::accumulate(std::begin(times), std::end(times), 0)) / TRIES;
float average_measure = static_cast<float>(std::accumulate(std::begin(measures), std::end(measures), 0)) / TRIES;
std::cout << "Average running time: " << average_time << std::endl;
std::cout << "Average solution: " << average_measure << std::endl;
std::cout << "all solutions: ";
for (const auto& m: measures) std::cout << m << " ";
std::cout << std::endl;
return 0;
}
The average running time is now only 530.4 milliseconds. However, the average target function value is 122.75, which is significantly higher than that of the Python solution.
If the average function values were the same and the times were as they are, I would simply write this off as 'C++ is faster than Python' (even though the difference in speed is also very suspicious). But since C++ yields worse solutions, I believe I've done something wrong in C++. What I'm suspicious of is the way I'm generating an alphabet index using weights. In Python I've done it using random.choices as follows:
ants[ant_idx][next_character_index] = random.choices(alphabet, weights=world_trails[next_character_index], k=1)[0]
As for C++, I haven't done it in a while so I'm a bit rusty on reading cppreference (which is a skill of its own), and the std::discrete_distribution solution is something I've plain copied from the reference:
std::random_device rd;
std::mt19937 gen(rd());
int get_from_distrib(const std::vector<float>& weights){
std::discrete_distribution<> d(std::begin(weights), std::end(weights));
return d(gen);
}
The suspicious thing here is that I'm declaring the std::random_device and std::mt19937 objects globally and using the same ones every time. I have not been able to find an answer to whether this is the way they're meant to be used. However, if I put them in the function:
int get_from_distrib(const std::vector<float>& weights){
std::random_device rd;
std::mt19937 gen(rd());
std::discrete_distribution<> d(std::begin(weights), std::end(weights));
return d(gen);
}
the average running time gets significantly worse, clocking in at 8.84 seconds. However, even more surprisingly, the average function value gets worse as well, at 130.
Again, if only one of the two things changed (say if only the time went up) I would have been able to draw some conclusions. This way it only gets more confusing.
So, does anybody have an idea of why this is happening?
Thanks in advance.
MAJOR EDIT: I feel embarrassed having asked such a huge question when in fact the problem lies in a simple typo: in the evaporation step of the C++ version I put a + instead of a *.
Now the algorithms behave identically in terms of average solution quality.
However, I could still use some tips on how to optimize the python version.
Apart from the dumb mistake I've mentioned in the question edit, it seems I've finally found a way to optimize the Python solution decently. First of all, keeping world_trails and ants as numpy arrays instead of lists of lists actually slowed things down. Furthermore, I stopped keeping a list of ants altogether, since I only ever need the best one per iteration.
Lastly, running cProfile indicated that a lot of the time was spent in random.choices, so I decided to implement my own version of it suited specifically to this case. I did this by precomputing, once per iteration, the total pheromone weight of each position (in the trail_row_wise_sums array) and using the following function:
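def fast_pick(arr, weights, ws):
    # ws is the precomputed sum of weights
    r = random.random() * ws
    for i in range(len(arr)):
        if r < weights[i]:
            return arr[i]
        r -= weights[i]
    return arr[-1]  # only reached through floating-point rounding; returning 0 would break ''.join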
The new version now looks like this:
def solve_(self, problem: CSProblem) -> CSSolution:
    m, n, alphabet, strings = problem.m, problem.n, problem.alphabet, problem.strings
    A = len(alphabet)
    rho = self.config['RHO']
    colony_size = self.config['COLONY_SIZE']
    miters = self.config['MAX_ITERS']
    global_best_ant = None
    global_best_metric = m
    init_pher = 1.0 / A
    world_trails = [[init_pher for _ in range(A)] for _ in range(m)]
    trail_row_wise_sums = [1.0 for _ in range(m)]
    for iteration in tqdm(range(miters)):
        local_best_ant = None
        local_best_metric = m
        for _ in range(colony_size):
            ant = ''.join(fast_pick(alphabet, world_trails[next_character_index], trail_row_wise_sums[next_character_index]) for next_character_index in range(m))
            ant_metric = utils.problem_metric(ant, strings)
            if ant_metric <= local_best_metric:
                local_best_metric = ant_metric
                local_best_ant = ant
        # First we perform pheromone evaporation
        for i in range(m):
            for j in range(A):
                world_trails[i][j] = world_trails[i][j] * (1 - rho)
        # Now, using the elitist strategy, only the best ant is allowed to update its pheromone trails
        best_ant_ys = (alphabet.index(a) for a in local_best_ant)
        best_ant_xs = range(m)
        for x, y in zip(best_ant_xs, best_ant_ys):
            world_trails[x][y] = world_trails[x][y] + (1 - 1.0*local_best_metric / m)
        if local_best_metric < global_best_metric:
            global_best_metric = local_best_metric
            global_best_ant = local_best_ant
        trail_row_wise_sums = [sum(world_trails[i]) for i in range(m)]
    return CSSolution(global_best_ant, global_best_metric)
The average running time is now down to 800 milliseconds (compared to the 5 seconds it was before). Granted, applying the same fast_pick optimization to the C++ solution also sped up the C++ version (to around 150 ms), but I guess now I can write it off as C++ being faster than Python.
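Another option, sketched here under the assumption that random.choices itself is not the bottleneck but the number of calls is: random.choices accepts precomputed cum_weights and a k argument, so each position's cumulative weights can be built once and used to draw that position's character for all ants in one call (m calls per iteration instead of colony_size * m). sample_colony is a hypothetical helper and has not been benchmarked against fast_pick.

```python
import itertools
import random


def sample_colony(alphabet, world_trails, colony_size, rng=random):
    """Draw colony_size ants at once, one position (row of world_trails) at a time."""
    columns = [
        # cumulative weights are computed once per position and shared by all ants
        rng.choices(alphabet, cum_weights=list(itertools.accumulate(row)), k=colony_size)
        for row in world_trails
    ]
    # columns[i][a] is ant a's character at position i; transpose into strings
    return [''.join(col[a] for col in columns) for a in range(colony_size)]
```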
The profiler also showed that a lot of the time was spent on calculating Hamming distances, but that's to be expected, and apart from that I see no other way of computing the Hamming distance between arbitrary strings more efficiently.
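For arbitrary alphabets that may well be true, but for a binary alphabet such as this instance ('0' and '1'), a common trick is to encode each string once as a Python int and count differing positions with XOR plus a popcount. A minimal sketch (to_bitmask and problem_metric_bits are hypothetical names; it assumes all strings have the same length m so that leading zeros line up):

```python
def to_bitmask(s):
    """Encode a '0'/'1' string as an int (binary alphabet only)."""
    return int(s, 2)


def problem_metric_bits(candidate, reference_masks):
    """Max Hamming distance from candidate to references pre-encoded with to_bitmask."""
    c = int(candidate, 2)
    # XOR leaves a 1 bit exactly where the strings differ; count those bits.
    # On Python 3.10+, (c ^ r).bit_count() does the same without building a string.
    return max(bin(c ^ r).count('1') for r in reference_masks)
```

The reference masks only need to be built once per problem, outside the ACO loop.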

Discrepancy between a program in Python and C++

I wrote a program in C++ to solve a programming challenge on Project Euler (problem 65: https://projecteuler.net/problem=65), but have not been able to get the correct answer. To check my algorithm, I wrote the "same" code in Python (same as far as I can tell, but there's obviously some difference), and got a different answer (the correct one).
The basic idea behind the program is to find the sum of digits in the numerator of the 100th convergent of the continued fraction for e. The continued fraction is e = [2;1,2,1,1,4,1,1,6,1,...,1,2k,1,...]; this is the vector e_terms that I fill in my code.
My question is: why does my C++ code give me a different answer than Python? (Again, I already know the answer to the challenge; I just want to understand the difference between my programs.) Both programs give me the same approximations up to the 45th convergent, but from the 46th on they give different values. Is it because a long double doesn't store integers exactly the way Python does? Or is there some kind of flaw in my C++ code?
Thanks in advance for the help! And let me know if there is something I should clarify about my code.
C++ code
#include <iostream>
#include <vector>
#include "math.h"
using namespace std;
void PrintDouble(long double num)
{
int val = 0;
for (int i = floor(log(num)/log(10)); i >= 0; i--)
{
val = floor(num/(pow(10,i)));
num -= val*pow(10,i);
cout << val;
}
}
void Convergents_e(int max_val)
{
long double numer = 0, denom = 1, temp;
vector<int> e_terms;
if (max_val > 1) e_terms.push_back(1);
if (max_val > 2) e_terms.push_back(2);
if (max_val > 3)
{
for (int i = 1; i <= max_val-3; i++)
{
if (i%3 == 0) e_terms.push_back((int)((i/3+1)*2));
else e_terms.push_back(1);
cout << e_terms.back() << endl;
}
}
for (vector<int>::reverse_iterator it = e_terms.rbegin(); it != e_terms.rend(); it++)
{
numer += ((double)(*it) * denom);
cout << "\tnumer = ";
PrintDouble(numer);
cout << endl;
swap(numer, denom);
}
numer += 2*denom; // +1 term
cout << "e = " << numer << "/" << denom << " = " << numer/denom << endl;
int val = 0, sum = 0;
for (int i = floor(log(numer)/log(10)); i >= 0; i--)
{
val = floor(numer/(pow(10,i)));
numer -= val*pow(10,i);
sum += val;
cout << val;
}
cout << endl << "The total of the digits is " << sum << endl << endl;
}
int main()
{
Convergents_e(100);
}
Python code
#!/usr/bin/env python3
import math as m
max_val = 100
numer = 0
denom = 1
e_terms = []
digits = []
if (max_val > 1): e_terms.append(1)
if (max_val > 2): e_terms.append(2)
if (max_val > 3):
    for i in range(1,max_val-3 + 1):
        if (i%3 == 0): e_terms.append(int((i/3+1)*2))
        else: e_terms.append(1)
        print(e_terms[-1])
for e_t in reversed(e_terms):
    numer += e_t * denom
    print("\tnumer =", numer)
    numer, denom = denom, numer
numer += 2*denom # +1 term
print("e =", numer, "/", denom, "=", numer/denom)
val = 0
sum = 0
for i in reversed(range(0,m.ceil(m.log(numer)/m.log(10)))):
    val = m.floor(numer/(10**i))
    numer -= (val*(10**i))
    digits.append(val)
    sum += val
    #print(val)
print( "The total of the digits is",sum)
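The long double suspicion is the likely explanation: the numerators quickly outgrow the floating-point significand (typically 64 bits for long double with GCC on x86, and only 53 bits where long double is the same as double, e.g. MSVC), whereas Python ints are arbitrary precision. A minimal sketch that stays in exact integers throughout, using the standard forward recurrence h_n = a_n*h_(n-1) + h_(n-2) instead of the question's backward loop and extracting digits with str() instead of log/floor (e_convergent_numerator and digit_sum are hypothetical names):

```python
def e_convergent_numerator(k):
    """Numerator of the k-th convergent of e = [2; 1, 2, 1, 1, 4, 1, 1, 6, 1, ...]."""
    h_prev, h = 1, 2  # h_(-1) and h_0 of the standard recurrence
    for i in range(k - 1):
        # partial quotients after the leading 2: 1, 2, 1, 1, 4, 1, 1, 6, 1, ...
        a = 2 * (i // 3 + 1) if i % 3 == 1 else 1
        h_prev, h = h, a * h + h_prev
    return h


def digit_sum(n):
    return sum(int(c) for c in str(n))
```

For example, the 10th convergent is 1457/536 (as in the problem statement), so digit_sum(e_convergent_numerator(10)) is 17.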

Knapsack Problem: Why do I need a 2 dimensional DP Matrix

I came across some classical Knapsack solutions and they always build a 2-dimensional DP array.
In my opinion, my code below solves the classical knapsack problem but with only a 1-dim DP array.
Can someone tell me where my solution does not work or why it is computationally inefficient compared to the 2D-DP version?
A 2D-DP version can be found here
https://www.geeksforgeeks.org/python-program-for-dynamic-programming-set-10-0-1-knapsack-problem/
example input:
weights = [(3,30),(2,20),(1,50),(4,30)]
constraint = 5
And my solution:
def knapsack(weights,constraint):
    n = len(weights)
    #define dp array
    dp = [0]*(constraint+1)
    #start filling in the array
    for k in weights:
        for i in range(constraint,k[0]-1,-1):
            dp[i] = max(dp[i],dp[i-k[0]]+k[1])
    return dp[constraint]
The version using O(nW) memory is more intuitive and makes it possible to easily retrieve the subset of items that produce the optimal answer value.
But with O(n + W) memory, we cannot retrieve this subset directly. It is still possible, using the divide-and-conquer technique explained in https://codeforces.com/blog/entry/47247?#comment-316200.
Sample code
#include <bits/stdc++.h>
using namespace std;
using vi = vector<int>;
#define FOR(i, b) for(int i = 0; i < (b); i++)
template<class T>
struct Knapsack{
int n, W;
vector<T> dp, vl;
vi ans, opt, wg;
Knapsack(int n_, int W): n(0), W(W),
dp(W + 1), vl(n_), opt(W + 1), wg(n_){}
void Add(T v, int w){
vl[n] = v;
wg[n++] = w;
}
T conquer(int l, int r, int W){
if(l == r){
if(W >= wg[l])
return ans.push_back(l), vl[l];
return 0;
}
FOR(i, W + 1)
opt[i] = dp[i] = 0;
int m = (l + r) >> 1;
for(int i = l; i <= r; i++)
for(int sz = W; sz >= wg[i]; sz--){
T dpCur = dp[sz - wg[i]] + vl[i];
if(dpCur > dp[sz]){
dp[sz] = dpCur;
opt[sz] = i <= m ? sz : opt[sz - wg[i]];
}
}
T ret = dp[W];
int K = opt[W];
T ret2 = conquer(l, m, K) + conquer(m + 1, r, W - K);
assert(ret2 == ret);
return ret;
}
T Solve(){
return conquer(0, n - 1, W);
}
};
int main(){
cin.tie(0)->sync_with_stdio(0);
int n, W, vl, wg;
cin >> n >> W;
Knapsack<int> ks(n, W);
FOR(i, n){
cin >> vl >> wg;
ks.Add(vl, wg);
}
cout << ks.Solve() << endl;
}
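To illustrate the point about the O(nW) table, here is a minimal Python sketch (knapsack_with_items is a hypothetical name) that keeps the full 2D table and walks it backwards to recover which items were taken. It uses the question's (weight, value) tuples.

```python
def knapsack_with_items(items, capacity):
    """0/1 knapsack over (weight, value) tuples; returns (best value, chosen item indices)."""
    n = len(items)
    # dp[i][c]: best value using only the first i items with capacity c
    dp = [[0] * (capacity + 1) for _ in range(n + 1)]
    for i, (w, v) in enumerate(items, start=1):
        for c in range(capacity + 1):
            dp[i][c] = dp[i - 1][c]  # skip item i
            if w <= c and dp[i - 1][c - w] + v > dp[i][c]:
                dp[i][c] = dp[i - 1][c - w] + v  # take item i
    # Walk back: item i was taken iff it changed the optimum at the remaining capacity
    chosen = []
    c = capacity
    for i in range(n, 0, -1):
        if dp[i][c] != dp[i - 1][c]:
            chosen.append(i - 1)
            c -= items[i - 1][0]
    chosen.reverse()
    return dp[n][capacity], chosen
```

For the question's example, knapsack_with_items([(3,30),(2,20),(1,50),(4,30)], 5) returns (80, [0, 2]), i.e. the items (3, 30) and (1, 50); the 1D version returns the same value 80 but has no table left to walk back through.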

Integer square root in python

Is there an integer square root somewhere in python, or in standard libraries? I want it to be exact (i.e. return an integer), and raise an exception if the input isn't a perfect square.
I tried using this code:
import math

def isqrt(n):
    i = int(math.sqrt(n) + 0.5)
    if i**2 == n:
        return i
    raise ValueError('input was not a perfect square')
But it's ugly and I don't really trust it for large integers. I could iterate through the squares and give up if I've exceeded the value, but I assume it would be kinda slow to do something like that. Also, surely this is already implemented somewhere?
See also: Check if a number is a perfect square.
Note: There is now math.isqrt in stdlib, available since Python 3.8.
Newton's method works perfectly well on integers:
def isqrt(n):
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x
This returns the largest integer x for which x * x does not exceed n. If you want to check if the result is exactly the square root, simply perform the multiplication to check if n is a perfect square.
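The multiply-back check can be wrapped so the function behaves as the question asks (exact root or ValueError). This is a minimal sketch; the name exact_isqrt and the explicit negative-input guard are additions for illustration, since the Newton loop above does not handle negative n.

```python
def isqrt(n):
    """Newton's method on integers, as above: the largest x with x*x <= n."""
    x = n
    y = (x + 1) // 2
    while y < x:
        x = y
        y = (x + n // x) // 2
    return x

def exact_isqrt(n):
    """Return the square root of n if n is a perfect square, else raise ValueError."""
    if n < 0:
        raise ValueError('square root not defined for negative numbers')
    r = isqrt(n)
    if r * r != n:  # multiply back: only perfect squares survive
        raise ValueError('input was not a perfect square')
    return r

print(exact_isqrt(144))  # 12
```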
I discuss this algorithm, and three other algorithms for calculating square roots, at my blog.
Update: Python 3.8 has a math.isqrt function in the standard library!
I benchmarked every (correct) function here on both small (0…2^22) and large (2^50001) inputs. The clear winners in both cases are gmpy2.isqrt, suggested by mathmandan, in first place, followed by Python 3.8's math.isqrt in second, and the ActiveState recipe linked by NPE in third. The ActiveState recipe has a bunch of divisions that can be replaced by shifts, which makes it a bit faster (but still behind the native functions):
def isqrt(n):
    if n > 0:
        x = 1 << (n.bit_length() + 1 >> 1)
        while True:
            y = (x + n // x) >> 1
            if y >= x:
                return x
            x = y
    elif n == 0:
        return 0
    else:
        raise ValueError("square root not defined for negative numbers")
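Why this loop may stop at the first non-decrease: with b = n.bit_length() we have n < 2**b, so sqrt(n) < 2**(b/2) <= 2**((b+1)//2), and the starting guess is never below the root. The standard argument is that each integer Newton step then stays at or above floor(sqrt(n)) while strictly decreasing until it reaches it. The instrumented copy below (the name isqrt_shift and the step counter are hypothetical additions) cross-checks the result against math.isqrt (Python 3.8+) and counts iterations.

```python
import math

def isqrt_shift(n):
    """Shift-based Newton iteration from the answer above, returning (root, steps)."""
    if n < 0:
        raise ValueError("square root not defined for negative numbers")
    if n == 0:
        return 0, 0
    # n < 2**b, so sqrt(n) < 2**(b/2) <= x: the first guess is never too small
    x = 1 << ((n.bit_length() + 1) >> 1)
    steps = 0
    while True:
        y = (x + n // x) >> 1
        steps += 1
        if y >= x:  # iterates stopped decreasing, so x == floor(sqrt(n))
            return x, steps
        x = y

if __name__ == "__main__":
    for n in (2**22, 2**50001):
        root, steps = isqrt_shift(n)
        print(n.bit_length(), "bits:", steps, "steps, matches math.isqrt:", root == math.isqrt(n))
```

Because the starting guess is already within a factor of about 1.5 of the root, the quadratic convergence of Newton's method keeps the step count small even for 50001-bit inputs.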
Benchmark results:
gmpy2.isqrt() (mathmandan): 0.08 µs small, 0.07 ms large
int(gmpy2.isqrt())*: 0.3 µs small, 0.07 ms large
Python 3.8 math.isqrt: 0.13 µs small, 0.9 ms large
ActiveState (optimized as above): 0.6 µs small, 17.0 ms large
ActiveState (NPE): 1.0 µs small, 17.3 ms large
castlebravo long-hand: 4 µs small, 80 ms large
mathmandan improved: 2.7 µs small, 120 ms large
martineau (with this correction): 2.3 µs small, 140 ms large
nibot: 8 µs small, 1000 ms large
mathmandan: 1.8 µs small, 2200 ms large
castlebravo Newton’s method: 1.5 µs small, 19000 ms large
user448810: 1.4 µs small, 20000 ms large
(* Since gmpy2.isqrt returns a gmpy2.mpz object, which behaves mostly but not exactly like an int, you may need to convert it back to an int for some uses.)
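The exact harness behind the numbers above is not shown, so here is a minimal timeit sketch of the same idea, assuming random inputs below 2^22 for "small" and between 2^50000 and 2^50001 for "large". Absolute times depend on the machine and Python version; only the relative ordering is meaningful. The function names are hypothetical.

```python
import math
import random
import timeit

def newton_isqrt(n):
    """Shift-based Newton iteration (the optimized ActiveState version above), n >= 0."""
    if n == 0:
        return 0
    x = 1 << ((n.bit_length() + 1) >> 1)
    while True:
        y = (x + n // x) >> 1
        if y >= x:
            return x
        x = y

def time_per_call(func, inputs, repeat=3):
    """Average seconds per call of func over inputs, taking the best of `repeat` runs."""
    timer = timeit.Timer(lambda: [func(n) for n in inputs])
    best = min(timer.repeat(repeat=repeat, number=1))
    return best / len(inputs)

small = [random.randrange(1 << 22) for _ in range(1000)]
large = [random.randrange(1 << 50000, 1 << 50001) for _ in range(3)]

if __name__ == "__main__":
    for name, f in [("math.isqrt", math.isqrt), ("newton_isqrt", newton_isqrt)]:
        print(f"{name:>12}: {time_per_call(f, small) * 1e6:.2f} µs small, "
              f"{time_per_call(f, large) * 1e3:.2f} ms large")
```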
Sorry for the very late response; I just stumbled onto this page. In case anyone visits this page in the future, the python module gmpy2 is designed to work with very large inputs, and includes among other things an integer square root function.
Example:
>>> import gmpy2
>>> gmpy2.isqrt((10**100+1)**2)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001L)
>>> gmpy2.isqrt((10**100+1)**2 - 1)
mpz(10000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000L)
Granted, everything will have the "mpz" tag, but mpz objects are compatible with ints:
>>> gmpy2.mpz(3)*4
mpz(12)
>>> int(gmpy2.mpz(12))
12
See my other answer for a discussion of this method's performance relative to some other answers to this question.
Download: https://code.google.com/p/gmpy/
Here's a very straightforward implementation:
def i_sqrt(n):
    i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
    m = 1 << i    # m = 2^i
    #
    # Fact: (2^(i + 1))^2 > n, so m has at least as many bits
    # as the floor of the square root of n.
    #
    # Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
    # >= 2^(ceil(log_2(n)) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
    #
    while m*m > n:
        m >>= 1
        i -= 1
    for k in range(i-1, -1, -1):
        x = m | (1 << k)
        if x*x <= n:
            m = x
    return m
This is just a binary search. Initialize the value m to be the largest power of 2 that does not exceed the square root, then check whether each smaller bit can be set while keeping the result no larger than the square root. (Check the bits one at a time, in descending order.)
For reasonably large values of n (say, around 10**6000, or around 20000 bits), this seems to be:
Faster than the Newton's method implementation described by user448810.
Much, much slower than the gmpy2 built-in method in my other answer.
Comparable to, but somewhat slower than, the Longhand Square Root described by nibot.
All of these approaches succeed on inputs of this size, but on my machine this function takes around 1.5 seconds, while @Nibot's takes about 0.9 seconds, @user448810's takes around 19 seconds, and the gmpy2 built-in method takes less than a millisecond(!). Example:
>>> import random
>>> import timeit
>>> import gmpy2
>>> r = random.getrandbits
>>> t = timeit.timeit
>>> t('i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # This function
1.5102493192883117
>>> t('exact_sqrt(r(20000))', 'from __main__ import *', number = 5)/5. # Nibot
0.8952787937686366
>>> t('isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # user448810
19.326695976676184
>>> t('gmpy2.isqrt(r(20000))', 'from __main__ import *', number = 5)/5. # gmpy2
0.0003599147067689046
>>> all(i_sqrt(n)==isqrt(n)==exact_sqrt(n)[0]==int(gmpy2.isqrt(n)) for n in (r(1500) for i in xrange(1500)))
True
This function can be generalized easily, though it's not quite as nice because I don't have as precise an initial guess for m:
def i_root(num, root, report_exactness = True):
    i = num.bit_length() // root
    m = 1 << i
    while m ** root < num:
        m <<= 1
        i += 1
    while m ** root > num:
        m >>= 1
        i -= 1
    for k in range(i-1, -1, -1):
        x = m | (1 << k)
        if x ** root <= num:
            m = x
    if report_exactness:
        return m, m ** root == num
    return m
However, note that gmpy2 also has an iroot function.
In fact this method could be adapted and applied to any (nonnegative, increasing) function f to determine an "integer inverse of f". However, to choose an efficient initial value of m you'd still want to know something about f.
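A minimal sketch of that generalization, assuming f is nondecreasing on the nonnegative integers and eventually exceeds the target (otherwise the first loop never ends). Instead of a tailored initial guess it simply doubles until it overshoots, then decides the remaining bits from the top down as i_sqrt does. The name integer_inverse is hypothetical.

```python
def integer_inverse(f, target):
    """Largest integer m >= 0 with f(m) <= target, for nondecreasing, unbounded f."""
    if f(0) > target:
        raise ValueError("f(0) already exceeds target")
    hi_bit = 0
    while f(1 << hi_bit) <= target:  # find a power of two that overshoots
        hi_bit += 1
    # The answer is below 2**hi_bit, so decide bits hi_bit-1 .. 0, highest first.
    m = 0
    for k in range(hi_bit - 1, -1, -1):
        x = m | (1 << k)
        if f(x) <= target:
            m = x
    return m

if __name__ == "__main__":
    print(integer_inverse(lambda x: x * x, 10**20))   # 10000000000 (square root)
    print(integer_inverse(lambda x: x ** 3, 10**30))  # 10000000000 (cube root)
    print(integer_inverse(lambda x: 2 ** x, 10**6))   # 19 (floor of log2)
```

As the text says, a good initial guess needs knowledge of f; the doubling phase is the generic, slower substitute.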
Edit: Thanks to @Greggo for pointing out that the i_sqrt function can be rewritten to avoid using any multiplications. This yields an impressive performance boost!
def improved_i_sqrt(n):
    assert n >= 0
    if n == 0:
        return 0
    i = n.bit_length() >> 1    # i = floor( (1 + floor(log_2(n))) / 2 )
    m = 1 << i    # m = 2^i
    #
    # Fact: (2^(i + 1))^2 > n, so m has at least as many bits
    # as the floor of the square root of n.
    #
    # Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
    # >= 2^(ceil(log_2(n)) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
    #
    while (m << i) > n: # (m<<i) = m*(2^i) = m*m
        m >>= 1
        i -= 1
    d = n - (m << i) # d = n-m^2
    for k in range(i-1, -1, -1):
        j = 1 << k
        new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
        if new_diff >= 0:
            d = new_diff
            m |= j
    return m
Note that by construction, the kth bit of m << 1 is not set, so bitwise-or may be used to implement the addition of (m<<1) + (1<<k). Ultimately I have (2*m*(2**k) + 2**(2*k)) written as (((m<<1) | (1<<k)) << k), so it's three shifts and one bitwise-or (followed by a subtraction to get new_diff). Maybe there is still a more efficient way to get this? Regardless, it's far better than multiplying m*m! Compare with above:
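The shift-and-OR form can be checked directly against the algebra (m + 2^k)^2 - m^2 = 2*m*2^k + 2^(2k). Inside the loop, bits 0..k of m are still zero because bits are decided from the top down, so the sketch below only tests values of that form. The name step_increment is hypothetical.

```python
def step_increment(m, k):
    """(m + 2**k)**2 - m**2 computed with three shifts and one OR, as in improved_i_sqrt.

    Valid when bits 0..k of m are zero, so (m << 1) and (1 << k) share no set bits
    and the OR acts as an addition.
    """
    return ((m << 1) | (1 << k)) << k

if __name__ == "__main__":
    for k in range(12):
        for high in range(50):
            m = high << (k + 1)  # bits 0..k of m are zero, as inside the loop
            assert step_increment(m, k) == (m + 2**k) ** 2 - m ** 2
    print("identity holds")
```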
>>> t('improved_i_sqrt(r(20000))', 'from __main__ import *', number = 5)/5.
0.10908999762373242
>>> all(improved_i_sqrt(n) == i_sqrt(n) for n in xrange(10**6))
True
Long-hand square root algorithm
It turns out that there is an algorithm for computing square roots that you can compute by hand, something like long-division. Each iteration of the algorithm produces exactly one digit of the resulting square root while consuming two digits of the number whose square root you seek. While the "long hand" version of the algorithm is specified in decimal, it works in any base, with binary being simplest to implement and perhaps the fastest to execute (depending on the underlying bignum representation).
Because this algorithm operates on numbers digit-by-digit, it produces exact results for arbitrarily large perfect squares, and for non-perfect-squares, can produce as many digits of precision (to the right of the decimal place) as desired.
There are two nice writeups on the "Dr. Math" site that explain the algorithm:
Square Roots in Binary
Longhand Square Roots
And here's an implementation in Python:
def exact_sqrt(x):
    """Calculate the square root of an arbitrarily large integer.
    The result of exact_sqrt(x) is a tuple (a, r) such that a**2 + r = x, where
    a is the largest integer such that a**2 <= x, and r is the "remainder". If
    x is a perfect square, then r will be zero.
    The algorithm used is the "long-hand square root" algorithm, as described at
    http://mathforum.org/library/drmath/view/52656.html
    Tobin Fricke 2014-04-23
    Max Planck Institute for Gravitational Physics
    Hannover, Germany
    """
    N = 0   # Problem so far
    a = 0   # Solution so far
    # We'll process the number two bits at a time, starting at the MSB
    L = x.bit_length()
    L += (L % 2)          # Round up to the next even number
    for i in range(L, -1, -1):
        # Get the next group of two bits
        n = (x >> (2*i)) & 0b11
        # Check whether we can reduce the remainder
        if ((N - a*a) << 2) + n >= (a<<2) + 1:
            b = 1
        else:
            b = 0
        a = (a << 1) | b   # Concatenate the next bit of the solution
        N = (N << 2) | n   # Concatenate the next two bits of the problem
    return (a, N-a*a)
You could easily modify this function to conduct additional iterations to calculate the fractional part of the square root. I was most interested in computing roots of large perfect squares.
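A shortcut for the fractional part, sketched here with math.isqrt (Python 3.8+) rather than the long-hand loop: scaling the input by 10^(2d) scales the square root by 10^d, so the integer square root of the scaled value gives the first d decimal places, truncated. The same trick works in binary, e.g. exact_sqrt(x << (2*f))[0] is floor(sqrt(x) * 2**f). The helper names are hypothetical.

```python
import math

def sqrt_fixed(x, frac_digits):
    """floor(sqrt(x) * 10**frac_digits), computed exactly with integers."""
    # sqrt(x * 10**(2d)) == sqrt(x) * 10**d
    return math.isqrt(x * 10 ** (2 * frac_digits))

def format_sqrt(x, frac_digits):
    """Render the truncated square root of x as a decimal string."""
    s = str(sqrt_fixed(x, frac_digits)).rjust(frac_digits + 1, "0")
    return s[:-frac_digits] + "." + s[-frac_digits:] if frac_digits else s

if __name__ == "__main__":
    print(format_sqrt(2, 10))  # 1.4142135623
```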
I'm not sure how this compares to the "integer Newton's method" algorithm. I suspect that Newton's method is faster, since it can in principle generate multiple bits of the solution in one iteration, while the "long hand" algorithm generates exactly one bit of the solution per iteration.
Source repo: https://gist.github.com/tobin/11233492
One option would be to use the decimal module, and do it in sufficiently-precise floats:
import decimal
def isqrt(n):
    nd = decimal.Decimal(n)
    with decimal.localcontext() as ctx:
        # bit_length() digits of precision is ample; max(..., 1) avoids prec = 0 when n == 0
        ctx.prec = max(n.bit_length(), 1)
        i = int(nd.sqrt())
    if i**2 != n:
        raise ValueError('input was not a perfect square')
    return i
which I think should work:
>>> isqrt(1)
1
>>> isqrt(7**14) == 7**7
True
>>> isqrt(11**1000) == 11**500
True
>>> isqrt(11**1000+1)
Traceback (most recent call last):
File "<ipython-input-121-e80953fb4d8e>", line 1, in <module>
isqrt(11**1000+1)
File "<ipython-input-100-dd91f704e2bd>", line 10, in isqrt
raise ValueError('input was not a perfect square')
ValueError: input was not a perfect square
Since Python 3.8, the standard math module has an integer square root function:
math.isqrt(n)
Return the integer square root of the nonnegative integer n. This is the floor of the exact square root of n, or equivalently the greatest integer a such that a² ≤ n.
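A minimal usage sketch for the question's perfect-square test, assuming Python 3.8+. The helper name is_perfect_square is hypothetical. math.isqrt raises ValueError for negative input and TypeError for floats such as 16.0, so negative numbers are handled before the call and floats would need an int() conversion first.

```python
import math

def is_perfect_square(n):
    """True if the integer n is a nonnegative perfect square; exact for any size."""
    if n < 0:
        return False          # math.isqrt(n) would raise ValueError here
    r = math.isqrt(n)         # floor of the exact square root
    return r * r == n

if __name__ == "__main__":
    print(is_perfect_square((10**100 + 1) ** 2))      # True
    print(is_perfect_square((10**100 + 1) ** 2 + 1))  # False
```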
Seems like you could check like this:
if int(math.sqrt(n))**2 == n:
    print(n, 'is a perfect square')
Update:
As you pointed out, the above fails for large values of n. For those, the following looks promising. It is an adaptation of the example C code by Martin Guy @ UKC, June 1985, for the relatively simple binary digit-by-digit calculation method mentioned in the Wikipedia article Methods of computing square roots:
from math import ceil, log
def isqrt(n):
    res = 0
    bit = 4**int(ceil(log(n, 4))) if n else 0  # smallest power of 4 >= the argument
    while bit:
        if n >= res + bit:
            n -= res + bit
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2
    return res
if __name__ == '__main__':
    from math import sqrt  # for comparison purposes
    for i in list(range(17)) + [2**53, (10**100+1)**2]:
        is_perfect_sq = isqrt(i)**2 == i
        print('{:21,d}: math.sqrt={:12,.7G}, isqrt={:10,d} {}'.format(
            i, sqrt(i), isqrt(i), '(perfect square)' if is_perfect_sq else ''))
Output:
0: math.sqrt= 0, isqrt= 0 (perfect square)
1: math.sqrt= 1, isqrt= 1 (perfect square)
2: math.sqrt= 1.414214, isqrt= 1
3: math.sqrt= 1.732051, isqrt= 1
4: math.sqrt= 2, isqrt= 2 (perfect square)
5: math.sqrt= 2.236068, isqrt= 2
6: math.sqrt= 2.44949, isqrt= 2
7: math.sqrt= 2.645751, isqrt= 2
8: math.sqrt= 2.828427, isqrt= 2
9: math.sqrt= 3, isqrt= 3 (perfect square)
10: math.sqrt= 3.162278, isqrt= 3
11: math.sqrt= 3.316625, isqrt= 3
12: math.sqrt= 3.464102, isqrt= 3
13: math.sqrt= 3.605551, isqrt= 3
14: math.sqrt= 3.741657, isqrt= 3
15: math.sqrt= 3.872983, isqrt= 3
16: math.sqrt= 4, isqrt= 4 (perfect square)
9,007,199,254,740,992: math.sqrt=9.490627E+07, isqrt=94,906,265
100,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,020,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001: math.sqrt= 1E+100, isqrt=10,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,000,001 (perfect square)
The script below extracts integer square roots. It uses no divisions, only bitshifts, so it is quite fast. It uses Newton's method on the inverse square root, a technique made famous by Quake III Arena as mentioned in the Wikipedia article, Fast inverse square root.
The strategy of the algorithm to compute s = sqrt(Y) is as follows.
Reduce the argument Y to y in the range [1/4, 1), i.e., y = Y/B, with 1/4 <= y < 1, where B is an even power of 2, so B = 2**(2*k) for some integer k. We want to find X, where x = X/B, and x = 1 / sqrt(y).
Determine a first approximation to X using a quadratic minimax polynomial.
Refine X using Newton's method.
Calculate s = X*Y/(2**(3*k)).
We don't actually create fractions or perform any divisions. All the arithmetic is done with integers, and we use bit shifting to divide by various powers of B.
Range reduction lets us find a good initial approximation to feed to Newton's method. Here's a version of the 2nd degree minimax polynomial approximation to the inverse square root in the interval [1/4, 1). [The formula image is not preserved; in scaled-integer form the code below uses 1/sqrt(x) ≈ (463x² − 896x + 698)/256.]
(Sorry, I've reversed the meaning of x & y here, to conform to the usual conventions.) The maximum error of this approximation is around 0.0355 ~= 1/28. [Graph of the approximation error not preserved.]
Using this poly, our initial x starts with at least 4 or 5 bits of precision. Each round of Newton's method doubles the precision, so it doesn't take many rounds to get thousands of bits, if we want them.
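The error figure can be checked numerically by sampling the polynomial on [1/4, 1]. This sketch uses the scaled coefficients that int_sqrt uses (463, -896 and 698, divided by 256); a dense sample only approximates the true maximum, so treat the printed value as an estimate. The function names are hypothetical.

```python
import math

def poly_inv_sqrt(y):
    """Quadratic approximation to 1/sqrt(y) on [1/4, 1), with int_sqrt's coefficients / 256."""
    return (463 * y * y - 896 * y + 698) / 256

def max_abs_error(samples=100_000):
    """Largest |poly_inv_sqrt(y) - 1/sqrt(y)| over evenly spaced y in [1/4, 1]."""
    worst = 0.0
    for i in range(samples + 1):
        y = 0.25 + 0.75 * i / samples
        worst = max(worst, abs(poly_inv_sqrt(y) - 1 / math.sqrt(y)))
    return worst

if __name__ == "__main__":
    err = max_abs_error()
    print(f"max error ~ {err:.4f}, i.e. about {-math.log2(err):.1f} correct bits")
```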
""" Integer square root
Uses no divisions, only shifts
"Quake" style algorithm,
i.e., Newton's method for 1 / sqrt(y)
Uses a quadratic minimax polynomial for the first approximation
Written by PM 2Ring 2022.01.23
"""
def int_sqrt(y):
    if y < 0:
        raise ValueError("int_sqrt arg must be >= 0, not %s" % y)
    if y < 2:
        return y
    # print("\n*", y, "*")
    # Range reduction.
    # Find k such that 1/4 <= y/b < 1, where b = 2 ** (k*2)
    j = y.bit_length()
    # Round k*2 up to the next even number
    k2 = j + (j & 1)
    # k and some useful multiples
    k = k2 >> 1
    k3 = k2 + k
    k6 = k3 << 1
    kd = k6 + 1
    # b cubed
    b3 = 1 << k6
    # Minimax approximation: x/b ~= 1 / sqrt(y/b)
    x = (((463 * y * y) >> k2) - (896 * y) + (698 << k2)) >> 8
    # print("  ", x, h)
    # Newton's method for 1 / sqrt(y/b)
    epsilon = 1 << k
    for i in range(1, 99):
        dx = x * (b3 - y * x * x) >> kd
        x += dx
        # print(f"  {i}: {x} {dx}")
        if abs(dx) <= epsilon:
            break
    # s == sqrt(y)
    s = x * y >> k3
    # Adjust if too low
    ss = s + 1
    return ss if ss * ss <= y else s
def test(lo, hi, step=1):
    for y in range(lo, hi, step):
        s = int_sqrt(y)
        ss = s + 1
        s2, ss2 = s * s, ss * ss
        assert s2 <= y < ss2, (y, s2, ss2)
    print("ok")
test(0, 100000, 1)
This code is certainly slower than math.isqrt and decimal.Decimal.sqrt. Its purpose is simply to illustrate the algorithm. It would be interesting to see how fast it would be if it were implemented in C...
Here's a live version, running on the SageMathCell server. Set hi <= 0 to calculate and display the result for a single value set in lo. You can put expressions in the input boxes, e.g., set hi to 0 and lo to 2 * 10**100 to get sqrt(2) * 10**50.
Inspired by all the answers, I decided to implement several of the best methods from them in pure C++, since C++ code is generally faster than equivalent pure Python.
To glue C++ and Python together I used Cython, which lets you build a Python module out of C++ code and call the C++ functions directly from Python.
In addition to the Python-adapted code, I also provide the pure C++ code with tests.
Here are timings from pure C++ tests:
Test 'GMP', bits 64, time 0.000001 sec
Test 'AndersKaseorg', bits 64, time 0.000003 sec
Test 'Babylonian', bits 64, time 0.000006 sec
Test 'ChordTangent', bits 64, time 0.000018 sec
Test 'GMP', bits 50000, time 0.000118 sec
Test 'AndersKaseorg', bits 50000, time 0.002777 sec
Test 'Babylonian', bits 50000, time 0.003062 sec
Test 'ChordTangent', bits 50000, time 0.009120 sec
and here are the timings of the same C++ functions wrapped as a Python module:
Bits 50000
math.isqrt: 2.819 ms
gmpy2.isqrt: 0.166 ms
ISqrt_GMP: 0.252 ms
ISqrt_AndersKaseorg: 3.338 ms
ISqrt_Babylonian: 3.756 ms
ISqrt_ChordTangent: 10.564 ms
My Cython/C++ code is also handy as a framework for anyone who wants to write their own C++ method and test it directly from Python.
As you can see in the timings above, I used the following methods:
math.isqrt, the standard library implementation.
gmpy2.isqrt, the GMPY2 library's implementation.
ISqrt_GMP, the same as GMPY2 but called through my Cython module, which uses the C++ GMP library (<gmpxx.h>) directly.
ISqrt_AndersKaseorg, code taken from the answer by @AndersKaseorg.
ISqrt_Babylonian, the so-called Babylonian method from the Wikipedia article, in my own implementation.
ISqrt_ChordTangent, my own method, which I call Chord-Tangent because it uses a chord and a tangent line to iteratively shrink the search interval. It is described in moderate detail in my other article. A nice property is that it finds not only square roots but also K-th roots for any K. I drew a small picture showing the details of this algorithm.
To compile the C++/Cython code you need the GMP library. Install it first; under Linux this is easy with sudo apt install libgmp-dev.
Under Windows, the easiest route is to install VCPKG, a great package manager similar to APT on Linux. VCPKG compiles all packages from source using Visual Studio (don't forget to install the Community edition of Visual Studio). After installing VCPKG you can install GMP with vcpkg install gmp. You can also install MPIR, an alternative fork of GMP, with vcpkg install mpir.
After GMP is installed under Windows, edit my Python code and replace the paths to the include directory and the library file. At the end of the installation VCPKG should show you the path to a ZIP file with the GMP library, which contains the .lib and .h files.
You may notice that the Python code also contains a handy cython_compile() function that I use to compile any C++ code into a Python module. It makes it easy to plug any C++ code into Python and can be reused many times.
If you have any questions or suggestions, or something doesn't work on your PC, please write in the comments.
Below I first show the Python code, then the C++ code. Use the Try it online! link above the C++ code to run it online on GodBolt servers. Apart from the GMP paths mentioned above, both snippets are runnable from scratch as they are.
def cython_compile(srcs):
    import json, hashlib, os, glob, importlib, sys, shutil, tempfile
    srch = hashlib.sha256(json.dumps(srcs, sort_keys = True, ensure_ascii = True).encode('utf-8')).hexdigest().upper()[:12]
    pdir = 'cyimp'
    if len(glob.glob(f'{pdir}/cy{srch}*')) == 0:
        class ChDir:
            def __init__(self, newd):
                self.newd = newd
            def __enter__(self):
                self.curd = os.getcwd()
                os.chdir(self.newd)
                return self
            def __exit__(self, ext, exv, tb):
                os.chdir(self.curd)
        os.makedirs(pdir, exist_ok = True)
        with tempfile.TemporaryDirectory(dir = pdir) as td, ChDir(str(td)) as chd:
            os.makedirs(pdir, exist_ok = True)
            for k, v in srcs.items():
                with open(f'cys{srch}_{k}', 'wb') as f:
                    f.write(v.replace('{srch}', srch).encode('utf-8'))
            import numpy as np
            from setuptools import setup, Extension
            from Cython.Build import cythonize
            sys.argv += ['build_ext', '--inplace']
            setup(
                ext_modules = cythonize(
                    Extension(
                        f'{pdir}.cy{srch}', [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] in ['pyx', 'c', 'cpp'], srcs.keys())],
                        depends = [f'cys{srch}_{k}' for k in filter(lambda e: e[e.rfind('.') + 1:] not in ['pyx', 'c', 'cpp'], srcs.keys())],
                        # MSVC flags; edit the GMP include path for your machine
                        extra_compile_args = ['/O2', '/std:c++latest',
                            '/ID:/dev/_3party/vcpkg_bin/gmp/include/',
                        ],
                    ),
                    compiler_directives = {'language_level': 3, 'embedsignature': True},
                    annotate = True,
                ),
                include_dirs = [np.get_include()],
            )
            del sys.argv[-2:]
            # Copy the built module from the temporary directory up into pdir
            for f in glob.glob(f'{pdir}/cy{srch}*'):
                shutil.copy(f, f'./../')
    print('Cython module:', f'cy{srch}')
    return importlib.import_module(f'{pdir}.cy{srch}')
def cython_import():
    srcs = {
        'lib.h': """
#include <cstring>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#pragma comment(lib, "D:/dev/_3party/vcpkg_bin/gmp/lib/gmp.lib")
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
""",
        'main.pyx': r"""
# distutils: language = c++
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
import numpy as np
cimport numpy as np
cimport cython
from libc.stdint cimport *
cdef extern from "cys{srch}_lib.h" nogil:
    void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt);
    void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt);
@cython.boundscheck(False)
@cython.wraparound(False)
def ISqrt(method, n):
    mask64 = (1 << 64) - 1
    def ToLimbs():
        return np.copy(np.frombuffer(n.to_bytes((n.bit_length() + 63) // 64 * 8, 'little'), dtype = np.uint64))
        # Unreachable: slower pure-Python fallback kept for reference
        words = (n.bit_length() + 63) // 64
        t = n
        r = np.zeros((words,), dtype = np.uint64)
        for i in range(words):
            r[i] = np.uint64(t & mask64)
            t >>= 64
        return r
    def FromLimbs(x):
        return int.from_bytes(x.tobytes(), 'little')
        # Unreachable: slower pure-Python fallback kept for reference
        n = 0
        for i in range(x.shape[0]):
            n |= int(x[i]) << (i * 64)
        return n
    n = ToLimbs()
    cdef uint64_t[:] cn = n
    cdef uint64_t ccnt = len(n)
    cdef uint64_t cmethod = {'GMP': 0, 'AndersKaseorg': 1, 'Babylonian': 2, 'ChordTangent': 3}[method]
    with nogil:
        (ISqrt_GMP_Py if cmethod == 0 else ISqrt_AndersKaseorg_Py if cmethod == 1 else ISqrt_Babylonian_Py if cmethod == 2 else ISqrt_ChordTangent_Py)(
            <uint64_t *>&cn[0], <uint64_t *>&ccnt
        )
    return FromLimbs(n[:ccnt])
""",
    }
    return cython_compile(srcs)
def main():
    import math, gmpy2, timeit, random
    mod = cython_import()
    fs = [
        ('math.isqrt', math.isqrt),
        ('gmpy2.isqrt', gmpy2.isqrt),
        ('ISqrt_GMP', lambda n: mod.ISqrt('GMP', n)),
        ('ISqrt_AndersKaseorg', lambda n: mod.ISqrt('AndersKaseorg', n)),
        ('ISqrt_Babylonian', lambda n: mod.ISqrt('Babylonian', n)),
        ('ISqrt_ChordTangent', lambda n: mod.ISqrt('ChordTangent', n)),
    ]
    times = [0] * len(fs)
    ntests = 1 << 6
    bits = 50000
    for i in range(ntests):
        n = random.randrange(1 << (bits - 1), 1 << bits)
        ref = None
        for j, (fn, f) in enumerate(fs):
            timeit_cnt = 3
            tim = timeit.timeit(lambda: f(n), number = timeit_cnt) / timeit_cnt
            times[j] += tim
            x = f(n)
            if j == 0:
                ref = x
            else:
                assert x == ref, (fn, ref, x)
    print('Bits', bits)
    print('\n'.join([f'{fs[i][0]:>19}: {round(times[i] / ntests * 1000, 3):>7} ms' for i in range(len(fs))]))
if __name__ == '__main__':
    main()
and C++:
Try it online!
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <tuple>
#include <iostream>
#include <string>
#include <type_traits>
#include <sstream>
#include <gmpxx.h>
#define ASSERT_MSG(cond, msg) { if (!(cond)) throw std::runtime_error("Assertion (" #cond ") failed at line " + std::to_string(__LINE__) + "! Msg '" + std::string(msg) + "'."); }
#define ASSERT(cond) ASSERT_MSG(cond, "")
#define LN { std::cout << "LN " << __LINE__ << std::endl; }
using u32 = uint32_t;
using u64 = uint64_t;
template <typename T>
size_t BitLen(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return mpz_sizeinbase(n.get_mpz_t(), 2);
else {
size_t cnt = 0;
while (n >= (1ULL << 32)) {
cnt += 32;
n >>= 32;
}
while (n >= (1 << 8)) {
cnt += 8;
n >>= 8;
}
while (n) {
++cnt;
n >>= 1;
}
return cnt;
}
}
template <typename T>
T ISqrt_Babylonian(T const & y) {
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
if (y <= 1)
return y;
T x = T(1) << (BitLen(y) / 2), a = 0, b = 0, limit = 3;
while (true) {
size_t constexpr loops = 3;
for (size_t i = 0; i < loops; ++i) {
if (i + 1 >= loops)
a = x;
b = y;
b /= x;
x += b;
x >>= 1;
}
if (b < a)
std::swap(a, b);
if (b - a > limit)
continue;
++b;
for (size_t i = 0; a <= b; ++a, ++i)
if (a * a > y) {
if (i == 0)
break;
else
return a - 1;
}
ASSERT(false);
}
}
template <typename T>
T ISqrt_AndersKaseorg(T const & n) {
// https://stackoverflow.com/a/53983683/941531
if (n > 0) {
T y = 0, x = T(1) << ((BitLen(n) + 1) >> 1);
while (true) {
y = (x + n / x) >> 1;
if (y >= x)
return x;
x = y;
}
} else if (n == 0)
return 0;
else
ASSERT_MSG(false, "square root not defined for negative numbers");
}
template <typename T>
T ISqrt_GMP(T const & y) {
// https://gmplib.org/manual/Integer-Roots
mpz_class r, n;
bool constexpr is_mpz = std::is_same_v<std::decay_t<T>, mpz_class>;
if constexpr(is_mpz)
n = y;
else {
static_assert(sizeof(T) <= 8);
n = u32(y >> 32);
n <<= 32;
n |= u32(y);
}
mpz_sqrt(r.get_mpz_t(), n.get_mpz_t());
if constexpr(is_mpz)
return r;
else
return (u64(mpz_get_ui(mpz_class(r >> 32).get_mpz_t())) << 32) | u64(mpz_get_ui(mpz_class(r & u32(-1)).get_mpz_t()));
}
template <typename T>
std::string IntToStr(T n) {
if constexpr(std::is_same_v<std::decay_t<T>, mpz_class>)
return n.get_str();
else {
std::ostringstream ss;
ss << n;
return ss.str();
}
}
template <typename T>
T KthRoot_ChordTangent(T const & n, size_t k = 2) {
// https://i.stack.imgur.com/et9O0.jpg
if (n <= 1)
return n;
auto KthPow = [&](auto const & x){
T y = x * x;
for (size_t i = 2; i < k; ++i)
y *= x;
return y;
};
auto KthPowDer = [&](auto const & x){
T y = x * u32(k);
for (size_t i = 1; i + 1 < k; ++i)
y *= x;
return y;
};
size_t root_bit_len = (BitLen(n) + k - 1) / k;
T hi = T(1) << root_bit_len,
x_begin = hi >> 1, x_end = hi,
y_begin = KthPow(x_begin), y_end = KthPow(x_end),
x_mid = 0, y_mid = 0, x_n = 0, y_n = 0, tangent_x = 0, chord_x = 0;
for (size_t icycle = 0; icycle < (1 << 30); ++icycle) {
//std::cout << "x_begin, x_end = " << IntToStr(x_begin) << ", " << IntToStr(x_end) << ", n " << IntToStr(n) << std::endl;
if (x_end <= x_begin + 2)
break;
if constexpr(0) { // Do Binary Search step if needed
x_mid = (x_begin + x_end) >> 1;
y_mid = KthPow(x_mid);
if (y_mid > n) {
x_end = x_mid; y_end = y_mid;
} else {
x_begin = x_mid; y_begin = y_mid;
}
}
// (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
x_n = x_begin + (n - y_begin) * (x_end - x_begin) / (y_end - y_begin);
y_n = KthPow(x_n);
tangent_x = x_n + (n - y_n) / KthPowDer(x_n) + 1;
chord_x = x_n + (n - y_n) * (x_end - x_n) / (y_end - y_n);
//ASSERT(chord_x <= tangent_x);
x_begin = chord_x; x_end = tangent_x;
y_begin = KthPow(x_begin); y_end = KthPow(x_end);
//ASSERT(y_begin <= n);
//ASSERT(y_end > n);
}
for (size_t i = 0; x_begin <= x_end; ++x_begin, ++i)
if (x_begin * x_begin > n) {
if (i == 0)
break;
else
return x_begin - 1;
}
ASSERT(false);
return 0;
}
mpz_class FromLimbs(uint64_t * limbs, uint64_t * cnt) {
mpz_class r;
mpz_import(r.get_mpz_t(), *cnt, -1, 8, -1, 0, limbs);
return r;
}
void ToLimbs(mpz_class const & n, uint64_t * limbs, uint64_t * cnt) {
uint64_t cnt_before = *cnt;
size_t cnt_res = 0;
mpz_export(limbs, &cnt_res, -1, 8, -1, 0, n.get_mpz_t());
ASSERT(cnt_res <= cnt_before);
std::memset(limbs + cnt_res, 0, (cnt_before - cnt_res) * 8);
*cnt = cnt_res;
}
void ISqrt_ChordTangent_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(KthRoot_ChordTangent<mpz_class>(FromLimbs(limbs, cnt), 2), limbs, cnt);
}
void ISqrt_GMP_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_GMP<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_AndersKaseorg_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_AndersKaseorg<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
void ISqrt_Babylonian_Py(uint64_t * limbs, uint64_t * cnt) {
ToLimbs(ISqrt_Babylonian<mpz_class>(FromLimbs(limbs, cnt)), limbs, cnt);
}
// Testing
#include <chrono>
#include <random>
#include <vector>
#include <iomanip>
inline double Time() {
static auto const gtb = std::chrono::high_resolution_clock::now();
return std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - gtb)
.count();
}
template <typename T, typename F>
std::vector<T> Test0(std::string const & test_name, size_t bits, size_t ntests, F && f) {
std::mt19937_64 rng{123};
std::vector<T> nums;
for (size_t i = 0; i < ntests; ++i) {
T n = 0;
for (size_t j = 0; j < bits; j += 32) {
size_t const cbits = std::min<size_t>(32, bits - j);
n <<= cbits;
n ^= u32(rng()) >> (32 - cbits);
}
nums.push_back(n);
}
auto tim = Time();
for (auto & n: nums)
n = f(n);
tim = Time() - tim;
std::cout << "Test " << std::setw(15) << ("'" + test_name + "'")
<< ", bits " << std::setw(6) << bits << ", time "
<< std::fixed << std::setprecision(6) << std::setw(9) << tim / ntests << " sec" << std::endl;
return nums;
}
void Test() {
auto f = [](auto ty, size_t bits, size_t ntests){
using T = std::decay_t<decltype(ty)>;
auto tim = Time();
auto a = Test0<T>("GMP", bits, ntests, [](auto const & x){ return ISqrt_GMP<T>(x); });
auto b = Test0<T>("AndersKaseorg", bits, ntests, [](auto const & x){ return ISqrt_AndersKaseorg<T>(x); });
ASSERT(b == a);
auto c = Test0<T>("Babylonian", bits, ntests, [](auto const & x){ return ISqrt_Babylonian<T>(x); });
ASSERT(c == a);
auto d = Test0<T>("ChordTangent", bits, ntests, [](auto const & x){ return KthRoot_ChordTangent<T>(x); });
ASSERT(d == a);
std::cout << "Bits " << bits << " nums " << ntests << " time " << std::fixed << std::setprecision(1) << (Time() - tim) << " sec" << std::endl;
};
for (auto p: std::vector<std::pair<int, int>>{{15, 1 << 19}, {30, 1 << 19}})
f(u64(), p.first, p.second);
for (auto p: std::vector<std::pair<int, int>>{{64, 1 << 15}, {8192, 1 << 10}, {50000, 1 << 5}})
f(mpz_class(), p.first, p.second);
}
int main() {
try {
Test();
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}
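To try the chord-tangent idea from Python without compiling the C++ code above, here is a pure-Python sketch. It is not a port of the benchmarked code. The name `iroot_chord_tangent`, the initial bracket taken from `bit_length()`, the clamp that keeps the tangent bound inside the current bracket, and the bisection fallback when integer rounding stalls are all assumptions added for this illustration.

```python
def iroot_chord_tangent(n: int, k: int = 2) -> int:
    """Return floor(n ** (1/k)) using only integer arithmetic.

    Hypothetical sketch of chord-tangent bracketing: keep lo**k <= n < hi**k.
    Because x**k is convex for x >= 0, a chord through two points of the
    curve lands at or below the real root, and a tangent (Newton) step
    lands at or above it.
    """
    if n < 0 or k < 2:
        raise ValueError("need n >= 0 and k >= 2")
    if n < 2:
        return n
    # Initial bracket from the bit length: (2**(b-1))**k <= n < (2**b)**k
    b = (n.bit_length() + k - 1) // k
    lo, hi = 1 << (b - 1), 1 << b
    while hi - lo > 1:
        ylo, yhi = lo ** k, hi ** k
        # Chord through (lo, ylo) and (hi, yhi): at or below the root
        x = lo + (n - ylo) * (hi - lo) // (yhi - ylo)
        y = x ** k
        # Tangent at x, rounded up past its crossing: strictly above the root
        new_hi = min(hi, x + (n - y) // (k * x ** (k - 1)) + 1)
        # Chord through (x, y) and (hi, yhi): at or below the root
        new_lo = x + (n - y) * (hi - x) // (yhi - y)
        if (new_lo, new_hi) == (lo, hi):
            # Rounding stalled the bracket: take one bisection step instead
            mid = (lo + hi) // 2
            if mid ** k <= n:
                new_lo = mid
            else:
                new_hi = mid
        lo, hi = new_lo, new_hi
    return lo
```

For example, `iroot_chord_tangent((10**100 + 1)**2)` returns `10**100 + 1`, and `iroot_chord_tangent(10**6, 3)` returns `100`. The bisection fallback guarantees the bracket shrinks on every iteration, so the loop always terminates.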
Your function fails for large inputs, because math.sqrt converts its argument to a float, which has only 53 bits of precision:
In [26]: isqrt((10**100+1)**2)
ValueError: input was not a perfect square
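On Python 3.8+, one way to get an exact answer is to wrap `math.isqrt` (which works in exact integer arithmetic) with a perfect-square test. The name `isqrt_exact` is hypothetical; this is a sketch, not the code from the question.

```python
import math

def isqrt_exact(n: int) -> int:
    """Exact integer square root; raise ValueError if n is not a perfect
    square. Hypothetical helper; needs Python 3.8+ for math.isqrt."""
    r = math.isqrt(n)  # floor(sqrt(n)); itself raises ValueError if n < 0
    if r * r != n:
        raise ValueError('input was not a perfect square')
    return r

n = (10**100 + 1)**2
print(isqrt_exact(n) == 10**100 + 1)           # True
print(int(math.sqrt(n) + 0.5) == 10**100 + 1)  # False: float rounding
```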
There is a recipe on the ActiveState site that should be more reliable, since it uses integer maths only. It is based on an earlier StackOverflow question: Writing your own square root function.
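The ActiveState recipe itself is not reproduced here. As a separate illustration of an integer-only approach, this sketch uses the classic bit-by-bit ("digit-by-digit") square root, which needs only shifts, additions and comparisons. The name `isqrt_bitwise` is hypothetical.

```python
def isqrt_bitwise(n: int) -> int:
    """Return floor(sqrt(n)) using only shifts, additions and comparisons."""
    if n < 0:
        raise ValueError("square root not defined for negative numbers")
    if n == 0:
        return 0
    rem = n  # part of n not yet accounted for by res**2
    res = 0
    # Start from the largest power of four that does not exceed n
    bit = 1 << ((n.bit_length() - 1) & ~1)
    while bit:
        if rem >= res + bit:
            rem -= res + bit
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2
    return res
```

For example, `isqrt_bitwise(99)` returns `9`. Each loop iteration settles one bit of the result, so the running time grows with the bit length of `n`.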
Not every number can be represented exactly as a float. You can instead test for approximate equality by setting epsilon to a small value within the accuracy of Python's floats.
def isqrt(n):
    epsilon = .00000000001
    i = int(n**.5 + 0.5)
    if abs(i**2 - n) < epsilon:
        return i
    raise ValueError('input was not a perfect square')
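A quick way to see the precision limit behind this: on the usual platforms a Python float is an IEEE 754 double with a 53-bit significand, so integers above 2**53 do not all survive a round trip through float. The helper name `float_roundtrips` is hypothetical.

```python
import sys

def float_roundtrips(n: int) -> bool:
    """True if converting n to float and back gives n again."""
    return int(float(n)) == n

print(sys.float_info.mant_dig)      # 53 on IEEE 754 platforms
print(float_roundtrips(2**53))      # True
print(float_roundtrips(2**53 + 1))  # False: rounds to 2**53
```

Every integer with absolute value up to 2**53 round-trips exactly; above that, gaps between adjacent floats exceed 1.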
Try this condition (no additional computation):
import math

def isqrt(n):
    i = math.sqrt(n)
    if i != int(i):
        raise ValueError('input was not a perfect square')
    return i
If you need it to return an int (for example 4 rather than 4.0), either assign int(i) to a second variable or compute int(i) twice.
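Here is a sketch of the int-returning version, converting once with `int(i)`. It also re-checks `r * r == n` in integer arithmetic, which is an addition not in the answer: every float of magnitude 2**52 or more is a whole number, so `i != int(i)` by itself cannot reject large non-squares. It can still reject large perfect squares whose root the float cannot hold exactly. The name `isqrt_int` is hypothetical.

```python
import math

def isqrt_int(n):
    """Return the square root of n as an int if n is a perfect square,
    else raise ValueError. Hypothetical variant of the answer above."""
    i = math.sqrt(n)
    if i != int(i):   # fractional part: definitely not a perfect square
        raise ValueError('input was not a perfect square')
    r = int(i)        # the second variable, holding an int
    if r * r != n:    # exact integer check, needed for large n
        raise ValueError('input was not a perfect square')
    return r
```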
I have compared the different methods given here with a loop:
for i in range(1000000):  # 700 msec
    r = int(123456781234567**0.5 + 0.5)
    if r**2 == 123456781234567:
        rr = r
    else:
        rr = -1
I found this one to be the fastest, and it needs no math import. Very long inputs might fail, but look at this:
15241576832799734552675677489**0.5 = 123456781234567.0
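To repeat this kind of comparison with less noise, the standard `timeit` module can time each candidate. This is a sketch: the function names are hypothetical, `math.isqrt` needs Python 3.8+, and timings depend on the machine, so none are quoted here. Both functions return the root for a perfect square and -1 otherwise, as in the loop above.

```python
import math
import timeit

N = 123456781234567  # the constant used in the loop above

def float_root(n=N):
    """Float-based check from the answer above."""
    r = int(n**0.5 + 0.5)
    return r if r * r == n else -1

def int_root(n=N):
    """Exact integer check with math.isqrt (Python 3.8+)."""
    r = math.isqrt(n)
    return r if r * r == n else -1

calls = 100_000
for fn in (float_root, int_root):
    secs = timeit.timeit(fn, number=calls)
    print(f"{fn.__name__}: {secs / calls * 1e9:.0f} ns per call")
```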
