Execute cleanup task on upstream task failure

Question:

I’m using the taskflow API on Airflow v2.3.0.

I have a DAG that has a bunch of tasks. The first task (TaskA) turns on an EC2 instance and one of the last tasks (TaskX) turns off the EC2 instance.

TaskA returns the instance ID of the EC2 instance, which TaskX then uses to turn it off.

Between the two tasks there are Tasks B, C, D, …, W, any of which may fail. When any of these tasks fails, I’d like TaskX to run so that it turns off my EC2 instance.

I know that I can use on_failure_callback in @dag, but the problem is that TaskX requires an EC2 instance ID that I will not know at DAG-definition time; it has to come from the return value of TaskA (see the MCVE below):

@dag(on_failure_callback = taskX(unknownInstanceID), ...)  # <- that instanceID is unknown at this time
def my_dag():

    @task
    def taskA():
        instanceID = turnOnEC2Instance()
        return instanceID

    @task
    def taskB(instanceID):
        # do stuff on EC2 instance
        return stuff

    @task
    def taskC(...):
        # do stuff

    # other tasks

    @task
    def taskX(instanceID, *_dependencies):
        shutDownEC2Instance(instanceID)

    instanceID = taskA()
    dependency = taskB(instanceID)
    taskX(instanceID, dependency)  # `dependency` ensures that the EC2 instance is not shutdown before TaskB finishes

Therefore, could I instead use try/except semantics? Is this supported in Airflow? (I don’t see it in the docs.) A task failure doesn’t exactly raise a Python error, so I doubt this will work:

@dag(...)
def my_dag():

    @task
    def taskA():
        instanceID = turnOnEC2Instance()
        return instanceID

    @task
    def taskB(instanceID):
        # do stuff on EC2 instance
        return stuff

    # other tasks

    @task
    def taskX(instanceID, *_dependencies):
        shutDownEC2Instance(instanceID)

    instanceID = taskA()
    try:
        dependency = taskB(instanceID)
    finally:
        taskX(instanceID)  # no need for `dependency` here

What’s the solution here? Airflow has got to have a mechanism for this that I’m just not seeing.

Update after trying Lucas’ answer. Here’s an MCVE that reproduces the errors I’m seeing:

import datetime as dt
import time

from airflow.decorators import dag, task
from airflow.utils.trigger_rule import TriggerRule


@dag(start_date = dt.datetime(2022, 5, 16),
     schedule_interval = None,  # manual trigger only
     catchup = False,
     max_active_runs = 1,
     )
def test_trigger():

    @task
    def taskA():
        return 5

    @task
    def taskB(*_deps):
        time.sleep(45)  # I can manually fail this task during the sleep
        return 10

    @task(trigger_rule=TriggerRule.ALL_DONE)
    def taskX(r, *_deps):
        print('Ending it all')
        print(r)

    i = taskA()
    d = taskB(i)
    taskX(i, d)


taskflow_dag = test_trigger()

If I run this and manually fail taskB, taskX also fails because taskB did not post a return value.


taskX Log

[2022-09-18, 22:21:35 UTC] {taskinstance.py:1889} ERROR - Task failed with exception
Traceback (most recent call last):
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/taskinstance.py", line 1451, in _run_raw_task
    self._execute_task_with_callbacks(context, test_mode)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/taskinstance.py", line 1555, in _execute_task_with_callbacks
    task_orig = self.render_templates(context=context)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/taskinstance.py", line 2202, in render_templates
    rendered_task = self.task.render_template_fields(context)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/baseoperator.py", line 1179, in render_template_fields
    self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/utils/session.py", line 71, in wrapper
    return func(*args, session=session, **kwargs)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/abstractoperator.py", line 344, in _do_render_template_fields
    rendered_content = self.render_template(
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/abstractoperator.py", line 398, in render_template
    return tuple(self.render_template(element, context, jinja_env) for element in value)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/abstractoperator.py", line 398, in <genexpr>
    return tuple(self.render_template(element, context, jinja_env) for element in value)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/abstractoperator.py", line 394, in render_template
    return value.resolve(context)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/utils/session.py", line 71, in wrapper
    return func(*args, session=session, **kwargs)
  File "/home/infra/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/xcom_arg.py", line 152, in resolve
    raise AirflowException(
airflow.exceptions.AirflowException: XComArg result from taskB at test_trigger with key="return_value" is not found!

Attempting to use the >> operator results in this error:

    @task(trigger_rule=TriggerRule.ALL_DONE)
    def taskX(r, *_deps):
        print('Ending it all')
        print(r)

    i = taskA()
    d = taskB(i)
    taskX(i)

    taskB >> taskX


taskflow_dag = test_trigger()

Broken DAG: [/home/infra/airflow/dags/test.py] Traceback (most recent call last):
  File "/home/username/.pyenv/versions/3.8.5/lib/python3.8/site-packages/airflow/models/dag.py", line 2967, in factory
    f(**f_kwargs)
  File "/home/username/airflow/dags/test.py", line 34, in test_trigger
    taskB >> taskX
TypeError: unsupported operand type(s) for >>: '_TaskDecorator' and '_TaskDecorator'
Asked By: inspectorG4dget


Answers:

Tasks in Airflow have a trigger rule, which can be passed to the decorators you are using:

TriggerRule.ALL_SUCCESS: triggers the task when all of the upstream tasks have succeeded (the default)

TriggerRule.ONE_FAILED: triggers the task as soon as at least one of the upstream tasks has failed

TriggerRule.ALWAYS: always triggers the task

TriggerRule.ALL_DONE: triggers the task once all upstream tasks have finished, whatever their state

More info on trigger rules: https://airflow.apache.org/docs/apache-airflow/stable/concepts/dags.html#concepts-trigger-rules

from airflow.utils.trigger_rule import TriggerRule

@task(trigger_rule=TriggerRule.ALL_DONE)
def taskX(instanceID, *_dependencies):
    shutDownEC2Instance(instanceID)

That way your taskX will be executed once all the other tasks are done.

If you want to turn the instance off only when a previous task fails, use TriggerRule.ONE_FAILED instead, as in the sketch below.
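A minimal sketch of that ONE_FAILED variant (this reuses the question’s hypothetical turnOnEC2Instance/shutDownEC2Instance helpers as local stand-ins, and the exact wiring is only an illustration of the trigger rule, not code from the original answer):

import datetime as dt

from airflow.decorators import dag, task
from airflow.utils.trigger_rule import TriggerRule


def turnOnEC2Instance():
    return "fake-instance-id"  # stand-in for the real EC2 call


def shutDownEC2Instance(instanceID):
    print(f"shutting down {instanceID}")  # stand-in for the real EC2 call


@dag(start_date=dt.datetime(2022, 5, 16),
     schedule_interval=None,
     catchup=False,
     )
def ec2_cleanup_on_failure():

    @task
    def taskA():
        return turnOnEC2Instance()

    @task
    def taskB(instanceID):
        raise RuntimeError("simulated failure while working on the instance")

    # ONE_FAILED: runs as soon as an upstream task fails, and is skipped otherwise
    @task(trigger_rule=TriggerRule.ONE_FAILED)
    def taskX(instanceID, *_deps):
        shutDownEC2Instance(instanceID)

    i = taskA()
    d = taskB(i)
    x = taskX(i)  # only taskA's output is passed in
    d >> x        # ordering only: taskX is evaluated after taskB fails


cleanup_dag = ec2_cleanup_on_failure()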

For your code to work, the end of your DAG, where you organize the tasks, should look something like this:

original = taskA()
taskC(taskB(original)) >> taskX(original)

This puts taskB and taskC before taskX, and since you want to use the decorators to pass your instance ID to taskX, you need the output of taskA, which is what I named original.


Answered By: Lucas M. Uriarte

The solution is to upgrade Airflow to v2.4.0, which resolved this issue:

pip install -U apache-airflow
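(For completeness, the Airflow docs recommend installing against a constraints file when upgrading; the exact command below assumes Airflow 2.4.0 on Python 3.8:)

pip install "apache-airflow==2.4.0" --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.4.0/constraints-3.8.txt"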

and migrate the database:

airflow db upgrade

Once upgraded, the bitshift operator still raises this error:

TypeError: unsupported operand type(s) for >>: '_TaskDecorator' and '_TaskDecorator'
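(Aside: the TypeError is presumably because >> is applied to the @task-decorated functions themselves; calling them first returns XComArg objects, which do support >>. A minimal sketch of that wiring, reusing the same taskA/taskB/taskX definitions, would be something like:)

    i = taskA()
    d = taskB(i)
    x = taskX(i)  # taskB's return value is deliberately not passed to taskX
    d >> x        # sets the ordering without taskX having to resolve taskB's XCom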

In any case, the following code solves the issue:

import datetime as dt
import time

from airflow.decorators import dag, task
from airflow.utils.trigger_rule import TriggerRule


@dag(start_date = dt.datetime(2022, 5, 16),
     schedule_interval = None,  # manual trigger only
     catchup = False,
     max_active_runs = 1,
     )
def test_trigger():

    @task
    def taskA():
        return 5


    @task
    def taskB(*_deps):
        time.sleep(45)  # I can manually fail this task during the sleep
        return 10


    @task
    def taskC(*_deps):
        return 15


    @task(trigger_rule=TriggerRule.ALL_DONE)
    def taskX(r, *_deps):
        print('Ending it all')
        print(r)

    i = taskA()
    d1 = taskB(i)
    d2 = taskC(d1)
    taskX(i, d2)


taskflow_dag = test_trigger()

With this, taskC no longer runs (it is marked upstream_failed), but taskX still runs to completion.


Answered By: inspectorG4dget