How to refactor procedural code to an MVC model? - Python

I have written a webapp which queries data from a database based on some inputs (start/end datetimes, machine ID and parameter ID) and shows it in a Bokeh figure.
So far this works as intended, but I have plans to extend the app further:
Allow data from different batches (with different start/end timestamps) to be loaded into the same graph.
Perform some statistical analysis of the different batches, e.g. averages, standard deviations, control limits, etc.
Get live streaming updates of parameters for different machines and/or parameters.
So I am now at the point where the app is becoming more complex, and I want to refactor the code into a maintainable and extensible form. Currently the code is written procedurally, and I would like to move to an MVC-like model to separate the data querying from the Bokeh visualizations and the statistical computations, but I am unsure how best to approach this.
How can I best refactor my code?
import logging
import pymssql, pandas
from dateutil import parser
from datetime import datetime, timedelta
from bokeh import layouts, models, plotting, settings
from bokeh.models import widgets
SETTINGS = {
    'server': '',
    'user': '',
    'password': '',
    'database': ''
}

def get_timestamps(datetimes):
    """ DB timestamps are in milliseconds """
    return [int(dt.timestamp()*1000) for dt in datetimes]

def get_db_names(timestamps):
    logging.debug('Started getting DB names ...')
    query = """
        SELECT
            [DBName]
        FROM [tblDBNames]
        WHERE {timestamp_ranges}
    """.format(
        timestamp_ranges = ' OR '.join([f'({timestamp} BETWEEN [LStart] AND [LStop])' for timestamp in timestamps])
    )
    logging.debug(query)
    db_names = []
    with pymssql.connect(**SETTINGS) as conn:
        with conn.cursor(as_dict=True) as cursor:
            cursor.execute(query)
            for row in cursor:
                db_names.append(row['DBName'])
    #logging.debug(db_names)
    logging.debug('Finished getting DB names')
    return list(set(db_names))
def get_machines():
    logging.debug('Started getting machines ...')
    query = """
        SELECT
            CONVERT(VARCHAR(2),[ID]) AS [ID],
            [Name]
        FROM [tblMaschinen]
        WHERE NOT [Name] = 'TestLine4'
        ORDER BY [Name]
    """
    logging.debug(query)
    with pymssql.connect(**SETTINGS) as conn:
        with conn.cursor(as_dict=False) as cursor:
            cursor.execute(query)
            data = cursor.fetchall()
    #logging.debug(data)
    logging.debug('Finished getting machines')
    return data

def get_parameters(machine_id, parameters):
    logging.debug('Started getting process parameters ...')
    query = """
        SELECT
            CONVERT(VARCHAR(4), TrendConfig.ID) AS [ID],
            TrendConfig_Text.description AS [Description]
        FROM [TrendConfig]
        INNER JOIN TrendConfig_Text
            ON TrendConfig.ID = TrendConfig_Text.ID
        WHERE (TrendConfig_Text.languageText_KEY = 'nl')
            AND TrendConfig.MaschinenID = {machine_id}
            AND TrendConfig_Text.description IN ('{parameters}')
        ORDER BY TrendConfig_Text.description
    """.format(
        machine_id = machine_id,
        parameters = "', '".join(parameters)
    )
    logging.debug(query)
    with pymssql.connect(**SETTINGS) as conn:
        with conn.cursor(as_dict=False) as cursor:
            cursor.execute(query)
            data = cursor.fetchall()
    #logging.debug(data)
    logging.debug('Finished getting process parameters')
    return data
def get_process_data(query):
    logging.debug('Started getting process data ...')
    with pymssql.connect(**SETTINGS) as conn:
        df = pandas.read_sql(query, conn, parse_dates={'LTimestamp': 'ms'}, index_col='LTimestamp')
    logging.debug('Finished getting process data')
    return df
batches = widgets.Slider(start=1, end=10, value=1, step=1, title="Batches")
now, min_date = datetime.now(), datetime.fromtimestamp(1316995200)
date_start = widgets.DatePicker(title="Start date:", value=str(now.date()), min_date=str(min_date), max_date=str(now.date()))
time_start = widgets.TextInput(title="Start time:", value=str((now-timedelta(hours=1)).replace(microsecond=0).time()))
start_row = layouts.Row(children=[date_start, time_start], width = 300)
date_end = widgets.DatePicker(title="End date:", value=str(now.date()), min_date=str(min_date), max_date=str(now.date()))
time_end = widgets.TextInput(title="End time:", value=str(now.replace(microsecond=0).time()))
end_row = layouts.Row(children=[date_end, time_end], width = 300)
datetimes = layouts.Column(children=[start_row, end_row])
## Machine list
machines = get_machines()
def select_machine_cb(attr, old, new):
    logging.debug(f'Changed machine ID: old={old}, new={new}')
    parameters = get_parameters(select_machine.value, default_params)
    select_parameters.options = parameters
    select_parameters.value = [parameters[0][0]]

select_machine = widgets.Select(
    options = machines,
    value = machines[0][0],
    title = 'Machine:'
)
select_machine.on_change('value', select_machine_cb)
## Parameters list
default_params = [
    'Debiet acuteel',
    'Extruder energie',
    'Extruder kWh/kg',
    'Gewicht bunker',
    'RPM Extruder acuteel',
    'Temperatuur Kop'
]
parameters = get_parameters(select_machine.value, default_params)
select_parameters = widgets.MultiSelect(
    options = parameters,
    value = [parameters[0][0]],
    title = 'Parameter:'
)
def btn_update_cb(arg):
    logging.debug('btn_update clicked')
    datetime_start = parser.parse(f'{date_start.value} {time_start.value}')
    datetime_end = parser.parse(f'{date_end.value} {time_end.value}')
    datetimes = [datetime_start, datetime_end]
    timestamps = get_timestamps(datetimes)
    db_names = get_db_names(timestamps)
    machine_id = select_machine.value
    parameter_ids = select_parameters.value
    query = """
        SELECT
            [LTimestamp],
            [TrendConfigID],
            [Text],
            [Value]
        FROM ({derived_table}) [Trend]
        LEFT JOIN [TrendConfig] AS [TrendConfig]
            ON [Trend].[TrendConfigID] = [TrendConfig].[ID]
        WHERE [LTimestamp] BETWEEN {timestamp_range}
            AND [Trend].[TrendConfigID] IN ({id_range})
    """.format(
        derived_table = ' UNION ALL '.join([f'SELECT * FROM [{db_name}].[dbo].[Trend_{machine_id}]' for db_name in db_names]),
        timestamp_range = ' AND '.join(map(str, timestamps)),
        id_range = ' ,'.join(parameter_ids)
    )
    logging.debug(query)
    df = get_process_data(query)
    ds = models.ColumnDataSource(df)
    plot.renderers = []  # clear plot
    #view = models.CDSView(source=ds, filters=[models.GroupFilter(column_name='TrendConfigID', group='')])
    #plot = plotting.figure(plot_width=600, plot_height=300, x_axis_type='datetime')
    plot.line(x='LTimestamp', y='Value', source=ds, name='line')

btn_update = widgets.Button(
    label="Update",
    button_type="primary",
    width = 150
)
btn_update.on_click(btn_update_cb)
btn_row = layouts.Row(children=[btn_update])
column = layouts.Column(children=[batches, datetimes, select_machine, select_parameters, btn_row], width = 300)
plot = plotting.figure(plot_width=600, plot_height=300, x_axis_type='datetime')
row = layouts.Row(children=[column, layouts.Spacer(width=20), plot])
tab1 = models.Panel(child=row, title="Viewer")
tab2 = models.Panel(child=layouts.Spacer(), title="Settings")
tabs = models.Tabs(tabs=[tab1, tab2])
plotting.curdoc().add_root(tabs)

I would recommend OOP (Object Oriented Programming).
To do this you would need to:
1. Come up with a standardized object that is being graphed, for example Point {value: x, startTime: a, endTime: y}.
2. Extract obtaining the object (or list of objects) into a separate class/file (example found here: http://www.qtrac.eu/pyclassmulti.html).
3. Come up with a standardized interface. For example, the interface could be called 'Batch' and could have a method 'obtainPoints' that returns a list of 'Point' objects defined in step 1.
Now you can create multiple Batch implementations that implement the interface, and the main class can call the obtainPoints method on each implementation and graph the results. In the end you would have one interface (Batch), X implementations of Batch (e.g. SQLDatabaseBatch, LDAPBatch, etc.) and a main class that uses all of these Batch implementations to create the graph.
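A minimal Python sketch of that layout, with hypothetical names (Point, Batch, SqlDatabaseBatch, method names adapted to Python style); the actual query logic would be lifted out of your existing get_* functions:
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

@dataclass
class Point:
    value: float
    start_time: int  # epoch milliseconds, as in the DB
    end_time: int

class Batch(ABC):
    """Interface: anything that can supply points for the plot."""
    @abstractmethod
    def obtain_points(self) -> List[Point]:
        ...

class SqlDatabaseBatch(Batch):
    def __init__(self, settings, machine_id, parameter_ids, timestamps):
        self.settings = settings
        self.machine_id = machine_id
        self.parameter_ids = parameter_ids
        self.timestamps = timestamps

    def obtain_points(self) -> List[Point]:
        # Reuse the existing query helpers (get_db_names, get_process_data)
        # here and map each row of the dataframe to a Point.
        raise NotImplementedError
The Bokeh widgets and figure (the view) would then only deal with Batch objects and the Points they return, and the statistical computations can be plain functions over lists of Points, which keeps the model, view and controller concerns separate.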

Related

Databricks DLT: reading a table from one schema (bronze), processing CDC data and storing it in another schema (processed)

I am developing an ETL pipeline using Databricks DLT pipelines for CDC data that I receive from Kafka. I have successfully created two pipelines for the landing and raw zones. The raw zone has an operation flag and a sequence column, and I would like to process the CDC data and store the clean data in the processed layer (SCD type 1). I am having difficulty reading a table from one schema, applying the CDC changes, and loading into the target schema's tables.
I have 100+ tables, so I am planning to loop through the tables in the RAW layer, apply CDC, and move them to the processed layer. Following is the code that I have tried (I have left the commented-out code in just for reference).
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *

raw_db_name = "raw_db"
processed_db_name = "processed_db_name"

def generate_curated_table(src_table_name, tgt_table_name, df):
    # @dlt.view(
    #     name= src_table_name,
    #     spark_conf={
    #         "pipelines.incompatibleViewCheck.enabled": "false"
    #     },
    #     comment="Processed data for " + str(src_table_name)
    # )
    # # def create_target_table():
    # #     return (df)

    # dlt.create_target_table(name=tgt_table_name,
    #     comment= f"Clean, merged {tgt_table_name}",
    #     #partition_cols=["topic"],
    #     table_properties={
    #         "quality": "silver"
    #     }
    # )

    # @dlt.view
    # def users():
    #     return spark.readStream.format("delta").table(src_table_name)

    @dlt.view
    def raw_tbl_data():
        return df

    dlt.create_target_table(name=tgt_table_name,
        comment="Clean, merged customers",
        table_properties={
            "quality": "silver"
        })

    dlt.apply_changes(
        target = tgt_table_name,
        source = f"{raw_db_name}.raw_tbl_data",
        keys = ["id"],
        sequence_by = col("timestamp_ms"),
        apply_as_deletes = expr("op = 'DELETE'"),
        apply_as_truncates = expr("op = 'TRUNCATE'"),
        except_column_list = ["id", "timestamp_ms"],
        stored_as_scd_type = 1
    )
    return

tbl_name = 'raw_po_details'
df = spark.sql(f'select * from {raw_db_name}.{tbl_name}')
processed_tbl_name = tbl_name.replace("raw", "processed")  # processed_po_details
generate_curated_table(tbl_name, processed_tbl_name, df)
I have tried dlt.view(), dlt.table(), dlt.create_streaming_live_table() and dlt.create_target_table(), but I end up with one of the following errors:
AttributeError: 'function' object has no attribute '_get_object_id'
pyspark.sql.utils.AnalysisException: Failed to read dataset '<raw_db_name.mytable>'. Dataset is not defined in the pipeline
Expected result:
Read the dataframe which is passed as a parameter (RAW_DB) and
Create new tables in PROCESSED_DB which is configured in DLT pipeline settings
https://www.databricks.com/blog/2022/04/27/how-uplift-built-cdc-and-multiplexing-data-pipelines-with-databricks-delta-live-tables.html
https://cprosenjit.medium.com/databricks-delta-live-tables-job-workflows-orchestration-patterns-bc7643935299
Appreciate any help please.
Thanks in advance
I found the solution myself and got it working, thanks all. I am adding my solution here so it can serve as a reference for others.
import dlt
from pyspark.sql.functions import *
from pyspark.sql.types import *

def generate_silver_tables(target_table, source_table):

    @dlt.table
    def customers_filteredB():
        return spark.table("my_raw_db.myraw_table_name")

    ### Create the target table definition
    dlt.create_target_table(name=target_table,
        comment= f"Clean, merged {target_table}",
        #partition_cols=["topic"],
        table_properties={
            "quality": "silver",
            "pipelines.autoOptimize.managed": "true"
        }
    )

    ## Do the merge
    dlt.apply_changes(
        target = target_table,
        source = "customers_filteredB",
        keys = ["id"],
        apply_as_deletes = expr("operation = 'DELETE'"),
        sequence_by = col("timestamp_ms"),  # primary key, auto-incrementing ID of any kind that can be used to identify the order of events, or a timestamp
        ignore_null_updates = False,
        except_column_list = ["operation", "timestamp_ms"],
        stored_as_scd_type = "1"
    )
    return

raw_dbname = "raw_db"
raw_tbl_name = 'raw_table_name'
processed_tbl_name = raw_tbl_name.replace("raw", "processed")
generate_silver_tables(processed_tbl_name, raw_tbl_name)
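To cover the 100+ tables mentioned in the question, the same helper could then be called in a loop. A minimal sketch, assuming a hypothetical raw_table_names list and that the view inside generate_silver_tables is parameterised to read source_table rather than a hardcoded table name:
raw_table_names = ["raw_po_details", "raw_customer_details"]  # hypothetical list of RAW-layer tables

for raw_tbl_name in raw_table_names:
    processed_tbl_name = raw_tbl_name.replace("raw", "processed")
    generate_silver_tables(processed_tbl_name, raw_tbl_name)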

How do I compile and bring in multiple outputs from the same worker?

I'm developing a kubeflow pipeline that takes in a data set, splits that dataset into two different datasets based on a filter inside the code, and outputs both datasets. That function looks like the following:
def merge_promo_sales(input_data: Input[Dataset],
                      output_data_hd: OutputPath("Dataset"),
                      output_data_shop: OutputPath("Dataset")):

    import pandas as pd
    pd.set_option('display.max_rows', 100)
    pd.set_option('display.max_columns', 500)
    import numpy as np
    from google.cloud import bigquery
    from utils import google_bucket

    client = bigquery.Client("gcp-sc-demand-plan-analytics")
    print("Client creating using default project: {}".format(client.project), "Pulling Data")

    query = """
        SELECT * FROM `gcp-sc-demand-plan-analytics.Modeling_Input.monthly_delivery_type_sales` a
        Left Join `gcp-sc-demand-plan-analytics.Modeling_Input.monthly_promotion` b
        on a.ship_base7 = b.item_no
        and a.oper_cntry_id = b.corp_cd
        and a.dmand_mo_yr = b.dates
    """
    query_job = client.query(
        query,
        # Location must match that of the dataset(s) referenced in the query.
        location="US",
    )  # API request - starts the query

    df = query_job.to_dataframe()
    df.drop(['corp_cd', 'item_no', 'dates'], axis=1, inplace=True)
    df.loc[:, 'promo_objective_increase_margin':] = df.loc[:, 'promo_objective_increase_margin':].fillna(0)

    items = df['ship_base7'].unique()
    df = df[df['ship_base7'].isin(items)]

    df_hd = df[df['location_type'] == 'home_delivery']
    df_shop = df[df['location_type'] != 'home_delivery']

    df_hd.to_pickle(output_data_hd)
    df_shop.to_pickle(output_data_shop)
That part works fine. When I try to feed those two data sets into the next function with the compiler, I hit errors.
I tried the following:
@kfp.v2.dsl.pipeline(name=PIPELINE_NAME)
def my_pipeline():
    merge_promo_sales_nl = merge_promo_sales(input_data=new_launch.output)
    rule_3_hd = rule_3(input_data=merge_promo_sales_nl.output_data_hd)
    rule_3_shop = rule_3(input_data=merge_promo_sales_nl.output_data_shop)
The error I get is the following:
AttributeError: 'ContainerOp' object has no attribute 'output_data_hd'
output_data_hd is the parameter I write that dataset out to, but apparently that is not the attribute name Kubeflow is looking for.
I just figured this out.
When a component produces multiple outputs, you reference them as follows in the pipeline definition:
rule_3_hd = rule_3(input_data = merge_promo_sales_nl.outputs['output_data_hd'])
rule_3_shop = rule_3(input_data = merge_promo_sales_nl.outputs['output_data_shop'])
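Put together, the pipeline from the question would then look roughly like this (new_launch, rule_3 and PIPELINE_NAME are the question's own names):
@kfp.v2.dsl.pipeline(name=PIPELINE_NAME)
def my_pipeline():
    merge_promo_sales_nl = merge_promo_sales(input_data=new_launch.output)
    # Multi-output components expose each output through the .outputs dict,
    # keyed by the output parameter name, rather than as attributes.
    rule_3_hd = rule_3(input_data=merge_promo_sales_nl.outputs['output_data_hd'])
    rule_3_shop = rule_3(input_data=merge_promo_sales_nl.outputs['output_data_shop'])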

Can't make apache beam write outputs to bigquery when using DataflowRunner

I'm trying to understand why this pipeline writes no output to BigQuery.
What I'm trying to achieve is to calculate the USD index for the last 10 years, starting from observations of different currency pairs.
All the data is in BigQuery and I need to organize it and sort it chronologically (if there is a better way to achieve this, I'm glad to read it, because I think this might not be the optimal approach).
The idea behind the Currencies() class is to group (and keep) the last observation of each currency pair (e.g. EURUSD), update all currency pair values as they "arrive", sort them chronologically and finally get the open, high, low and close values of the USD index for each day.
This code works in my Jupyter notebook and in Cloud Shell using DirectRunner, but when I use DataflowRunner it does not write any output. To see if I could figure it out, I tried to just create the data using beam.Create() and write it to BigQuery (which worked), and also to just read something from BQ and write it to another table (which also worked), so my best guess is that the problem is in the beam.CombineGlobally part, but I don't know what it is.
The code is as follows:
import logging
import collections
import apache_beam as beam
from datetime import datetime

SYMBOLS = ['usdjpy', 'usdcad', 'usdchf', 'eurusd', 'audusd', 'nzdusd', 'gbpusd']
TABLE_SCHEMA = "date:DATETIME,index:STRING,open:FLOAT,high:FLOAT,low:FLOAT,close:FLOAT"

class Currencies(beam.CombineFn):
    def create_accumulator(self):
        return {}

    def add_input(self, accumulator, inputs):
        logging.info(inputs)
        date, currency, bid = inputs.values()
        if '.' not in date:
            date = date + '.0'
        date = datetime.strptime(date, '%Y-%m-%dT%H:%M:%S.%f')
        data = currency + ':' + str(bid)
        accumulator[date] = [data]
        return accumulator

    def merge_accumulators(self, accumulators):
        merged = {}
        for accum in accumulators:
            ordered_data = collections.OrderedDict(sorted(accum.items()))
            prev_date = None
            for date, date_data in ordered_data.items():
                if date not in merged:
                    merged[date] = {}
                if prev_date is None:
                    prev_date = date
                else:
                    prev_data = merged[prev_date]
                    merged[date].update(prev_data)
                    prev_date = date
                for data in date_data:
                    currency, bid = data.split(':')
                    bid = float(bid)
                    currency = currency.lower()
                    merged[date].update({
                        currency: bid
                    })
        return merged

    def calculate_index_value(self, data):
        return data['usdjpy']*data['usdcad']*data['usdchf']/(data['eurusd']*data['audusd']*data['nzdusd']*data['gbpusd'])

    def extract_output(self, accumulator):
        ordered = collections.OrderedDict(sorted(accumulator.items()))
        index = {}
        for dt, currencies in ordered.items():
            if not all([symbol in currencies.keys() for symbol in SYMBOLS]):
                continue
            date = str(dt.date())
            index_value = self.calculate_index_value(currencies)
            if date not in index:
                index[date] = {
                    'date': date,
                    'index': 'usd',
                    'open': index_value,
                    'high': index_value,
                    'low': index_value,
                    'close': index_value
                }
            else:
                max_value = max(index_value, index[date]['high'])
                min_value = min(index_value, index[date]['low'])
                close_value = index_value
                index[date].update({
                    'high': max_value,
                    'low': min_value,
                    'close': close_value
                })
        return index

def main():
    query = """
        select date, currency, bid from data_table
        where date(date) between '2022-01-13' and '2022-01-16'
        and currency like ('%USD%')
    """
    options = beam.options.pipeline_options.PipelineOptions(
        temp_location = 'gs://PROJECT/temp',
        project = 'PROJECT',
        runner = 'DataflowRunner',
        region = 'REGION',
        num_workers = 1,
        max_num_workers = 1,
        machine_type = 'n1-standard-1',
        save_main_session = True,
        staging_location = 'gs://PROJECT/stag'
    )
    with beam.Pipeline(options=options) as pipeline:
        inputs = (pipeline
            | 'Read From BQ' >> beam.io.ReadFromBigQuery(query=query, use_standard_sql=True)
            | 'Accumulate' >> beam.CombineGlobally(Currencies())
            | 'Flat' >> beam.ParDo(lambda x: x.values())
            | beam.io.Write(beam.io.WriteToBigQuery(
                table = 'TABLE',
                dataset = 'DATASET',
                project = 'PROJECT',
                schema = TABLE_SCHEMA))
        )

if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    main()
The way I execute this is from the shell, using python3 -m first_script (is this how I should run these batch jobs?).
What am I missing or doing wrong? This is my first attempt to use Dataflow, so I'm probably making several mistakes along the way.
For whom it may help: I faced a similar problem. I had already used the same code for a different flow that had Pub/Sub as input, where it worked flawlessly, but with a file-based input it simply did not write anything. After a lot of experimenting I found that in the options I had to change the flag
options = PipelineOptions(streaming=True, ..
to
options = PipelineOptions(streaming=False,
since it is of course not a streaming source but a bounded source, i.e. a batch. After I set this flag to False I found my rows in the BigQuery table. After it had finished it even stopped the pipeline, as it was a batch operation. Hope this helps.
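If the same flag is the issue here, the options block from the question would gain an explicit streaming=False (a sketch that keeps the question's placeholder values):
options = beam.options.pipeline_options.PipelineOptions(
    runner = 'DataflowRunner',
    project = 'PROJECT',
    region = 'REGION',
    temp_location = 'gs://PROJECT/temp',
    staging_location = 'gs://PROJECT/stag',
    save_main_session = True,
    streaming = False,  # bounded (batch) source rather than a streaming one
)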

Airflow Pipeline to read CSVs and load into PostgreSQL

So, I am trying to write an Airflow DAG that will 1) read a few different CSVs from my local disk, 2) create different PostgreSQL tables, and 3) load the files into their respective tables. When I run the DAG, the second step seems to fail.
Below is the code for the DAG's logic operators:
AIRFLOW_HOME = os.getenv('AIRFLOW_HOME')

def get_listings_data():
    listings = pd.read_csv(AIRFLOW_HOME + '/dags/data/listings.csv')
    return listings

def get_g01_data():
    demographics = pd.read_csv(AIRFLOW_HOME + '/dags/data/demographics.csv')
    return demographics

def insert_listing_data_func(**kwargs):
    ps_pg_hook = PostgresHook(postgres_conn_id="postgres")
    conn_ps = ps_pg_hook.get_conn()
    ti = kwargs['ti']
    insert_df = pd.DataFrame.listings
    if len(insert_df) > 0:
        col_names = ['host_id', 'host_name', 'host_neighbourhood', 'host_total_listings_count', 'neighbourhood_cleansed', 'property_type', 'price', 'has_availability', 'availability_30']
        values = insert_df[col_names].to_dict('split')
        values = values['data']
        logging.info(values)
        insert_sql = """
            INSERT INTO assignment_2.listings (host_name, host_neighbourhood, host_total_listings_count, neighbourhood_cleansed, property_type, price, has_availability, availability_30)
            VALUES %s
        """
        result = execute_values(conn_ps.cursor(), insert_sql, values, page_size=len(insert_df))
        conn_ps.commit()
    else:
        None
    return None

def insert_demographics_data_func(**kwargs):
    ps_pg_hook = PostgresHook(postgres_conn_id="postgres")
    conn_ps = ps_pg_hook.get_conn()
    ti = kwargs['ti']
    insert_df = pd.DataFrame.demographics
    if len(insert_df) > 0:
        col_names = ['LGA', 'Median_age_persons', 'Median_mortgage_repay_monthly', 'Median_tot_prsnl_inc_weekly', 'Median_rent_weekly', 'Median_tot_fam_inc_weekly', 'Average_num_psns_per_bedroom', 'Median_tot_hhd_inc_weekly', 'Average_household_size']
        values = insert_df[col_names].to_dict('split')
        values = values['data']
        logging.info(values)
        insert_sql = """
            INSERT INTO assignment_2.demographics (LGA, Median_age_persons, Median_mortgage_repay_monthly, Median_tot_prsnl_inc_weekly, Median_rent_weekly, Median_tot_fam_inc_weekly, Average_num_psns_per_bedroom, Median_tot_hhd_inc_weekly, Average_household_size)
            VALUES %s
        """
        result = execute_values(conn_ps.cursor(), insert_sql, values, page_size=len(insert_df))
        conn_ps.commit()
    else:
        None
    return None
And my PostgresOperator for the demographics table (just one example) is below:
create_psql_table_demographics = PostgresOperator(
    task_id="create_psql_table_demographics",
    postgres_conn_id="postgres",
    sql="""
        CREATE TABLE IF NOT EXISTS postgres.demographics (
            LGA VARCHAR,
            Median_age_persons INT,
            Median_mortgage_repay_monthly INT,
            Median_tot_prsnl_inc_weekly INT,
            Median_rent_weekly INT,
            Median_tot_fam_inc_weekly INT,
            Average_num_psns_per_bedroom DECIMAL(10,1),
            Median_tot_hhd_inc_weekly INT,
            Average_household_size DECIMAL(10,2)
        );
    """,
    dag=dag)
Am I missing something in my code that stops the create_psql_table_demographics task from running successfully in Airflow?
If your PostgreSQL database has access to the CSV files,
you may simply use the copy_expert method of the PostgresHook class (cf. documentation).
PostgreSQL is pretty efficient at loading flat files: you'll save a lot of CPU cycles by not involving Python (and pandas!), not to mention the potential encoding issues that you would have to address.
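A minimal sketch of that approach, assuming the target table already exists and the CSV columns line up with it (the file path, connection id and table name are taken from the question):
from airflow.providers.postgres.hooks.postgres import PostgresHook

def load_demographics_csv():
    hook = PostgresHook(postgres_conn_id="postgres")
    # COPY streams the file straight into PostgreSQL in one pass,
    # without going through pandas.
    hook.copy_expert(
        sql="COPY assignment_2.demographics FROM STDIN WITH CSV HEADER",
        filename=AIRFLOW_HOME + '/dags/data/demographics.csv',
    )
This function can then be wrapped in a PythonOperator and placed downstream of the table-creation task.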

Simplest way to create a complex dask graph

There is a complex system of calculations over some objects.
The difficulty is that some calculations are group calculations.
This can be demonstrated by the following example:
from dask.distributed import Client
def load_data_from_db(id):
    # load some data
    ...
    return data

def task_a(data):
    # some calculations
    ...
    return result

def group_task(*args):
    # some calculations
    ...
    return result

def task_b(data, group_data):
    # some calculations
    ...
    return result

def task_c(data, task_a_result):
    # some calculations
    ...
    return result
ids = [1, 2]
dsk = {'id_{}'.format(i): id for i, id in enumerate(ids)}
dsk['data_0'] = (load_data_from_db, 'id_0')
dsk['data_1'] = (load_data_from_db, 'id_1')
dsk['task_a_result_0'] = (task_a, 'data_0')
dsk['task_a_result_1'] = (task_a, 'data_1')
dsk['group_result'] = (
    group_task,
    'data_0', 'task_a_result_0',
    'data_1', 'task_a_result_1')
dsk['task_b_result_0'] = (task_b, 'data_0', 'group_result')
dsk['task_b_result_1'] = (task_b, 'data_1', 'group_result')
dsk['task_c_result_0'] = (task_c, 'data_0', 'task_a_result_0')
dsk['task_c_result_1'] = (task_c, 'data_1', 'task_a_result_1')

client = Client(scheduler_address)
result = client.get(
    dsk,
    ['task_a_result_0',
     'task_b_result_0',
     'task_c_result_0',
     'task_a_result_1',
     'task_b_result_1',
     'task_c_result_1'])
The list of objects numbers in the thousands of elements, and the number of tasks is in the dozens (including several group tasks).
With this method of graph creation it is difficult to modify the graph (add new tasks, change dependencies, etc.).
Is there a more convenient way to set up this kind of distributed computation with dask?
Added
With futures, the graph becomes:
from itertools import chain

client = Client(scheduler_address)
ids = [1, 2]
data = client.map(load_data_from_db, ids)
result_a = client.map(task_a, data)
group_args = list(chain(*zip(data, result_a)))
result_group = client.submit(group_task, *group_args)
result_b = client.map(task_b, data, [result_group] * len(ids))
result_c = client.map(task_c, data, result_a)
result = client.gather(result_a + result_b + result_c)
And inside the task functions the input arguments are Future instances, so I call arg.result() before using them.
If you want to modify the computation during computation then I recommend the futures interface.
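For example, building on the futures version added above, extra tasks can be attached while earlier ones are still computing (a rough sketch reusing the question's function names):
client = Client(scheduler_address)
data = client.map(load_data_from_db, ids)
result_a = client.map(task_a, data)

# While these are still running, more work can be submitted against the
# existing futures; dask wires the dependencies up automatically.
result_group = client.submit(group_task, *chain(*zip(data, result_a)))
result = client.gather(result_a + [result_group])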
