The below code works but my requirement is to pass totalbuckets as an input to the function as opposed to global variable. I am having trouble passing it as a variable and do xcom_pull in next task. This dag basically creates buckets based on the number of inputs and totalbuckets is a constant. Appreciate your help in advance.
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:
totalbuckets = 3
# branches based on number of buckets
def branch_buckets(**context):
buckets = defaultdict(list)
for i in range(len(inputs_to_process)):
buckets[f'bucket_{(1+i % totalbuckets)}'].append(inputs_to_process[i])
for bucket_name, input_sublist in buckets.items():
context['ti'].xcom_push(key = bucket_name, value = input_sublist)
return list(buckets.keys())
# BranchPythonOperator will launch the buckets and distributes inputs among the buckets
branch_buckets = BranchPythonOperator(
task_id='branch_buckets',
python_callable=branch_buckets,
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True,
dag=test_live
)
# update provider tables with merge sql
def update_inputs(sf_conn_id, bucket_name, **context):
input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
print(f"Processing inputs {input_sublist} in {bucket_name}")
from custom.hooks.snowflake_hook import SnowflakeHook
for p in input_sublist:
merge_sql=f"""
merge into ......"""
bucket_tasks = []
for i in range(totalbuckets):
task= PythonOperator(
task_id=f'bucket_{i+1}',
python_callable=update_inputs,
provide_context=True,
op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
dag=test_live
)
bucket_tasks.append(task)
If totalbuckets is different from run to other, it should be a run conf variable, you can provide it for each run crated from the UI, CLI, Airflow REST API or even python API.
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.models.param import Param
with DAG(
'test-live',
catchup=False,
schedule_interval=None,
default_args=args,
params={"totalbuckets": Param(default=3, type="integer")},
) as test_live:
# branches based on number of buckets
def branch_buckets(**context):
buckets = defaultdict(list)
for i in range(len(inputs_to_process)):
buckets[f'bucket_{(1+i % int("{{ params.totalbuckets }}"))}'].append(inputs_to_process[i])
for bucket_name, input_sublist in buckets.items():
context['ti'].xcom_push(key = bucket_name, value = input_sublist)
return list(buckets.keys())
# BranchPythonOperator will launch the buckets and distributes inputs among the buckets
branch_buckets = BranchPythonOperator(
task_id='branch_buckets',
python_callable=branch_buckets,
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True,
dag=test_live
)
# update provider tables with merge sql
def update_inputs(sf_conn_id, bucket_name, **context):
input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
print(f"Processing inputs {input_sublist} in {bucket_name}")
from custom.hooks.snowflake_hook import SnowflakeHook
for p in input_sublist:
merge_sql=f"""
merge into ......"""
bucket_tasks = []
for i in range(int("{{ params.totalbuckets }}")):
task= PythonOperator(
task_id=f'bucket_{i+1}',
python_callable=update_inputs,
provide_context=True,
op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
dag=test_live
)
bucket_tasks.append(task)
Example to run it:
airflow dags trigger --conf '{"totalbuckets": 10}' test-live
Or via the UI.
update:
And if it's static, but different from an environment to other, it can be an Airflow variable, and read it directly in the tasks using jinja to avoid reading it at each Dag Files processing.
But if it's completely static, the most recommended solution is using python variable as you do, because to read dag run conf and Airflow variables, the task/dag send a query to the database.
#hussein awala I am doing something like below but cannot parse totalbuckets in bucket_tasks
from airflow.operators.python import PythonOperator, BranchPythonOperator
with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:
#totalbuckets = 3
def branch_buckets(totalbuckets, **context):
buckets = defaultdict(list)
for i in range(len(inputs_to_process)):
buckets[f'bucket_{(1+i % totalbuckets)}'].append(inputs_to_process[i])
for bucket_name, input_sublist in buckets.items():
context['ti'].xcom_push(key = bucket_name, value = input_sublist)
return list(buckets.keys())
# BranchPythonOperator will launch the buckets and distributes inputs among the buckets
branch_buckets = BranchPythonOperator(
task_id='branch_buckets',
python_callable=branch_buckets,
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True, op_kwargs={'totalbuckets':3},
dag=test_live
)
# update provider tables with merge sql
def update_inputs(sf_conn_id, bucket_name, **context):
input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
print(f"Processing inputs {input_sublist} in {bucket_name}")
from custom.hooks.snowflake_hook import SnowflakeHook
for p in input_sublist:
merge_sql=f"""
merge into ......"""
bucket_tasks = []
for i in range(totalbuckets):
task= PythonOperator(
task_id=f'bucket_{i+1}',
python_callable=update_inputs,
provide_context=True,
op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
dag=test_live
)
bucket_tasks.append(task)```
Related
I am working on polling boto3 to check the status of a SageMaker Autopilot job using Airflow. I am using a PythonSensor to wait for the status to return Completed for both JobStatus and JobSecondaryStatus, then end the entire pipeline. These are the values that they can contain which I made enums of in the code:
'AutoMLJobStatus': 'Completed'|'InProgress'|'Failed'|'Stopped'|'Stopping',
'AutoMLJobSecondaryStatus': 'Starting'|'AnalyzingData'|'FeatureEngineering'|'ModelTuning'|'MaxCandidatesReached'|'Failed'|'Stopped'|'MaxAutoMLJobRuntimeReached'|'Stopping'|'CandidateDefinitionsGenerated'|'GeneratingExplainabilityReport'|'Completed'|'ExplainabilityError'|'DeployingModel'|'ModelDeploymentError'
_sagemaker_job_status takes automl_job_name through xcom from an upstream task and it successfully gets passed. With this job name I can pass it to descibe_auto_ml_job() to get the status through AutoMLJobStatus and AutoMLJobSecondaryStatus.
The main point of this is for messaging through Slack to see all the unique stages the job is at. Currently, I am trying to save all the unique job statuses to a set and then checking that set before sending a message with the job statuses in it.
But everytime _sagemaker_job_status is poked, the values of the set seem to be the same therefore sending a slack message everytime the function is poked, I logged the sets and both are empty. Below this I made a simpler example that worked.
import airflow
from airflow import DAG
from airflow.exceptions import AirflowFailException
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.sensors.python import PythonSensor
import boto3
def _sagemaker_job_status(templates_dict, **context):
"""
Checks the SageMaker AutoMLJobStatus and AutoMLJobSecondaryStatus
for updates and when both are complete the entire process is marked as
successful
"""
automl_job_name = templates_dict.get("automl_job_name")
if not automl_job_name:
error_message = "AutoMLJobName was not passed from upstream"
print(error_message)
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
client = boto3.client("sagemaker", "us-east-1")
response = client.describe_auto_ml_job(
AutoMLJobName=automl_job_name,
)
job_status = response.get("AutoMLJobStatus")
secondary_job_status = response.get("AutoMLJobSecondaryStatus")
past_job_statuses = set()
past_secondary_job_statuses = set()
print(f"Past Job Statuses : {past_job_statuses}")
print(f"Past Secondary Job Statuses : {past_secondary_job_statuses}")
# If the job status has not been already seen
if (
job_status not in past_job_statuses
and secondary_job_status not in past_secondary_job_statuses
):
message = f"""
JobStatus : {job_status}
JobSecondaryStatus : {secondary_job_status}
"""
print(message)
task_success_slack_alert(
context=context,
extra_message=message,
)
past_job_statuses.add(job_status)
past_secondary_job_statuses.add(secondary_job_status)
# If the main job fails
if job_status == JobStatus.Failed.value:
error_message = "SageMaker Autopilot Job Failed!"
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
raise AirflowFailException(error_message)
return (
job_status == JobStatus.Completed.value
and secondary_job_status == JobSecondaryStatus.Completed.value
)
args = {
"owner": "Yudhiesh",
"start_date": airflow.utils.dates.days_ago(1),
"schedule_interval": "#once",
"on_failure_callback": task_fail_slack_alert,
}
with DAG(
dag_id="02_lasic_retraining_sagemaker_autopilot",
default_args=args,
render_template_as_native_obj=True,
) as dag:
sagemaker_job_status = PythonSensor(
task_id="sagemaker_job_status",
python_callable=_sagemaker_job_status,
templates_dict={
"automl_job_name": "{{task_instance.xcom_pull(task_ids='train_model_sagemaker_autopilot')}}", # noqa: E501
},
)
end = DummyOperator(
task_id="end",
)
sagemaker_job_status >> end
I created a similar setup as before but this time I randomly generated the values from an enum of JobStatus & JobSecondaryStatus and tried to only print the values if they are unique, and turns out it works perfectly. Could anyone explain why this happens and what I can do to the main example to get it to work?
import airflow
import random
from airflow import DAG
from airflow.sensors.python import PythonSensor
from airflow.operators.dummy import DummyOperator
from airflow.exceptions import AirflowFailException
def _mimic_sagemaker_job_status():
job_statuses = [status.value for status in JobStatus]
job_secondary_statuses = [
secondary_status.value for secondary_status in JobSecondaryStatus
]
past_job_statuses = set()
past_secondary_job_statuses = set()
job_status = random.choice(job_statuses)
job_secondary_status = random.choice(job_secondary_statuses)
if (
job_status not in past_job_statuses
and job_secondary_status not in past_secondary_job_statuses
):
message = f"""
JobStatus : {job_status}
JobSecondaryStatus : {job_secondary_status}
"""
# Send alerts on every new job status update
print(message)
past_job_statuses.add(job_status)
past_secondary_job_statuses.add(job_secondary_status)
if (
job_status == JobStatus.Failed.value
or job_secondary_status == JobSecondaryStatus.Failed.value
):
raise AirflowFailException("SageMaker Autopilot Job Failed!")
return (
job_secondary_status == JobSecondaryStatus.Completed.value
and job_status == JobStatus.Completed.value
)
with DAG(
dag_id="04_sagemaker_sensor",
start_date=airflow.utils.dates.days_ago(3),
schedule_interval="#once",
render_template_as_native_obj=True,
) as dag:
wait_for_status = PythonSensor(
task_id="wait_for_status",
python_callable=_mimic_sagemaker_job_status,
dag=dag,
)
end = DummyOperator(
task_id="end",
)
wait_for_status >> end
Enums used in the above code:
from enum import Enum
class JobStatus(Enum):
"""
Enum of all the potential values of a SageMaker Autopilot job status
"""
Completed = "Completed"
InProgress = "InProgress"
Failed = "Failed"
Stopped = "Stopped"
Stopping = "Stopping"
class JobSecondaryStatus(Enum):
"""
Enum of all the potential values of a SageMaker Autopilot job secondary
status
"""
Starting = "Starting"
AnalyzingData = "AnalyzingData"
FeatureEngineering = "FeatureEngineering"
ModelTuning = "ModelTuning"
MaxCandidatesReached = "MaxCandidatesReached"
Failed = "Failed"
Stopped = "Stopped"
MaxAutoMLJobRuntimeReached = "MaxAutoMLJobRuntimeReached"
Stopping = "Stopping"
CandidateDefinitionsGenerated = "CandidateDefinitionsGenerated"
GeneratingExplainabilityReport = "GeneratingExplainabilityReport"
Completed = "Completed"
ExplainabilityError = "ExplainabilityError"
DeployingModel = "DeployingModel"
ModelDeploymentError = "ModelDeploymentError"
EDIT:
I suppose another work around for the main example would be to have an operator create a temporary file containing JSON of the set before the sagemaker job status, then within the sagemaker job status I can check the job statuses saved to the file and then print them if they are unique. I just realised that I can make use of the database as well.
So I couldn't seem to get it working as it is so I resorted to creating a JSON file that stores the different SageMaker Autopilot job statuses which I read and write to in the PythonSensor.
This takes in the AutoMLJobName from the previous step, creates a temporary file of the job statuses, and returns the AutoMLJobName and the name of the JSON file.
import tempfile
def _create_job_status_json(templates_dict, **context):
automl_job_name = templates_dict.get("sagemaker_autopilot_data_paths")
if not automl_job_name:
error_message = "AutoMLJobName was not passed from upstream"
print(error_message)
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
initial = {
"JobStatus": [],
"JobSecondaryStatus": [],
}
file = tempfile.NamedTemporaryFile(mode="w", delete=False)
json.dump({"Status": initial}, file)
file.flush()
return (file.name, automl_job_name)
Next this function reads the JSON file based on the name and then checks the different job statuses based on the boto3 sagemaker client. If the main job fails then the whole run fails. It adds the job statuses to a dictionary if one of them are unique. Once that is done it will write the dictionary to the JSON file. When the entire job finishes, it sends some details about the best model as a Slack message. It returns true when both job statuses are Completed. Just a note, I also removed the JSON file if the job is successfull or if it fails.
import airflow
from airflow import DAG
from airflow.exceptions import AirflowFailException
import boto3
def _sagemaker_job_status(templates_dict, **context):
"""
Checks the SageMaker AutoMLJobStatus and AutoMLJobSecondaryStatus
for updates and when both are complete the entire process is marked as
successful
"""
file_name, automl_job_name = templates_dict.get("automl_job_data")
job_status_dict = {}
client = boto3.client("sagemaker", "us-east-1")
if not client:
raise AirflowFailException(
"Unable to get access to boto3 sagemaker client",
)
with open(file_name, "r") as json_file:
response = client.describe_auto_ml_job(
AutoMLJobName=automl_job_name,
)
job_status = response.get("AutoMLJobStatus")
secondary_job_status = response.get("AutoMLJobSecondaryStatus")
job_status_dict = json.load(json_file)
status = job_status_dict.get("Status")
past_job_statuses = status.get("JobStatus")
past_secondary_job_statuses = status.get("JobSecondaryStatus")
if job_status == JobStatus.Failed.value:
error_message = "SageMaker Autopilot Job Failed!"
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
os.remove(file_name)
raise AirflowFailException(error_message)
if (
job_status not in past_job_statuses
or secondary_job_status not in past_secondary_job_statuses
):
message = f"""
JobStatus : {job_status}
JobSecondaryStatus : {secondary_job_status}
"""
print(message)
task_success_slack_alert(
context=context,
extra_message=message,
)
past_job_statuses.append(job_status)
past_secondary_job_statuses.append(secondary_job_status)
with open(file_name, "w") as file:
json.dump(job_status_dict, file)
if (
job_status == JobStatus.Completed.value
and secondary_job_status == JobSecondaryStatus.Completed.value
):
os.remove(file_name)
response = client.describe_auto_ml_job(
AutoMLJobName=automl_job_name,
)
best_candidate = response.get("BestCandidate")
best_candidate_id = best_candidate.get("CandidateName")
best_metric_name = (
best_candidate.get("FinalAutoMLJobObjectiveMetric")
.get("MetricName")
.split(":")[1]
.upper()
)
best_metric_value = round(
best_candidate.get("FinalAutoMLJobObjectiveMetric").get(
"Value",
),
3,
)
message = f"""
Best Candidate ID : {best_candidate_id}
Best Candidate Metric Score : {best_metric_value}{best_metric_name}
""" # noqa: E501
task_success_slack_alert(
context=context,
extra_message=message,
)
return (
job_status == JobStatus.Completed.value
and secondary_job_status == JobSecondaryStatus.Completed.value
)
DAG code:
import airflow
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.sensors.python import PythonSensor
args = {
"owner": "Yudhiesh",
"start_date": airflow.utils.dates.days_ago(1),
"schedule_interval": "#once",
"on_failure_callback": task_fail_slack_alert,
}
with DAG(
dag_id="02_lasic_retraining_sagemaker_autopilot",
default_args=args,
render_template_as_native_obj=True,
) as dag:
create_job_status_json = PythonOperator(
task_id="create_job_status_json",
python_callable=_create_job_status_json,
templates_dict={
"sagemaker_autopilot_data_paths": "{{task_instance.xcom_pull(task_ids='train_model_sagemaker_autopilot')}}", # noqa: E501
},
)
sagemaker_job_status = PythonSensor(
task_id="sagemaker_job_status",
python_callable=_sagemaker_job_status,
templates_dict={
"automl_job_data": "{{task_instance.xcom_pull(task_ids='create_job_status_json')}}", # noqa: E501
},
)
# train_model_sagemaker_autopilot is not included but it initiates the training through boto3
train_model_sagemaker_autopilot >> create_job_status_json
create_job_status_json >> sagemaker_job_status
I am trying to create a DAG in which one of the task does athena query using boto3. It worked for one query however I am facing issues when I try to run multiple athena queries.
This problem can be broken as follows:-
If one goes through this blog, it can be seen that athena uses start_query_execution to trigger query and get_query_execution for getting status, queryExecutionId and other data about the query (docs for athena)
After following the above pattern I have following code:-
import json
import time
import asyncio
import boto3
import logging
from airflow import DAG
from airflow.operators.python import PythonOperator
def execute_query(client, query, database, output_location):
response = client.start_query_execution(
QueryString=query,
QueryExecutionContext={
'Database': database
},
ResultConfiguration={
'OutputLocation': output_location
}
)
return response['QueryExecutionId']
async def get_ids(client_athena, query, database, output_location):
query_responses = []
for i in range(5):
query_responses.append(execute_query(client_athena, query, database, output_location))
res = await asyncio.gather(*query_responses, return_exceptions=True)
return res
def run_athena_query(query, database, output_location, region_name, **context):
BOTO_SESSION = boto3.Session(
aws_access_key_id = 'YOUR_KEY',
aws_secret_access_key = 'YOUR_ACCESS_KEY')
client_athena = BOTO_SESSION.client('athena', region_name=region_name)
loop = asyncio.get_event_loop()
query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
loop.close()
repetitions = 900
error_messages = []
s3_uris = []
while repetitions > 0 and len(query_execution_ids) > 0:
repetitions = repetitions - 1
query_response_list = client_athena.batch_get_query_execution(
QueryExecutionIds=query_execution_ids)['QueryExecutions']
for query_response in query_response_list:
if 'QueryExecution' in query_response and \
'Status' in query_response['QueryExecution'] and \
'State' in query_response['QueryExecution']['Status']:
state = query_response['QueryExecution']['Status']['State']
if state in ['FAILED', 'CANCELLED']:
error_reason = query_response['QueryExecution']['Status']['StateChangeReason']
error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
state, query_execution_id, error_message
)
error_messages.append(error_message)
query_execution_ids.remove(query_response['QueryExecutionId'])
elif state == 'SUCCEEDED':
result_location = query_response['QueryExecution']['ResultConfiguration']['OutputLocation']
s3_uris.append(result_location)
query_execution_ids.remove(query_response['QueryExecutionId'])
time.sleep(2)
logging.exception(error_messages)
return s3_uris
DEFAULT_ARGS = {
'owner': 'ubuntu',
'depends_on_past': True,
'start_date': datetime(2021, 6, 8),
'retries': 0,
'concurrency': 2
}
with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:
ATHENA_QUERY = PythonOperator(
task_id='athena_query',
python_callable=run_athena_query,
provide_context=True,
op_kwargs={
'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;', # query provide in athena tutorial
'database':'sampledb',
'output_location':'YOUR_BUCKET',
'region_name':'YOUR_REGION'
}
)
ATHENA_QUERY
On running above code, I am getting following error:-
[2021-06-16 20:34:52,981] {taskinstance.py:1455} ERROR - An asyncio.Future, a coroutine or an awaitable is required
Traceback (most recent call last):
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1112, in _run_raw_task
self._prepare_and_execute_task_with_callbacks(context, task)
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1285, in _prepare_and_execute_task_with_callbacks
result = self._execute_task(context, task_copy)
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1315, in _execute_task
result = task_copy.execute(context=context)
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/operators/python.py", line 117, in execute
return_value = self.execute_callable()
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/operators/python.py", line 128, in execute_callable
return self.python_callable(*self.op_args, **self.op_kwargs)
File "/home/ubuntu/iac-airflow/dags/helper/tasks.py", line 93, in run_athena_query
query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
return future.result()
File "/home/ubuntu/iac-airflow/dags/helper/tasks.py", line 79, in get_ids
res = await asyncio.gather(*query_responses, return_exceptions=True)
File "/usr/lib/python3.6/asyncio/tasks.py", line 602, in gather
fut = ensure_future(arg, loop=loop)
File "/usr/lib/python3.6/asyncio/tasks.py", line 526, in ensure_future
raise TypeError('An asyncio.Future, a coroutine or an awaitable is '
TypeError: An asyncio.Future, a coroutine or an awaitable is required
I am unable to get where I am going wrong. Would appreciate some hint over the issue
I think what you are doing here isn't really needed.
Your issues ares:
Executing multiple queries in parallel.
Being able to recover queryExecutionId per query.
Both issues are solved simply by using AWSAthenaOperator. The operator already handles everything you mentioned for you.
Example:
from airflow.models import DAG
from airflow.utils.dates import days_ago
from airflow.operators.dummy import DummyOperator
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
with DAG(
dag_id="athena",
schedule_interval='#daily',
start_date=days_ago(1),
catchup=False,
) as dag:
start_op = DummyOperator(task_id="start_task")
query_list = ["SELECT 1;", "SELECT 2;" "SELECT 3;"]
for i, sql in enumerate(query_list):
run_query = AWSAthenaOperator(
task_id=f'run_query_{i}',
query=sql,
output_location='s3://my-bucket/my-path/',
database='my_database'
)
start_op >> query_op
Athena tasks will be created dynamically simply by adding more queries to query_list:
Note that the QueryExecutionId is pushed to xcom thus you can access the in a downstream task if needed.
Following as well worked for me. I just complicated simple problem with asyncio.
Since I needed S3 URIs for each query at last therefore I went for writing script from scratch. In the current implementation of AWSAthenaOperator, one can get the queryExecutionId and then do the remaining processing(i.e create another task) for getting S3 URI of CSV result file. This can add some overhead in terms of delay between two tasks(of getting queryExecutionId and retrieving S3 URI) along with added resource usuage.
Therefore I went for doing the complete operation in a single operator as follows:-
Code:-
import json
import time
import asyncio
import boto3
import logging
from airflow import DAG
from airflow.operators.python import PythonOperator
def execute_query(client, query, database, output_location):
response = client.start_query_execution(
QueryString=query,
QueryExecutionContext={
'Database': database
},
ResultConfiguration={
'OutputLocation': output_location
}
)
return response
def run_athena_query(query, database, output_location, region_name, **context):
BOTO_SESSION = boto3.Session(
aws_access_key_id = 'YOUR_KEY',
aws_secret_access_key = 'YOUR_ACCESS_KEY')
client_athena = BOTO_SESSION.client('athena', region_name=region_name)
query_execution_ids = []
if message_list:
for parameter in message_list:
query_response = execute_query(client_athena, query, database, output_location)
query_execution_ids.append(query_response['QueryExecutionId'])
else:
raise Exception(
'Error in upstream value recived from kafka consumer. Got message list as - {}, with type {}'
.format(message_list, type(message_list))
)
repetitions = 900
error_messages = []
s3_uris = []
while repetitions > 0 and len(query_execution_ids) > 0:
repetitions = repetitions - 1
query_response_list = client_athena.batch_get_query_execution(
QueryExecutionIds=query_execution_ids)['QueryExecutions']
for query_response in query_response_list:
if 'QueryExecution' in query_response and \
'Status' in query_response['QueryExecution'] and \
'State' in query_response['QueryExecution']['Status']:
state = query_response['QueryExecution']['Status']['State']
if state in ['FAILED', 'CANCELLED']:
error_reason = query_response['QueryExecution']['Status']['StateChangeReason']
error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
state, query_execution_id, error_message
)
error_messages.append(error_message)
query_execution_ids.remove(query_response['QueryExecutionId'])
elif state == 'SUCCEEDED':
result_location = query_response['QueryExecution']['ResultConfiguration']['OutputLocation']
s3_uris.append(result_location)
query_execution_ids.remove(query_response['QueryExecutionId'])
time.sleep(2)
logging.exception(error_messages)
return s3_uris
DEFAULT_ARGS = {
'owner': 'ubuntu',
'depends_on_past': True,
'start_date': datetime(2021, 6, 8),
'retries': 0,
'concurrency': 2
}
with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:
ATHENA_QUERY = PythonOperator(
task_id='athena_query',
python_callable=run_athena_query,
provide_context=True,
op_kwargs={
'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;', # query provide in athena tutorial
'database':'sampledb',
'output_location':'YOUR_BUCKET',
'region_name':'YOUR_REGION'
}
)
ATHENA_QUERY
However, the approach shared by #Elad is more clean and apt if one wants to get queryExecutionIds of all the queries.
I created a dag which contains a subdag for loop through a list which is return value of a task.
subdag function
def mySubDag(parent: Text, child: Text, args, **context):
task = context['tasl_instance']
data = task.xcom_pull(task_ids='task1', dag_id=parent)
for d in data:
# do something...
parent dag
with DAG(...) as dag:
task1 = PythonOperator(task_id="task1", ..., providde_context=True, dag=dag)
task2 = SubDagOperator(subdag=mySubDag(...),..., provide_context=True, dag=dag)
task1 >> task2
I dont know where to put the argument 'context' or how to put it for the subdag function to use it.
really appreciate if any one could help to resolve it.
code define xcom_pull in taskinstance.py
def xcom_pull(
self,
task_ids=None,
dag_id=None,
key=XCOM_RETURN_KEY,
include_prior_dates=False):
if dag_id is None:
dag_id = self.dag_id
...
it's pass dag_id of current dag to xom_pull
so if you want to get data from parent dag, please override dag_id argument with parent's dag_id
For your example, pass more context with op_kwargs:
def set_cookies_func(config, **kwargs):
cookies = service_get_cookies_login(config)
kwargs['ti'].xcom_push(key="SESSION", value=cookies)
def get_data_func(parent_dag_name, **kwargs):
cookies = kwargs['ti'].xcom_pull(
task_ids='set_cookies_task',
key="SESSION",
dag_id="my_dag_id.set_cookies_task"
)
def sub_cache_load_to_gcs(parent_dag_name, child_dag_name):
sub_dag = DAG(
dag_id= "{}.{}".format(parent_dag_name, child_dag_name),
...
)
PythonOperator(
task_id="get_data_func",
python_callable=get_data_func,
op_kwargs={"parent_dag_name": parent_dag_name},
providde_context=True,
dag=sub_dag
)
with DAG(
dag_id= "my_dag_id",
...
) as dag:
task1 = PythonOperator(
task_id="set_cookies_task",
python_callable=set_cookies_func,
op_kwargs={"config": config},
providde_context=True,
dag=dag
)
task2 = SubDagOperator(
task_id='branch_cache_task',
subdag=sub_cache_load_to_gcs(dag.dag_id, 'branch_cache_task'),
provide_context=True,
dag=dag
)
task1 >> task2
As we get run_id in Airflow, how to get timestamp(ts)?
First:
In your task set provide_context=True
bye_operator = PythonOperator(
task_id='bye_task',
python_callable=print_goodbye,
provide_context=True,
dag=dag
)
Second:
Ensure you are passing the known arguments into your callback function:
def print_goodbye(**kwargs):
ts = kwargs.get('ts', None)
print(ts)
return 'Good bye world!'
I have 3 tasks to run in same dags. While Task1 return list of dictionary task2 and task3 try to use one dictionary element from result return by
task1.
def get_list():
....
return listOfDict
def parse_1(example_dict):
...
def parse_2(example_dict):
...
dag = DAG('dagexample', default_args=default_args)
data_list = PythonOperator(
task_id='get_lists',
python_callable=get_list,
dag=dag)
for data in data_list:
sub_task1 = PythonOperator(
task_id='data_parse1' + data['id'],
python_callable=parse_1,
op_kwargs={'dataObject': data},
dag=dag,
)
sub_task2 = PythonOperator(
task_id='data_parse2' + data['id'],
python_callable=parse_2,
op_kwargs={'dataObject': data},
dag=dag,
)
You should use XCom for passing variables/messages between different task. Take a look at this example: https://github.com/apache/incubator-airflow/blob/master/airflow/example_dags/example_xcom.py
For your case, it should be something similar as below:
default_args = {
'owner': 'airflow',
'start_date': airflow.utils.dates.days_ago(2),
'provide_context': True, # This is needed
}
def get_list():
....
return listOfDict
def parse_1(**kwargs):
ti = kwargs['ti']
# get listOfDict
v1 = ti.xcom_pull(key=None, task_ids='get_lists')
# You can now use this v1 dictionary as a normal python dict
...
def parse_2(**kwargs):
ti = kwargs['ti']
# get listOfDict
v1 = ti.xcom_pull(key=None, task_ids='get_lists')
...
dag = DAG('dagexample', default_args=default_args)
data_list = PythonOperator(
task_id='get_lists',
python_callable=get_list,
dag=dag)
for data in get_list():
sub_task1 = PythonOperator(
task_id='data_parse1' + data['id'],
python_callable=parse_1,
op_kwargs={'dataObject': data},
dag=dag,
)
sub_task2 = PythonOperator(
task_id='data_parse2' + data['id'],
python_callable=parse_2,
op_kwargs={'dataObject': data},
dag=dag,
)
You can use XComs as they are designed for inter-task communication. If your dictionary is very big, then I recommend storing it as a csv file.
Generally, tasks in Airflow don't share data between them, so XComs are a way to achieve them but are limited to small amounts of data.