What is the big O of the following if statement?
if "pl" in "apple":
...
What is the overall big O of how Python determines whether the string "pl" is found in the string "apple", or of substring-in-string search in general?
Is this the most efficient way to test if a substring is in a string? Does it use the same algorithm as .find()?
The time complexity is O(N) on average, O(NM) worst case (N being the length of the longer string, M, the shorter string you search for). As of Python 3.10, heuristics are used to lower the worst-case scenario to O(N + M) by switching algorithms.
The same algorithm is used for str.index(), str.find(), str.__contains__() (the in operator) and str.replace(); it is a simplification of the Boyer-Moore with ideas taken from the Boyer–Moore–Horspool and Sunday algorithms.
See the original stringlib discussion post, as well as the fastsearch.h source code; until Python 3.10, the base algorithm had not changed since its introduction in Python 2.5 (apart from some low-level optimisations and corner-case fixes).
The post includes a Python-code outline of the algorithm:
def find(s, p):
    # find first occurrence of p in s
    n = len(s)
    m = len(p)
    skip = delta1(p)[p[m-1]]  # delta1(p) is the character-skip table from the post (not shown here)
    i = 0
    while i <= n-m:
        if s[i+m-1] == p[m-1]: # (boyer-moore)
            # potential match
            if s[i:i+m-1] == p[:m-1]:
                return i
            if s[i+m] not in p:
                i = i + m + 1 # (sunday)
            else:
                i = i + skip # (horspool)
        else:
            # skip
            if s[i+m] not in p:
                i = i + m + 1 # (sunday)
            else:
                i = i + 1
    return -1 # not found
as well as speed comparisons.
In Python 3.10, the algorithm was updated to use an enhanced version of Crochemore and Perrin's Two-Way string searching algorithm for larger problems (with p and s longer than 100 and 2100 characters, respectively, with s at least 6 times as long as p), in response to a pathological edge case someone reported. The commit adding this change included a write-up on how the algorithm works.
The Two-way algorithm has a worst-case time complexity of O(N + M), where O(M) is a cost paid up front to build a shift table from the search needle p. Once you have that table, this algorithm does have a best-case performance of O(N/M).
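For a rough feel of the kind of pathological input that change targets, here is an illustrative probe (my own sketch; the exact thresholds and timings depend on the Python version and machine):
import timeit

haystack = "A" * 100_000
needle = "A" * 5_000 + "B" + "A" * 5_000   # long repetitive needle whose partial matches are expensive

t = timeit.timeit(lambda: needle in haystack, number=5)
# Typically far slower on Python < 3.10 than on >= 3.10 for inputs shaped like this.
print(f"{t:.3f} s for 5 searches")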
In Python 3.4.2, both appear to end up in the same underlying function, but there may still be a timing difference. For example, s.find first has to look up the find method on the string, and so on.
The algorithm used is a mix between Boyer-Moore and Horspool.
You can use timeit and test it yourself:
maroun#DQHCPY1:~$ python -m timeit 's = "apple";s.find("pl")'
10000000 loops, best of 3: 0.125 usec per loop
maroun#DQHCPY1:~$ python -m timeit 's = "apple";"pl" in s'
10000000 loops, best of 3: 0.0371 usec per loop
Using in is indeed faster (0.0371 usec compared to 0.125 usec).
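Part of that gap is interpreter overhead rather than the search itself; one quick (illustrative) way to see it is to compare the bytecode of the two expressions:
import dis

dis.dis('s.find("pl")')  # loads s, looks up the find attribute/method, then calls it
dis.dis('"pl" in s')     # the membership test itself is one CONTAINS_OP (COMPARE_OP 'in' before 3.9)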
For the actual implementation, you can look at the code itself.
I think the best way to find out is to look at the source. This looks like it would implement __contains__:
static int
bytes_contains(PyObject *self, PyObject *arg)
{
    Py_ssize_t ival = PyNumber_AsSsize_t(arg, PyExc_ValueError);
    if (ival == -1 && PyErr_Occurred()) {
        Py_buffer varg;
        Py_ssize_t pos;
        PyErr_Clear();
        if (PyObject_GetBuffer(arg, &varg, PyBUF_SIMPLE) != 0)
            return -1;
        pos = stringlib_find(PyBytes_AS_STRING(self), Py_SIZE(self),
                             varg.buf, varg.len, 0);
        PyBuffer_Release(&varg);
        return pos >= 0;
    }
    if (ival < 0 || ival >= 256) {
        PyErr_SetString(PyExc_ValueError, "byte must be in range(0, 256)");
        return -1;
    }
    return memchr(PyBytes_AS_STRING(self), (int) ival, Py_SIZE(self)) != NULL;
}
That is, it is implemented in terms of stringlib_find(), which uses fastsearch().
Related
I'm having trouble trying to count the number of leading zero bits after an SHA-256 hash, as I don't have a lot of experience with 'low level' stuff in Python:
hex = hashlib.sha256((some_data_from_file).encode('ascii')).hexdigest()
# hex = 0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f
# this has 20 leading zero bits
hex_digits = int.from_bytes(bytes(hex.encode('ascii')), 'big') #convert str to int

#count num of leading zeroes
def countZeros(x):
    total_bits = 256
    res = 0
    while ((x & (1 << (total_bits - 1))) == 0):
        x = (x << 1)
        res += 1
    return res

print(countZeros(hex_digits)) # returns 2
I've also tried converting it using bin(), but that didn't provide me with any leading zeros.
Instead of getting the hex digest and analyzing that hex string, you could just get the digest, interpret it as an int, ask for its bit-length, and subtract that from 256:
digest = hashlib.sha256(some_data_from_file.encode('ascii')).digest()
print(256 - int.from_bytes(digest, 'big').bit_length())
Demo (Try it online!):
import hashlib
some_data_from_file = '665782'
# Show hex digest for clarity
hex = hashlib.sha256(some_data_from_file.encode('ascii')).hexdigest()
print(hex)
# Show number of leading zero bits
digest = hashlib.sha256(some_data_from_file.encode('ascii')).digest()
print(256 - int.from_bytes(digest, 'big').bit_length())
Output:
0000000399c6aea5ad0c709a9bc331a3ed6494702bd1d129d8c817a0257a1462
30
Benchmark along with Pranav's solution (not sure how to handle mtraceur's), starting with sha256 objects (i.e., before calling hexdigest() or digest()):
462 ns 464 ns 471 ns just_digest
510 ns 518 ns 519 ns just_hexdigest
566 ns 568 ns 574 ns Kelly3
608 ns 608 ns 611 ns Kelly2
688 ns 688 ns 692 ns Kelly
1139 ns 1139 ns 1140 ns Pranav
Benchmark code (Try it online!):
def Kelly(sha256):
    return 256 - int.from_bytes(sha256.digest(), 'big').bit_length()

def Kelly2(sha256):
    zeros = 0
    for byte in sha256.digest():
        if byte:
            return zeros + 8 - byte.bit_length()
        zeros += 8
    return zeros

def Kelly3(sha256):
    digest = sha256.digest()
    if byte := digest[0]:
        return 8 - byte.bit_length()
    zeros = 0
    for byte in digest:
        if byte:
            return zeros + 8 - byte.bit_length()
        zeros += 8
    return zeros

def Pranav(sha256):
    nzeros = 0
    for c in sha256.hexdigest():
        if c == "0": nzeros += 4
        else:
            digit = int(c, base=16)
            nzeros += 4 - (math.floor(math.log2(digit)) + 1)
            break
    return nzeros

def just_digest(sha256):
    return sha256.digest()

def just_hexdigest(sha256):
    return sha256.hexdigest()

funcs = just_digest, just_hexdigest, Kelly3, Kelly2, Kelly, Pranav

from timeit import repeat
import hashlib, math
from collections import deque

sha256s = [hashlib.sha256(str(i).encode('ascii'))
           for i in range(10_000)]
expect = list(map(Kelly, sha256s))
for func in funcs:
    result = list(map(func, sha256s))
    print(result == expect, func.__name__)

tss = [[] for _ in funcs]
for _ in range(10):
    print()
    for func, ts in zip(funcs, tss):
        t = min(repeat(lambda: deque(map(func, sha256s), 0), number=1))
        ts.append(t)
    for func, ts in zip(funcs, tss):
        print(*('%4d ns ' % (t / len(sha256s) * 1e9) for t in sorted(ts)[:3]), func.__name__)
.hexdigest() returns a string, so your hex variable is a string.
I'm going to call it h instead, because hex is a built-in Python function.
So you have:
h = "0000094e7cc7303a3e33aaeaba76ad937603d4d040064f473a12f10ab30a879f"
Now this is a hexadecimal string. Each digit in a hexadecimal number gives you four bits in binary. Since this has five leading zeros, you already have 5 * 4 = 20 leading zeros.
nzeros = 0
for c in h:
    if c == "0": nzeros += 4
    else: break
Then, you need to count the leading zeros in the binary representation of the first non-zero hexadecimal digit. This is easy to get: A number has math.floor(math.log2(number)) + 1 binary digits, i.e. 4 - (math.floor(math.log2(number)) + 1) leading zeros if it's a hexadecimal digit, since they can only have a max of 4 bits. In this case, the digit is a 9 (1001 in binary), so there are zero additional leading zeros.
So, modify the previous loop:
nzeros = 0
for c in h:
    if c == "0": nzeros += 4
    else:
        digit = int(c, base=16)
        nzeros += 4 - (math.floor(math.log2(digit)) + 1)
        break

print(nzeros) # 20
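A small variation (my addition, not part of the answer above): int(c, 16).bit_length() counts the significant bits directly, so the log2/floor step can be avoided:
nzeros = 0
for c in h:
    if c == "0":
        nzeros += 4
    else:
        nzeros += 4 - int(c, 16).bit_length()  # same as 4 - (floor(log2(digit)) + 1)
        break

print(nzeros)  # 20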
Danger!!!
Is this security-sensitive code? Can this hash ever be the result of hashing secret/private data?
If so, then you should probably implement something in C or similar, while taking care to protect against leaking information about the hash through side-channels.
Otherwise, I suggest picking the version (from any of these answers) that you and the people working on your code find the most intuitive, clear, and so on, unless performance matters more than readability, in which case pick the fastest one.
If your hashes are never of security-sensitive inputs, then:
If you just want a good balance of simplicity and low-effort:
def count_leading_zeroes(value, max_bits=256):
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    if value == 0:
        return max_bits
    significant_bits = len(bin(value)) - 2  # has "0b" prefix
    return max_bits - significant_bits
If you want to really embrace the bit twiddling you were trying in your question:
def count_leading_zeroes(value, max_bits=256):
    top_bit = 1 << (max_bits - 1)
    count = 0
    value &= (1 << max_bits) - 1  # note: if value is 0 here, the loop below never terminates
    while not value & top_bit:
        count += 1
        value <<= 1
    return count
If you're doing manual bit twiddling, I think in this case a loop which counts from the top is the most justified option, because hashes are rather evenly randomly distributed and so each bit has about even chance of not being zero.
So you have a good chance of exiting the loop early and thus executing more efficiently if you start from the top (if you start from the bottom you have to check every bit).
You could alternatively do a bit twiddling thing that's inspired by binary search. That way instead of O(n) steps you do O(log(n)) steps. However, this arguably isn't an optimization worth doing in CPython, and for a JIT implementation like PyPy this manual optimization can actually make it harder for automatic optimization to realize that you can just use a raw "count leading zeroes" CPU instruction. If I ever get the time I'll edit this answer with an example of that later.
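For illustration, here is a rough sketch of that binary-search idea (my sketch, not the promised example; it takes O(log(max_bits)) steps, but note it is neither constant-time nor branch-free):
def count_leading_zeroes_bisect(value, max_bits=256):
    value &= (1 << max_bits) - 1
    if value == 0:
        return max_bits
    count = 0
    width = max_bits
    while width > 1:
        half = width >> 1
        upper = value >> (width - half)        # top `half` bits of the current window
        if upper:
            value = upper                      # the highest 1 bit is in the upper half
            width = half
        else:
            count += half                      # upper half is all zeroes
            value &= (1 << (width - half)) - 1
            width -= half
    return count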
Now about those side-channel attacks: basically any time you have code that works on secret data (or any results of secret data which you can't prove, like a cryptographer would, have fully and irretrievably lost all information about the secret data), you should make sure your code performs exactly the same amount of operations and takes the same branches regardless of the data!
Explaining why you should do this is outside the scope of this answer, but if you don't, your code could be harming users by leaking their secret information in ways that hackers could access.
So!
You might be tempted to modify the simple version that uses bin, but bin is inherently hash-dependent: it produces a string whose length is conditional on the leading zeroes, and as far as I know it doesn't (and logically can't! at least not in the general case) guarantee that it does so in constant-time without data-dependent branches. So we should assume merely running bin on an integer leaks information about the integer through side-channels like runtime and branch predictor state and amount of memory allocated and so on.
For illustrative purposes, if we did have a side-channel-safe bin, which I'll call "bin_", we could do:
def count_leading_zeroes(value, max_bits=256):
    value &= (1 << max_bits) - 1  # truncate; treat negatives as two's complement
    value <<= 1  # securely compensate for 0
    significant_bits = len(bin_(value)) - 3  # has "0b" prefix and "0" suffix
    return max_bits - significant_bits
In principle, a bit-twiddling loop could do leading zero bit count in constant-time and free of input-dependent branches.
The challenge is writing this neatly in Python, because Python is so abstracted from the underlying machine.
The core problem is this: at the CPU level, it's really easy to take a 1 or 0 bit and turn it, branchlessly, into something more useful (like a mask with all bits 1 or all bits 0, which then lets you conditionally but branchlessly select one of two numbers or clear a number, which you can then use to implement something like "if the lowest bit is set, reset the counter to zero"). At the Python level, implementing stuff like this is a struggle through the fog of a lot of uncertainty about how the Python runtime is implemented - there are many places where it might be reasonable to have data-dependent branches under the covers. So really we want to reduce the amount of Python steps and conversions between the digest that hashlib gives us and our leading zeroes answer.
So the best option is actually to never even reach for human-readable stuff like hex or integer forms of the digest at all! Just stick to the raw digest. Something like this, conceptually:
def count_leading_zeroes_in_bytes(data):
    count = 0
    # branchless "latch" mask to stop counting:
    still_in_leading_zeroes = 1
    for byte in data:
        for index in reversed(range(8)):
            bit = (byte >> index) & 1
            # branchless "conditional" if bit is zero:
            is_zero = bit ^ 1
            # branchlessly increment count if needed:
            count += is_zero & still_in_leading_zeroes
            # branchlessly latch count on first 1 bit:
            still_in_leading_zeroes &= is_zero
    return count
This is the best I was able to think of in pure Python. And it still failed.
But some quick testing by both @KellyBundy and me (see comments and Kelly's answer for some examples) shows this version is both extremely slow, and does not actually achieve input-independent execution times (because there's yet another relevant data-dependent optimization inside Python, and possibly for other reasons we're missing).
So if you're going to try to implement anything like this in Python, test it thoroughly before relying on it to actually be secure, or just take the general gist and implement a C or assembly version. Something like this:
/* _clz.c */
#include <limits.h> /* CHAR_BIT */
#include <stddef.h> /* size_t */

int count_leading_zeroes_bytes(char * bytes, size_t length)
{
    int still_in_leading_zeroes = 1;
    int count = 0;
    while(length--)
    {
        char byte = *bytes++;
        int bits = CHAR_BIT;
        while(bits--)
        {
            int bit = (byte >> bits) & 1;
            int is_zero = bit ^ 1;
            count += is_zero & still_in_leading_zeroes;
            still_in_leading_zeroes &= is_zero;
        }
    }
    return count;
}
# clz.py
import ctypes

# This is just a quick example. A mature version
# would load the library as appropriate for each
# platform.
_library = ctypes.CDLL('./_clz.so')
_count_leading_zeroes_bytes = _library.count_leading_zeroes_bytes

def count_leading_zeroes_bytes(data):
    return _count_leading_zeroes_bytes(
        ctypes.c_char_p(data),
        ctypes.c_size_t(len(data)),
    )
When comparing floats to integers, some pairs of values take much longer to be evaluated than other values of a similar magnitude.
For example:
>>> import timeit
>>> timeit.timeit("562949953420000.7 < 562949953421000") # run 1 million times
0.5387085462592742
But if the float or integer is made smaller or larger by a certain amount, the comparison runs much more quickly:
>>> timeit.timeit("562949953420000.7 < 562949953422000") # integer increased by 1000
0.1481498428446173
>>> timeit.timeit("562949953423001.8 < 562949953421000") # float increased by 3001.1
0.1459577925548956
Changing the comparison operator (e.g. using == or > instead) does not affect the times in any noticeable way.
This is not solely related to magnitude because picking larger or smaller values can result in faster comparisons, so I suspect it is down to some unfortunate way the bits line up.
Clearly, comparing these values is more than fast enough for most use cases. I am simply curious as to why Python seems to struggle more with some pairs of values than with others.
A comment in the Python source code for float objects acknowledges that:
Comparison is pretty much a nightmare
This is especially true when comparing a float to an integer, because, unlike floats, integers in Python can be arbitrarily large and are always exact. Trying to cast the integer to a float might lose precision and make the comparison inaccurate. Trying to cast the float to an integer is not going to work either because any fractional part will be lost.
To get around this problem, Python performs a series of checks, returning the result if one of the checks succeeds. It compares the signs of the two values, then whether the integer is "too big" to be a float, then compares the exponent of the float to the length of the integer. If all of these checks fail, it is necessary to construct two new Python objects to compare in order to obtain the result.
When comparing a float v to an integer/long w, the worst case is that:
v and w have the same sign (both positive or both negative),
the integer w has few enough bits that it can be held in the size_t type (typically 32 or 64 bits),
the integer w has at least 49 bits,
the exponent of the float v is the same as the number of bits in w.
And this is exactly what we have for the values in the question:
>>> import math
>>> math.frexp(562949953420000.7) # gives the float's (significand, exponent) pair
(0.9999999999976706, 49)
>>> (562949953421000).bit_length()
49
We see that 49 is both the exponent of the float and the number of bits in the integer. Both numbers are positive and so the four criteria above are met.
Choosing one of the values to be larger (or smaller) can change the number of bits of the integer, or the value of the exponent, and so Python is able to determine the result of the comparison without performing the expensive final check.
This is specific to the CPython implementation of the language.
The comparison in more detail
The float_richcompare function handles the comparison between two values v and w.
Below is a step-by-step description of the checks that the function performs. The comments in the Python source are actually very helpful when trying to understand what the function does, so I've left them in where relevant. I've also summarised these checks in a list at the foot of the answer.
The main idea is to map the Python objects v and w to two appropriate C doubles, i and j, which can then be easily compared to give the correct result. Both Python 2 and Python 3 use the same ideas to do this (the former just handles int and long types separately).
The first thing to do is check that v is definitely a Python float and map it to a C double i. Next the function looks at whether w is also a float and maps it to a C double j. This is the best case scenario for the function as all the other checks can be skipped. The function also checks to see whether v is inf or nan:
static PyObject*
float_richcompare(PyObject *v, PyObject *w, int op)
{
    double i, j;
    int r = 0;

    assert(PyFloat_Check(v));
    i = PyFloat_AS_DOUBLE(v);

    if (PyFloat_Check(w))
        j = PyFloat_AS_DOUBLE(w);

    else if (!Py_IS_FINITE(i)) {
        if (PyLong_Check(w))
            j = 0.0;
        else
            goto Unimplemented;
    }
Now we know that if w failed these checks, it is not a Python float. Next the function checks whether it is a Python integer. If so, the easiest test is to extract the sign of v and the sign of w (return 0 if zero, -1 if negative, 1 if positive). If the signs are different, this is all the information needed to return the result of the comparison:
else if (PyLong_Check(w)) {
    int vsign = i == 0.0 ? 0 : i < 0.0 ? -1 : 1;
    int wsign = _PyLong_Sign(w);
    size_t nbits;
    int exponent;

    if (vsign != wsign) {
        /* Magnitudes are irrelevant -- the signs alone
         * determine the outcome.
         */
        i = (double)vsign;
        j = (double)wsign;
        goto Compare;
    }
}
If this check failed, then v and w have the same sign.
The next check counts the number of bits in the integer w. If it has too many bits then it can't possibly be held as a float and so must be larger in magnitude than the float v:
nbits = _PyLong_NumBits(w);
if (nbits == (size_t)-1 && PyErr_Occurred()) {
    /* This long is so large that size_t isn't big enough
     * to hold the # of bits. Replace with little doubles
     * that give the same outcome -- w is so large that
     * its magnitude must exceed the magnitude of any
     * finite float.
     */
    PyErr_Clear();
    i = (double)vsign;
    assert(wsign != 0);
    j = wsign * 2.0;
    goto Compare;
}
On the other hand, if the integer w has 48 or fewer bits, it can safely be turned into a C double j and compared:
if (nbits <= 48) {
    j = PyLong_AsDouble(w);
    /* It's impossible that <= 48 bits overflowed. */
    assert(j != -1.0 || ! PyErr_Occurred());
    goto Compare;
}
From this point onwards, we know that w has 49 or more bits. It will be convenient to treat w as a positive integer, so change the sign and the comparison operator as necessary:
if (vsign < 0) {
    /* "Multiply both sides" by -1; this also swaps the
     * comparator.
     */
    i = -i;
    op = _Py_SwappedOp[op];
}
Now the function looks at the exponent of the float. Recall that a float can be written (ignoring sign) as significand * 2^exponent and that the significand represents a number between 0.5 and 1:
(void) frexp(i, &exponent);
if (exponent < 0 || (size_t)exponent < nbits) {
    i = 1.0;
    j = 2.0;
    goto Compare;
}
This checks two things. If the exponent is less than 0 then the float is smaller than 1 (and so smaller in magnitude than any integer). Or, if the exponent is less than the number of bits in w then we have that v < |w| since significand * 2^exponent is less than 2^nbits.
Failing these two checks, the function looks to see whether the exponent is greater than the number of bits in w. This shows that significand * 2^exponent is greater than 2^nbits and so v > |w|:
if ((size_t)exponent > nbits) {
    i = 2.0;
    j = 1.0;
    goto Compare;
}
If this check did not succeed we know that the exponent of the float v is the same as the number of bits in the integer w.
The only way that the two values can be compared now is to construct two new Python integers from v and w. The idea is to discard the fractional part of v, double the integer part, and then add one. w is also doubled, and these two new Python objects can be compared to give the correct return value. Using an example with small values, 4.65 < 4 would be decided by the comparison 9 < 8 (since (2*4)+1 == 9 and 2*4 == 8), which returns false.
    {
        double fracpart;
        double intpart;
        PyObject *result = NULL;
        PyObject *one = NULL;
        PyObject *vv = NULL;
        PyObject *ww = w;

        // snip

        fracpart = modf(i, &intpart); // split i (the double that v mapped to)
        vv = PyLong_FromDouble(intpart);

        // snip

        if (fracpart != 0.0) {
            /* Shift left, and or a 1 bit into vv
             * to represent the lost fraction.
             */
            PyObject *temp;

            one = PyLong_FromLong(1);

            temp = PyNumber_Lshift(ww, one); // left-shift doubles an integer
            ww = temp;

            temp = PyNumber_Lshift(vv, one);
            vv = temp;

            temp = PyNumber_Or(vv, one); // a doubled integer is even, so this adds 1
            vv = temp;
        }
        // snip
    }
}
For brevity I've left out the additional error-checking and garbage-tracking Python has to do when it creates these new objects. Needless to say, this adds additional overhead and explains why the values highlighted in the question are significantly slower to compare than others.
Here is a summary of the checks that are performed by the comparison function.
Let v be a float and cast it as a C double. Now, if w is also a float:
Check whether w is nan or inf. If so, handle this special case separately depending on the type of w.
If not, compare v and w directly by their representations as C doubles.
If w is an integer:
Extract the signs of v and w. If they are different then we know v and w are different and which is the greater value.
(The signs are the same.) Check whether w has too many bits to possibly be held as a float (so many that even its bit count overflows a size_t). If so, w has greater magnitude than v.
Check if w has 48 or fewer bits. If so, it can be safely cast to a C double without losing its precision and compared with v.
(w has more than 48 bits. We will now treat w as a positive integer having changed the compare op as appropriate.)
Consider the exponent of the float v. If the exponent is negative, then v is less than 1 and therefore less than any positive integer. Else, if the exponent is less than the number of bits in w then it must be less than w.
If the exponent of v is greater than the number of bits in w then v is greater than w.
(The exponent is the same as the number of bits in w.)
The final check. Split v into its integer and fractional parts. Double the integer part and add 1 to compensate for the fractional part. Now double the integer w. Compare these two new integers instead to get the result.
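To make that last step concrete, here is a small Python illustration of the "double both sides and add 1 for the lost fraction" trick (a sketch of the idea only; it assumes the positive case with the exponent equal to the bit length, i.e. that all the earlier checks have already failed):
import math

def float_lt_int(v, w):
    # Sketch of CPython's final step for v < w (v a positive float, w a positive int,
    # with v's exponent equal to w's bit length).
    fracpart, intpart = math.modf(v)
    vv, ww = int(intpart), w
    if fracpart != 0.0:
        vv = (vv << 1) | 1   # doubling makes vv even, so |1 stands in for the lost fraction
        ww <<= 1
    return vv < ww

print(float_lt_int(4.65, 4))                             # False, same as 4.65 < 4
print(float_lt_int(562949953420000.7, 562949953421000))  # True, the slow pair from the question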
Using gmpy2 with arbitrary precision floats and integers it is possible to get more uniform comparison performance:
~ $ ptipython
Python 3.5.1 |Anaconda 4.0.0 (64-bit)| (default, Dec 7 2015, 11:16:01)
Type "copyright", "credits" or "license" for more information.
IPython 4.1.2 -- An enhanced Interactive Python.
? -> Introduction and overview of IPython's features.
%quickref -> Quick reference.
help -> Python's own help system.
object? -> Details about 'object', use 'object??' for extra details.
In [1]: import gmpy2
In [2]: from gmpy2 import mpfr
In [3]: from gmpy2 import mpz
In [4]: gmpy2.get_context().precision=200
In [5]: i1=562949953421000
In [6]: i2=562949953422000
In [7]: f=562949953420000.7
In [8]: i11=mpz('562949953421000')
In [9]: i12=mpz('562949953422000')
In [10]: f1=mpfr('562949953420000.7')
In [11]: f<i1
Out[11]: True
In [12]: f<i2
Out[12]: True
In [13]: f1<i11
Out[13]: True
In [14]: f1<i12
Out[14]: True
In [15]: %timeit f<i1
The slowest run took 10.15 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 441 ns per loop
In [16]: %timeit f<i2
The slowest run took 12.55 times longer than the fastest. This could mean that an intermediate result is being cached.
10000000 loops, best of 3: 152 ns per loop
In [17]: %timeit f1<i11
The slowest run took 32.04 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 269 ns per loop
In [18]: %timeit f1<i12
The slowest run took 36.81 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 231 ns per loop
In [19]: %timeit f<i11
The slowest run took 78.26 times longer than the fastest. This could mean that an intermediate result is being cached.
10000000 loops, best of 3: 156 ns per loop
In [20]: %timeit f<i12
The slowest run took 21.24 times longer than the fastest. This could mean that an intermediate result is being cached.
10000000 loops, best of 3: 194 ns per loop
In [21]: %timeit f1<i1
The slowest run took 37.61 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 275 ns per loop
In [22]: %timeit f1<i2
The slowest run took 39.03 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 259 ns per loop
I am trying to play around with some R code I found recently that imitates parts of Norvig's spell checker written in Python; in particular, I am trying to work out the right way to implement the edit2 function in R:
def splits(word):
    return [(word[:i], word[i:])
            for i in range(len(word)+1)]

def edits1(word):
    pairs = splits(word)
    deletes = [a+b[1:] for (a, b) in pairs if b]
    transposes = [a+b[1]+b[0]+b[2:] for (a, b) in pairs if len(b) > 1]
    replaces = [a+c+b[1:] for (a, b) in pairs for c in alphabet if b]
    inserts = [a+c+b for (a, b) in pairs for c in alphabet]
    return set(deletes + transposes + replaces + inserts)

def edits2(word):
    return set(e2 for e1 in edits1(word) for e2 in edits1(e1))
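For reference, a quick way to exercise these functions (note that alphabet is not defined in the snippet above; Norvig uses the lowercase ASCII letters, assumed here):
alphabet = 'abcdefghijklmnopqrstuvwxyz'

print(len(splits('abcd')))  # 5 prefix/suffix pairs
print(len(edits1('abcd')))  # distinct strings one edit away
print(len(edits2('abcd')))  # distinct strings reachable by applying edits1 twice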
However, in my benchmarks it seems that generating thousands of small strings in R using paste0 (or str_c from stringr, or stri_join from stringi) results in code that is roughly 10x (or ~100x, or ~50x, respectively) slower than the Python implementation shown by Norvig. (Yes, interestingly the stringr- and stringi-based functions are even slower than paste0.) My questions are (with #3 being the main one I want resolved):
Am I doing this correctly (is the code "right")?
If so, is this a known issue of R (extremely slow string concatenation)?
Is there anything I can do about this to make this significantly faster (one or more orders of magnitude, at least) without rewriting the whole function in Rcpp11 or something like that?
Here is my R code I came up with for the edit2 function:
library(stringi)  # for stri_sub() and stri_reverse()

# 1. generate a list of all binary splits of a word
binary.splits <- function(w) {
  n <- nchar(w)
  lapply(0:n, function(x)
    c(stri_sub(w, 0, x), stri_sub(w, x + 1, n)))
}

# 2. generate a list of all bigrams for a word
bigram.unsafe <- function(word)
  sapply(2:nchar(word), function(i) substr(word, i-1, i))
bigram <- function(word)
  if (nchar(word) > 1) bigram.unsafe(word) else word

# 3. four edit types: deletion, transposition, replacement, and insertion
alphabet = letters
deletions <- function(splits) if (length(splits) > 1) {
  sapply(1:(length(splits)-1), function(i)
    paste0(splits[[i]][1], splits[[i+1]][2]), simplify=FALSE)
} else {
  splits[[1]][2]
}
transpositions <- function(splits) if (length(splits) > 2) {
  swaps <- rev(bigram.unsafe(stri_reverse(splits[[1]][2])))
  sapply(1:length(swaps), function(i)
    paste0(splits[[i]][1], swaps[i], splits[[i+2]][2]), simplify=FALSE)
} else {
  stri_reverse(splits[[1]][2])
}
replacements <- function(splits) if (length(splits) > 1) {
  sapply(1:(length(splits)-1), function(i)
    lapply(alphabet, function(symbol)
      paste0(splits[[i]][1], symbol, splits[[i+1]][2])))
} else {
  alphabet
}
insertions <- function(splits)
  sapply(splits, function(pair)
    lapply(alphabet, function(symbol)
      paste0(pair[1], symbol, pair[2])))

# 4. create a vector of all words at edit distance 1 given the input word
edit.1 <- function(word) {
  splits <- binary.splits(word)
  unique(unlist(c(deletions(splits),
                  transpositions(splits),
                  replacements(splits),
                  insertions(splits))))
}

# 5. create a simple function to generate all words of edit distance 1 and 2
edit.2 <- function(word) {
  e1 <- edit.1(word)
  unique(c(unlist(lapply(e1, edit.1)), e1))
}
If you start profiling this code, you will see that replacements and insertions have nested "lapplies" and seem to take 10x longer than the deletions or transpositions, because they generate far more spelling variants.
library(rbenchmark)
benchmark(edit.2('abcd'), replications=20)
This takes about 8 seconds on my Core i5 MacBook Air, while the corresponding Python benchmark (running the corresponding edit2 function 20 times) takes about 0.6 seconds, i.e., it is about 10-15 times faster!
I have tried using expand.grid to get rid of the inner lapply, but this made the code slower, not faster. And I know that using lapply in place of sapply makes my code a bit faster, but I do not see the point of using the "wrong" function (I want a vector back) for a minor speed bump. But maybe generating the result of the edit.2 function can be made much faster in pure R?
Performance of R's paste0 vs. python's ''.join
The original title asked whether paste0 in R was 10x slower than string concatenation in python. If it is, then there's no hope of writing an algorithm that relies heavily on string concatenation in R that is as fast as the corresponding python algorithm.
I have
> R.version.string
[1] "R version 3.1.0 Patched (2014-05-31 r65803)"
and
>>> sys.version
'3.4.0 (default, Apr 11 2014, 13:05:11) \n[GCC 4.8.2]'
Here's a first comparison
> library(microbenchmark)
> microbenchmark(paste0("a", "b"), times=1e6)
Unit: nanoseconds
expr min lq median uq max neval
paste0("a", "b") 951 1071 1162 1293 21794972 1e+06
(so about 1s for all replicates) versus
>>> import timeit
>>> timeit.timeit("''.join(x)", "x=('a', 'b')", number=int(1e6))
0.119668865998392
I guess that's the 10x performance difference the original poster observed.
However, R works better on vectors, and the algorithm involves vectors
of words anyway, so we might be interested in the comparison
> x = y = sample(LETTERS, 1e7, TRUE); system.time(z <- paste0(x, y))
user system elapsed
1.479 0.009 1.488
and
>>> setup = '''
import random
import string
y = x = [random.choice(string.ascii_uppercase) for _ in range(10000000)]
'''
>>> timeit.Timer("map(''.join, zip(x, y))", setup=setup).repeat(1)
[0.362522566007101]
This suggests that we would be on the right track if our R algorithm
were to run at 1/4 the speed of python; the OP found a 10-fold
difference, so it looks like there's room for improvement.
R iteration versus vectorization
The OP uses iteration (lapply and friends), rather than vectorization. We can compare the vector version to various approaches to iteration with the following
f0 = paste0
f1 = function(x, y)
  vapply(seq_along(x), function(i, x, y) paste0(x[i], y[i]), character(1), x, y)
f2 = function(x, y) Map(paste0, x, y)
f3 = function(x, y) {
  z = character(length(x))
  for (i in seq_along(x))
    z[i] = paste0(x[i], y[i])
  z
}
f3c = compiler::cmpfun(f3)  # explicitly compile
f4 = function(x, y) {
  z = character()
  for (i in seq_along(x))
    z[i] = paste0(x[i], y[i])
  z
}
Scaling the data back, defining the 'vectorized' solution as f0, and
comparing these approaches
> x = y = sample(LETTERS, 100000, TRUE)
> library(microbenchmark)
> microbenchmark(f0(x, y), f1(x, y), f2(x, y), f3(x, y), f3c(x, y), times=5)
Unit: milliseconds
expr min lq median uq max neval
f0(x, y) 14.69877 14.70235 14.75409 14.98777 15.14739 5
f1(x, y) 241.34212 250.19018 268.21613 279.01582 292.21065 5
f2(x, y) 198.74594 199.07489 214.79558 229.50684 271.77853 5
f3(x, y) 250.64388 251.88353 256.09757 280.04688 296.29095 5
f3c(x, y) 174.15546 175.46522 200.09589 201.18543 214.18290 5
with f4 being too painfully slow to include
> system.time(f4(x, y))
user system elapsed
24.325 0.000 24.330
So from this one can see the point of Dr. Tierney's advice: there may be a benefit to vectorizing those lapply calls.
Further vectorizing the updated original post
@fnl adapted the original code by partly unrolling the loops. There remain opportunities for more of the same, for instance,
replacements <- function(splits) if (length(splits$left) > 1) {
  lapply(1:(length(splits$left)-1), function(i)
    paste0(splits$left[i], alphabet, splits$right[i+1]))
} else {
  splits$right[1]
}
might be revised to perform a single paste call, relying on argument recycling (short vectors recycled until their length matches longer vectors)
replacements1 <- function(splits) if (length(splits$left) > 1) {
  len <- length(splits$left)
  paste0(splits$left[-len], rep(alphabet, each = len - 1), splits$right[-1])
} else {
  splits$right[1]
}
The values are in a different order, but that is not important for the algorithm. Dropping subscripts (prefix with -) is potentially more memory efficient. Similarly,
deletions1 <- function(splits) if (length(splits$left) > 1) {
  paste0(splits$left[-length(splits$left)], splits$right[-1])
} else {
  splits$right[1]
}

insertions1 <- function(splits)
  paste0(splits$left, rep(alphabet, each=length(splits$left)), splits$right)
We then have
edit.1.1 <- function(word) {
  splits <- binary.splits(word)
  unique(c(deletions1(splits),
           transpositions(splits),
           replacements1(splits),
           insertions1(splits)))
}
with some speed-up
> identical(sort(edit.1("word")), sort(edit.1.1("word")))
[1] TRUE
> microbenchmark(edit.1("word"), edit.1.1("word"))
Unit: microseconds
expr min lq median uq max neval
edit.1("word") 354.125 358.7635 362.5260 372.9185 521.337 100
edit.1.1("word") 296.575 298.9830 300.8305 307.3725 369.419 100
The OP indicates that their original version was 10x slower than
python, and that their original modifications resulted in a 5x
speed-up. We gain a further 1.2x speed-up, so are perhaps at the
expected performance of the algorithm using R's paste0. A next step is to ask whether alternative algorithms or implementations are more performant; in particular, substr might be promising.
Following @LukeTierney's tips in the question's comments on vectorizing the paste0 calls and returning two vectors from binary.splits, I edited the functions to be correctly vectorized. I have added the additional modifications as described by @MartinMorgan in his answer, too: dropping items using single suffixes instead of using selection ranges (i.e., "[-1]" instead of "[2:n]", etc.; but NB: for multiple suffixes, as used in transpositions, this is actually slower) and, particularly, using rep to further vectorize the paste0 calls in replacements and insertions.
This results in the best possible answer (so far?) to implement edit.2 in R (thank you, Luke and Martin!). In other words, with the main hints provided by Luke and some subsequent improvements by Martin, the R implementation ends up roughly half as fast as Python (but see Martin's final comments in his answer below). (The functions edit.1, edit.2, and bigram.unsafe remain unchanged, as shown above.)
binary.splits <- function(w) {
  n <- nchar(w)
  list(left=stri_sub(w, rep(0, n + 1), 0:n),
       right=stri_sub(w, 1:(n + 1), rep(n, n + 1)))
}

deletions <- function(splits) {
  n <- length(splits$left)
  if (n > 1) paste0(splits$left[-n], splits$right[-1])
  else splits$right[1]
}

transpositions <- function(splits) if (length(splits$left) > 2) {
  swaps <- rev(bigram.unsafe(stri_reverse(splits$right[1])))
  paste0(splits$left[1:length(swaps)], swaps,
         splits$right[3:length(splits$right)])
} else {
  stri_reverse(splits$right[1])
}

replacements <- function(splits) {
  n <- length(splits$left)
  if (n > 1) paste0(splits$left[-n],
                    rep(alphabet, each=n-1),
                    splits$right[-1])
  else alphabet
}

insertions <- function(splits)
  paste0(splits$left,
         rep(alphabet, each=length(splits$left)),
         splits$right)
Overall, and to conclude this exercise, Luke's and Martin's suggestions made the R implementation run roughly half as fast as the Python code shown in the beginning, improving my original code by about a factor of 6. What worries me even more in the end, however, are two different issues: (1) the R code seems to be far more verbose (LOC, but might be polished up a bit), and (2) even a slight deviation from "correct vectorization" makes R code perform horribly, while in Python slight deviations from "correct Python" usually do not have such an extreme impact. Nonetheless, I'll keep on with my "coding efficient R" effort - thanks to everybody involved!
My use case is to evaluate the Poisson pmf at all points less than, say, 10, and I would call such a function multiple times with different lambdas. The lambdas are not known ahead of time, so I cannot vectorize over the lambdas.
I heard from somewhere about a secret trick, which is to use _pmf. What is the downside of doing so? But still, it is a bit slow; is there any way to improve it without rewriting the pmf in C from scratch?
%timeit scipy.stats.poisson.pmf(np.arange(0,10),3.3)
%timeit scipy.stats.poisson._pmf(np.arange(0,10),3.3)
a = np.arange(0,10)
%timeit scipy.stats.poisson._pmf(a,3.3)
10000 loops, best of 3: 94.5 µs per loop
100000 loops, best of 3: 15.2 µs per loop
100000 loops, best of 3: 13.7 µs per loop
Update
OK, simply put, I was just too lazy to write it in Cython. I had expected there to be a faster solution for any discrete distribution that can be evaluated sequentially (iteratively) for consecutive x, e.g. P(X=3) = P(X=2) * lambda / 3 if X ~ Pois(lambda).
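(For what it's worth, that recurrence is easy to sketch directly; this is an illustrative NumPy version, not a scipy API:)
import numpy as np

def poisson_pmf_upto(lam, n):
    # P(X=k) = P(X=k-1) * lam / k, starting from P(X=0) = exp(-lam).
    # Illustrative only; for large lam, working in log-space is numerically safer.
    out = np.empty(n)
    out[0] = np.exp(-lam)
    for k in range(1, n):
        out[k] = out[k - 1] * lam / k
    return out

print(poisson_pmf_upto(3.3, 10))  # same values as scipy.stats.poisson.pmf(np.arange(10), 3.3)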
Related: Is the build-in probability density functions of `scipy.stat.distributions` slower than a user provided one?
I have less faith in SciPy and Python now. The library function isn't as advanced as I had expected.
Most of scipy.stats distributions support vectorized evaluation:
>>> poisson.pmf(1, [5, 6, 7, 8])
array([ 0.03368973, 0.01487251, 0.00638317, 0.0026837 ])
This may or may not be fast enough, but you can try taking pmf calls out of the loop.
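For example, thanks to broadcasting you can evaluate a whole grid of (k, lambda) pairs in a single call (the lambda values below are just placeholders):
import numpy as np
from scipy.stats import poisson

ks = np.arange(10)
lambdas = np.array([3.3, 4.1, 5.0])   # e.g. the lambdas collected across iterations

# One vectorized call instead of one pmf call per lambda:
pmf_table = poisson.pmf(ks[:, None], lambdas[None, :])
print(pmf_table.shape)                # (10, 3): one column per lambda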
Re the difference between pmf and _pmf: the real work is done in the underscored functions (_pmf, _cdf, etc.), while the public functions (pmf, cdf) make sure that only valid arguments make it to _pmf (the output of _pmf is not guaranteed to be meaningful if the arguments are invalid, so use it at your own risk).
>>> poisson.pmf(1, -1)
nan
>>> poisson._pmf(1, -1)
/home/br/virtualenvs/scipy-dev/local/lib/python2.7/site-packages/scipy/stats/_discrete_distns.py:432: RuntimeWarning: invalid value encountered in log
Pk = k*log(mu)-gamln(k+1) - mu
nan
Further details: https://github.com/scipy/scipy/blob/master/scipy/stats/_distn_infrastructure.py#L2721
Try implementing the pmf in cython. If your scipy is part of a package like Anaconda or Enthought you probably have cython installed. http://cython.org/
Try running it with pypy. http://pypy.org/
Rent time on a big AWS server (or similar).
I found that the scipy.stats.poisson class is tragically slow compared to just a simple python implementation.
No cython or vectors or anything.
import math

def poisson_pmf(x, mu):
    return mu**x / math.factorial(x) * math.exp(-mu)

def poisson_cdf(k, mu):
    p_total = 0.0
    for x in range(k + 1):
        p_total += poisson_pmf(x, mu)
    return p_total
And if you check the source code of scipy.stats.poisson (even the underscore prefixed versions) then it is clear why!
The above implementation is now ONLY 10x slower than the exact equivalent in C (compiled with gcc -O3 v9.3). The scipy version is at least another 10x slower.
#include <math.h>

/* note: wraps around for n >= 21 with a 64-bit unsigned long */
unsigned long factorial(unsigned n) {
    unsigned long fact = 1;
    for (unsigned k = 2; k <= n; ++k)
        fact *= k;
    return fact;
}

double poisson_pmf(unsigned x, double mu) {
    return pow(mu, x) / factorial(x) * exp(-mu);
}

double poisson_cdf(unsigned k, double mu) {
    double p_total = 0.0;
    for (unsigned x = 0; x <= k; ++x)
        p_total += poisson_pmf(x, mu);
    return p_total;
}
PS: This is not a duplicate of How to find the overlap between 2 sequences, and return it
[Although I do ask whether the approach from that question could be applied to the following problem.]
Q: Although I got it right, it is still not a scalable solution and is definitely not optimized (low on score). Read the following description of the problem and kindly offer a better solution.
Question:
For simplicity, we require prefixes and suffixes to be non-empty and shorter than the whole string S. A border of a string S is any string that is both a prefix and a suffix. For example, "cut" is a border of a string "cutletcut", and a string "barbararhubarb" has two borders: "b" and "barb".
class Solution { public int solution(String S); }
that, given a string S consisting of N characters, returns the length of its longest border that has at least three non-overlapping occurrences in the given string. If there is no such border in S, the function should return 0.
For example,
if S = "barbararhubarb" the function should return 1, as explained above;
if S = "ababab" the function should return 2, as "ab" and "abab" are both borders of S, but only "ab" has three non-overlapping occurrences;
if S = "baaab" the function should return 0, as its only border "b" occurs only twice.
Assume that:
N is an integer within the range [0..1,000,000];
string S consists only of lower-case letters (a−z).
Complexity:
expected worst-case time complexity is O(N);
expected worst-case space complexity is O(N) (not counting the storage required for input arguments).
def solution(S):
    S = S.lower()
    presuf = []
    f = l = str()
    rank = []
    wordlen = len(S)
    for i, j in enumerate(S):
        y = -i-1
        f += S[i]
        l = S[y] + l
        if f==l and f != S:
            #print f,l
            new=S[i+1:-i-1]
            mindex = new.find(f)
            if mindex != -1:
                mid = f #new[mindex]
                #print mid
            else:
                mid = None
            presuf.append((f,mid,l,(i,y)))
    #print presuf
    for i,j,k,o in presuf:
        if o[0]<wordlen+o[-1]: #non overlapping
            if i==j:
                rank.append(len(i))
            else:
                rank.append(0)
    if len(rank)==0:
        return 0
    else:
        return max(rank)
My solution's time complexity is O(N^2) or O(N^4).
Help greatly appreciated.
My solution is a combination of the Rabin-Karp and Knuth–Morris–Pratt algorithms.
http://codility.com/cert/view/certB6J4FV-W89WX4ZABTDRVAG6/details
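The certified code itself isn't shown here, but the KMP piece such an approach builds on, enumerating every border length of S in O(N) via the failure function, can be sketched like this (my sketch, not the certified solution):
def all_borders(s):
    # KMP failure function: fail[i] = length of the longest proper border of s[:i+1].
    n = len(s)
    fail = [0] * n
    k = 0
    for i in range(1, n):
        while k and s[i] != s[k]:
            k = fail[k - 1]
        if s[i] == s[k]:
            k += 1
        fail[i] = k
    # Following the failure-function chain from the end lists every border length.
    borders = []
    b = fail[-1] if n else 0
    while b:
        borders.append(b)
        b = fail[b - 1]
    return borders

print(all_borders("barbararhubarb"))  # [4, 1] -> the borders "barb" and "b"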
I have a (Java) solution that performs in O(N) or O(N**3), for a resulting 90/100 overall, but I can't figure out how to make it go through 2 different test cases:
almost_all_same_letters
aaaaa...aa??aaaa??....aaaaaaa 2.150 s. TIMEOUT ERROR
running time: >2.15 sec., time limit: 1.20 sec.
same_letters_on_both_ends 2.120 s. TIMEOUT ERROR
running time: >2.12 sec., time limit: 1.24 sec.
Edit: Nailed it!
Now I have a solution that perform in O(N) and passes all the checks for a 100/100 result :)
I didn't know Codility, but it's a nice tool!
I have a solution with suffix arrays (there actually is an algorithm for constructing the SA and LCP in linear time, or something a bit worse than that, but surely not quadratic).
Still not sure if I can go without RMQs (O(log n) with a SegmentTree), which I couldn't make pass my own cases and which seems quite complicated, but with RMQs it can work (not to mention the approach with a for loop instead of RMQ, which would make it quadratic anyway).
The solution is performing quite fast and passing my 21 test cases with various perks I've managed to craft, but it's still failing on some of their cases. Not sure if that helped you or gave you an idea of how to approach the problem, but I am sure that a naive solution, like @Vicenco said in some of his comments, can't get you better than Silver.
EDIT:
Managed to fix all the problems, but it's still too slow. I had to enforce some conditions, but that increased the complexity; still not sure how to optimize it. Will keep you posted. Good luck!
protected int calcBorder(String input) {
    if (null != input) {
        int mean = (input.length() / 3);
        while (mean >= 1) {
            if (input.substring(0, mean).equals(
                    input.substring(input.length() - mean))) {
                String reference = input.substring(0, mean);
                String temp = input
                        .substring(mean, (input.length() - mean));
                int startIndex = 0;
                int endIndex = mean;
                int count = 2;
                while (endIndex <= temp.length()) {
                    if (reference.equals(temp.substring(startIndex,
                            endIndex))) {
                        count++;
                        if (count >= 3) {
                            return reference.length();
                        }
                    }
                    startIndex++;
                    endIndex++;
                }
            }
            mean--;
        }
    }
    return 0;
}
The Z-Algorithm would be a good solution.
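For reference, a minimal Python sketch of the Z-function (my sketch, not a complete solution to the task): z[i] is the length of the longest common prefix of S and S[i:], and S has a border of length k exactly when z[len(S) - k] == k.
def z_array(s):
    # Classic O(N) Z-algorithm: maintain the rightmost match window [l, r).
    n = len(s)
    z = [0] * n
    if n:
        z[0] = n
    l = r = 0
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z

S = "barbararhubarb"
z = z_array(S)
print([k for k in range(1, len(S)) if z[len(S) - k] == k])  # [1, 4] -> borders "b" and "barb"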