I am trying to create a DAG in which one of the task does athena query using boto3. It worked for one query however I am facing issues when I try to run multiple athena queries.
This problem can be broken as follows:-
If one goes through this blog, it can be seen that athena uses start_query_execution to trigger query and get_query_execution for getting status, queryExecutionId and other data about the query (docs for athena)
After following the above pattern I have following code:-
import json
import time
import asyncio
import boto3
import logging
from airflow import DAG
from airflow.operators.python import PythonOperator
def execute_query(client, query, database, output_location):
response = client.start_query_execution(
QueryString=query,
QueryExecutionContext={
'Database': database
},
ResultConfiguration={
'OutputLocation': output_location
}
)
return response['QueryExecutionId']
async def get_ids(client_athena, query, database, output_location):
query_responses = []
for i in range(5):
query_responses.append(execute_query(client_athena, query, database, output_location))
res = await asyncio.gather(*query_responses, return_exceptions=True)
return res
def run_athena_query(query, database, output_location, region_name, **context):
BOTO_SESSION = boto3.Session(
aws_access_key_id = 'YOUR_KEY',
aws_secret_access_key = 'YOUR_ACCESS_KEY')
client_athena = BOTO_SESSION.client('athena', region_name=region_name)
loop = asyncio.get_event_loop()
query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
loop.close()
repetitions = 900
error_messages = []
s3_uris = []
while repetitions > 0 and len(query_execution_ids) > 0:
repetitions = repetitions - 1
query_response_list = client_athena.batch_get_query_execution(
QueryExecutionIds=query_execution_ids)['QueryExecutions']
for query_response in query_response_list:
if 'QueryExecution' in query_response and \
'Status' in query_response['QueryExecution'] and \
'State' in query_response['QueryExecution']['Status']:
state = query_response['QueryExecution']['Status']['State']
if state in ['FAILED', 'CANCELLED']:
error_reason = query_response['QueryExecution']['Status']['StateChangeReason']
error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
state, query_execution_id, error_message
)
error_messages.append(error_message)
query_execution_ids.remove(query_response['QueryExecutionId'])
elif state == 'SUCCEEDED':
result_location = query_response['QueryExecution']['ResultConfiguration']['OutputLocation']
s3_uris.append(result_location)
query_execution_ids.remove(query_response['QueryExecutionId'])
time.sleep(2)
logging.exception(error_messages)
return s3_uris
DEFAULT_ARGS = {
'owner': 'ubuntu',
'depends_on_past': True,
'start_date': datetime(2021, 6, 8),
'retries': 0,
'concurrency': 2
}
with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:
ATHENA_QUERY = PythonOperator(
task_id='athena_query',
python_callable=run_athena_query,
provide_context=True,
op_kwargs={
'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;', # query provide in athena tutorial
'database':'sampledb',
'output_location':'YOUR_BUCKET',
'region_name':'YOUR_REGION'
}
)
ATHENA_QUERY
On running above code, I am getting following error:-
[2021-06-16 20:34:52,981] {taskinstance.py:1455} ERROR - An asyncio.Future, a coroutine or an awaitable is required
Traceback (most recent call last):
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1112, in _run_raw_task
self._prepare_and_execute_task_with_callbacks(context, task)
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1285, in _prepare_and_execute_task_with_callbacks
result = self._execute_task(context, task_copy)
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/models/taskinstance.py", line 1315, in _execute_task
result = task_copy.execute(context=context)
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/operators/python.py", line 117, in execute
return_value = self.execute_callable()
File "/home/ubuntu/venv/lib/python3.6/site-packages/airflow/operators/python.py", line 128, in execute_callable
return self.python_callable(*self.op_args, **self.op_kwargs)
File "/home/ubuntu/iac-airflow/dags/helper/tasks.py", line 93, in run_athena_query
query_execution_ids = loop.run_until_complete(get_ids(client_athena, query, database, output_location))
File "/usr/lib/python3.6/asyncio/base_events.py", line 484, in run_until_complete
return future.result()
File "/home/ubuntu/iac-airflow/dags/helper/tasks.py", line 79, in get_ids
res = await asyncio.gather(*query_responses, return_exceptions=True)
File "/usr/lib/python3.6/asyncio/tasks.py", line 602, in gather
fut = ensure_future(arg, loop=loop)
File "/usr/lib/python3.6/asyncio/tasks.py", line 526, in ensure_future
raise TypeError('An asyncio.Future, a coroutine or an awaitable is '
TypeError: An asyncio.Future, a coroutine or an awaitable is required
I am unable to get where I am going wrong. Would appreciate some hint over the issue
I think what you are doing here isn't really needed.
Your issues ares:
Executing multiple queries in parallel.
Being able to recover queryExecutionId per query.
Both issues are solved simply by using AWSAthenaOperator. The operator already handles everything you mentioned for you.
Example:
from airflow.models import DAG
from airflow.utils.dates import days_ago
from airflow.operators.dummy import DummyOperator
from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator
with DAG(
dag_id="athena",
schedule_interval='#daily',
start_date=days_ago(1),
catchup=False,
) as dag:
start_op = DummyOperator(task_id="start_task")
query_list = ["SELECT 1;", "SELECT 2;" "SELECT 3;"]
for i, sql in enumerate(query_list):
run_query = AWSAthenaOperator(
task_id=f'run_query_{i}',
query=sql,
output_location='s3://my-bucket/my-path/',
database='my_database'
)
start_op >> query_op
Athena tasks will be created dynamically simply by adding more queries to query_list:
Note that the QueryExecutionId is pushed to xcom thus you can access the in a downstream task if needed.
Following as well worked for me. I just complicated simple problem with asyncio.
Since I needed S3 URIs for each query at last therefore I went for writing script from scratch. In the current implementation of AWSAthenaOperator, one can get the queryExecutionId and then do the remaining processing(i.e create another task) for getting S3 URI of CSV result file. This can add some overhead in terms of delay between two tasks(of getting queryExecutionId and retrieving S3 URI) along with added resource usuage.
Therefore I went for doing the complete operation in a single operator as follows:-
Code:-
import json
import time
import asyncio
import boto3
import logging
from airflow import DAG
from airflow.operators.python import PythonOperator
def execute_query(client, query, database, output_location):
response = client.start_query_execution(
QueryString=query,
QueryExecutionContext={
'Database': database
},
ResultConfiguration={
'OutputLocation': output_location
}
)
return response
def run_athena_query(query, database, output_location, region_name, **context):
BOTO_SESSION = boto3.Session(
aws_access_key_id = 'YOUR_KEY',
aws_secret_access_key = 'YOUR_ACCESS_KEY')
client_athena = BOTO_SESSION.client('athena', region_name=region_name)
query_execution_ids = []
if message_list:
for parameter in message_list:
query_response = execute_query(client_athena, query, database, output_location)
query_execution_ids.append(query_response['QueryExecutionId'])
else:
raise Exception(
'Error in upstream value recived from kafka consumer. Got message list as - {}, with type {}'
.format(message_list, type(message_list))
)
repetitions = 900
error_messages = []
s3_uris = []
while repetitions > 0 and len(query_execution_ids) > 0:
repetitions = repetitions - 1
query_response_list = client_athena.batch_get_query_execution(
QueryExecutionIds=query_execution_ids)['QueryExecutions']
for query_response in query_response_list:
if 'QueryExecution' in query_response and \
'Status' in query_response['QueryExecution'] and \
'State' in query_response['QueryExecution']['Status']:
state = query_response['QueryExecution']['Status']['State']
if state in ['FAILED', 'CANCELLED']:
error_reason = query_response['QueryExecution']['Status']['StateChangeReason']
error_message = 'Final state of Athena job is {}, query_execution_id is {}. Error: {}'.format(
state, query_execution_id, error_message
)
error_messages.append(error_message)
query_execution_ids.remove(query_response['QueryExecutionId'])
elif state == 'SUCCEEDED':
result_location = query_response['QueryExecution']['ResultConfiguration']['OutputLocation']
s3_uris.append(result_location)
query_execution_ids.remove(query_response['QueryExecutionId'])
time.sleep(2)
logging.exception(error_messages)
return s3_uris
DEFAULT_ARGS = {
'owner': 'ubuntu',
'depends_on_past': True,
'start_date': datetime(2021, 6, 8),
'retries': 0,
'concurrency': 2
}
with DAG('resync_job_dag', default_args=DEFAULT_ARGS, schedule_interval=None) as dag:
ATHENA_QUERY = PythonOperator(
task_id='athena_query',
python_callable=run_athena_query,
provide_context=True,
op_kwargs={
'query': 'SELECT request_timestamp FROM "sampledb"."elb_logs" limit 10;', # query provide in athena tutorial
'database':'sampledb',
'output_location':'YOUR_BUCKET',
'region_name':'YOUR_REGION'
}
)
ATHENA_QUERY
However, the approach shared by #Elad is more clean and apt if one wants to get queryExecutionIds of all the queries.
Related
The below code works but my requirement is to pass totalbuckets as an input to the function as opposed to global variable. I am having trouble passing it as a variable and do xcom_pull in next task. This dag basically creates buckets based on the number of inputs and totalbuckets is a constant. Appreciate your help in advance.
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:
totalbuckets = 3
# branches based on number of buckets
def branch_buckets(**context):
buckets = defaultdict(list)
for i in range(len(inputs_to_process)):
buckets[f'bucket_{(1+i % totalbuckets)}'].append(inputs_to_process[i])
for bucket_name, input_sublist in buckets.items():
context['ti'].xcom_push(key = bucket_name, value = input_sublist)
return list(buckets.keys())
# BranchPythonOperator will launch the buckets and distributes inputs among the buckets
branch_buckets = BranchPythonOperator(
task_id='branch_buckets',
python_callable=branch_buckets,
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True,
dag=test_live
)
# update provider tables with merge sql
def update_inputs(sf_conn_id, bucket_name, **context):
input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
print(f"Processing inputs {input_sublist} in {bucket_name}")
from custom.hooks.snowflake_hook import SnowflakeHook
for p in input_sublist:
merge_sql=f"""
merge into ......"""
bucket_tasks = []
for i in range(totalbuckets):
task= PythonOperator(
task_id=f'bucket_{i+1}',
python_callable=update_inputs,
provide_context=True,
op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
dag=test_live
)
bucket_tasks.append(task)
If totalbuckets is different from run to other, it should be a run conf variable, you can provide it for each run crated from the UI, CLI, Airflow REST API or even python API.
from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.models.param import Param
with DAG(
'test-live',
catchup=False,
schedule_interval=None,
default_args=args,
params={"totalbuckets": Param(default=3, type="integer")},
) as test_live:
# branches based on number of buckets
def branch_buckets(**context):
buckets = defaultdict(list)
for i in range(len(inputs_to_process)):
buckets[f'bucket_{(1+i % int("{{ params.totalbuckets }}"))}'].append(inputs_to_process[i])
for bucket_name, input_sublist in buckets.items():
context['ti'].xcom_push(key = bucket_name, value = input_sublist)
return list(buckets.keys())
# BranchPythonOperator will launch the buckets and distributes inputs among the buckets
branch_buckets = BranchPythonOperator(
task_id='branch_buckets',
python_callable=branch_buckets,
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True,
dag=test_live
)
# update provider tables with merge sql
def update_inputs(sf_conn_id, bucket_name, **context):
input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
print(f"Processing inputs {input_sublist} in {bucket_name}")
from custom.hooks.snowflake_hook import SnowflakeHook
for p in input_sublist:
merge_sql=f"""
merge into ......"""
bucket_tasks = []
for i in range(int("{{ params.totalbuckets }}")):
task= PythonOperator(
task_id=f'bucket_{i+1}',
python_callable=update_inputs,
provide_context=True,
op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
dag=test_live
)
bucket_tasks.append(task)
Example to run it:
airflow dags trigger --conf '{"totalbuckets": 10}' test-live
Or via the UI.
update:
And if it's static, but different from an environment to other, it can be an Airflow variable, and read it directly in the tasks using jinja to avoid reading it at each Dag Files processing.
But if it's completely static, the most recommended solution is using python variable as you do, because to read dag run conf and Airflow variables, the task/dag send a query to the database.
#hussein awala I am doing something like below but cannot parse totalbuckets in bucket_tasks
from airflow.operators.python import PythonOperator, BranchPythonOperator
with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:
#totalbuckets = 3
def branch_buckets(totalbuckets, **context):
buckets = defaultdict(list)
for i in range(len(inputs_to_process)):
buckets[f'bucket_{(1+i % totalbuckets)}'].append(inputs_to_process[i])
for bucket_name, input_sublist in buckets.items():
context['ti'].xcom_push(key = bucket_name, value = input_sublist)
return list(buckets.keys())
# BranchPythonOperator will launch the buckets and distributes inputs among the buckets
branch_buckets = BranchPythonOperator(
task_id='branch_buckets',
python_callable=branch_buckets,
trigger_rule=TriggerRule.NONE_FAILED,
provide_context=True, op_kwargs={'totalbuckets':3},
dag=test_live
)
# update provider tables with merge sql
def update_inputs(sf_conn_id, bucket_name, **context):
input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
print(f"Processing inputs {input_sublist} in {bucket_name}")
from custom.hooks.snowflake_hook import SnowflakeHook
for p in input_sublist:
merge_sql=f"""
merge into ......"""
bucket_tasks = []
for i in range(totalbuckets):
task= PythonOperator(
task_id=f'bucket_{i+1}',
python_callable=update_inputs,
provide_context=True,
op_kwargs={'bucket_name':f'bucket_{i+1}','sf_conn_id': SF_CONN_ID},
dag=test_live
)
bucket_tasks.append(task)```
I am running MWAA on local docker with Airflow2.0.2, and I am trying to get data from snowflake and store into a local folder and I am getting the below "missing 1 required positional ". How to rectify this error?
I have checked arguments to snowflake_to_pandas and it matched required arguments, but still I get this error
from airflow import DAG
from contextlib import closing
from airflow.models import Variable
from airflow import models
import pandas as pd
import time
from datetime import datetime, timedelta
from airflow.contrib.hooks.snowflake_hook import *
from airflow.providers.snowflake.operators.snowflake import *
#f
from airflow.operators.python_operator import PythonOperator
default_args = {
'owner': 'Datexland',
'depends_on_past': False,
'start_date': datetime(2021, 6, 14),
'email': ['ata.kmu#roer.com'],
'email_on_failure': False,
'email_on_retry': False,
'retries': 1,
'retry_delay': timedelta(minutes=5),
}
# queries
query_test = """select * from DATA_CS_DEV.MING.CTY limit 1000;"""
query = """select current_date() as date;"""
def execute_snowflake(sql, snowflake_conn_id, with_cursor=False):
"""Execute snowflake query."""
hook_connection = SnowflakeHook(
snowflake_conn_id=snowflake_conn_id
)
with closing(hook_connection.get_conn()) as conn:
with closing(conn.cursor()) as cur:
cur.execute(sql)
res = cur.fetchall()
if with_cursor:
return (res, cur)
else:
return res
def snowflake_to_pandas(query_test, snowflake_conn_id,**kwargs):
"""Convert snowflake list to df."""
result, cursor = execute_snowflake(query_test, snowflake_conn_id, True)
headers = list(map(lambda t: t[0], cursor.description))
df = pd.DataFrame(result)
df.columns = headers
# save file before to send
# NOTE : This is not recommended in case of multi-worker deployments
df.show()
df.to_csv('/usr/local/airflow/data.csv',header=True,mode='w',sep=',')
return 'This File Sent Successfully'
dag = DAG(
'SNOWFLAKE_TO_S3',
default_args=default_args,
schedule_interval=None,
max_active_runs=1,
catchup=False
)
# Connection Test
snowflake_dag = SnowflakeOperator(
task_id='test_snowflake_connection',
sql=query,
snowflake_conn_id='snowflake_conn_id',
dag=dag
)
# Data Upload Task
upload_stage = PythonOperator(task_id='UPLOAD_FILE_INTO_S3_BUCKET',
python_callable=snowflake_to_pandas,
op_kwargs={"query":query_test,"snowflake_conn_id":'snowflake_conn_id'},
dag=dag )
snowflake_dag >> upload_stage
I am working on polling boto3 to check the status of a SageMaker Autopilot job using Airflow. I am using a PythonSensor to wait for the status to return Completed for both JobStatus and JobSecondaryStatus, then end the entire pipeline. These are the values that they can contain which I made enums of in the code:
'AutoMLJobStatus': 'Completed'|'InProgress'|'Failed'|'Stopped'|'Stopping',
'AutoMLJobSecondaryStatus': 'Starting'|'AnalyzingData'|'FeatureEngineering'|'ModelTuning'|'MaxCandidatesReached'|'Failed'|'Stopped'|'MaxAutoMLJobRuntimeReached'|'Stopping'|'CandidateDefinitionsGenerated'|'GeneratingExplainabilityReport'|'Completed'|'ExplainabilityError'|'DeployingModel'|'ModelDeploymentError'
_sagemaker_job_status takes automl_job_name through xcom from an upstream task and it successfully gets passed. With this job name I can pass it to descibe_auto_ml_job() to get the status through AutoMLJobStatus and AutoMLJobSecondaryStatus.
The main point of this is for messaging through Slack to see all the unique stages the job is at. Currently, I am trying to save all the unique job statuses to a set and then checking that set before sending a message with the job statuses in it.
But everytime _sagemaker_job_status is poked, the values of the set seem to be the same therefore sending a slack message everytime the function is poked, I logged the sets and both are empty. Below this I made a simpler example that worked.
import airflow
from airflow import DAG
from airflow.exceptions import AirflowFailException
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.sensors.python import PythonSensor
import boto3
def _sagemaker_job_status(templates_dict, **context):
"""
Checks the SageMaker AutoMLJobStatus and AutoMLJobSecondaryStatus
for updates and when both are complete the entire process is marked as
successful
"""
automl_job_name = templates_dict.get("automl_job_name")
if not automl_job_name:
error_message = "AutoMLJobName was not passed from upstream"
print(error_message)
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
client = boto3.client("sagemaker", "us-east-1")
response = client.describe_auto_ml_job(
AutoMLJobName=automl_job_name,
)
job_status = response.get("AutoMLJobStatus")
secondary_job_status = response.get("AutoMLJobSecondaryStatus")
past_job_statuses = set()
past_secondary_job_statuses = set()
print(f"Past Job Statuses : {past_job_statuses}")
print(f"Past Secondary Job Statuses : {past_secondary_job_statuses}")
# If the job status has not been already seen
if (
job_status not in past_job_statuses
and secondary_job_status not in past_secondary_job_statuses
):
message = f"""
JobStatus : {job_status}
JobSecondaryStatus : {secondary_job_status}
"""
print(message)
task_success_slack_alert(
context=context,
extra_message=message,
)
past_job_statuses.add(job_status)
past_secondary_job_statuses.add(secondary_job_status)
# If the main job fails
if job_status == JobStatus.Failed.value:
error_message = "SageMaker Autopilot Job Failed!"
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
raise AirflowFailException(error_message)
return (
job_status == JobStatus.Completed.value
and secondary_job_status == JobSecondaryStatus.Completed.value
)
args = {
"owner": "Yudhiesh",
"start_date": airflow.utils.dates.days_ago(1),
"schedule_interval": "#once",
"on_failure_callback": task_fail_slack_alert,
}
with DAG(
dag_id="02_lasic_retraining_sagemaker_autopilot",
default_args=args,
render_template_as_native_obj=True,
) as dag:
sagemaker_job_status = PythonSensor(
task_id="sagemaker_job_status",
python_callable=_sagemaker_job_status,
templates_dict={
"automl_job_name": "{{task_instance.xcom_pull(task_ids='train_model_sagemaker_autopilot')}}", # noqa: E501
},
)
end = DummyOperator(
task_id="end",
)
sagemaker_job_status >> end
I created a similar setup as before but this time I randomly generated the values from an enum of JobStatus & JobSecondaryStatus and tried to only print the values if they are unique, and turns out it works perfectly. Could anyone explain why this happens and what I can do to the main example to get it to work?
import airflow
import random
from airflow import DAG
from airflow.sensors.python import PythonSensor
from airflow.operators.dummy import DummyOperator
from airflow.exceptions import AirflowFailException
def _mimic_sagemaker_job_status():
job_statuses = [status.value for status in JobStatus]
job_secondary_statuses = [
secondary_status.value for secondary_status in JobSecondaryStatus
]
past_job_statuses = set()
past_secondary_job_statuses = set()
job_status = random.choice(job_statuses)
job_secondary_status = random.choice(job_secondary_statuses)
if (
job_status not in past_job_statuses
and job_secondary_status not in past_secondary_job_statuses
):
message = f"""
JobStatus : {job_status}
JobSecondaryStatus : {job_secondary_status}
"""
# Send alerts on every new job status update
print(message)
past_job_statuses.add(job_status)
past_secondary_job_statuses.add(job_secondary_status)
if (
job_status == JobStatus.Failed.value
or job_secondary_status == JobSecondaryStatus.Failed.value
):
raise AirflowFailException("SageMaker Autopilot Job Failed!")
return (
job_secondary_status == JobSecondaryStatus.Completed.value
and job_status == JobStatus.Completed.value
)
with DAG(
dag_id="04_sagemaker_sensor",
start_date=airflow.utils.dates.days_ago(3),
schedule_interval="#once",
render_template_as_native_obj=True,
) as dag:
wait_for_status = PythonSensor(
task_id="wait_for_status",
python_callable=_mimic_sagemaker_job_status,
dag=dag,
)
end = DummyOperator(
task_id="end",
)
wait_for_status >> end
Enums used in the above code:
from enum import Enum
class JobStatus(Enum):
"""
Enum of all the potential values of a SageMaker Autopilot job status
"""
Completed = "Completed"
InProgress = "InProgress"
Failed = "Failed"
Stopped = "Stopped"
Stopping = "Stopping"
class JobSecondaryStatus(Enum):
"""
Enum of all the potential values of a SageMaker Autopilot job secondary
status
"""
Starting = "Starting"
AnalyzingData = "AnalyzingData"
FeatureEngineering = "FeatureEngineering"
ModelTuning = "ModelTuning"
MaxCandidatesReached = "MaxCandidatesReached"
Failed = "Failed"
Stopped = "Stopped"
MaxAutoMLJobRuntimeReached = "MaxAutoMLJobRuntimeReached"
Stopping = "Stopping"
CandidateDefinitionsGenerated = "CandidateDefinitionsGenerated"
GeneratingExplainabilityReport = "GeneratingExplainabilityReport"
Completed = "Completed"
ExplainabilityError = "ExplainabilityError"
DeployingModel = "DeployingModel"
ModelDeploymentError = "ModelDeploymentError"
EDIT:
I suppose another work around for the main example would be to have an operator create a temporary file containing JSON of the set before the sagemaker job status, then within the sagemaker job status I can check the job statuses saved to the file and then print them if they are unique. I just realised that I can make use of the database as well.
So I couldn't seem to get it working as it is so I resorted to creating a JSON file that stores the different SageMaker Autopilot job statuses which I read and write to in the PythonSensor.
This takes in the AutoMLJobName from the previous step, creates a temporary file of the job statuses, and returns the AutoMLJobName and the name of the JSON file.
import tempfile
def _create_job_status_json(templates_dict, **context):
automl_job_name = templates_dict.get("sagemaker_autopilot_data_paths")
if not automl_job_name:
error_message = "AutoMLJobName was not passed from upstream"
print(error_message)
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
initial = {
"JobStatus": [],
"JobSecondaryStatus": [],
}
file = tempfile.NamedTemporaryFile(mode="w", delete=False)
json.dump({"Status": initial}, file)
file.flush()
return (file.name, automl_job_name)
Next this function reads the JSON file based on the name and then checks the different job statuses based on the boto3 sagemaker client. If the main job fails then the whole run fails. It adds the job statuses to a dictionary if one of them are unique. Once that is done it will write the dictionary to the JSON file. When the entire job finishes, it sends some details about the best model as a Slack message. It returns true when both job statuses are Completed. Just a note, I also removed the JSON file if the job is successfull or if it fails.
import airflow
from airflow import DAG
from airflow.exceptions import AirflowFailException
import boto3
def _sagemaker_job_status(templates_dict, **context):
"""
Checks the SageMaker AutoMLJobStatus and AutoMLJobSecondaryStatus
for updates and when both are complete the entire process is marked as
successful
"""
file_name, automl_job_name = templates_dict.get("automl_job_data")
job_status_dict = {}
client = boto3.client("sagemaker", "us-east-1")
if not client:
raise AirflowFailException(
"Unable to get access to boto3 sagemaker client",
)
with open(file_name, "r") as json_file:
response = client.describe_auto_ml_job(
AutoMLJobName=automl_job_name,
)
job_status = response.get("AutoMLJobStatus")
secondary_job_status = response.get("AutoMLJobSecondaryStatus")
job_status_dict = json.load(json_file)
status = job_status_dict.get("Status")
past_job_statuses = status.get("JobStatus")
past_secondary_job_statuses = status.get("JobSecondaryStatus")
if job_status == JobStatus.Failed.value:
error_message = "SageMaker Autopilot Job Failed!"
task_fail_slack_alert(
context=context,
extra_message=error_message,
)
os.remove(file_name)
raise AirflowFailException(error_message)
if (
job_status not in past_job_statuses
or secondary_job_status not in past_secondary_job_statuses
):
message = f"""
JobStatus : {job_status}
JobSecondaryStatus : {secondary_job_status}
"""
print(message)
task_success_slack_alert(
context=context,
extra_message=message,
)
past_job_statuses.append(job_status)
past_secondary_job_statuses.append(secondary_job_status)
with open(file_name, "w") as file:
json.dump(job_status_dict, file)
if (
job_status == JobStatus.Completed.value
and secondary_job_status == JobSecondaryStatus.Completed.value
):
os.remove(file_name)
response = client.describe_auto_ml_job(
AutoMLJobName=automl_job_name,
)
best_candidate = response.get("BestCandidate")
best_candidate_id = best_candidate.get("CandidateName")
best_metric_name = (
best_candidate.get("FinalAutoMLJobObjectiveMetric")
.get("MetricName")
.split(":")[1]
.upper()
)
best_metric_value = round(
best_candidate.get("FinalAutoMLJobObjectiveMetric").get(
"Value",
),
3,
)
message = f"""
Best Candidate ID : {best_candidate_id}
Best Candidate Metric Score : {best_metric_value}{best_metric_name}
""" # noqa: E501
task_success_slack_alert(
context=context,
extra_message=message,
)
return (
job_status == JobStatus.Completed.value
and secondary_job_status == JobSecondaryStatus.Completed.value
)
DAG code:
import airflow
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.sensors.python import PythonSensor
args = {
"owner": "Yudhiesh",
"start_date": airflow.utils.dates.days_ago(1),
"schedule_interval": "#once",
"on_failure_callback": task_fail_slack_alert,
}
with DAG(
dag_id="02_lasic_retraining_sagemaker_autopilot",
default_args=args,
render_template_as_native_obj=True,
) as dag:
create_job_status_json = PythonOperator(
task_id="create_job_status_json",
python_callable=_create_job_status_json,
templates_dict={
"sagemaker_autopilot_data_paths": "{{task_instance.xcom_pull(task_ids='train_model_sagemaker_autopilot')}}", # noqa: E501
},
)
sagemaker_job_status = PythonSensor(
task_id="sagemaker_job_status",
python_callable=_sagemaker_job_status,
templates_dict={
"automl_job_data": "{{task_instance.xcom_pull(task_ids='create_job_status_json')}}", # noqa: E501
},
)
# train_model_sagemaker_autopilot is not included but it initiates the training through boto3
train_model_sagemaker_autopilot >> create_job_status_json
create_job_status_json >> sagemaker_job_status
I'm trying to rendered correctly data inside a SimpleHttpOperator in Airflow with configuration that I send via dag_run
result = SimpleHttpOperator(
task_id="schema_detector",
http_conn_id='schema_detector',
endpoint='api/schema/infer',
method='PUT',
data=json.dumps({
'url': '{{ dag_run.conf["url"] }}',
'fileType': '{{ dag_run.conf["fileType"] }}',
}),
response_check=lambda response: response.ok,
response_filter=lambda response: response.json())
Issue is that the rendered data appears to be like this
{"url": "{{ dag_run.conf[\"url\"] }}", "fileType": "{{ dag_run.conf[\"fileType\"] }}"}
I'm not sure what I'm doing wrong here.
Edit
Full code below
default_args = {
'owner': 'airflow',
'start_date': days_ago(0),
}
def print_result(**kwargs):
ti = kwargs['ti']
pulled_value_1 = ti.xcom_pull(task_ids='schema_detector')
pprint.pprint(pulled_value_1)
with DAG(
dag_id='airflow_http_operator',
default_args=default_args,
catchup=False,
schedule_interval="#once",
tags=['http']
) as dag:
result = SimpleHttpOperator(
task_id="schema_detector",
http_conn_id='schema_detector',
endpoint='api/schema/infer',
method='PUT',
headers={"Content-Type": "application/json"},
data=json.dumps({
'url': '{{ dag_run.conf["url"] }}',
'fileType': '{{ dag_run.conf["fileType"] }}',
}),
response_check=lambda response: response.ok,
response_filter=lambda response: response.json())
pull = PythonOperator(
task_id='print_result',
python_callable=print_result,
)
result >> pull
I struggled a lot due to the same error. So, I created my own Operator (called as ExtendedHttpOperator) which is a combination of PythonOperator and SimpleHttpOperator. This worked for me :)
This operator receives a function where we can collect data passed from the API (using dag_run.conf), and parse it (if required) before passing it to an API.
from plugins.operators.extended_http_operator import ExtendedHttpOperator
testing_extend = ExtendedHttpOperator(
task_id="process_user_ids",
http_conn_id="user_api",
endpoint="/kafka",
headers={"Content-Type": "application/json"},
data_fn=passing_data,
op_kwargs={"api": "kafka"},
method="POST",
log_response=True,
response_check=lambda response: True
if validate_response(response) is True
else False,
)
def passing_data(**context):
api = context["api"]
dag_run_conf = context["dag_run"].conf
return json.dumps(dag_run_conf[api])
def validate_response(res):
if res.status_code == 200:
return True
else:
return False
Here is how you can add ExtendedHttpOperator to your airflow:
Put extended_http_operator.py file inside your_airflow_project/plugins/operators folder
# extended_http_operator.py file
from airflow.operators.http_operator import SimpleHttpOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException
from airflow.hooks.http_hook import HttpHook
from typing import Optional, Dict
"""
Extend Simple Http Operator with a callable function to formulate data. This data function will
be able to access the context to retrieve data such as task instance. This allow us to write cleaner
code rather than writing one long template line to formulate the json data.
"""
class ExtendedHttpOperator(SimpleHttpOperator):
#apply_defaults
def __init__(
self,
data_fn,
log_response: bool = False,
op_kwargs: Optional[Dict] = None,
*args,
**kwargs
):
super(ExtendedHttpOperator, self).__init__(*args, **kwargs)
if not callable(data_fn):
raise AirflowException("`data_fn` param must be callable")
self.data_fn = data_fn
self.context = None
self.op_kwargs = op_kwargs or {}
self.log_response = log_response
def execute(self, context):
context.update(self.op_kwargs)
self.context = context
http = HttpHook(self.method, http_conn_id=self.http_conn_id)
data_result = self.execute_callable(context)
self.log.info("Calling HTTP method")
self.log.info("Post Data: {}".format(data_result))
response = http.run(
self.endpoint, data_result, self.headers, self.extra_options
)
if self.log_response:
self.log.info(response.text)
if self.response_check:
if not self.response_check(response):
raise AirflowException("Invalid parameters")
def execute_callable(self, context):
return self.data_fn(**context)
Dont forget to create empty __init__.py files inside plugins and plugins/operators folders.
I couldn't find a solution.
Only way I could do this, like passing information that I send over the --conf to the operator was adding a new PythonOperator that collect the info, and using then XCom on my SimpleHTTPOperator
Code
def generate_data(**kwargs):
confs = kwargs['dag_run'].conf
logging.info(confs)
return {'url': confs["url"], 'fileType': confs["fileType"]}
with DAG(
dag_id='airflow_http_operator',
default_args=default_args,
catchup=False,
schedule_interval="#once",
tags=['http']
) as dag:
generate_dict = PythonOperator(
task_id='generate_dict',
python_callable=generate_data,
provide_context=True
)
result = SimpleHttpOperator(
task_id="schema_detector",
http_conn_id='schema_detector',
endpoint='api/schema/infer',
method='PUT',
headers={"Content-Type": "application/json"},
data="{{ task_instance.xcom_pull(task_ids='generate_dict') |tojson}}",
log_response=True,
response_check=lambda response: response.ok,
response_filter=lambda response: response.json())
I use FastAPI to develope data layer APIs accessing SQL Server.
No mater using pytds or pyodbc,
if there is a database transaction caused any request hangs,
all the other requests would be blocked. (even without database operation)
Reproduce:
Intentaionally do a serializable SQL Server session, begin a transaction and do not rollback or commit
INSERT INTO [dbo].[KVStore] VALUES ('1', '1', 0)
begin tran
SET TRANSACTION ISOLATION LEVEL Serializable
SELECT * FROM [dbo].[KVStore]
Send a request to the API with async handler function like this:
def kv_delete_by_key_2_sql():
conn = pytds.connect(dsn='192.168.0.1', database=cfg.kvStore_db, user=cfg.kvStore_uid,
password=cfg.kvStore_upwd, port=1435, autocommit=True)
engine = conn.cursor()
try:
sql = "delete KVStore; commit"
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(engine.execute, sql)
rs = future.result()
j = {
'success': True,
'rowcount': rs.rowcount
}
return jsonable_encoder(j)
except Exception as exn:
j = {
'success': False,
'reason': exn_handle(exn)
}
return jsonable_encoder(j)
#app.post("/kvStore/delete")
async def kv_delete(request: Request, type_: Optional[str] = Query(None, max_length=50)):
request_data = await request.json()
return kv_delete_by_key_2_sql()
And send a request to the API of the same app with async handler function like this:
async def hangit0(request: Request, t: int = Query(0)):
print(t, datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
await asyncio.sleep(t)
print(t, datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3])
j = {
'success': True
}
return jsonable_encoder(j)
#app.get("/kvStore/hangit/")
async def hangit(request: Request, t: int = Query(0)):
return await hangit0(request, t)
I expected step.2 would hang and step.3 should directly return after 2 seconds.
However step.3 never return if the transaction doesn't commit or rollback...
How do I make these handler functions work concurrently?
The reason is that rs = future.result() is actually a blocking call - see python docs. Unfortunately, executor.submit() doesn't return an awaitable object (concurrent.futures.Future is different from asyncio.Future.
You can use asyncio.wrap_future which takes concurrent.futures.Future and returns asyncio.Future (see python docs). The new Future object is awaitable thus you can convert your blocking function into an async function.
An Example:
import asyncio
import concurrent.futures
async def my_async():
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(lambda x: x + 1, 1)
return await asyncio.wrap_future(future)
print(asyncio.run(my_async()))
In your code, simply change the rs = future.result() to rs = await asyncio.wrap_future(future) and make the whole function async. That should do the magic, good luck! :)