caching and reusing pyspark dataframe in loop - python

I have two scripts and a use case where I need to create a DataFrame in one script and use it in the other inside a loop, something like below:
Script 1:
import sys

from pyspark.sql import SparkSession

def generate_data(spark, logger, conf):
    processed_data_final = None
    path_1 = conf["raw_data_path_1"]
    path_2 = conf["raw_data_path_2"]
    df_path1 = spark.read.parquet(path_1)
    df_path1.cache()
    df_path1.take(1)  # calling an action, as Spark evaluates lazily
    df_path2 = spark.read.parquet(path_2)
    df_path2.cache()
    df_path2.take(1)
    for dt in date_list:  # date_list is defined elsewhere in the real script
        processed_data = process_data(spark, logger, conf, dt, df_path1, df_path2)
        if processed_data_final is None:  # first iteration: nothing to union with yet
            processed_data_final = processed_data
        else:
            processed_data_final = processed_data_final.union(processed_data)
    return processed_data_final

if __name__ == "__main__":
    # generate global variables: spark, logger
    if 5 == len(sys.argv):
        env = sys.argv[1]
        job_id = sys.argv[2]
    else:
        print("missing parameter: env or job_id")
        sys.exit(1)
    app_name = "past_viewership_" + job_id
    spark = SparkSession \
        .builder \
        .appName(app_name) \
        .config("spark.storage.memoryFraction", 0) \
        .config("spark.driver.maxResultSize", "-1") \
        .getOrCreate()
    # note: spark.storage.memoryFraction is deprecated since Spark 1.6; only read in legacy memory mode
    sc = spark.sparkContext
    generate_data(spark, logger, conf)
In script 2, I reuse the DataFrames from script 1 like this:
def process_data(spark, logger, conf, dt, df_path1, df_path2):
    # signature matches the call in script 1
    path3 = conf['path3']
    df3 = spark.read.parquet(path3)
    res_df = df3.join(df_path1, ["id"], "inner").join(df_path2, ["id"], "inner")
    return res_df
This is rough code explaining the flow. In this flow, I see in the logs that df_path1 and df_path2 are loaded again inside the loop, although I expected Spark to use the cached DataFrames. How can I avoid reading df_path1 and df_path2 again in the loop?

Calling dataframe.take(1) does not materialize the entire DataFrame. Spark's Catalyst optimizer will modify the physical plan to only read the first partition of the DataFrame, since only the first record is needed. Hence, only the first partition is cached; the remaining partitions are read from the source (and cached) the first time a later action, such as the joins in the loop, needs them.
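A minimal sketch of the usual fix, assuming `df` is a PySpark DataFrame: mark it with cache() and then run an action that has to touch every partition, such as count(), instead of take(1). The helper name cache_fully is hypothetical; it only relies on the cache() and count() methods.

```python
def cache_fully(df):
    """Cache `df` and force every partition to be computed.

    cache() is lazy: it only marks the DataFrame for caching. take(1) can stop
    after the first partition, but count() has to scan all of them, so once it
    returns every partition has been computed and stored (subject to the
    available storage).
    """
    df.cache()
    return df.count()
```

In generate_data, cache_fully(df_path1) and cache_fully(df_path2) would replace the cache()/take(1) pairs; the Storage tab of the Spark UI then shows how much of each DataFrame is actually cached.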

Related

What is the best way to parallel process in DataBricks to minimize query time?

I am working on a project where I need to take a list of ids and run each id through an API pull that can only retrieve one record's details per submitted id. Essentially, I have a dataframe called df_ids that consists of over 12M ids, each of which needs to go through the function below to obtain the information the end user requested for the entire population:
import pandas as pd

def ELOQUA_CONTACT(id):
    API = EloquaAPI(f'1.0/data/contact/{id}')  # EloquaAPI is a helper not shown in the question
    try:
        contactid = API['id'].lower()
    except Exception:
        contactid = ''
    try:
        company = API['accountName']
    except Exception:
        company = ''
    df = pd.DataFrame([contactid, company]).T.rename(columns={0: 'contactid', 1: 'company'})
    return df
If I run something like ELOQUA_CONTACT(df_ids['Eloqua_Contact_IDs'][2]), it returns the API record for the id at index 2 in the form of a dataframe. The issue is that I now need to scale this to the entire 12M-id population and build it so that it can run on a daily basis.
I have tried two techniques for parallel processing in DataBricks (Python-based; AWS-backed). The first is based on a threading template that my manager developed; on a sample of just 1000 records it takes just under 2 minutes to query.
from math import ceil
from threading import Thread

import pandas as pd

def DealEntries(df_input, n_sets):
    # Split row positions 0..n_rows-1 into n_sets contiguous ranges;
    # the last range also takes the remainder.
    n_rows = df_input.shape[0]
    entry_per_set = n_rows // n_sets
    extra = n_rows % n_sets
    outlist = []
    for i in range(n_sets):
        if i != n_sets - 1:
            idx = range(entry_per_set * i, entry_per_set * (i + 1))
        else:
            idx = range(entry_per_set * i, entry_per_set * (i + 1) + extra)
        outlist.append(idx)
    return outlist

class ThreadWithReturnValue(Thread):
    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None):
        Thread.__init__(self, group, target, name, args, kwargs, daemon=daemon)
        self._return = None

    def run(self):
        if self._target is not None:
            self._return = self._target(*self._args, **self._kwargs)

    def join(self, timeout=None):
        Thread.join(self, timeout)
        return self._return

data_input = pd.DataFrame(df_ids['Eloqua_Contact_IDs'][:1000])
rows_per_thread = 300
n_rows = data_input.shape[0]
threads = ceil(n_rows / rows_per_thread)  # number of batches; each batch starts one thread per id
outlist = DealEntries(data_input, threads)
df_results = []
for i in range(threads):
    # DealEntries returns row positions, so select with .iloc
    curr_input = data_input['Eloqua_Contact_IDs'].iloc[list(outlist[i])]
    jobs = []
    for id in curr_input.astype(str):
        thread = ThreadWithReturnValue(target=ELOQUA_CONTACT, kwargs={'id': id})
        jobs.append(thread)
    for j in jobs:
        j.start()
    for j in jobs:
        df_results.append(j.join())
df_out = pd.concat(df_results)
df_out
The second method is something that I just put together and runs in about 20 seconds.
from multiprocessing.pool import ThreadPool

import pandas as pd

with ThreadPool(1000) as parallels:  # the pool is cleaned up when the block exits
    df_results = parallels.map(ELOQUA_CONTACT, list(df_ids['Eloqua_Contact_IDs'][:1000]))
df_out = pd.concat(df_results)
df_out
The issue with both of these is that when scaling the time per record up from 1k to 12M, the first method would take around 16 days to run and the second just under 3 days. This needs to be scaled and parallelized enough to process the 12M records in less than a day. Are there any other methodologies or features associated with DataBricks/AWS/Python/Spark/etc. that I can leverage to meet this objective? Once built, this would be put into a scheduled workflow (formerly a job) in DataBricks and run on its own spin-up cluster whose backend resources (CPU and RAM size) I can alter.
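A back-of-the-envelope check of the scaling numbers, assuming throughput stays constant as the job grows (a straight linear extrapolation of the two 1,000-id sample timings, taking "just under 2 minutes" as 120 s). The function names are illustrative.

```python
SECONDS_PER_DAY = 24 * 60 * 60


def days_for(total_records, sample_records, sample_seconds):
    # Linear extrapolation: assumes the per-record cost seen in the sample stays constant.
    return total_records * (sample_seconds / sample_records) / SECONDS_PER_DAY


def required_rate(total_records, deadline_seconds=SECONDS_PER_DAY):
    # Records per second needed to finish within the deadline.
    return total_records / deadline_seconds


total = 12_000_000
print(round(days_for(total, 1000, 120), 1))  # first method (~2 min per 1,000 ids): 16.7
print(round(days_for(total, 1000, 20), 1))   # second method (~20 s per 1,000 ids): 2.8
print(round(required_rate(total)))           # ids per second for a one-day run: 139
```

At the second method's measured rate of about 50 ids/s, throughput would need to rise roughly threefold to reach the ~139 ids/s that a one-day run requires.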
Any insight or advice is very much welcomed. Thank you.

Pyspark Luigi multiple workers issue

I want to load multiple files into Spark DataFrames in parallel using a Luigi workflow and store them in a dictionary.
Once all the files are loaded, I want to access these DataFrames from the dictionary in the main task and then do further processing. This works when I run Luigi with one worker, but when I run Luigi with more than one worker, the dictionary is empty in the main task's run method.
Any suggestion will be helpful.
import luigi
from luigi import LocalTarget
from pyspark import SQLContext
from src.etl.SparkAbstract import SparkAbstract
from src.util.getSpark import get_spark_session
from src.util import getSpark, read_json
import configparser as cp
import datetime
from src.input.InputCSVFileComponent import InputCSVFile
import os
from src.etl.Component import ComponentInfo

class fileloadTask(luigi.Task):
    compinfo = luigi.Parameter()

    def output(self):
        return luigi.LocalTarget("src/workflow_output/" + datetime.date.today().isoformat() + "-" + str(self.compinfo.id) + ".csv")

    def run(self):
        a = InputCSVFile(self.compinfo)  # this class returns a Spark DataFrame object and puts it in the dictionary
        a.execute()
        with self.output().open('w') as f:
            f.write("done")

class EnqueueTask(luigi.WrapperTask):
    compinfo = read_json.read_json_config('path to json file')

    def requires(self):
        folders = [comp.id for comp in self.compinfo if comp.component_type == 'INPUTFILE']
        print(folders)
        # select the INPUTFILE components themselves; indexing self.compinfo by position
        # in `folders` picks the wrong objects when other component types come first
        newcominfo = [comp for comp in self.compinfo if comp.component_type == 'INPUTFILE']
        for i in newcominfo:
            print(f" in compinfo..{i.id}")
        callmethod = [fileloadTask(compinfo) for compinfo in newcominfo]
        print(callmethod)
        return callmethod

class MainTask(luigi.WrapperTask):
    def requires(self):
        return EnqueueTask()

    def output(self):
        return luigi.LocalTarget("src/workflow_output/" + datetime.date.today().isoformat() + "-maintask.csv")

    def run(self):
        print(f"printing mapdf..{SparkAbstract.mapDf}")
        res = not SparkAbstract.mapDf
        print("Is dictionary empty ? : " + str(res))  # <-- this is empty when workers > 1
        for key, value in SparkAbstract.mapDf.items():
            print("printing from dict")
            print(key)
            value.show(10)  # show() prints the rows itself and returns None
        with self.output().open('w') as f:
            f.write("done")

"""
entry point for spark application
"""
if __name__ == "__main__":
    luigi.build([MainTask()], workers=2, local_scheduler=True)
Each worker runs in its own process. That means workers can't share Python objects (in this instance, the dictionary in which you put the results).
Generally speaking, Luigi is best used to orchestrate tasks with side effects (like writing files).
If you are trying to parallelise tasks that load data into memory, I'd recommend using Dask instead of Luigi.
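A small stdlib sketch of the point above, using plain multiprocessing rather than Luigi: a module-level dict filled in by child processes stays empty in the parent, while values returned from a Pool are sent back. shared_results is a stand-in for SparkAbstract.mapDf, and the "fork" start method (Linux/macOS) is an assumption made so the sketch runs without a __main__ guard.

```python
import multiprocessing as mp

# Stand-in for SparkAbstract.mapDf: a module-level dict that tasks try to fill in.
shared_results = {}


def load_into_dict(name):
    # Runs in a child process, so this writes to the child's private copy of the dict.
    shared_results[name] = f"data for {name}"


def load_and_return(name):
    # Returns the result instead; the Pool pickles it and sends it back to the parent.
    return name, f"data for {name}"


def demo(names):
    ctx = mp.get_context("fork")  # assumption: POSIX "fork" start method
    procs = [ctx.Process(target=load_into_dict, args=(n,)) for n in names]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    parent_view = dict(shared_results)  # still empty: the children's writes never reach the parent

    with ctx.Pool(2) as pool:
        returned = dict(pool.map(load_and_return, names))
    return parent_view, returned


if __name__ == "__main__":
    print(demo(["a.csv", "b.csv"]))
    # ({}, {'a.csv': 'data for a.csv', 'b.csv': 'data for b.csv'})
```

Returning values this way works for picklable data. A Spark DataFrame is a handle tied to the driver's SparkSession, so for DataFrames the file-based side effects suggested above are the safer route.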

create list of variables that satisfy condition without manually typing them

I want to do this:
cwd = Path(os.getcwd())
target = 'data/aggregations/'
# want this:
agg_12.to_csv(cwd / target / 'agg_12.csv', index=False)
for each of my dataframe variables:
agg_12 = make_agg_12(df)
agg_11 = make_agg_11(agg_12)
agg_10 = make_agg_10(agg_11)
agg_9 = make_agg_9(agg_12)
agg_8 = make_agg_8(agg_9)
agg_7 = make_agg_7(agg_9)
agg_6 = make_agg_6(agg_7)
agg_5 = make_agg_5(agg_7)
agg_4 = make_agg_4(agg_8)
agg_3 = make_agg_3(agg_8)
agg_2 = make_agg_2(agg_6)
agg_1 = make_agg_1(agg_2)
It doesn't seem pythonic to write it for each one.
This solution works (but...):
""" as a convenience, this step saves the aggregations locally so this doesn't need to run again. """
cwd = Path(os.getcwd())
target = 'data/aggregations/'
# template: agg_12.to_csv(cwd / target / 'agg_12.csv', index=False)
agg_list = [agg_12, agg_11, agg_10, agg_9, agg_8, agg_7, agg_6, agg_5, agg_4, agg_3, agg_2, agg_1]
agg_str = ['agg_12', 'agg_11', 'agg_10', 'agg_9', 'agg_8', 'agg_7', 'agg_6', 'agg_5', 'agg_4', 'agg_3', 'agg_2', 'agg_1']
for agg_i, agg in enumerate(agg_list):
    agg_name = agg_str[agg_i]
    agg.to_csv(cwd / target / f'{agg_name}.csv', index=False)
    print('wrote: ', f'{agg_name}.csv')
However, what if I had 100s or 1000s of df variables that I wanted to do this for? How can I create a list of them without typing them out manually?
I have tried using dir(), which of course only gives me the string name of the variable, not an actual reference.
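Since dir() only returns names, one way to get the objects themselves is to look the names up in the namespace dict (globals() at module level). A minimal sketch; collect_by_pattern and the agg_<number> pattern are illustrative assumptions.

```python
import re


def collect_by_pattern(namespace, pattern=r"agg_\d+"):
    """Return {name: object} for every name in `namespace` that fully matches `pattern`."""
    regex = re.compile(pattern)
    return {name: obj for name, obj in namespace.items() if regex.fullmatch(name)}


# Usage in the question's script (cwd and target as defined there):
# for name, agg in collect_by_pattern(globals()).items():
#     agg.to_csv(cwd / target / f"{name}.csv", index=False)
```

A sturdier design avoids the loose variables altogether: store each result in a dict as it is built (aggs["agg_12"] = make_agg_12(df), aggs["agg_11"] = make_agg_11(aggs["agg_12"]), and so on) and loop over aggs.items().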

Multiprocessing in python - processes not closing after completing

I have a Process pool in Python that starts processes as normal; however, I have just realized that these processes are not closed after completion (I know they completed because the last statement is a file write).
Below is the code, with an example function ppp:
from multiprocessing import Pool
import itertools
import time

def current_milli_time():
    # helper not shown in the question; assumed to return wall-clock time in ms
    return int(round(time.time() * 1000))

def ppp(element):
    window, day = element
    print(window, day)
    time.sleep(10)

if __name__ == '__main__':  ## The line marked
    print('START')
    start_time = current_milli_time()
    days = ['0808', '0810', '0812', '0813', '0814', '0817', '0818', '0827']
    windows = [1000, 2000, 3000, 4000, 5000, 10000, 15000, 20000, 30000, 60000, 120000, 180000]
    processes_args = list(itertools.product(windows, days))
    pool = Pool(8)
    results = pool.map(ppp, processes_args)
    pool.close()
    pool.join()
    print('END', current_milli_time() - start_time)
I am working on Linux (Ubuntu 16.04). Everything was working fine before I added the line marked in the example. I am wondering whether this behavior is related to the missing return statement. Anyway, this is what my htop looks like:
As you can see, no process is closed, but all have completed their work.
I found this related question: Python Multiprocessing pool.close() and join() does not close processes; however, I don't understand whether the solution to this problem is to use map_async instead of map.
EDIT: real function code:
import glob
from collections import Counter

import numpy as np
import pandas as pd
import progressbar
# readDataset, time_series_processing and network_inference_process are the asker's own helpers (not shown)

def process_day(element):
    window, day = element
    noise = 0.2
    print('Processing day:', day, ', window:', window)
    individual_files = glob.glob('datan/' + day + '/*[0-9].csv')
    individual = readDataset(individual_files)
    label_time = individual.loc[(individual['LABEL_O'] != -2) | (individual['LABEL_F'] != -2), 'TIME']
    label_time = list(np.unique(list(label_time)))
    individual = individual[individual['TIME'].isin(label_time)]
    # Saving IDs for further processing
    individual['ID'] = individual['COLLAR']
    # Time variable in seconds for aggregation and merging
    individual['TIME_S'] = individual['TIME'].copy()
    noise_x = np.random.normal(0, noise, len(individual))
    noise_y = np.random.normal(0, noise, len(individual))
    noise_z = np.random.normal(0, noise, len(individual))
    individual['X_AXIS'] = individual['X_AXIS'] + noise_x
    individual['Y_AXIS'] = individual['Y_AXIS'] + noise_y
    individual['Z_AXIS'] = individual['Z_AXIS'] + noise_z
    # Time synchronization (applying milliseconds for time series processing)
    print('Time synchronization:')
    with progressbar.ProgressBar(max_value=len(individual.groupby('ID'))) as bar:
        for baboon, df_baboon in individual.groupby('ID'):
            times = list(df_baboon['TIME'].values)
            d = Counter(times)
            result = []
            for timestamp in np.unique(times):
                for i in range(0, d[timestamp]):
                    result.append(str(timestamp + i * 1000 / d[timestamp]))
            individual.loc[individual['ID'] == baboon, 'TIME'] = result
            bar.update(1)
    # Time series process
    ts_process = time_series_processing(window, 'TIME_S', individual, 'COLLAR', ['COLLAR', 'TIME', 'X_AXIS', 'Y_AXIS', 'Z_AXIS'])
    # Aggregation and tsfresh
    ts_process.do_process()
    individual = ts_process.get_processed_dataframe()
    individual.to_csv('noise2/processed_data/' + str(window) + '/agg/' + str(day) + '.csv', index=False)
    # Network inference process
    ni = network_inference_process(individual, 'TIME_S_mean')
    # Inference
    ni.do_process()
    final = ni.get_processed_dataframe()
    final.to_csv('noise2/processed_data/' + str(window) + '/net/' + str(day) + '.csv', index=False)
    # Saving not aggregated ground truth
    ground_truth = final[['ID_mean', 'TIME_S_mean', 'LABEL_O_values', 'LABEL_F_values']].copy()
    # Neighbor features process
    neighbors_features_f = ni.get_neighbor_features(final, 'TIME_S_mean', 'ID_mean')
    neighbors_features_f = neighbors_features_f.drop(['LABEL_O_values_n', 'LABEL_F_values_n'], axis=1)
    neighbors_features_f.to_csv('noise2/processed_data/' + str(window) + '/net/' + str(day) + '_neigh.csv', index=False)
    # Final features dataframe
    final_neigh = pd.merge(final, neighbors_features_f, how='left', left_on=['TIME_S_mean', 'ID_mean'], right_on=['TIME_S_mean_n', 'BABOON_NODE_n'])
    final_neigh.to_csv('noise2/processed_data/' + str(window) + '/complete/' + str(day) + '.csv', index=False)
    return
So, as you can see, the last statement is a file write and it is executed by all the processes, so I don't think the problem is inside this function.
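One way to check whether the pool's workers really exit is multiprocessing.active_children(), which lists the current process's live children (reaping finished ones along the way); after close() and join() it should be empty. A minimal sketch with a shortened ppp; the "fork" start method, the Linux default in the question's setup, is set explicitly as an assumption so the check runs without a __main__ guard.

```python
import multiprocessing as mp
import time


def ppp(element):
    # Same shape as the question's ppp, with a short sleep so the check finishes quickly.
    window, day = element
    time.sleep(0.01)
    return window, day


def run_and_check(args, processes=4):
    pool = mp.get_context("fork").Pool(processes)  # assumption: "fork" start method, as on Linux
    results = pool.map(ppp, args)
    pool.close()  # no more tasks will be submitted
    pool.join()   # wait for every worker process to exit
    # active_children() lists this process's live children (reaping finished ones on the way).
    return results, mp.active_children()


if __name__ == "__main__":
    results, alive = run_and_check([(1000, "0808"), (2000, "0810"), (3000, "0812")])
    print(len(results), alive)  # 3 []
```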

in python: child processes going defunct while others are not, unsure why

Edit: the answer was that the OS was killing processes because I was consuming all the memory.
I am spawning enough subprocesses to keep the load average 1:1 with cores. However, at some point within the first hour (this script can run for days), 3 of the processes go defunct:
tipu 14804 0.0 0.0 328776 428 pts/1 Sl 00:20 0:00 python run.py
tipu 14808 64.4 24.1 2163796 1848156 pts/1 Rl 00:20 44:41 python run.py
tipu 14809 8.2 0.0 0 0 pts/1 Z 00:20 5:43 [python] <defunct>
tipu 14810 60.3 24.3 2180308 1864664 pts/1 Rl 00:20 41:49 python run.py
tipu 14811 20.2 0.0 0 0 pts/1 Z 00:20 14:04 [python] <defunct>
tipu 14812 22.0 0.0 0 0 pts/1 Z 00:20 15:18 [python] <defunct>
tipu 15358 0.0 0.0 103292 872 pts/1 S+ 01:30 0:00 grep python
I have no idea why this is happening. Attached are the master and slave; I can attach the MySQL/PG wrappers as well if needed. Any suggestions?
slave.py:
from boto.s3.key import Key
import multiprocessing
import gzip
import os
from mysql_wrapper import MySQLWrap
from pgsql_wrapper import PGSQLWrap
import boto
import re

class Slave:
    CHUNKS = 250000
    BUCKET_NAME = "bucket"
    AWS_ACCESS_KEY = ""
    AWS_ACCESS_SECRET = ""
    KEY = Key(boto.connect_s3(AWS_ACCESS_KEY, AWS_ACCESS_SECRET).get_bucket(BUCKET_NAME))
    S3_ROOT = "redshift_data_imports"
    COLUMN_CACHE = {}
    DEFAULT_COLUMN_VALUES = {}

    def __init__(self, job_queue):
        self.log_handler = open("logs/%s" % str(multiprocessing.current_process().name), "a")
        self.mysql = MySQLWrap(self.log_handler)
        self.pg = PGSQLWrap(self.log_handler)
        self.job_queue = job_queue

    def do_work(self):
        self.log(str(os.getpid()))
        while True:
            # sample job in the abstract: mysql_db.table_with_date-iteration
            job = self.job_queue.get()
            # queue is empty
            if job is None:
                self.log_handler.close()
                self.pg.close()
                self.mysql.close()
                print("good bye and good day from %d" % (os.getpid()))
                self.job_queue.task_done()
                break
            # curtail iteration
            table = job.split('-')[0]
            # strip redshift table from job name
            redshift_table = re.sub(r"(_[1-9].*)", "", table.split(".")[1])
            iteration = int(job.split("-")[1])
            offset = (iteration - 1) * self.CHUNKS
            # columns redshift is expecting
            # bad tables will slip through and error out, so we catch it
            try:
                colnames = self.COLUMN_CACHE[redshift_table]
            except KeyError:
                self.job_queue.task_done()
                continue
            # mysql fields to use in SELECT statement
            fields = self.get_fields(table)
            # list subtraction determining which columns redshift has that mysql does not
            delta = list(set(colnames) - set(fields.keys()))
            # subtract columns that have a default value and so do not need padding
            if delta:
                delta = list(set(delta) - set(self.DEFAULT_COLUMN_VALUES[redshift_table]))
            # concatenate columns with padded \N
            select_fields = ",".join(fields.values()) + (",\\N" * len(delta))
            query = "SELECT %s FROM %s LIMIT %d, %d" % (select_fields, table,
                                                        offset, self.CHUNKS)
            rows = self.mysql.execute(query)
            self.log("%s: %s\n" % (table, len(rows)))
            if not rows:
                self.job_queue.task_done()
                continue
            # if there is more data potentially, add it to the queue
            if len(rows) == self.CHUNKS:
                self.log("putting %s-%s" % (table, (iteration + 1)))
                self.job_queue.put("%s-%s" % (table, (iteration + 1)))
            # various characters need escaping
            clean_rows = []
            redshift_escape_chars = set(["\\", "|", "\t", "\r", "\n"])
            in_chars = ""
            for row in rows:
                new_row = []
                for value in row:
                    if value is not None:
                        in_chars = str(value)
                    else:
                        in_chars = ""
                    # escape any naughty characters
                    new_row.append("".join(["\\" + c if c in redshift_escape_chars else c for c in in_chars]))
                new_row = "\t".join(new_row)
                clean_rows.append(new_row)
            rows = ",".join(list(fields.keys()) + delta)
            rows += "\n" + "\n".join(clean_rows)
            offset = offset + self.CHUNKS
            filename = "%s-%s.gz" % (table, iteration)
            self.move_file_to_s3(filename, rows)
            self.begin_data_import(job, redshift_table, ",".join(list(fields.keys()) +
                                                                 delta))
            self.job_queue.task_done()

    def move_file_to_s3(self, uri, contents):
        tmp_file = "/dev/shm/%s" % str(os.getpid())
        self.KEY.key = "%s/%s" % (self.S3_ROOT, uri)
        self.log("key is %s" % self.KEY.key)
        f = gzip.open(tmp_file, "wb")
        f.write(contents)
        f.close()
        # local saving allows for debugging when copy commands fail
        # text_file = open("tsv/%s" % uri, "w")
        # text_file.write(contents)
        # text_file.close()
        self.KEY.set_contents_from_filename(tmp_file, replace=True)

    def get_fields(self, table):
        """
        Returns a dict used as:
        {"column_name": "altered_column_name"}
        Currently only the debug column gets altered
        """
        exclude_fields = ["_qproc_id", "_mob_id", "_gw_id", "_batch_id", "Field"]
        query = "show columns from %s" % (table)
        fields = self.mysql.execute(query)
        # key raw field, value mysql formatted field
        new_fields = {}
        # for field in fields:
        for field in [val[0] for val in fields]:
            if field in exclude_fields:
                continue
            old_field = field
            if "debug_mode" == field.strip():
                field = "IFNULL(debug_mode, 0)"
            new_fields[old_field] = field
        return new_fields

    def log(self, text):
        self.log_handler.write("\n%s" % text)

    def begin_data_import(self, table, redshift_table, fields):
        query = "copy %s (%s) from 's3://bucket/redshift_data_imports/%s' \
            credentials 'aws_access_key_id=%s;aws_secret_access_key=%s' delimiter '\\t' \
            gzip NULL AS '' COMPUPDATE ON ESCAPE IGNOREHEADER 1;" \
            % (redshift_table, fields, table, self.AWS_ACCESS_KEY, self.AWS_ACCESS_SECRET)
        self.pg.execute(query)
master.py:
from slave import Slave as Slave
import multiprocessing
from mysql_wrapper import MySQLWrap as MySQLWrap
from pgsql_wrapper import PGSQLWrap as PGSQLWrap

class Master:
    SLAVE_COUNT = 5

    def __init__(self):
        self.mysql = MySQLWrap()
        self.pg = PGSQLWrap()

    def do_work(self, table):
        pass

    def get_table_listings(self):
        """Gathers a list of MySQL log tables needed to be imported"""
        query = 'show databases'
        result = self.mysql.execute(query)
        # turns list[tuple] into a flat list
        databases = list(sum(result, ()))
        # overriding during development
        databases = ['db1', 'db2', 'db3']
        exclude = ('mysql', 'Database', 'information_schema')
        scannable_tables = []
        for database in databases:
            if database in exclude:
                continue
            query = "show tables from %s" % database
            result = self.mysql.execute(query)
            # turns list[tuple] into a flat list
            tables = list(sum(result, ()))
            for table in tables:
                # separate name, so the database-level `exclude` above is not overwritten
                table_exclude = ("Tables_in_%s" % database, "(", "201303", "detailed", "ltv")
                # exclude any of the unfavorables
                if any(s in table for s in table_exclude):
                    continue
                scannable_tables.append("%s.%s-1" % (database, table))
        return scannable_tables

    def init(self):
        # fetch redshift columns once and cache
        # get columns from redshift so we can pad the mysql column delta with nulls
        tables = ('table1', 'table2', 'table3')
        for table in tables:
            # cache columns
            query = "SELECT column_name FROM information_schema.columns WHERE \
                table_name = '%s'" % (table)
            # note: `async` is a reserved word from Python 3.7, so this call only parses on older versions
            result = self.pg.execute(query, async=False, ret=True)
            Slave.COLUMN_CACHE[table] = list(sum(result, ()))
            # cache default values
            query = "SELECT column_name FROM information_schema.columns WHERE \
                table_name = '%s' and column_default is not \
                null" % (table)
            result = self.pg.execute(query, async=False, ret=True)
            # turns list[tuple] into a flat list
            result = list(sum(result, ()))
            Slave.DEFAULT_COLUMN_VALUES[table] = result

    def run(self):
        self.init()
        job_queue = multiprocessing.JoinableQueue()
        tables = self.get_table_listings()
        for table in tables:
            job_queue.put(table)
        processes = []
        for i in range(Master.SLAVE_COUNT):
            process = multiprocessing.Process(target=slave_runner, args=(job_queue,))
            process.daemon = True
            process.start()
            processes.append(process)
        # blocks this process until queue reaches 0
        job_queue.join()
        # signal each child process to GTFO
        for i in range(Master.SLAVE_COUNT):
            job_queue.put(None)
        # blocks this process until queue reaches 0
        job_queue.join()
        job_queue.close()
        # do not end this process until child processes close out
        for process in processes:
            process.join()
        # toodles !
        print("this is master saying goodbye")

def slave_runner(queue):
    slave = Slave(queue)
    slave.do_work()
There's not enough information to be sure, but the problem is very likely to be that Slave.do_work is raising an unhandled exception. (There are many lines of your code that could do that in various different conditions.)
When you do that, the child process will just exit.
On POSIX systems… well, the full details are a bit complicated, but in the simple case (what you have here), a child process that exits will stick around as a <defunct> process until it gets reaped (because the parent either waits on it, or exits). Since your parent code doesn't wait on the children until the queue is finished, that's exactly what happens.
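The zombie behaviour described above can be reproduced with a short sketch. It reads /proc, so it is Linux-only (an assumption matching the ps output in the question); zombie_demo and proc_state are illustrative names. A child that has exited shows state Z until the parent waits on it, after which its process entry disappears.

```python
import os
import time


def proc_state(pid):
    # Field 3 of /proc/<pid>/stat is the state letter; split after the last ")"
    # because the command name in field 2 may itself contain spaces or parentheses.
    with open(f"/proc/{pid}/stat") as f:
        return f.read().rsplit(")", 1)[1].split()[0]


def zombie_demo():
    pid = os.fork()
    if pid == 0:
        os._exit(0)  # child: exit immediately
    state = None
    for _ in range(200):  # wait up to ~2 s for the child to finish exiting
        state = proc_state(pid)
        if state == "Z":  # exited but not yet reaped: shown as <defunct> by ps
            break
        time.sleep(0.01)
    os.waitpid(pid, 0)  # reap it, which is what Process.join() does under the hood
    reaped = not os.path.exists(f"/proc/{pid}")
    return state, reaped


if __name__ == "__main__":
    print(zombie_demo())  # ('Z', True)
```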
So, there's a simple duct-tape fix:
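def do_work(self):
    self.log(str(os.getpid()))
    while True:
        try:
            ...  # the rest of your loop body goes here
        except Exception as e:
            self.log("something appropriate {}".format(e))
            # you may also want to post a reply back to the parent
            # if the failed job was not marked done yet, call self.job_queue.task_done()
            # here too, or the master's job_queue.join() will block forever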
You might also want to break the massive try up into different ones, so you can distinguish between all the different stages where things could go wrong (especially if some of them mean you need a reply, and some mean you don't).
However, it looks like what you're attempting to do is duplicate exactly the behavior of multiprocessing.Pool, but you've missed the mark in a couple of places. Which raises the question: why not just use Pool in the first place? You could then simplify/optimize things even further by using one of the map family methods. For example, your entire Master.run could be reduced to:
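self.init()
tables = self.get_table_listings()
pool = multiprocessing.Pool(Master.SLAVE_COUNT, initializer=slave_setup)
pool.map(slave_job, tables)
pool.close()  # join() raises ValueError unless close() or terminate() is called first
pool.join()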
And this will handle exceptions for you, allow you to return values/exceptions if you later need that, let you use the built-in logging library instead of trying to build your own, and so on. It should only take about a dozen lines of minor code changes to Slave, and then you're done.
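A self-contained sketch of that Pool-based shape, assuming per-process setup is done in an initializer function. slave_setup and slave_job are the placeholder names from the snippet above, given toy bodies here instead of the real MySQL/PostgreSQL/S3 work; the "fork" context is an assumption so the sketch runs without a __main__ guard. An exception raised in a job comes back to the parent out of pool.map instead of silently killing a worker.

```python
import multiprocessing as mp
import os

# Per-process state, filled in once per worker by the initializer.
_worker_state = {}


def slave_setup(tag):
    # Hypothetical stand-in for Slave.__init__: the real code would open the
    # per-process log file and MySQL/PostgreSQL connections here.
    _worker_state["tag"] = tag
    _worker_state["pid"] = os.getpid()


def slave_job(table):
    # Hypothetical stand-in for one iteration of Slave.do_work's loop.
    if not table:
        raise ValueError("empty table name")
    return table, _worker_state["pid"]


def run(tables, processes=2):
    # assumption: "fork" start method, so no __main__ guard is needed for the sketch
    pool = mp.get_context("fork").Pool(processes, initializer=slave_setup, initargs=("slave",))
    try:
        # map re-raises the first exception from any job here, in the parent
        return pool.map(slave_job, tables)
    finally:
        pool.close()
        pool.join()


if __name__ == "__main__":
    print(run(["db1.t1-1", "db2.t2-1"]))
```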
If you want to submit new jobs from within jobs, the easiest way to do this is probably with a Future-based API (which turns things around, making the future result the focus and the pool/executor the dumb thing that provides them, instead of making the pool the focus and the result the dumb thing it gives back), but there are multiple ways to do it with Pool as well. For example, right now, you're not returning anything from each job, so, you can just return a list of tables to execute. Here's a simple example that shows how to do it:
import multiprocessing

def foo(x):
    print(x, x**2)
    return list(range(x))

if __name__ == '__main__':
    pool = multiprocessing.Pool(2)
    jobs = [5]
    while jobs:
        jobs, oldjobs = [], jobs
        for job in oldjobs:
            jobs.extend(pool.apply(foo, [job]))
    pool.close()
    pool.join()
Obviously you can condense this a bit by replacing the whole loop with, e.g., a list comprehension fed into itertools.chain, and you can make it a lot cleaner-looking by passing "a submitter" object to each job and adding to that instead of returning a list of new jobs, and so on. But I wanted to make it as explicit as possible to show how little there is to it.
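The condensed variant mentioned above might look like the following sketch. One behavioural difference is worth knowing: pool.apply blocks until its result is ready, so the explicit loop runs one job at a time, whereas pool.map here runs each "generation" of jobs in parallel and itertools.chain.from_iterable flattens their follow-up lists into the next generation. run_all is an illustrative name, and the "fork" context is an assumption so the sketch runs without a __main__ guard.

```python
import itertools
import multiprocessing as mp


def foo(x):
    # Same toy job as above (minus the print): its follow-up jobs are 0..x-1.
    return list(range(x))


def run_all(initial_jobs, processes=2):
    done = []
    # assumption: "fork" start method, so no __main__ guard is needed for the sketch
    with mp.get_context("fork").Pool(processes) as pool:
        jobs = list(initial_jobs)
        while jobs:
            done.extend(jobs)
            # run one generation in parallel, then flatten all follow-up lists into the next one
            jobs = list(itertools.chain.from_iterable(pool.map(foo, jobs)))
    return done


if __name__ == "__main__":
    print(len(run_all([5])))  # 32
```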
At any rate, if you think the explicit queue is easier to understand and manage, go for it. Just look at the source for multiprocessing.pool.worker and/or concurrent.futures.ProcessPoolExecutor to see what you need to do yourself. It's not that hard, but there are enough things you could get wrong (personally, I always forget at least one edge case when I try to do something like this myself) that it's worth looking at code that gets it right.
Alternatively, it seems like the only reason you can't use concurrent.futures.ProcessPoolExecutor here is that you need to initialize some per-process state (the boto.s3.key.Key, MySqlWrap, etc.), for what are probably very good caching reasons. (If this involves a web-service query, a database connect, etc., you certainly don't want to do that once per task!) But there are a few different ways around that.
For example, you can subclass ProcessPoolExecutor and override the undocumented method _adjust_process_count (see the source for how simple it is) to pass your setup function, and… that's all you have to do.
Or you can mix and match. Wrap the Future from concurrent.futures around the AsyncResult from multiprocessing.
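On Python 3.7 and later, concurrent.futures.ProcessPoolExecutor accepts initializer and initargs directly, so per-process setup no longer needs the subclassing workaround. The sketch below combines that with the Future-based idea of submitting follow-up jobs as earlier futures complete. init_worker, job and run are illustrative names, the toy job stands in for the real per-table work, and the "fork" context is an assumption so the sketch runs without a __main__ guard.

```python
import concurrent.futures as cf
import multiprocessing as mp
import os

_state = None  # per-process state, set once in each worker by the initializer


def init_worker(prefix):
    # Runs once per worker process; stands in for expensive setup such as
    # opening database connections or an S3 key (illustrative only).
    global _state
    _state = {"prefix": prefix, "pid": os.getpid()}


def job(x):
    # Toy job: returns a label built from the per-process state plus its follow-up jobs.
    return f"{_state['prefix']}{x}", list(range(x))


def run(initial_jobs, workers=2):
    labels = []
    with cf.ProcessPoolExecutor(max_workers=workers,
                                mp_context=mp.get_context("fork"),  # assumption: POSIX "fork"
                                initializer=init_worker,
                                initargs=("job-",)) as ex:
        pending = {ex.submit(job, j) for j in initial_jobs}
        while pending:
            done, pending = cf.wait(pending, return_when=cf.FIRST_COMPLETED)
            for fut in done:
                label, follow_ups = fut.result()  # re-raises any exception from the worker
                labels.append(label)
                pending |= {ex.submit(job, j) for j in follow_ups}
    return labels


if __name__ == "__main__":
    print(sorted(run([3])))
```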
