Why is FlatMap after GroupByKey in Apache Beam Python so slow?

Given a relatively small data source (3,000-10,000 key/value pairs), I am trying to process only the records whose group meets a threshold (50-100 elements per key). The simplest method is to group them by key, filter, and unwind - either with FlatMap or a ParDo. The largest group so far has only 1,500 records, yet this seems to be a severe bottleneck in production on Google Cloud Dataflow.
With given list
(1, 1)
(1, 2)
(1, 3)
...
(2, 1)
(2, 2)
(2, 3)
...
run through a set of transforms to group by key, filter, and unwind:
(p | 'Group' >> beam.GroupByKey()
   | 'Filter' >> beam.Filter(lambda (key, values): len(list(values)) > 50)
   | 'Unwind' >> beam.FlatMap(lambda (key, values): values))
Any ideas on how to make this more performant? Thanks for your help!

This is an interesting corner case for a pipeline. I believe that your issue here is in the way you read the data that comes out of GroupByKey. Let me give you a quick summary of how GBK works.
What's GroupByKey, and how big data systems implement it
All big data systems provide ways to apply operations across multiple elements that share the same key. In MapReduce this was called reduce; in other big data systems it is called Group By Key or Combine.
When you do a GroupByKey transform, Dataflow needs to gather all the elements for a single key into the same machine. Since different elements for the same key may be processed in different machines, data needs to be serialized somehow.
This means that when you read data that comes from a GroupByKey, you are accessing the IO of the workers (i.e. not from memory), so you really want to avoid reading shuffle data too many times.
How this translates to your pipeline
I believe that your problem here is that Filter and Unwind will both read data from shuffle separately (so you will read the data for each list twice). What you want to do is to read your shuffle data only once. You can do this with a single FlatMap that both filters and unwinds your data without double-reading from shuffle. Something like this:
def unwind_and_filter((key, values)):
    # This consumes all the data from shuffle once
    value_list = list(values)
    if len(value_list) > 50:
        # Yield individual values so FlatMap emits them one by one
        for value in value_list:
            yield value

(p | 'Group' >> beam.GroupByKey()
   | 'UnwindAndFilter' >> beam.FlatMap(unwind_and_filter))
Let me know if this helps.

Related

Is there an efficient way to fill a numba `Dict` in parallel?

I'm having some trouble quickly filling a numba Dict object with key-value pairs (around 63 million of them). Is there an efficient way to do this in parallel?
The documentation (https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#typed-dict) is clear that numba.typed.Dict is not thread-safe, so I think using prange with a single Dict object would be a bad idea. I've tried using a numba List of Dicts, populating them in parallel and then stitching them together with update, but I think this last step is also inefficient.
Note that one thing (which may be important) is that all the keys are unique, i.e. once assigned, a key will not be reassigned a value. I think this property makes the problem amenable to an efficient parallelised solution.
Below is an example of the serial approach, which is slow with a large number of key-value pairs.
from numba import njit, typed, types

d = typed.Dict.empty(
    key_type=types.UnicodeCharSeq(128), value_type=types.int64
)

@njit
def fill_dict(keys_list, values_list, d):
    n = len(keys_list)
    for i in range(n):
        d[keys_list[i]] = values_list[i]

fill_dict(keys_list, values_list, d)
Can anybody help me?
Many thanks.
You don't have to stitch them together if you preprocess each key into an integer hash and take it modulo num_shards to pick which shard dictionary it belongs to.
# look up: pick the shard for this key
# (assuming hash() returns an arbitrary integer computed from the key)
shard = hash(key) % num_shards
selected_dictionary = dictionary[shard]
value = selected_dictionary[key]

# inserting: lock only the selected_dictionary
shard = hash(key) % num_shards
selected_dictionary = dictionary[shard]
selected_dictionary[key] = value
The hash could be something as simple as the sum of the ASCII codes of the characters in the key. The modulo-based indexing separates keys into blocks that can be filled independently, with no extra coordination beyond the hashing itself; a minimal sketch follows below.
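Here is a minimal pure-Python sketch of that sharding scheme (not numba-jitted; num_shards, the toy key/value lists, and the use of the built-in hash() are illustrative assumptions, not part of the original question):
num_shards = 8
shards = [dict() for _ in range(num_shards)]

def shard_for(key):
    # Route every key to exactly one shard; the same key always lands in the
    # same shard, so no merging step is needed afterwards.
    return hash(key) % num_shards

def insert(key, value):
    shards[shard_for(key)][key] = value

def lookup(key):
    return shards[shard_for(key)][key]

keys_list = ['key-%d' % i for i in range(1000)]
values_list = list(range(1000))

for key, value in zip(keys_list, values_list):
    insert(key, value)

assert lookup('key-42') == 42
Each shard can then be filled by its own worker, and lookups simply recompute the shard index instead of merging the shards back into one dictionary.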

Can I get more than one iterator from groupby?

I am a Python beginner and was facing an issue with iterating over grouped data more than once. I understand that once consumed, an iterator can't be reused, but is it possible to get multiple iterators from a single groupby()?
This answer says that multiple iterators can be created over lists etc., but I don't understand how I can do the same for groupby:
Multiple Iterators
What I am trying to do is as follows:
1. I have data that are (key, value) pairs, and I want to group them by key.
2. There is some special kind of data (based on the value part) in each group, and I want to extract these special pairs and process them separately.
3. After I am done I need to go back to the original data and process the remaining pairs (this is where I need the second iterator).
If you need to see my code, here is the basic layout of what I am doing, though I don't know if it is really required:
from itertools import groupby
from operator import itemgetter

for current_vertex, group in groupby(data, itemgetter(0)):
    try:
        # Special data extraction
        matching = [int(value.rstrip().split(':')[0])
                    for key, value in group if CURRENT_NODE_IDENTIFIER in value]
        if len(matching) != 0:
            # Do something with the data extracted (some variables generated here -- say x, y, z)
            for key, value in group:
                if CURRENT_NODE_IDENTIFIER not in value:
                    # Do something with remaining key, value pairs (use x, y, z)
In case anyone is wondering the same, I resolved the problem by duplicating the iterator as described here:
How to duplicate an Iterator?
Since the group itself is an iterator, all I had to do was duplicate it:
# To duplicate an iterator given the iterator group
group, duplicate_iterator = tee(group)
Don't forget to import the tee function from itertools. I don't know if this is the best way possible, but at least it works and gets the job done.
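For anyone wanting a runnable illustration, here is a small self-contained sketch of the tee approach, with toy data in the same (key, value) shape as above; SPECIAL is a purely illustrative stand-in for CURRENT_NODE_IDENTIFIER:
from itertools import groupby, tee
from operator import itemgetter

SPECIAL = 'special'  # stand-in for CURRENT_NODE_IDENTIFIER
data = [(1, 'special:10'), (1, 'plain:11'),
        (2, 'plain:20'), (2, 'special:21')]

for current_vertex, group in groupby(data, itemgetter(0)):
    group, duplicate_iterator = tee(group)
    # First pass over the group: pull out the special values.
    matching = [value for key, value in group if SPECIAL in value]
    # Second pass: the duplicate iterator still yields every pair of this group.
    remaining = [(key, value) for key, value in duplicate_iterator
                 if SPECIAL not in value]
    print(current_vertex, matching, remaining)
Both tee branches are fully consumed inside the loop body, before the outer groupby advances to the next key, which is what keeps this safe.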

Python: update a frequency table (in the form of a list of lists)

I have two lists of US state abbreviations (for example):
s1=['CO','MA','IN','OH','MA','CA','OH','OH']
s2=['MA','FL','CA','GA','MA','OH']
What I want to end up with is this (basically an ordered frequency table):
S=[['CA',2],['CO',1],['FL',1],['GA',1],['IN',1],['MA',4],['OH',4]]
The way I came up with was:
s3=s1+s2
S=[[x,s3.count(x)] for x in set(s3)]
This works great - though, to be honest, I don't know how memory efficient it is.
BUT... there is a catch.
s1+s2
...is too big to hold in memory, so what I'm doing is appending to s1 until it reaches a length of 10K (yes, resources are THAT limited), then summarizing it (using the list comprehension step above), deleting the contents of s1, and re-filling s1 with the next chunk of data (only represented as 's2' above for purpose of demonstration). ...and so on through the loop until it reaches the end of the data.
So with each iteration of the loop, I want to sum the 'base' list of lists 'S' with the current iteration's list of lists 's'. My question is, essentially, how do I add these:
(the current master data):
S=[['CA',1],['CO',1],['IN',1],['MA',2],['OH',3]]
(the new data):
s=[['CA',1],['FL',1],['GA',1],['MA',2],['OH',1]]
...to get (the new master data):
S=[['CA',2],['CO',1],['FL',1],['GA',1],['IN',1],['MA',4],['OH',4]]
...in some sort of reasonably efficient way. If this is better to do with dictionaries or something else, I am fine with that. What I can't do, unfortunately is make use of ANY remotely specialized Python module -- all I have to work with is the most stripped-down version of Python 2.6 imaginable in a closed-off, locked-down, resource-poor Linux environment (hazards of the job). Any help is greatly appreciated!!
You can use itertools.chain to chain two iterators efficiently:
import itertools
import collections
counts = collections.Counter()
for val in itertools.chain(s1, s2):  # memory efficient
    counts[val] += 1
A collections.Counter object is a dict specialized for counting; if you know how to use a dict, you can use a collections.Counter. It also lets you write the above more succinctly as:
counts = collections.Counter(itertools.chain(s1, s2))
Also note that the following construction:
S=[[x,s3.count(x)] for x in set(s3)]
happens to be very time inefficient as well, since you are calling s3.count in a loop (each call scans the whole of s3). This might not be too bad if len(set(s3)) << len(s3), though.
Note, you can do the chaining "manually" by doing something like:
it1 = iter(s1)
it2 = iter(s2)
for val in it1:
    ...
for val in it2:
    ...
You can call Counter.update as many times as you like, cutting your data into chunks that fit in memory or streaming it as you like.
import collections
counter = collections.Counter()
counter.update(['foo', 'bar'])
assert counter['foo'] == counter['bar'] == 1
counter.update(['foo', 'bar', 'foo'])
assert counter['foo'] == 3
assert counter['bar'] == 2
assert sorted(counter.items(), key=lambda rec: -rec[1]) == [('foo', 3), ('bar', 2)]
The last line uses negated count as the sorting key to make the higher counts come first.
If even then your count structure does not fit in memory, you need a (disk-based) database such as Postgres, or more likely just a machine with more memory and a more efficient key-value store such as Redis.
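If collections.Counter is unavailable (it was added in Python 2.7, and the question mentions a stripped-down 2.6), the same chunk-at-a-time pattern works with a plain dict. A sketch, where read_chunks() is a hypothetical stand-in for however the state abbreviations actually arrive:
counts = {}
for chunk in read_chunks():  # hypothetical: yields lists of <= 10K abbreviations
    for state in chunk:
        counts[state] = counts.get(state, 0) + 1

# Reproduce the question's ordered list-of-lists format.
S = [[state, n] for state, n in sorted(counts.items())]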

Fast looping over combinations of elements from two large dictionaries of sets

I have two very big dictionaries of sets and I want to loop over all combinations of pairs of them to calculate a "score" for each pair and store this in another object. The key in each dictionary is a unique identifier for each set. The code I am currently using is something along the lines of:
score_matrix = {}
for id_1, set_1 in set_dict_1.items():
    number_elements_1 = len(set_1)
    score_matrix[id_1] = {}
    for id_2, set_2 in set_dict_2.items():
        score_matrix[id_1][id_2] = float(len(set_1 & set_2)) / number_elements_1
I am testing this on data where set_dict_1 and set_dict_2 each have around 25k entries, so this thing has to process around 625,000,000 combinations! Obviously I don't expect this to be done in seconds, but in pure Python like this it is taking too long. In fact, I don't know how long it takes, because I didn't let it finish; it had been going for around 15 minutes before I gave up.
I assume there is a more efficient way to try and achieve this. Any ideas? Perhaps using numpy might help but I'm not sure how.
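One way NumPy (together with SciPy's sparse matrices) might help, as a hedged sketch: represent each set as a sparse 0/1 indicator row over the combined element vocabulary, so that all pairwise intersection sizes come from a single sparse matrix product. The names set_dict_1/set_dict_2 follow the question; the toy data and the use of scipy.sparse are assumptions for illustration.
import numpy as np
from scipy import sparse

set_dict_1 = {'a': {1, 2, 3}, 'b': {2, 4}}
set_dict_2 = {'x': {2, 3}, 'y': {5}}

all_elements = set().union(*set_dict_1.values()).union(*set_dict_2.values())
element_index = {element: i for i, element in enumerate(all_elements)}

def indicator_matrix(set_dict):
    # One row per set, one column per distinct element, 1.0 where present.
    ids = list(set_dict)
    rows, cols = [], []
    for r, key in enumerate(ids):
        for element in set_dict[key]:
            rows.append(r)
            cols.append(element_index[element])
    data = np.ones(len(rows), dtype=np.float32)
    return ids, sparse.csr_matrix(
        (data, (rows, cols)), shape=(len(ids), len(element_index)))

ids_1, m1 = indicator_matrix(set_dict_1)
ids_2, m2 = indicator_matrix(set_dict_2)

# intersections[i, j] == len(set_i & set_j). For 25k x 25k the dense result is
# roughly 2.5 GB as float32, so in practice you would compute it in row blocks.
intersections = m1.dot(m2.T).toarray()
scores = intersections / np.asarray(m1.sum(axis=1))  # divide by len(set_1)

# The question's nested-dict form, if you still need it:
score_matrix = {id_1: dict(zip(ids_2, scores[i])) for i, id_1 in enumerate(ids_1)}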

Best way to sort 1M records in Python

I have a service that takes a list of about 1,000,000 dictionaries and does the following:
myHashTable = {}
myLists = { 'hits':{}, 'misses':{}, 'total':{} }
mySorted = { 'hits':[], 'misses':[], 'total':[] }  # avoids shadowing the built-in sorted()
for item in myList:
    id = item.pop('id')
    myHashTable[id] = item
    for k, v in item.iteritems():
        myLists[k][id] = v
So, if I had the following list of dictionaries:
[ {'id':'id1', 'hits':200, 'misses':300, 'total':400},
{'id':'id2', 'hits':300, 'misses':100, 'total':500},
{'id':'id3', 'hits':100, 'misses':400, 'total':600}
]
I end up with
myHashTable =
{
    'id1': {'hits':200, 'misses':300, 'total':400},
    'id2': {'hits':300, 'misses':100, 'total':500},
    'id3': {'hits':100, 'misses':400, 'total':600}
}
and
myLists =
{
    'hits': {'id1':200, 'id2':300, 'id3':100},
    'misses': {'id1':300, 'id2':100, 'id3':400},
    'total': {'id1':400, 'id2':500, 'id3':600}
}
I then need to sort all of the data in each of the myLists dictionaries.
What I'm doing currently is something like the following:
def doSort(key):
    mySorted[key] = sorted(myLists[key].items(), key=operator.itemgetter(1), reverse=True)
which would yield, in the case of misses:
[('id3', 400), ('id1', 300), ('id2', 100)]
This works great when I have up to 100,000 records or so, but with 1,000,000 it takes at least 5-10 minutes to sort each of the 16 lists (my original list of dictionaries actually has 17 fields, including the id, which is popped).
* EDIT *
This service is a ThreadingTCPServer which has a method allowing a client to connect and add new data. The new data may include new records (meaning dictionaries with 'id's unique to what is already in memory) or modified records (meaning the same 'id' with different data for the other key/value pairs).
So, once this is running, I would pass in
[
    {'id':'id1', 'hits':205, 'misses':305, 'total':480},
    {'id':'id4', 'hits':30, 'misses':40, 'total':60},
    {'id':'id5', 'hits':50, 'misses':90, 'total':20}
]
I have been using dictionaries to store the data so that I don't end up with duplicates. After the dictionaries are updated with the new/modified data I re-sort each of them.
* END EDIT *
So, what is the best way for me to sort these? Is there a better method?
You may find this related answer from Guido: Sorting a million 32-bit integers in 2MB of RAM using Python
What you really want is an ordered container, instead of an unordered one. That would implicitly sort the results as they're inserted. The standard data structure for this is a tree.
However, there doesn't seem to be one of these in Python. I can't explain that; this is a core, fundamental data type in any language. Python's dict and set are both unordered containers, which map to the basic data structure of a hash table. It should definitely have an optimized tree data structure; there are many things you can do with them that are impossible with a hash table, and they're quite tricky to implement well, so people generally don't want to be doing it themselves.
(There's also nothing mapping to a linked list, which also should be a core data type. No, a deque is not equivalent.)
I don't have an existing ordered container implementation to point you to (and it should probably be implemented natively, not in Python), but hopefully this will point you in the right direction.
A good tree implementation should support iterating across a range by value ("iterate all values from [2,100] in order"), find next/prev value from any other node in O(1), efficient range extraction ("delete all values in [2,100] and return them in a new tree"), etc. If anyone has a well-optimized data structure like this for Python, I'd love to know about it. (Not all operations fit nicely in Python's data model; for example, to get next/prev value from another value, you need a reference to a node, not the value itself.)
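As a partial, stdlib-only approximation of "keep it ordered as you insert", the bisect module can maintain a sorted list; this is only a sketch and not a real tree, since bisect.insort shifts list elements and each insertion is therefore O(n):
import bisect

# Keep (count, id) records sorted as they arrive; tuples compare by count first.
ordered = []
for record in [(300, 'id1'), (100, 'id3'), (200, 'id2')]:
    bisect.insort(ordered, record)

assert ordered == [(100, 'id3'), (200, 'id2'), (300, 'id1')]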
If you have a fixed number of fields, use tuples instead of dictionaries. Place the field you want to sort on in first position, and just use mylist.sort()
This seems to be pretty fast.
raw= [ {'id':'id1', 'hits':200, 'misses':300, 'total':400},
{'id':'id2', 'hits':300, 'misses':100, 'total':500},
{'id':'id3', 'hits':100, 'misses':400, 'total':600}
]
hits= [ (r['hits'],r['id']) for r in raw ]
hits.sort()
misses = [ (r['misses'],r['id']) for r in raw ]
misses.sort()
total = [ (r['total'],r['id']) for r in raw ]
total.sort()
Yes, it makes three passes through the raw data. I think it's faster than pulling out the data in one pass.
Instead of trying to keep your list ordered, maybe you can get by with a heap queue. It lets you push any item while keeping the 'smallest' one at h[0], and popping that item (and 'bubbling up' the next smallest) is an O(log n) operation.
So, just ask yourself:
- Do I need the whole list ordered all the time? Use an ordered structure (like Zope's BTree package, as mentioned by Ealdwulf).
- Or the whole list ordered, but only after a day's worth of random insertions? Use sort like you're doing, or like S.Lott's answer.
- Or just a few 'smallest' or 'largest' items at any moment? Use heapq (see the sketch below).
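A small sketch of the heapq option for the "just a few largest items" case, reusing the record shape from the question; heapq.nlargest scans the data once and keeps only the top n in memory:
import heapq

records = [
    {'id': 'id1', 'hits': 200, 'misses': 300, 'total': 400},
    {'id': 'id2', 'hits': 300, 'misses': 100, 'total': 500},
    {'id': 'id3', 'hits': 100, 'misses': 400, 'total': 600},
]

# Top two records by misses, without sorting the whole list.
top_two = heapq.nlargest(2, records, key=lambda r: r['misses'])
assert [r['id'] for r in top_two] == ['id3', 'id1']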
Others have provided some excellent advice; try it out.
As general advice, in situations like this you need to profile your code and know exactly where most of the time is spent. Bottlenecks hide well, in places you least expect them to be.
If there is a lot of number crunching involved, then a JIT compiler like the (now-dead) psyco might also help. When processing takes minutes or hours, a 2x speed-up really counts.
http://docs.python.org/library/profile.html
http://www.vrplumber.com/programming/runsnakerun/
http://psyco.sourceforge.net/
sorted(myLists[key], key=myLists[key].get, reverse=True)
should save you some time, though not a lot.
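Note that for the misses example above this yields just the ids in descending order of value, e.g. ['id3', 'id1', 'id2'], rather than (id, value) pairs.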
I would look into using a different sorting algorithm. Something like a merge sort might work: break the list up into smaller lists, sort them individually, and then merge them back together in a single loop.
Pseudo code:
list1 = [] // sorted separately
list2 = [] // sorted separately

// Recombine sorted lists
result = []
while (list1.hasMoreElements || list2.hasMoreElements):
    if (!list1.hasMoreElements):
        result.addAll(list2)
        break
    elseif (!list2.hasMoreElements):
        result.addAll(list1)
        break
    if (list1.peek < list2.peek):
        result.add(list1.pop)
    else:
        result.add(list2.pop)
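In Python the recombination step of this pseudocode is already available as heapq.merge, which lazily merges any number of individually sorted sequences; the two toy lists below are illustrative:
import heapq

list1 = [(100, 'id3'), (300, 'id2')]   # sorted separately
list2 = [(200, 'id1'), (400, 'id4')]   # sorted separately

result = list(heapq.merge(list1, list2))
assert result == [(100, 'id3'), (200, 'id1'), (300, 'id2'), (400, 'id4')]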
Glenn Maynard is correct that a sorted mapping would be appropriate here. This is one for python: http://wiki.zope.org/ZODB/guide/node6.html#SECTION000630000000000000000
I've done some quick profiling of both the original way and S.Lott's proposal. In neither case does it take 5-10 minutes per field; the actual sorting is not the problem. It looks like most of the time is spent in slinging data around and transforming it.
Also, my memory usage is skyrocketing: my Python process is over 350 MB of RAM. Are you sure you're not using up all your RAM and paging to disk? Even on my crappy three-year-old power-saving laptop processor, I am seeing results well under 5-10 minutes per key sorted for a million items.
What I can't explain is the variability in the actual sort() calls. I know Python's sort is extra good at sorting partially sorted lists, so maybe his list is getting partially sorted in the transform from the raw data to the list to be sorted.
Here are the results for S.Lott's method:
done creating data
done transform. elapsed: 16.5160000324
sorting one key slott's way takes 1.29699993134
Here's the code to get those results:
starttransform = time.time()
hits= [ (r['hits'],r['id']) for r in myList ]
endtransform = time.time()
print "done transform. elapsed: " + str(endtransform - starttransform)
hits.sort()
endslottsort = time.time()
print "sorting one key slott's way takes " + str(endslottsort - endtransform)
Now the results for the original method, or at least a close version with some instrumentation added:
done creating data
done transform. elapsed: 8.125
about to get stuff to be sorted
done getting data. elapsed time: 37.5939998627
about to sort key hits
done sorting on key <hits> elapsed time: 5.54699993134
Here's the code:
from operator import itemgetter
import time

for k, v in myLists.iteritems():
    time1 = time.time()
    print "about to get stuff to be sorted "
    tobesorted = myLists[k].items()
    time2 = time.time()
    print "done getting data. elapsed time: " + str(time2-time1)
    print "about to sort key " + str(k)
    # list.sort() sorts in place and returns None, so sort first, then store
    tobesorted.sort(key=itemgetter(1))
    mysorted[k] = tobesorted
    time3 = time.time()
    print "done sorting on key <" + str(k) + "> elapsed time: " + str(time3-time2)
Honestly, the best way is to not use Python. If performance is a major concern for this, use a faster language.
