How to create a conditional task in Airflow

Question:

I would like to create a conditional task in Airflow as described in the schema below. The expected scenario is the following:

  • Task 1 executes
  • If Task 1 succeeds, then execute Task 2a
  • Else, if Task 1 fails, then execute Task 2b
  • Finally execute Task 3

[Schema: Conditional Task]
All tasks above are SSHExecuteOperators.
I'm guessing I should be using the ShortCircuitOperator and/or XCom to manage the condition, but I'm not clear on how to implement that. Could you please describe the solution?

Asked By: Alexis.Rolland


Answers:

You have to use Airflow trigger rules.

All operators have a trigger_rule argument, which defines the rule by which the generated task gets triggered. The default is all_success.

The trigger rule possibilities (in the Airflow version this answer targets; later releases add more, such as none_failed):

ALL_SUCCESS = 'all_success'
ALL_FAILED = 'all_failed'
ALL_DONE = 'all_done'
ONE_SUCCESS = 'one_success'
ONE_FAILED = 'one_failed'
DUMMY = 'dummy'

Here is one way to solve your problem:

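from datetime import datetime

from airflow import DAG
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.contrib.operators.ssh_execute_operator import SSHExecuteOperator
from airflow.utils.trigger_rule import TriggerRule

# Minimal DAG to attach the tasks to (dag_id and start_date are placeholders)
dag = DAG('conditional_ssh_tasks', start_date=datetime(2017, 1, 1), schedule_interval=None)

sshHook = SSHHook(conn_id='<YOUR CONNECTION ID FROM THE UI>')

task_1 = SSHExecuteOperator(
    task_id='task_1',
    bash_command='<YOUR COMMAND>',
    ssh_hook=sshHook,
    dag=dag)

# The task whose success or failure decides which branch runs
task_2 = SSHExecuteOperator(
    task_id='conditional_task',
    bash_command='<YOUR COMMAND>',
    ssh_hook=sshHook,
    dag=dag)

# Runs only if conditional_task succeeded
task_2a = SSHExecuteOperator(
    task_id='task_2a',
    bash_command='<YOUR COMMAND>',
    trigger_rule=TriggerRule.ALL_SUCCESS,
    ssh_hook=sshHook,
    dag=dag)

# Runs only if conditional_task failed (upstream_failed also counts as failed)
task_2b = SSHExecuteOperator(
    task_id='task_2b',
    bash_command='<YOUR COMMAND>',
    trigger_rule=TriggerRule.ALL_FAILED,
    ssh_hook=sshHook,
    dag=dag)

# Runs as soon as either branch has succeeded
task_3 = SSHExecuteOperator(
    task_id='task_3',
    bash_command='<YOUR COMMAND>',
    trigger_rule=TriggerRule.ONE_SUCCESS,
    ssh_hook=sshHook,
    dag=dag)

# task_1 -> conditional_task -> (task_2a | task_2b) -> task_3
task_2.set_upstream(task_1)
task_2a.set_upstream(task_2)
task_2b.set_upstream(task_2)
task_3.set_upstream(task_2a)
task_3.set_upstream(task_2b)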
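To check that these rules give the behaviour the question asks for, here is a simplified plain-Python model of the three rules used above. It is not Airflow's scheduler code; rule_met and simulate are hypothetical helpers, and the model assumes every task that actually runs succeeds. It mirrors Airflow in counting upstream_failed as failed for all_failed, and in marking a task upstream_failed when its rule fails because of a failed upstream, skipped otherwise.

```python
# Simplified model of Airflow trigger-rule evaluation (NOT Airflow's real implementation).
FAILED_STATES = {"failed", "upstream_failed"}


def rule_met(rule, upstream_states):
    """Return True if `rule` lets a task run, given its upstream tasks' final states."""
    if rule == "all_success":
        return all(s == "success" for s in upstream_states)
    if rule == "all_failed":
        # upstream_failed counts as failed for this rule
        return all(s in FAILED_STATES for s in upstream_states)
    if rule == "one_success":
        return any(s == "success" for s in upstream_states)
    raise ValueError(f"rule not modelled here: {rule}")


def simulate(conditional_outcome):
    """Final states of the answer's DAG for a given conditional_task outcome.

    Hypothetical helper: assumes every task that actually runs succeeds.
    """
    states = {"conditional_task": conditional_outcome}
    # (task_id, trigger_rule, upstream task_ids), listed in dependency order
    downstream = [
        ("task_2a", "all_success", ["conditional_task"]),
        ("task_2b", "all_failed", ["conditional_task"]),
        ("task_3", "one_success", ["task_2a", "task_2b"]),
    ]
    for task_id, rule, upstream in downstream:
        ups = [states[u] for u in upstream]
        if rule_met(rule, ups):
            states[task_id] = "success"
        elif any(s in FAILED_STATES for s in ups):
            states[task_id] = "upstream_failed"
        else:
            states[task_id] = "skipped"
    return states


print(simulate("success"))
print(simulate("failed"))
```

In both cases exactly one of task_2a/task_2b runs and task_3 runs. In the model, a failure of task_1 would leave conditional_task as upstream_failed, which still satisfies all_failed on task_2b, so the question's "Task 1 fails, run Task 2b" case is covered too.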
Answered By: Jean S

Airflow 2.x

Airflow provides a branching decorator that allows you to return the task_id (or list of task_ids) that should run:

from airflow.decorators import task


@task.branch(task_id="branch_task")
def branch_func(ti):
    # Value pushed to XCom by an upstream task with task_id "start_task"
    xcom_value = int(ti.xcom_pull(task_ids="start_task"))
    if xcom_value >= 5:
        return "big_task"  # run just this one task, skip all else
    elif xcom_value >= 3:
        return ["small_task", "warn_task"]  # run these, skip all else
    else:
        return None  # skip everything

You can also subclass BaseBranchOperator and override its choose_branch method, but for simple branching logic the decorator is the simplest option.
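The return-value semantics of branch_func can be illustrated without Airflow. In this sketch, branch_func_logic is branch_func with the XCom value passed in directly, and resolve_branch is a hypothetical helper (not an Airflow API) that turns the return value (a task_id, a list of task_ids, or None) into the set of followed and skipped direct downstream tasks. Rejecting IDs that are not directly downstream is a modelling choice based on the documented requirement that the returned task_id reference a task directly downstream of the branch.

```python
def branch_func_logic(xcom_value):
    # Same decision as branch_func above, with the XCom value passed in directly
    if xcom_value >= 5:
        return "big_task"
    elif xcom_value >= 3:
        return ["small_task", "warn_task"]
    else:
        return None


def resolve_branch(returned, direct_downstream):
    """Hypothetical helper: (followed, skipped) sets for a branch task's return value."""
    if returned is None:
        chosen = set()  # nothing is followed, every direct downstream task is skipped
    elif isinstance(returned, str):
        chosen = {returned}
    else:
        chosen = set(returned)
    unknown = chosen - set(direct_downstream)
    if unknown:
        raise ValueError(f"not directly downstream of the branch: {sorted(unknown)}")
    return chosen, set(direct_downstream) - chosen


downstream = ["big_task", "small_task", "warn_task"]
for value in (7, 3, 1):
    print(value, resolve_branch(branch_func_logic(value), downstream))
```

The model only covers tasks directly downstream of the branch; anything further down is run or skipped according to its own trigger rule.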

Airflow 1.x

Airflow has a BranchPythonOperator that can be used to express the branching dependency more directly.

The docs describe its use:

The BranchPythonOperator is much like the PythonOperator except that it expects a python_callable that returns a task_id. The task_id returned is followed, and all of the other paths are skipped. The task_id returned by the Python function has to be referencing a task directly downstream from the BranchPythonOperator task.

If you want to skip some tasks, keep in mind that you can't have an empty path; if you need one, add a dummy task.

Code Example

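from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator

# Placeholder DAG so the snippet is self-contained (dag_id/start_date are arbitrary)
dag = DAG('branch_example', start_date=datetime(2018, 1, 1), schedule_interval=None)


def dummy_test():
    # Return the task_id of the path to follow; other direct downstream paths are skipped
    return 'branch_a'


A_task = DummyOperator(task_id='branch_a', dag=dag)
B_task = DummyOperator(task_id='branch_false', dag=dag)

branch_task = BranchPythonOperator(
    task_id='branching',
    python_callable=dummy_test,
    dag=dag,
)

# Both candidate paths must be directly downstream of the branch task
branch_task >> A_task
branch_task >> B_task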

If you're installing Airflow >= 1.10.3, you can also return a list of task IDs, so a single operator can follow several downstream paths at once (skipping the rest) and you don't have to use a dummy task before joining.
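Why an empty path needs a dummy task, and why the join still needs care, can be shown with a small plain-Python model. This is a sketch under stated assumptions, not Airflow's scheduler: simulate_branch is hypothetical, every task that runs is assumed to succeed, a task whose rule is not met is marked skipped, and only all_success (the default) and one_success are modelled. The branching, branch_a and branch_false IDs come from the example above; join is a hypothetical downstream task.

```python
from graphlib import TopologicalSorter  # Python 3.9+


def simulate_branch(upstream, branch_task, returned, rules=None):
    """Hypothetical model: final states after `branch_task` returns `returned`.

    upstream maps task_id -> list of upstream task_ids.
    """
    rules = rules or {}
    chosen = {returned} if isinstance(returned, str) else set(returned or [])
    states = {}
    for task in TopologicalSorter(upstream).static_order():
        ups = upstream.get(task, [])
        if branch_task in ups and task not in chosen:
            # the branch skips its direct downstream tasks that it did not pick
            states[task] = "skipped"
            continue
        up_states = [states[u] for u in ups]
        rule = rules.get(task, "all_success")
        if rule == "all_success":
            met = all(s == "success" for s in up_states)
        elif rule == "one_success":
            met = any(s == "success" for s in up_states)
        else:
            raise ValueError(f"rule not modelled here: {rule}")
        states[task] = "success" if met else "skipped"
    return states


# Empty path: join hangs directly off the branch, so it is skipped with branch_a
empty = {"branch_a": ["branching"], "join": ["branching", "branch_a"]}
print(simulate_branch(empty, "branching", "branch_a"))

# Dummy task on the other path; join only runs with a rule such as one_success
with_dummy = {"branch_a": ["branching"], "branch_false": ["branching"],
              "join": ["branch_a", "branch_false"]}
print(simulate_branch(with_dummy, "branching", "branch_a"))
print(simulate_branch(with_dummy, "branching", "branch_a", {"join": "one_success"}))

# Airflow >= 1.10.3: return a list that names the join directly, no dummy needed
print(simulate_branch(empty, "branching", ["branch_a", "join"]))
```

The second case reflects Airflow's behaviour that a task with the default all_success rule is skipped when one of its upstream tasks was skipped.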

Answered By: villasv

Let me add my take on this.

First of all, sorry for the lengthy post, but I wanted to share the complete solution that works for me.

background

We have a script that pulls data from a very crappy and slow API.
It's slow, so we need to be selective about what we do and don't pull from it (1 request/s, with more than 750k requests to make).
Occasionally the requirements change, forcing us to pull the data in full, but only for one or some endpoints. So we need something we can control.

The strict rate limit of 1 request/s, with a penalty delay of several seconds if breached, would halt all tasks running in parallel.

The 'catchup': True option in our config is essentially a backfill: it is translated into a command-line option (-c) for our script. It is unrelated to Airflow's own DAG-level catchup setting.

There are no data dependencies between our tasks; we only need to follow the order of (some) tasks.
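The two background points (the catchup flag becoming -c, and tasks that only need to run in a fixed order) can be sketched in plain Python. This is not the author's code: build_command and chain are hypothetical helpers, the three options are a subset of the task_options dict in the DAG below, and the exact command string (no trailing space) is an assumption.

```python
# Hypothetical subset of the task_options config used in the DAG below
task_options = {
    "option_name_1": {"param": "fetch-users", "enabled": True, "catchup": False},
    "option_name_2": {"param": "fetch-jobs", "enabled": True},
    "option_name_3": {"param": "fetch-schedules", "enabled": True, "catchup": True},
}


def build_command(config):
    """Translate one option into the script call; catchup=True becomes the -c flag."""
    command = f"cd people_data; python3 get_people_data.py {config['param']}"
    if config.get("catchup"):  # a missing key counts as False
        command += " -c"
    return command


def chain(task_ids):
    """(upstream, downstream) pairs that run the tasks strictly one after another."""
    return list(zip(task_ids, task_ids[1:]))


for task_id in task_options:  # dicts keep insertion order (Python 3.7+)
    print(task_id, "->", build_command(task_options[task_id]))
print(chain(list(task_options)))
```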

solution

First, a pre_execute callable combined with the DAG run config takes care of skipping tasks: it raises AirflowSkipException for any task listed under the config's skip_task key.

Second, based on the task_options config, we swap the original operator for a simple PythonOperator with the same task_id whose only job is to raise the skip exception.
This way the UI isn't confused and the run history stays complete, showing the runs in which a task was skipped.

from airflow import DAG
from airflow.exceptions import AirflowSkipException
from airflow.operators.python import PythonOperator

from plugins.airflow_utils import default_args, kubernetes_pod_task


# callable for pre_execute arg
def skip_if_specified(context):
    task_id = context['task'].task_id
    conf = context['dag_run'].conf or {}
    skip_tasks = conf.get('skip_task', [])
    if task_id in skip_tasks:
        raise AirflowSkipException()

# these are necessary to make this solution work
support_task_skip_args = {'pre_execute': skip_if_specified,
                          'trigger_rule': 'all_done'}
extended_args = {**default_args, **support_task_skip_args}

dag_name = 'optional_task_skip'

dag = DAG(dag_name,
          max_active_runs=3,
          schedule_interval=None,
          catchup=False,
          default_args=extended_args)

# select endpoints and modes
# !! dicts keep insertion order (Python 3.7+), so list the options in the order you want them to run !!
task_options = {
    'option_name_1':
        {'param': 'fetch-users', 'enabled': True, 'catchup': False},
    'option_name_2':
        {'param': 'fetch-jobs', 'enabled': True},
    'option_name_3':
        {'param': 'fetch-schedules', 'enabled': True, 'catchup': True},
    'option_name_4':
        {'param': 'fetch-messages', 'enabled': True, 'catchup': False},
    'option_name_5':
        {'param': 'fetch-holidays', 'enabled': True, 'catchup': False},
}


# the callable that throws the skip signal (stand-in for disabled options)
def skip_task():
    raise AirflowSkipException()


def add_tasks():
    task_list_ = []
    for task_name_, task_config_ in task_options.items():
        if task_config_['enabled']:
            parameter_ = task_config_['param']
            catchup_ = '-c ' if task_config_.get('catchup') else ''
            task_list_.append(
                kubernetes_pod_task(
                    dag=dag,
                    command=f"cd people_data; python3 get_people_data.py {parameter_} {catchup_}",
                    task_id=task_name_))
        else:
            # same task_id as the real task, but it only raises the skip signal
            task_list_.append(
                PythonOperator(dag=dag,
                               python_callable=skip_task,
                               task_id=task_name_))
        # chain each task after the previous one to keep the configured order
        if len(task_list_) > 1:
            task_list_[-2] >> task_list_[-1]


# populate the DAG
add_tasks()

Note:
default_args and kubernetes_pod_task are just convenience wrappers.
kubernetes_pod_task injects some variables and secrets in a simple function and uses KubernetesPodOperator (from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator). I won't and can't share those wrappers with you.
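The skip mechanism can be exercised without a running Airflow: the sketch below copies skip_if_specified from the DAG above and feeds it a fake context built with types.SimpleNamespace. The AirflowSkipException class here is a local stand-in (an assumption so the snippet runs without Airflow installed; in the DAG it comes from airflow.exceptions), and fake_context and would_skip are hypothetical helpers. In a real deployment the skip list arrives through the run config, for example `airflow dags trigger optional_task_skip --conf '{"skip_task": ["option_name_2"]}'`.

```python
from types import SimpleNamespace


class AirflowSkipException(Exception):
    """Local stand-in for airflow.exceptions.AirflowSkipException."""


# copied from the DAG above: raise the skip signal if the run config lists this task
def skip_if_specified(context):
    task_id = context['task'].task_id
    conf = context['dag_run'].conf or {}
    skip_tasks = conf.get('skip_task', [])
    if task_id in skip_tasks:
        raise AirflowSkipException()


def fake_context(task_id, conf):
    """Hypothetical helper: only the two context keys skip_if_specified reads."""
    return {"task": SimpleNamespace(task_id=task_id),
            "dag_run": SimpleNamespace(conf=conf)}


def would_skip(task_id, conf):
    try:
        skip_if_specified(fake_context(task_id, conf))
    except AirflowSkipException:
        return True
    return False


print(would_skip("option_name_2", {"skip_task": ["option_name_2"]}))
print(would_skip("option_name_1", {"skip_task": ["option_name_2"]}))
```

Because the default args also set trigger_rule='all_done', the next task in the chain still runs after a skipped one; under the default all_success rule it would be skipped as well.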

The solution extends the great ideas of this gentleman:
https://www.youtube.com/watch?v=abLGyapcbw0

This solution works with Kubernetes operators, too.

Of course, this could be improved, and you can absolutely extend or rework the code to parse the manual trigger config as well (as shown in the video).

Here's what it looks like in my UI:
[Screenshot: Airflow UI showing the task runs]

(it doesn’t reflect the example config above but rather the actual runs in our staging infrastructure)

Answered By: Gergely M