There is a complex system of calculations over a set of objects. The difficulty is that some of the calculations are group calculations over several objects at once. This can be demonstrated by the following example:
from dask.distributed import Client
def load_data_from_db(id):
    # load some data
    ...
    return data

def task_a(data):
    # some calculations
    ...
    return result

def group_task(*args):
    # some calculations
    ...
    return result

def task_b(data, group_data):
    # some calculations
    ...
    return result

def task_c(data, task_a_result):
    # some calculations
    ...
    return result
ids = [1, 2]
dsk = {'id_{}'.format(i): id for i, id in enumerate(ids)}
dsk['data_0'] = (load_data_from_db, 'id_0')
dsk['data_1'] = (load_data_from_db, 'id_1')
dsk['task_a_result_0'] = (task_a, 'data_0')
dsk['task_a_result_1'] = (task_a, 'data_1')
dsk['group_result'] = (
group_task,
'data_0', 'task_a_result_0',
'data_1', 'task_a_result_1')
dsk['task_b_result_0'] = (task_b, 'data_0', 'group_result')
dsk['task_b_result_1'] = (task_b, 'data_1', 'group_result')
dsk['task_c_result_0'] = (task_c, 'data_0', 'task_a_result_0')
dsk['task_c_result_1'] = (task_c, 'data_1', 'task_a_result_1')
client = Client(scheduler_address)
result = client.get(
dsk,
['task_a_result_0',
'task_b_result_0',
'task_c_result_0',
'task_a_result_1',
'task_b_result_1',
'task_c_result_1'])
The list of objects numbers in the thousands of elements, and the number of tasks is in the dozens (including several group tasks).
With this method of graph construction it is difficult to modify the graph (add new tasks, change dependencies, etc.).
Is there a more efficient way of doing distributed computing with dask in this context?
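For comparison (a sketch, not from the original post), the same dependencies can be expressed with dask.delayed, which builds the graph from ordinary function calls and is usually easier to extend than a hand-written dict:

import dask
from dask import delayed

# same functions, ids and client as above; delayed(...) records the call instead of running it
data = [delayed(load_data_from_db)(i) for i in ids]
a_results = [delayed(task_a)(d) for d in data]
group_args = [x for pair in zip(data, a_results) for x in pair]
group_result = delayed(group_task)(*group_args)
b_results = [delayed(task_b)(d, group_result) for d in data]
c_results = [delayed(task_c)(d, a) for d, a in zip(data, a_results)]

# compute on the distributed scheduler registered by Client above
results = dask.compute(*(a_results + b_results + c_results))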
Added
With futures, the graph is:
from itertools import chain

client = Client(scheduler_address)
ids = [1, 2]
data = client.map(load_data_from_db, ids)
result_a = client.map(task_a, data)
group_args = list(chain(*zip(data, result_a)))
result_group = client.submit(group_task, *group_args)
result_b = client.map(task_b, data, [result_group] * len(ids))
result_c = client.map(task_c, data, result_a)
result = client.gather(result_a + result_b + result_c)
And in the task functions the input arguments are Future instances, so arg.result() is called before use.
If you want to modify the computation while it is running, then I recommend the futures interface.
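A minimal sketch of what that can look like (postprocess is a hypothetical extra task, not from the original post): with futures you can keep submitting follow-up work that depends on results as they arrive.

from dask.distributed import as_completed

# submit additional work as each task_b result finishes;
# `postprocess` is a hypothetical follow-up task, not part of the original code
extra = [client.submit(postprocess, fut) for fut in as_completed(result_b)]
extra_results = client.gather(extra)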
Related
I started two days ago with the Ethereum blockchain, so my knowledge is still a little bit all over the place. Nevertheless, I managed to connect to a node, pull some general block data and so on. As a next level of difficulty, I tried to start building event filters in order to look at more specific types of historical data (to be clear, I don't want to fetch live data; I would rather query through the entire chain and get historical sample extracts for various types of data).
See here my first attempt to build an event filter for the USDC Uniswap V2 contract, in order to collect Swap events (it's not about speed or efficiency right now, just about making it work):
import json
import requests
from web3 import Web3

w3 = Web3(Web3.HTTPProvider(NODE_ADDRESS))

# uniswap v2 USDC
address = w3.toChecksumAddress('0xb4e16d0168e52d35cacd2c6185b44281ec28c9dc')

# get the ABI for uniswap v2 pair events
resp = requests.get("https://unpkg.com/@uniswap/v2-core@1.0.0/build/IUniswapV2Pair.json")
if resp.status_code == 200:
    abi = json.loads(resp.content)['abi']

# create contract object
contract = w3.eth.contract(address=address, abi=abi)

# get topics by hashing abi event signatures
res = contract.events.Swap.build_filter()

# put this into a filter input dictionary
filter_params = {'fromBlock': int_to_hex(12000000), 'toBlock': int_to_hex(12010000), **res.filter_params}
# res.filter_params contains: 'topics' and 'address'

# create a filter id (i.e. a hashed version of the filter data, representing the filter)
method = 'eth_newFilter'
params = [filter_params]
resp = self.block_manager.general_sample_request(method, params)
if 'error' in resp:
    print(resp)
else:
    filter_id = resp['result']

# pass on the filter id, in order to query the respective logs
params = [filter_id]
method = 'eth_getFilterLogs'
resp = self.block_manager.general_sample_request(method, params)
# takes about 10-12s for about 12000 events
# takes about 10-12s for about 12000 events
The resulting array contains event logs with this structure:
resp['result'][0]
>>>
{'address': '0xb4e16d0168e52d35cacd2c6185b44281ec28c9dc',
'topics': ['0xd78ad95fa46c994b6551d0da85fc275fe613ce37657fb8d5e3d130840159d822',
'0x0000000000000000000000007a250d5630b4cf539739df2c5dacb4c659f2488d',
'0x0000000000000000000000000ffd670749d4179558b6b367e30e72ce2efea28f'],
'data': '0x0000000000000000000000000000000000000000000000000000000000000000000000000000000000000\
00000000000000000000000000034f0f8a0c7663264000000000000000000000000000000000000000000000\
000000000019002d5b60000000000000000000000000000000000000000000000000000000000000000',
'blockNumber': '0xb71b01',
'transactionHash': '0x76403053ee0300411b68fc223b327b51fb4f1a26e1f6cb8667e05ec370e8176e',
'transactionIndex': '0x22',
'blockHash': '0x4bd35cb48395e77fd317a0309342c95d6687dbc4fcb85ada2d635fe266d1e769',
'logIndex': '0x16',
'removed': False}
As far as I understand now, I can somehow apply the ABI to decode the 'data' field.
I tried with this function:
contract.decode_function_input(resp['result'][0]['data'])
but it gives me this error:
>>> ValueError: Could not find any function with matching selector
Seems like there is some problem with decoding the data. However, I am so close now to getting the real data that I don't wanna give up xD. Any help will be appreciated!
Thanks!
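A note on the error: decode_function_input parses transaction calldata (a function selector plus arguments), not event logs, which is why no matching selector is found; logs are decoded against the event ABI instead, as the answer below does. A minimal sketch of that idea (assuming web3.py v5 and reusing abi, resp and w3 from the question):

from hexbytes import HexBytes
from web3._utils.events import get_event_data

# pick the Swap event ABI out of the pair ABI fetched from unpkg above
swap_abi = [a for a in abi if a.get('type') == 'event' and a.get('name') == 'Swap'][0]

# get_event_data expects HexBytes topics; the raw RPC response has hex strings
log = dict(resp['result'][0])
log['topics'] = [HexBytes(t) for t in log['topics']]

decoded = get_event_data(w3.codec, swap_abi, log)
print(decoded['args'])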
import json
import traceback
from pprint import pprint
from eth_utils import event_abi_to_log_topic, to_hex
from hexbytes import HexBytes
from web3._utils.events import get_event_data
from web3.auto import w3
def decode_tuple(t, target_field):
    output = dict()
    for i in range(len(t)):
        if isinstance(t[i], (bytes, bytearray)):
            output[target_field[i]['name']] = to_hex(t[i])
        elif isinstance(t[i], (tuple)):
            output[target_field[i]['name']] = decode_tuple(t[i], target_field[i]['components'])
        else:
            output[target_field[i]['name']] = t[i]
    return output

def decode_list_tuple(l, target_field):
    output = l
    for i in range(len(l)):
        output[i] = decode_tuple(l[i], target_field)
    return output

def decode_list(l):
    output = l
    for i in range(len(l)):
        if isinstance(l[i], (bytes, bytearray)):
            output[i] = to_hex(l[i])
        else:
            output[i] = l[i]
    return output

def convert_to_hex(arg, target_schema):
    """
    utility function to convert byte codes into human readable and json serializable data structures
    """
    output = dict()
    for k in arg:
        if isinstance(arg[k], (bytes, bytearray)):
            output[k] = to_hex(arg[k])
        elif isinstance(arg[k], (list)) and len(arg[k]) > 0:
            target = [a for a in target_schema if 'name' in a and a['name'] == k][0]
            if target['type'] == 'tuple[]':
                target_field = target['components']
                output[k] = decode_list_tuple(arg[k], target_field)
            else:
                output[k] = decode_list(arg[k])
        elif isinstance(arg[k], (tuple)):
            target_field = [a['components'] for a in target_schema if 'name' in a and a['name'] == k][0]
            output[k] = decode_tuple(arg[k], target_field)
        else:
            output[k] = arg[k]
    return output

def _get_topic2abi(abi):
    if isinstance(abi, (str)):
        abi = json.loads(abi)
    event_abi = [a for a in abi if a['type'] == 'event']
    topic2abi = {event_abi_to_log_topic(_): _ for _ in event_abi}
    return topic2abi

def _sanitize_log(log):
    for i, topic in enumerate(log['topics']):
        if not isinstance(topic, HexBytes):
            log['topics'][i] = HexBytes(topic)
    if 'address' not in log:
        log['address'] = None
    if 'blockHash' not in log:
        log['blockHash'] = None
    if 'blockNumber' not in log:
        log['blockNumber'] = None
    if 'logIndex' not in log:
        log['logIndex'] = None
    if 'transactionHash' not in log:
        log['transactionHash'] = None
    if 'transactionIndex' not in log:
        log['transactionIndex'] = None

def decode_log(log, abi):
    if abi is not None:
        try:
            # get a dict with all available events from the ABI
            topic2abi = _get_topic2abi(abi)
            # ensure the log contains all necessary keys
            _sanitize_log(log)
            # get the ABI of the event in question (stored as the first topic)
            event_abi = topic2abi[log['topics'][0]]
            # get the event name
            evt_name = event_abi['name']
            # get the event data
            data = get_event_data(w3.codec, event_abi, log)['args']
            target_schema = event_abi['inputs']
            decoded_data = convert_to_hex(data, target_schema)
            return (evt_name, decoded_data, target_schema)
        except Exception:
            return ('decode error', traceback.format_exc(), None)
    else:
        return ('no matching abi', None, None)
Example usage:
output = decode_log(
{'data': '0x000000000000000000000000000000000000000000000000000000009502f90000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000093f8f932b016b1c',
'topics': [
'0xd78ad95fa46c994b6551d0da85fc275fe613ce37657fb8d5e3d130840159d822',
'0x0000000000000000000000007a250d5630b4cf539739df2c5dacb4c659f2488d',
'0x000000000000000000000000242301fa62f0de9e3842a5fb4c0cdca67e3a2fab'],
},
pair_abi
)
print(output[0])
pprint(output[1])
# Swap
# {'amount0In': 2500000000,
# 'amount0Out': 0,
# 'amount1In': 0,
# 'amount1Out': 666409132118600476,
# 'sender': '0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D',
# 'to': '0x242301FA62f0De9e3842A5Fb4c0CdCa67e3A2Fab'}
Or in your case:
output = decode_log(resp['result'][0], pair_abi)
print(output[0])
pprint(output[1])
# Swap
# {'amount0In': 0,
# 'amount0Out': 6711072182,
# 'amount1In': 3814822253806629476,
# 'amount1Out': 0,
# 'sender': '0x7a250d5630B4cF539739dF2C5dAcb4c659F2488D',
# 'to': '0x0Ffd670749D4179558b6B367E30e72ce2efea28F'}
Now, note that you need to provide the pair_abi variable. It depends on the type of smart contract you're using. I've found that when working with Uniswap V3, the UniswapV2Pair ABI worked for some events, while the UniswapV3Pool ABI worked for others, in particular for the Swap event, which I've found the most useful.
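For instance (a sketch, reusing the ABI fetched from unpkg in the question), the pair ABI can be passed directly, since _get_topic2abi accepts either a JSON string or an already-parsed list:

# reuse the IUniswapV2Pair ABI loaded from unpkg earlier in the question
pair_abi = abi                  # already-parsed list of ABI entries
# or, equivalently, as a JSON string:
pair_abi = json.dumps(abi)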
After a few hours of digging I managed to find this solution, which is a slightly modified version of the one proposed in: https://towardsdatascience.com/decoding-ethereum-smart-contract-data-eed513a65f76 Big thumbs up to its author! You can read more there on parsing the transaction input too.
I am running a dask.distributed Client that gets data from an API with several parameters, parses results and joins/aggregates on each result. This is done with client.map()
Sometimes the API call gives an empty string because the specific combination of input parameters doesn't exist. It doesn't make sense to continue with computations and I would like to just kill that worker (without passing on e.g. a None).
How do I tell Dask to kill a worker if its result is None/error and exclude that future from the following operations?
Please let me know if you need more details.
Thanks.
EDIT:
Added a minimal working example to show the logic: the first map produces a lot of "useless" workers that I would like to kill.
Please note that this is not my actual use case: I am querying an Influx database via HTTP requests, but the general structure of the code is the same. I am open to any comments on how to do this faster/more efficiently.
import requests
import numpy as np
import pandas as pd
from dask.distributed import Client, LocalCluster, as_completed
import dask.dataframe as dd
def fetch_html(pair):
    req_string = 'https://www.bitstamp.net/api/v2/order_book/{currency_pair}/'
    response = requests.get(req_string.format(currency_pair=pair))
    try:
        result = response.json()
        return result
    except Exception as e:
        print('Error: {}\nMessage: {}'.format(e, response.reason))
        return None

def parse_result(result):
    if result:
        data = {}
        data['prices'] = [e[0] for e in result['bids']]
        data['vols'] = [e[1] for e in result['bids']]
        data['index'] = [result['timestamp'] for i in data['prices']]
        df = pd.DataFrame.from_dict(data).set_index('index')
        return df
    else:
        return pd.DataFrame()

def other_calcs(result):
    if not result.empty:
        # something
        return result
    else:
        return pd.DataFrame()

def aggregator(res1, res2):
    if (not res1.empty) and (not res2.empty):
        # something
        return res1
    elif not res2.empty:
        # something
        return res2
    elif not res1.empty:
        return res1
    else:
        return pd.DataFrame()

if __name__ == '__main__':

    pairs = [
        # legit params (100s of these):
        'btcusd',
        'btceur',
        'btcgbp',
        'bateur',
        'batbtc',
        'umausd',
        'xrpusdt',
        'eurteur',
        'eurtusd',
        'manausd',
        'sandeur',
        'storjusd',
        'storjeur',
        'adausd',
        'adaeur',
        # bad params resulting in error / empty result (100s of these)
        'foobar',
        'foobaz',
        'foousd',
        'barbaz',
        'bazbar',
    ]

    cluster = LocalCluster(n_workers=16, threads_per_worker=1)
    client = Client(cluster)

    futures_list = client.map(fetch_html, pairs)
    futures_list = client.map(parse_result, futures_list)
    futures_list = client.map(other_calcs, futures_list)

    seq = as_completed(futures_list)
    while seq.count() > 1:
        f1 = next(seq)
        f2 = next(seq)
        new = client.submit(aggregator, f1, f2, priority=1)
        seq.add(new)

    final = next(seq)
    final = final.result()
    print(final.head())
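One possible pattern (a sketch, not from the original post): instead of returning None, raise an exception in the fetch step so the corresponding future ends up in the 'error' state, then drop errored futures before the later maps and the aggregation loop. The bad parameters then never reach parse_result or the aggregator. fetch_html_strict below is a hypothetical variant of the fetch_html from the example above.

from dask.distributed import wait

def fetch_html_strict(pair):
    req_string = 'https://www.bitstamp.net/api/v2/order_book/{currency_pair}/'
    response = requests.get(req_string.format(currency_pair=pair))
    result = response.json()
    if not result or 'bids' not in result:
        # raising marks this future as errored instead of propagating a None
        raise ValueError('no order book for {}'.format(pair))
    return result

futures_list = client.map(fetch_html_strict, pairs)
wait(futures_list)
good_futures = [f for f in futures_list if f.status == 'finished']
futures_list = client.map(parse_result, good_futures)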
I've got some functions with SQL statements. My first function is fine because I get only one result.
My second function returns a large list of error codes, and I don't know how to get them back into the response.
TypeError: <sqlalchemy.engine.result.ResultProxy object at 0x7f98b85ef910> is not JSON serializable
I've tried everything and need help.
My Code:
def topalarms():
    customer_name = request.args.get('customer_name')
    machine_serial = request.args.get('machine_serial')
    #ts = request.args.get('ts')
    #ts_start = request.args.get('ts')

    if (customer_name is None) or (machine_serial is None):
        return missing_param()

    # def form_response(response, session):
    #     response['customer'] = customer_name
    #     response['serial'] = machine_serial
    # return do_response(customer_name, form_response)

    def form_response(response, session):
        result_machine_id = machine_id(session, machine_serial)
        if not result_machine_id:
            response['Error'] = 'Seriennummer nicht vorhanden/gefunden'
            return
        #response[''] = result_machine_id[0]["id"]
        machineid = result_machine_id[0]["id"]
        result_errorcodes = error_codes(session, machineid)
        response['ErrorCodes'] = result_errorcodes

    return do_response(customer_name, form_response)

def machine_id(session, machine_serial):
    stmt_raw = '''
        SELECT
            id
        FROM
            machine
        WHERE
            machine.serial = :machine_serial_arg
    '''
    utc_now = datetime.datetime.utcnow()
    utc_now_iso = pytz.utc.localize(utc_now).isoformat()
    utc_start = datetime.datetime.utcnow() - datetime.timedelta(days=30)
    utc_start_iso = pytz.utc.localize(utc_start).isoformat()
    stmt_args = {
        'machine_serial_arg': machine_serial,
    }
    stmt = text(stmt_raw).columns(
        #ts_insert = ISODateTime
    )
    result = session.execute(stmt, stmt_args)
    ts = utc_now_iso
    ts_start = utc_start_iso
    ID = []
    for row in result:
        ID.append({
            'id': row[0],
            'ts': ts,
            'ts_start': ts_start,
        })
    return ID

def error_codes(session, machineid):
    stmt_raw = '''
        SELECT
            name
        FROM
            identifier
        WHERE
            identifier.machine_id = :machineid_arg
    '''
    stmt_args = {
        'machineid_arg': machineid,
    }
    stmt = text(stmt_raw).columns(
        #ts_insert = ISODateTime
    )
    result = session.execute(stmt, stmt_args)
    errors = []
    for row in result:
        errors.append(result)
        #({'result': [dict(row) for row in result]})
    #errors = {i: result[i] for i in range(0, len(result))}
    #errors = dict(result)
    return errors
My problem is the function error_codes; something is wrong with my result.
My output should look like this:
ABCNormal
ABCSafety
Alarm_G01N01
Alarm_G01N02
Alarm_G01N03
Alarm_G01N04
Alarm_G01N05
I think you need to take a closer look at what you are doing correctly with your working function and compare that to your non-working function.
Firstly, what do you think this code does?
for row in result:
errors.append(result)
This adds to errors one copy of the result object for each row in result. So if you have six rows in result, errors contains six copies of result. I suspect this isn't what you are looking for. You want to be doing something with the row variable.
Taking a closer look at your working function, you are taking the first value out of the row, using row[0]. So, you probably want to do the same in your non-working function:
for row in result:
errors.append(row[0])
I don't have SQLAlchemy set up so I haven't tested this: I have provided this answer based solely on the differences between your working function and your non-working function.
You need a JSON serializer. I suggest using Marshmallow: https://marshmallow.readthedocs.io/en/stable/
There are some great tutorials online on how to do this.
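A minimal sketch of that idea (the schema name and field are assumptions, not from the original post), serializing the rows returned by the identifier query once they are plain values:

from marshmallow import Schema, fields

class ErrorCodeSchema(Schema):
    name = fields.String()

# rows coming back from the identifier query, e.g. [{'name': 'ABCNormal'}, ...]
payload = ErrorCodeSchema(many=True).dump([{'name': row[0]} for row in result])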
I have written a webapp which queries data from a database based on some inputs (start/end datetimes, machine id and parameter id) and shows it in a bokeh figure.
As you can see, so far it works as intended, but I have some plans to extend this app further:
Allow data from different batches (with different start/end timestamps) to be loaded into the same graph.
Perform some statistical analysis of the different batches, e.g. averages, standard deviations, control limits, etc.
Get live streaming updates of parameters for different machines and/or parameters.
So I am now at the point where the app starts to become more complex, and I want to refactor the code into a maintainable and extensible format. Currently the code is written procedurally, and I would like to move to an MVC-like model to separate the data querying from the bokeh visualizations and statistical computations, but I am unsure how best to approach this.
How can I best refactor my code?
import logging
import pymssql, pandas
from dateutil import parser
from datetime import datetime, timedelta
from bokeh import layouts, models, plotting, settings
from bokeh.models import widgets
SETTINGS = {
    'server': '',
    'user': '',
    'password': '',
    'database': ''
}

def get_timestamps(datetimes):
    """ DB timestamps are in milliseconds """
    return [int(dt.timestamp()*1000) for dt in datetimes]

def get_db_names(timestamps):
    logging.debug('Started getting DB names ...')
    query = """
        SELECT
            [DBName]
        FROM [tblDBNames]
        WHERE {timestamp_ranges}
    """.format(
        timestamp_ranges = ' OR '.join([f'({timestamp} BETWEEN [LStart] AND [LStop])' for timestamp in timestamps])
    )
    logging.debug(query)
    db_names = []
    with pymssql.connect(**SETTINGS) as conn:
        with conn.cursor(as_dict=True) as cursor:
            cursor.execute(query)
            for row in cursor:
                db_names.append(row['DBName'])
    #logging.debug(db_names)
    logging.debug('Finished getting DB names')
    return list(set(db_names))
def get_machines():
    logging.debug('Started getting machines ...')
    query = """
        SELECT
            CONVERT(VARCHAR(2),[ID]) AS [ID],
            [Name]
        FROM [tblMaschinen]
        WHERE NOT [Name] = 'TestLine4'
        ORDER BY [Name]
    """
    logging.debug(query)
    with pymssql.connect(**SETTINGS) as conn:
        with conn.cursor(as_dict=False) as cursor:
            cursor.execute(query)
            data = cursor.fetchall()
    #logging.debug(data)
    logging.debug('Finished getting machines')
    return data

def get_parameters(machine_id, parameters):
    logging.debug('Started getting process parameteres ...')
    query = """
        SELECT
            CONVERT(VARCHAR(4), TrendConfig.ID) AS [ID],
            TrendConfig_Text.description AS [Description]
        FROM [TrendConfig]
        INNER JOIN TrendConfig_Text
            ON TrendConfig.ID = TrendConfig_Text.ID
        WHERE (TrendConfig_Text.languageText_KEY = 'nl')
            AND TrendConfig.MaschinenID = {machine_id}
            AND TrendConfig_Text.description IN ('{parameters}')
        ORDER BY TrendConfig_Text.description
    """.format(
        machine_id = machine_id,
        parameters = "', '".join(parameters)
    )
    logging.debug(query)
    with pymssql.connect(**SETTINGS) as conn:
        with conn.cursor(as_dict=False) as cursor:
            cursor.execute(query)
            data = cursor.fetchall()
    #logging.debug(data)
    logging.debug('Finished getting process parameters')
    return data

def get_process_data(query):
    logging.debug('Started getting process data ...')
    with pymssql.connect(**SETTINGS) as conn:
        return pandas.read_sql(query, conn, parse_dates={'LTimestamp': 'ms'}, index_col='LTimestamp')
    logging.debug('Finished getting process data')
batches = widgets.Slider(start=1, end=10, value=1, step=1, title="Batches")
now, min_date = datetime.now(), datetime.fromtimestamp(1316995200)
date_start = widgets.DatePicker(title="Start date:", value=str(now.date()), min_date=str(min_date), max_date=str(now.date()))
time_start = widgets.TextInput(title="Start time:", value=str((now-timedelta(hours=1)).replace(microsecond=0).time()))
start_row = layouts.Row(children=[date_start, time_start], width = 300)
date_end = widgets.DatePicker(title="End date:", value=str(now.date()), min_date=str(min_date), max_date=str(now.date()))
time_end = widgets.TextInput(title="End time:", value=str(now.replace(microsecond=0).time()))
end_row = layouts.Row(children=[date_end, time_end], width = 300)
datetimes = layouts.Column(children=[start_row, end_row])
## Machine list
machines = get_machines()
def select_machine_cb(attr, old, new):
    logging.debug(f'Changed machine ID: old={old}, new={new}')
    parameters = get_parameters(select_machine.value, default_params)
    select_parameters.options = parameters
    select_parameters.value = [parameters[0][0]]

select_machine = widgets.Select(
    options = machines,
    value = machines[0][0],
    title = 'Machine:'
)
select_machine.on_change('value', select_machine_cb)

## Parameters list
default_params = [
    'Debiet acuteel',
    'Extruder energie',
    'Extruder kWh/kg',
    'Gewicht bunker',
    'RPM Extruder acuteel',
    'Temperatuur Kop'
]
parameters = get_parameters(select_machine.value, default_params)
select_parameters = widgets.MultiSelect(
    options = parameters,
    value = [parameters[0][0]],
    title = 'Parameter:'
)
def btn_update_cb(arg):
    logging.debug('btn_update clicked')
    datetime_start = parser.parse(f'{date_start.value} {time_start.value}')
    datetime_end = parser.parse(f'{date_end.value} {time_end.value}')
    datetimes = [datetime_start, datetime_end]
    timestamps = get_timestamps(datetimes)
    db_names = get_db_names(timestamps)
    machine_id = select_machine.value
    parameter_ids = select_parameters.value
    query = """
        SELECT
            [LTimestamp],
            [TrendConfigID],
            [Text],
            [Value]
        FROM ({derived_table}) [Trend]
        LEFT JOIN [TrendConfig] AS [TrendConfig]
            ON [Trend].[TrendConfigID] = [TrendConfig].[ID]
        WHERE [LTimestamp] BETWEEN {timestamp_range}
            AND [Trend].[TrendConfigID] IN ({id_range})
    """.format(
        derived_table = ' UNION ALL '.join([f'SELECT * FROM [{db_name}].[dbo].[Trend_{machine_id}]' for db_name in db_names]),
        timestamp_range = ' AND '.join(map(str, timestamps)),
        id_range = ' ,'.join(parameter_ids)
    )
    logging.debug(query)
    df = get_process_data(query)
    ds = models.ColumnDataSource(df)
    plot.renderers = []  # clear plot
    #view = models.CDSView(source=ds, filters=[models.GroupFilter(column_name='TrendConfigID', group='')])
    #plot = plotting.figure(plot_width=600, plot_height=300, x_axis_type='datetime')
    plot.line(x='LTimestamp', y='Value', source=ds, name='line')

btn_update = widgets.Button(
    label="Update",
    button_type="primary",
    width = 150
)
btn_update.on_click(btn_update_cb)
btn_row = layouts.Row(children=[btn_update])
column = layouts.Column(children=[batches, datetimes, select_machine, select_parameters, btn_row], width = 300)
plot = plotting.figure(plot_width=600, plot_height=300, x_axis_type='datetime')
row = layouts.Row(children=[column, layouts.Spacer(width=20), plot])
tab1 = models.Panel(child=row, title="Viewer")
tab2 = models.Panel(child=layouts.Spacer(), title="Settings")
tabs = models.Tabs(tabs=[tab1, tab2])
plotting.curdoc().add_root(tabs)
I would recommend OOP (Object Oriented Programming).
To do this you would need to:
1. Come up with a standardized object that is being graphed, for example (Point: {value: x, startTime: a, endTime: y}).
2. Extract obtaining the object (or list of objects) into a separate class/file (example found here: http://www.qtrac.eu/pyclassmulti.html).
3. Come up with a standardized interface. For example, the interface could be called 'Batch' and could have a method 'obtainPoints' that returns a list of 'Point' objects defined in step 1.
Now, you can create multiple Batch implementations that implement the interface, and the main class can call the obtainPoints method on the separate implementations and graph them. In the end you would have one interface (Batch), X implementations of Batch (e.g. SQLDatabaseBatch, LDAPBatch, etc.) and a main class that utilizes all of these Batch implementations and creates a graph; a sketch follows below.
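A minimal sketch of that layout (names follow the answer; the query and the row keys 'Value', 'LStart', 'LStop' are assumptions based on the question's schema, and pymssql is reused from the question's imports):

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import pymssql

@dataclass
class Point:
    value: float
    start_time: int   # epoch milliseconds, matching the DB convention above
    end_time: int

class Batch(ABC):
    @abstractmethod
    def obtain_points(self) -> List[Point]:
        ...

class SQLDatabaseBatch(Batch):
    def __init__(self, settings, query):
        self.settings = settings
        self.query = query

    def obtain_points(self) -> List[Point]:
        points = []
        with pymssql.connect(**self.settings) as conn:
            with conn.cursor(as_dict=True) as cursor:
                cursor.execute(self.query)
                for row in cursor:
                    points.append(Point(row['Value'], row['LStart'], row['LStop']))
        return points

# the bokeh/view code then only deals with Batch objects and Point lists,
# e.g. plotting [p.value for p in batch.obtain_points()]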
Suppose I am given a collection of documents. I am required to tokenize them and then turn them into vectors for further work. As I find that elasticsearch's tokenizer works much better than my own solution, I am switching to that. However, it is considerably slower. The end result is then expected to be fed into the vectorizer as a stream.
The whole process can be done with a chained list of generators:
def fetch_documents(_cursor):
    with _cursor:
        # a lot of documents expected, may not fit in memory
        _cursor.execute('select ... from ...')
        for doc in _cursor:
            yield doc

def tokenize(documents):
    for doc in documents:
        yield elasticsearch_tokenize_me(doc)

def build_model(documents):
    some_model = SomeModel()
    for doc in documents:
        some_model.add_document(doc)
    return some_model

build_model(tokenize(fetch_documents(cursor)))  # cursor: a database cursor obtained elsewhere
So this basically works fine, but it doesn't utilize all the available processing capability. As dask is used in other related projects, I tried to adapt it and got the following (I am using psycopg2 for database access).
from itertools import chain

from dask import delayed
import psycopg2
import psycopg2.extras
from elasticsearch import Elasticsearch
from elasticsearch.client import IndicesClient

def loader():
    conn = psycopg2.connect()
    cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    cur.execute('''
        SELECT document, ... FROM ...
    ''')
    return cur

@delayed
def tokenize(partition):
    result = []
    client = IndicesClient(Elasticsearch())
    for row in partition:
        _result = client.analyze(analyzer='standard', text=row['document'])
        result.append(dict(row,
                           tokens=tuple(item['token'] for item in _result['tokens'])))
    return result

@delayed
def build_model(sequence_of_data):
    some_model = SomeModel()
    for item in chain.from_iterable(sequence_of_data):
        some_model.add_document(item)
    return some_model

with loader() as cur:
    partitions = []
    for idx_start in range(0, cur.rowcount, 200):
        partitions.append(delayed(cur.fetchmany)(200))

    tokenized = []
    for partition in partitions:
        tokenized.append(tokenize(partition))

    result = build_model(tokenized)
    result.compute()
The code more or less works, except that in the end all documents are tokenized before being fed into the model. While this works for smaller collections of data, it does not work for a huge collection (due to huge memory consumption). Should I just use plain concurrent.futures for this work, or am I using dask wrongly?
A simple solution would be to load data locally on your machine (it's hard to partition a single SQL query) and then send the data to the dask-cluster for the expensive tokenization step. Perhaps something as follows:
rows = cur.execute(''' SELECT document, ... FROM ... ''')
from toolz import partition_all, concat
partitions = partition_all(10000, rows)
from dask.distributed import Executor
e = Executor('scheduler-address:8786')
futures = []
for part in partitions:
    x = e.submit(tokenize, part)
    y = e.submit(process, x)
    futures.append(y)
results = e.gather(futures)
result = list(concat(results))
In this example the functions tokenize and process expect to consume and return a list of elements.
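As a side note (not in the original answer), in current dask.distributed the Executor class has been renamed to Client, with the same submit/gather interface:

from dask.distributed import Client

client = Client('scheduler-address:8786')
futures = [client.submit(process, client.submit(tokenize, part)) for part in partitions]
results = client.gather(futures)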
Using just concurrent.futures for the work
from concurrent.futures import ProcessPoolExecutor

def loader():
    conn = psycopg2.connect()
    cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
    cur.execute('''
        SELECT document, ... FROM ...
    ''')
    return cur

def tokenize(partition):
    result = []
    client = IndicesClient(Elasticsearch())
    for row in partition:
        _result = client.analyze(analyzer='standard', text=row['document'])
        result.append(dict(row,
                           tokens=tuple(item['token'] for item in _result['tokens'])))
    return result

def do_something(partitions, total):
    some_model = SomeModel()
    for partition in partitions:
        result = partition.result()
        for item in result:
            some_model.add_document(item)
    return some_model

with loader() as cur, \
     ProcessPoolExecutor(max_workers=8) as executor:
    print(cur.rowcount)
    partitions = []
    for idx_start in range(0, cur.rowcount, 200):
        partitions.append(executor.submit(tokenize,
                                          cur.fetchmany(200)))
    do_something(partitions, cur.rowcount)
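One possible refinement (a sketch, not part of the original answer) that addresses the memory concern from the question: consume the submitted partitions with concurrent.futures.as_completed, so each tokenized partition is folded into the model as soon as it finishes instead of after all of them are done. build_model_streaming is a hypothetical helper, and SomeModel, loader and tokenize are reused from above.

from concurrent.futures import ProcessPoolExecutor, as_completed

def build_model_streaming(futures):
    some_model = SomeModel()
    for future in as_completed(futures):
        # fold each partition into the model as soon as it is tokenized
        for item in future.result():
            some_model.add_document(item)
    return some_model

with loader() as cur, ProcessPoolExecutor(max_workers=8) as executor:
    futures = [executor.submit(tokenize, cur.fetchmany(200))
               for _ in range(0, cur.rowcount, 200)]
    model = build_model_streaming(futures)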