Celery - bottom line: I want to get the task name by using the task id (I don't have a task object)
Suppose I have this code:
res = chain(add.s(4,5), add.s(10)).delay()
cache.save_task_id(res.task_id)
And then in some other place:
task_id = cache.get_task_ids()[0]
task_name = get_task_name_by_id(task_id) #how?
print(f'Some information about the task status of: {task_name}')
I know I can get the task name if I have a task object, like here: celery: get function name by task id?.
But I don't have a task object (perhaps it can be created by the task_id or by some other way? I didn't see anything related to that in the docs).
I also don't want to save the task name in the cache. (Suppose I have a very long chain or other Celery primitives; I don't want to save all their names/task_ids. Just the last task_id should be enough to get all the information about all the tasks, using .parents, etc.)
I looked at all the relevant methods of AsyncResult and AsyncResult.Backend objects. The only thing that seemed relevant is backend.get_task_meta(task_id), but that doesn't contain the task name.
Thanks in advance
PS: AsyncResult.name always returns None:
result = AsyncResult(task_id, app=celery_app)
result.name #Returns None
result.args #Also returns None
Finally found an answer.
For anyone wondering:
You can solve this by enabling result_extended = True in your celery config.
Then:
result = AsyncResult(task_id, app=celery_app)
result.task_name #tasks.add
You have to enable it first in Celery configurations:
celery_app = Celery()
...
celery_app.conf.update(result_extended=True)
Then, you can access it:
task = AsyncResult(task_id, app=celery_app)
task.name
Something like the following (pseudocode) should be enough:
app = Celery("myapp") # add your parameters here
task_id = "6dc5f968-3554-49c9-9e00-df8aaf9e7eb5"
aresult = app.AsyncResult(task_id)
task_name = aresult.name
task_args = aresult.args
print(task_name, task_args)
Unfortunately, it does not work (I would say it is a bug in Celery), so we have to find an alternative. The first thing that came to mind was that the Celery CLI has an inspect query_task command, which hinted that it should be possible to find the task name through the inspect API, and I was right. Here is the code:
# Since the expected way does not work, we need to use the inspect API:
insp = app.control.inspect()
task_ids = [task_id]
# query_task() returns None when no worker replies, hence the "or {}"
inspect_result = insp.query_task(*task_ids) or {}
# print(inspect_result)
for node_name in inspect_result:
    val = inspect_result[node_name]
    if val:
        # we found the node that executes the task
        arr = val[task_id]
        state = arr[0]
        meta = arr[1]
        task_name = meta["name"]
        task_args = meta["args"]
        print(task_name, task_args)
The problem with this approach is that it only works while the task is running. Once the task is done, the code above will no longer find it.
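If you go the inspect route, it may help to wrap the reply parsing in a small function. A minimal sketch, assuming the reply has the shape the loop above relies on ({node_name: {task_id: [state, info]}}); `find_task_in_inspect_reply` is a hypothetical helper name, and since it is plain Python you can try it on a hand-built reply without a running worker:

```python
from typing import Optional


def find_task_in_inspect_reply(reply: Optional[dict], task_id: str) -> Optional[dict]:
    """Return node, state, name and args for task_id, or None if no worker reported it.

    Assumes the reply shape used above: {node_name: {task_id: [state, info_dict]}}.
    A None reply (no worker answered) is treated as "not found".
    """
    for node_name, tasks in (reply or {}).items():
        # Nodes that don't know the task reply with an empty dict
        if tasks and task_id in tasks:
            state, info = tasks[task_id]
            return {
                "node": node_name,
                "state": state,
                "name": info.get("name"),
                "args": info.get("args"),
            }
    return None
```

With a live app you would call it as `find_task_in_inspect_reply(app.control.inspect().query_task(task_id), task_id)`; like the loop above, it only finds tasks that a worker still holds.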
This is not very clear from the docs for celery.result.AsyncResult, but not all of the properties are populated unless you enable result_extended = True, as per the configuration docs:
result_extended
Default: False
Enables extended task result attributes (name, args, kwargs, worker, retries, queue, delivery_info) to be written to backend.
Then the following will work:
result = AsyncResult(task_id)
result.name    # 'project.tasks.my_task'
result.args    # [2, 3]
result.kwargs  # {'a': 'b'}
Also be aware that the rpc:// backend does not store this data; you will need Redis or a similar backend. If you are using rpc, you will still get None, even with result_extended = True.
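For reference, the same settings as a Celery configuration module. This is a minimal sketch: the Redis URLs are placeholders, and any persistent result backend other than rpc:// should work the same way.

```python
# celeryconfig.py: minimal sketch; broker and backend URLs are placeholders.
broker_url = "redis://localhost:6379/0"

# A persistent result backend (rpc:// does not keep the extended fields)
result_backend = "redis://localhost:6379/1"

# Store name, args, kwargs, worker, retries, queue and delivery_info with each result
result_extended = True
```

Load it with `app.config_from_object("celeryconfig")`. Results stored before the workers were restarted with this setting were written without the extended fields, so `.name` will still be None for those.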
I found a good answer in this code snippet.
If you already have an instance of AsyncResult, you do not need the task_id; you can simply do this:
result # instance of AsyncResult
result_meta = result._get_task_meta()
task_name = result_meta.get("task_name")
Of course, this relies on a private method, so it's a bit hacky. I hope Celery introduces a simpler way to retrieve this; it's especially useful for testing.
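Note that the answers here read the name from two different keys: the `.name` property (which, as far as I can tell, looks up "name" in the meta) and `"task_name"` in the snippet above. A small sketch that tries both; which key your backend actually fills in is an assumption worth checking by printing the meta once. `task_name_from_meta` is a hypothetical helper:

```python
from typing import Any, Mapping, Optional


def task_name_from_meta(meta: Mapping[str, Any]) -> Optional[str]:
    """Return the task name stored in a result-meta dict, or None.

    Assumption: depending on the result backend, the name may be stored
    under "name" or "task_name", so both keys are tried in that order.
    """
    for key in ("name", "task_name"):
        value = meta.get(key)
        if value:
            return value
    return None
```

Usage would be `task_name_from_meta(result._get_task_meta())`, which still depends on the private method.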
Related
I can't figure out how to use pytest to test a DAG task that waits for an XComArg.
I created the following DAG using the new Airflow API syntax:
import os
from airflow.decorators import dag, task
@dag(...)
def transfer_files():
    @task()
    def retrieve_existing_files():
        existing = []
        for elem in os.listdir("./backup"):
            existing.append(elem)
        return existing
    @task()
    def get_new_file_to_sync(existing: list[str]):
        new_files = []
        for elem in os.listdir("./prod"):
            if elem not in existing:
                new_files.append(elem)
        return new_files
    r = retrieve_existing_files()
    get_new_file_to_sync(r)
# Calling the decorated function registers the DAG so that DagBag can find it
transfer_files()
Now I want to unit test the get_new_file_to_sync task. I wrote the following test:
from airflow.models import DagBag
def test_get_new_elan_list():
    mocked_existing = ["a.out", "b.out"]
    dag_bag = DagBag(include_examples=False)
    dag = dag_bag.get_dag("transfer_files")
    task = dag.get_task("get_new_file_to_sync")
    result = task.execute({}, mocked_existing)
    print(result)
The test fails because task.execute expects 2 positional arguments but 3 were given.
My issue is that I have no idea how to test tasks that take arguments by passing in a mocked custom argument.
Thanks for your insights
I managed to find a way to unit test Airflow tasks declared using the new Airflow API.
Here is a test case for the task get_new_file_to_sync contained in the DAG transfer_files declared in the question:
from airflow.models import DagBag
def test_get_new_file_to_sync():
    mocked_existing = ["a.out", "b.out"]
    # Ask Airflow to load the DAGs in its home folder
    dag_bag = DagBag(include_examples=False)
    # Retrieve the DAG to test
    dag = dag_bag.get_dag("transfer_files")
    # Retrieve the task to test
    task = dag.get_task("get_new_file_to_sync")
    # Extract the function to test from the task
    function_to_unit_test = task.python_callable
    # Call the function normally
    results = function_to_unit_test(mocked_existing)
    # The expected count depends on the contents of ./prod
    assert len(results) == 10
This bypasses all the Airflow machinery that runs before the code you wrote for your task is called, so you can focus on testing that code.
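The assertion `len(results) == 10` depends on whatever happens to be in ./prod. One way to make it deterministic is to patch `os.listdir` with the standard library's `unittest.mock`. In this sketch the function is a plain copy of the task body so that it runs without Airflow; in the real test you would call `task.python_callable` instead:

```python
import os
from unittest import mock


def get_new_file_to_sync(existing):
    # Same logic as the task body in the question
    new_files = []
    for elem in os.listdir("./prod"):
        if elem not in existing:
            new_files.append(elem)
    return new_files


def test_get_new_file_to_sync_with_fake_listing():
    fake_prod = ["a.out", "b.out", "c.out"]
    # Patch os.listdir so the result no longer depends on the real ./prod folder
    with mock.patch("os.listdir", return_value=fake_prod) as listdir:
        result = get_new_file_to_sync(["a.out", "b.out"])
    listdir.assert_called_once_with("./prod")
    assert result == ["c.out"]
```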
For testing such a task, I believe you'll need to use mocking with pytest.
Take this user-defined operator as an example:
from collections import Counter, defaultdict
from airflow.models import BaseOperator
# MovielensHook is a custom hook defined elsewhere in the project
class MovielensPopularityOperator(BaseOperator):
    def __init__(self, conn_id, start_date, end_date, min_ratings=4, top_n=5, **kwargs):
        super().__init__(**kwargs)
        self._conn_id = conn_id
        self._start_date = start_date
        self._end_date = end_date
        self._min_ratings = min_ratings
        self._top_n = top_n
    def execute(self, context):
        with MovielensHook(self._conn_id) as hook:
            ratings = hook.get_ratings(start_date=self._start_date, end_date=self._end_date)
            rating_sums = defaultdict(Counter)
            for rating in ratings:
                rating_sums[rating["movieId"]].update(count=1, rating=rating["rating"])
            averages = {
                movie_id: (rating_counter["rating"] / rating_counter["count"], rating_counter["count"])
                for movie_id, rating_counter in rating_sums.items()
                if rating_counter["count"] >= self._min_ratings
            }
            return sorted(averages.items(), key=lambda x: x[1], reverse=True)[: self._top_n]
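To see what execute() actually returns, here is the same aggregation pulled out into a plain function so it can run on sample data without a hook or a connection. `top_rated_movies` is a hypothetical helper, not part of the original operator:

```python
from collections import Counter, defaultdict


def top_rated_movies(ratings, min_ratings=4, top_n=5):
    """Same aggregation as MovielensPopularityOperator.execute, without the hook.

    Returns [(movie_id, (average_rating, rating_count)), ...], highest average first.
    """
    rating_sums = defaultdict(Counter)
    for rating in ratings:
        # Counter.update with keyword arguments adds to the existing counts
        rating_sums[rating["movieId"]].update(count=1, rating=rating["rating"])
    averages = {
        movie_id: (counter["rating"] / counter["count"], counter["count"])
        for movie_id, counter in rating_sums.items()
        # Movies with too few ratings are dropped
        if counter["count"] >= min_ratings
    }
    # Sort by (average, count), descending, and keep the top N
    return sorted(averages.items(), key=lambda x: x[1], reverse=True)[:top_n]
```

Factoring the logic out like this also makes it testable on its own, leaving only the hook call to mock.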
And a test written just like yours:
def test_movielenspopularityoperator():
    task = MovielensPopularityOperator(
        task_id="test_id",
        start_date="2015-01-01",
        end_date="2015-01-03",
        top_n=5,
    )
    result = task.execute(context={})
    assert len(result) == 5
Running this test fails as follows:
=============================== FAILURES ===============================
___________________ test_movielenspopularityoperator ___________________
mocker = <pytest_mock.plugin.MockFixture object at 0x10fb2ea90>
    def test_movielenspopularityoperator(mocker: MockFixture):
>       task = MovielensPopularityOperator(
            task_id="test_id", start_date="2015-01-01", end_date="2015-01-03", top_n=5
        )
E       TypeError: __init__() missing 1 required positional argument: 'conn_id'
tests/dags/chapter9/custom/test_operators.py:30: TypeError
========================== 1 failed in 0.10s ==========================
The test failed because we’re missing the required argument conn_id, which points to the connection ID in the metastore. But how do you provide this in a test? Tests should be isolated from each other; they should not be able to influence the results of other tests, so a database shared between tests is not an ideal situation. In this case, mocking comes to the rescue.
Mocking is “faking” certain operations or objects. For example, the call to a database that is expected to exist in a production setting but not while testing could be faked, or mocked, by telling Python to return a certain value instead of making the actual call to the (nonexistent during testing) database. This allows you to develop and run tests without requiring a connection to external systems. It requires insight into the internals of whatever it is you’re testing, and thus sometimes requires you to dive into third-party code.
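The same idea with only the standard library: pytest-mock's `mocker.patch.object` is a thin wrapper around `unittest.mock.patch.object` that undoes the patch automatically. `DatabaseHook` below is a hypothetical class standing in for something like MovielensHook, so the sketch runs without Airflow:

```python
from unittest import mock


class DatabaseHook:
    """Hypothetical hook standing in for something like MovielensHook."""

    def get_connection(self, conn_id):
        # In production this would look the connection up in the metastore.
        raise RuntimeError("no database available in tests")

    def describe(self, conn_id):
        conn = self.get_connection(conn_id)
        return f"{conn['login']}@{conn_id}"


def test_describe_with_mocked_connection():
    fake_conn = {"login": "airflow", "password": "airflow"}
    # Replace get_connection only for the duration of the with-block
    with mock.patch.object(DatabaseHook, "get_connection", return_value=fake_conn) as patched:
        assert DatabaseHook().describe("test") == "airflow@test"
    patched.assert_called_once_with("test")
```

Outside the with-block the real `get_connection` is back in place, which keeps tests isolated from each other.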
Install pytest-mock in your environment:
pip install pytest-mock
Here is the test rewritten to use mocking:
from airflow.models import Connection
def test_movielenspopularityoperator(mocker):
    # Make the hook return a fake connection instead of querying the metastore
    mocker.patch.object(
        MovielensHook,
        "get_connection",
        return_value=Connection(conn_id="test", login="airflow", password="airflow"),
    )
    task = MovielensPopularityOperator(
        task_id="test_id",
        conn_id="test",
        start_date="2015-01-01",
        end_date="2015-01-03",
        top_n=5,
    )
    result = task.execute(context=None)
    assert len(result) == 5
Hopefully this gives you an idea of how to write tests for Airflow tasks.
For more about mocking and unit tests, you can check here and here.
Hi all, I have a function
import logging
def get_campaign_active(ds, **kwargs):
    logging.info('Checking for inactive campaign types..')
    the_db = ds['client']
    db = the_db['misc-server']
    collection = db.campaigntypes
    campaign = list(collection.find({}))
    for item in campaign:
        if item['active'] == False:
            # storing false 'active' campaigns
            result = "'{}' active status set to False".format(item['text'])
            logging.info("'{}' active status set to False".format(item['text']))
which is mapped to an Airflow task:
get_campaign_active = PythonOperator(
    task_id='get_campaign_active',
    provide_context=True,
    python_callable=get_campaign_active,
    xcom_push=True,
    op_kwargs={'client': client_production},
    dag=dag)
As you can see, I pass the client_production variable into op_kwargs with the task. I expect this variable to be passed in through the **kwargs parameter of the function when the task runs in Airflow.
However, for testing, when I call the function like so:
get_campaign_active({"client":client_production})
the client_production variable ends up inside the ds parameter. I don't have a staging server for Airflow to test this, so could someone tell me: if I deploy this function/task to Airflow, will it read the client_production variable from ds or from kwargs?
Right now, if I try to access the 'client' key in kwargs, kwargs is empty.
Thanks
You should do:
def get_campaign_active(ds, **kwargs):
    logging.info('Checking for inactive campaign types..')
    the_db = kwargs['client']
ds (and all other macros) are passed in as keyword arguments because you set provide_context=True; you can either take ds as a named parameter, as you did, or let it be passed into kwargs as well.
Since your code doesn't actually use ds or any other macro, you can change the function signature to get_campaign_active(**kwargs) and remove provide_context=True. Note that from Airflow 2.0 onwards, provide_context=True is not needed at all.
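A small sketch that imitates how the arguments get routed, so it can be run without Airflow. The merge below is a simplification of what Airflow 1.x's PythonOperator does with provide_context=True (the context is updated with op_kwargs and everything is passed as keyword arguments); the context values and the client are placeholders. It also shows how to call the function in a local test: pass client as a keyword argument, not inside the positional ds argument.

```python
def get_campaign_active(ds, **kwargs):
    # "client" arrives through **kwargs because it was given in op_kwargs
    client = kwargs["client"]
    return ds, client


# Placeholder stand-ins for what Airflow would supply at run time
context = {"ds": "2021-01-01", "ts": "2021-01-01T00:00:00+00:00"}
op_kwargs = {"client": "client_production"}

# Simplified version of the provide_context=True call: op_kwargs merged into the context
ds, client = get_campaign_active(**{**context, **op_kwargs})

# In a local test, pass ds positionally and client as a keyword
ds_test, client_test = get_campaign_active("2021-01-01", client="fake_client")
```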
I am building a Flask API that lets users send an XML document and a transformation name; it returns the XML with the transformation applied, using Saxon/C's Python API (https://www.saxonica.com/saxon-c/doc/html/saxonc.html).
The incoming endpoint looks like this (removed logging and irrelevant info):
@app.route("/v1/transform/", methods=["POST"])
def transform():
    xml = request.data
    transformation = request.args.get("transformation")
    result = transform_xml(xml, transformation)
    return result
The transform function looks like this:
import os
import saxonc
def transform_xml(xml: bytes, transformation: str) -> str:
    with saxonc.PySaxonProcessor(license=False) as proc:
        base_dir = os.getcwd()
        xslt_path = os.path.join(base_dir, "resources", transformation, "main.xslt")
        xslt_proc = proc.new_xslt30_processor()
        node = proc.parse_xml(xml_text=xml.decode("utf-8"))
        result = xslt_proc.transform_to_string(stylesheet_file=xslt_path, xdm_node=node)
        return result
The XSLTs are available locally, and a user chooses one of them by passing the corresponding transformation name.
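Since `transformation` comes straight from the query string, one way to make sure a user can only pick one of the available XSLTs is to check the name against the subdirectories of resources/ before building the path. A sketch with a hypothetical helper `resolve_xslt_path`; it is not part of the original code and does not address the Saxon crash itself:

```python
import os


def resolve_xslt_path(transformation: str, base_dir: str) -> str:
    """Return base_dir/resources/<transformation>/main.xslt, or raise ValueError.

    Hypothetical helper: only names of existing subdirectories of resources/
    are accepted, so a value such as "../../etc" is rejected.
    """
    resources_dir = os.path.join(base_dir, "resources")
    available = {
        name
        for name in os.listdir(resources_dir)
        if os.path.isdir(os.path.join(resources_dir, name))
    }
    if transformation not in available:
        raise ValueError(f"unknown transformation: {transformation!r}")
    return os.path.join(resources_dir, transformation, "main.xslt")
```

In transform_xml this would replace the direct os.path.join call, with base_dir = os.getcwd() as before.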
Now the problem is, this works (fast) for the first incoming call, but the second one crashes:
JNI_CreateJavaVM() failed with result: -5
DAMN ! worker 1 (pid: 517095) died :( trying respawn ...
What does work is changing the transform_xml function like this:
proc = saxonc.PySaxonProcessor(license=False)
xslt_path = self.__get_path_to_xslt(transformation)
xslt_proc = proc.new_xslt30_processor()
node = proc.parse_xml(xml_text=xml.decode("utf-8"))
result = xslt_proc.transform_to_string(stylesheet_file=xslt_path, xdm_node=node)
return result
But then the resources are never released, and over time (1k+ requests) memory starts to fill up.
It seems like Saxon is trying to create a new VM while the old one is going down.
I found this thread from 2016: https://saxonica.plan.io/boards/4/topics/6399 but it didn't clear things up for me. I also looked at the pysaxon repo on GitHub, but found no answer to this problem.
Also made a ticket at Saxon: https://saxonica.plan.io/issues/4942
Whenever my Spyne application receives a request, XSD validation is performed. This is good, but whenever there is an XSD violation a fault is raised and my app returns a Client.SchemaValidationError like so:
<soap11env:Fault>
<faultcode>soap11env:Client.SchemaValidationError</faultcode>
<faultstring>:25:0:ERROR:SCHEMASV:SCHEMAV_CVC_DATATYPE_VALID_1_2_1: Element '{http://services.sp.pas.ng.org}DateTimeStamp': '2018-07-25T13:01' is not a valid value of the atomic type 'xs:dateTime'.</faultstring>
<faultactor></faultactor>
</soap11env:Fault>
I would like to know how to handle the schema validation error gracefully and return the details in the Details field of my service's out_message, rather than just raising a standard Client.SchemaValidationError. I want to store the details of the error as a variable and pass it to my OperationOne function.
Here is my code (I have changed variable names for confidentiality):
from spyne import Application, ComplexModel, DateTime, ServiceBase, Unicode, rpc
from spyne.protocol.soap import Soap11
from spyne.server.wsgi import WsgiApplication
TNS = "http://services.so.example.org"
class InMessageType(ComplexModel):
    __namespace__ = TNS
    class Attributes(ComplexModel.Attributes):
        declare_order = 'declared'
    field_one = Unicode(values=["ONE", "TWO"],
                        min_occurs=1)
    field_two = Unicode(20, min_occurs=1)
    field_three = Unicode(20, min_occurs=0)
    Confirmation = Unicode(values=["ACCEPTED", "REJECTED"], min_occurs=1)
    FileReason = Unicode(200, min_occurs=0)
    DateTimeStamp = DateTime(min_occurs=1)
class OperationOneResponse(ComplexModel):
    __namespace__ = TNS
    class Attributes(ComplexModel.Attributes):
        declare_order = 'declared'
    ResponseMessage = Unicode(values=["SUCCESS", "FAILURE"], min_occurs=1)
    Details = Unicode(min_len=0, max_len=2000)
class ServiceOne(ServiceBase):
    @rpc(InMessageType,
         _returns=OperationOneResponse,
         _out_message_name='OperationOneResponse',
         _in_message_name='InMessageType',
         _body_style='bare',
         )
    def OperationOne(ctx, message):
        # DO STUFF HERE
        # e.g. return {'ResponseMessage': Failure, 'Details': XSDValidationError}
        pass
application = Application([ServiceOne],
                          TNS,
                          in_protocol=Soap11(validator='lxml'),
                          out_protocol=Soap11(),
                          name='ServiceOne',)
wsgi_application = WsgiApplication(application)
if __name__ == '__main__':
    pass
I have considered the following approach but I can't quite seem to make it work yet:
1. Create a subclass MyApplication with the call_wrapper() function overridden.
2. Instantiate the application with in_protocol=Soap11(validator=None).
3. Inside the call wrapper, set the protocol to Soap11(validator='lxml') and (somehow) call something that validates the message. Wrap this in a try/except block and, in case of error, catch the error and handle it however necessary.
I just haven't figured out what I can call inside my overridden call_wrapper() function which will actually perform the validation. I have tried protocol.decompose_incoming_envelope() and other such things but no luck yet.
Overriding call_wrapper would not work, as the validation error is raised before it is called.
You should instead use the event subsystem. More specifically, you must register an application-level handler for the method_exception_object event.
Here's an example:
def _on_exception_object(ctx):
    if isinstance(ctx.out_error, ValidationError):
        ctx.out_error = NicerValidationError(...)
app = Application(...)
app.event_manager.add_listener('method_exception_object', _on_exception_object)
See this test for more info: https://github.com/arskom/spyne/blob/4a74cfdbc7db7552bc89c0e5d5c19ed5d0755bc7/spyne/test/test_service.py#L69
As per your clarification, if you don't want to reply with a nicer error but a regular response, I'm afraid Spyne is not designed to satisfy that use-case. "Converting" an errored-out request processing state to a regular one would needlessly complicate the already heavy request handling logic.
What you can do instead is to HACK the heck out of the response document.
One way to do it is to implement an additional method_exception_document event handler where the <Fault> tag and its contents are either edited to your taste or even swapped out.
Off the top of my head:
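from spyne.const.xml import NS_SOAP11_ENV
from spyne.model.complex import ComplexModel
from spyne.model.primitive import Integer32, Unicode
from spyne.util.xml import get_object_as_xml
class ValidationErrorReport(ComplexModel):
    _type_info = [
        ('foo', Unicode),
        ('bar', Integer32),
    ]
def _on_exception_document(ctx):
    fault_elt, = ctx.out_document.xpath("//soap11:Fault", namespaces={'soap11': NS_SOAP11_ENV})
    explanation_elt = get_object_as_xml(ValidationErrorReport(...))
    # lxml elements use getparent() and append()
    fault_parent = fault_elt.getparent()
    fault_parent.remove(fault_elt)
    fault_parent.append(explanation_elt)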
The above needs to be double-checked with the relevant Spyne and lxml APIs (maybe you can use find() instead of xpath()), but you get the idea.
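To get a feel for the swap itself, here is a standalone sketch that uses the standard library's xml.etree.ElementTree instead of lxml and works on a serialized envelope rather than Spyne's ctx.out_document. `replace_fault` and the ValidationErrorReport/Details element names are placeholders; ElementTree has no getparent(), so the Fault is removed through its parent, the Body.

```python
import xml.etree.ElementTree as ET

# SOAP 1.1 envelope namespace
NS_SOAP11_ENV = "http://schemas.xmlsoap.org/soap/envelope/"


def replace_fault(envelope_xml: str, replacement: ET.Element) -> str:
    """Swap the <Fault> element in a SOAP 1.1 envelope for `replacement`."""
    root = ET.fromstring(envelope_xml)
    body = root.find(f"{{{NS_SOAP11_ENV}}}Body")
    fault = body.find(f"{{{NS_SOAP11_ENV}}}Fault") if body is not None else None
    if fault is not None:
        # Remove the Fault from the Body and put the custom report in its place
        body.remove(fault)
        body.append(replacement)
    return ET.tostring(root, encoding="unicode")
```

In a real Spyne handler you would operate on ctx.out_document (an lxml tree) as in the snippet above.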
Hope that helps!
I have this datastore model
from google.appengine.ext import db
class Project(db.Model):
    projectname = db.StringProperty()
    projecturl = db.StringProperty()
class Task(db.Model):
    project = db.ReferenceProperty(Project)
    taskname = db.StringProperty()
    taskdesc = db.StringProperty()
How do I edit the value of taskname? Say I have task1 and I want to change it to task1-project.
Oops, sorry, here is the formatted code:
taskkey = self.request.get("taskkey")
taskid = Task.get(taskkey)
query = db.GqlQuery("SELECT * FROM Task WHERE key = :taskid", taskid=taskid)
if query.count() > 0:
    task = Task()
    task.taskname = "task1-project"
    task.put()
By the way, I get it now: I changed task = Task() to task = query.get() and it worked.
Thanks for helping by the way.
Given an instance t of Task (e.g. from some get operation on the db), you can perform the alteration you want, e.g. with t.taskname = t.taskname + '-project' (if what you want is to append '-project' to whatever was there before). Eventually you will also need to .put() t back into the store, of course (but if you make multiple changes, you don't need to put it back after each and every change: only when you're done changing it!-).
Probably the easiest way is to use the admin console. Locally it's:
http://localhost:8080/_ah/admin
and if you've uploaded it, it's the dashboard:
http://appengine.google.com/dashboard?&app_id=******
Here's a link: