Run event loop until all tasks are blocked in Python

Question:

I am writing code that has some long-running coroutines that interact with each other. These coroutines can be blocked on await until something external happens. I want to be able to drive these coroutines in a unittest. The regular way of doing await on the coroutine doesn’t work, because I want to be able to intercept something in the middle of their operation. I would also prefer not to mess with the coroutine internals either, unless there is something generic/reusable that can be done.

Ideally I would want to run an event loop until all tasks are currently blocked. This should be fairly easy to tell in an event loop implementation. Once everything is blocked, the event loop yields back control, where I can assert some state about the coroutines, and poke them externally. Then I can resume the loop until it gets blocked again. This would allow for deterministic simulation of tasks in an event loop.

Minimal example of the desired API:

import asyncio
from asyncio import Event

# Imagine this is a complicated "main" with many coroutines.
# But event is some external "mockable" event
# that can be used to drive in unit tests
async def wait_on_event(event: Event):
  print("Waiting on event")
  await event.wait()
  print("Done waiting on event")

def test_deterministic():
  loop = asyncio.get_event_loop()
  event = Event()
  task = loop.create_task(wait_on_event(event))
  run_until_blocked_or_complete(loop) # define this magic function
  # Should print "Waiting on event"

  # can make some test assertions here
  event.set()

  run_until_blocked_or_complete(loop)
  # Should print "Done waiting on event"

Is anything like this possible? Or would it require writing a custom event loop just for tests?

Additionally, I am currently on Python 3.9 (AWS runtime limitation). If it’s not possible to do this in 3.9, what version would support this?

Answers:

Yes, you can achieve this without writing an event loop from scratch: install a custom event loop policy whose loops carry a run_until_blocked helper. The helper runs the loop until every task is blocked, then returns control to the test code so it can make assertions or poke the coroutines from outside. Calling it again resumes the loop until everything is blocked once more, and so on.

import asyncio
import types

def _is_blocked(task):
    # _fut_waiter is a private CPython attribute holding the future the task
    # is currently awaiting; it may change between Python versions.
    waiter = task._fut_waiter
    return waiter is not None and not waiter.done()

def _run_until_blocked(loop):
    """Run single loop iterations until every pending task is blocked."""
    while True:
        loop.call_soon(loop.stop)  # stop after exactly one iteration
        loop.run_forever()
        if all(_is_blocked(task) for task in asyncio.all_tasks(loop)):
            return

class DeterministicEventLoopPolicy(asyncio.DefaultEventLoopPolicy):
    def new_event_loop(self):
        loop = super().new_event_loop()
        # Bind the helper so tests can call loop.run_until_blocked().
        loop.run_until_blocked = types.MethodType(_run_until_blocked, loop)
        return loop

This policy only overrides new_event_loop, so every loop it creates gets a run_until_blocked method. A task counts as blocked when it is awaiting a future that has not completed yet. A task that has not started, that has just yielded with asyncio.sleep(0), or whose awaited future is already done can still make progress, so the loop keeps running.

The run_until_blocked method advances the loop one iteration at a time and checks all pending tasks after each iteration. It returns as soon as every remaining task is blocked, or when no tasks are left. This is the point where you can perform any necessary assertions or external pokes.

Here’s an example usage of this policy:

async def wait_on_event(event: asyncio.Event):
    print("Waiting on event")
    await event.wait()
    print("Done waiting on event")

def test_deterministic():
    asyncio.set_event_loop_policy(DeterministicEventLoopPolicy())
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    event = asyncio.Event()
    task = loop.create_task(wait_on_event(event))

    loop.run_until_blocked()  # prints "Waiting on event"
    assert not task.done()  # still blocked on event.wait()

    event.set()  # poke the coroutine from the test
    loop.run_until_blocked()  # prints "Done waiting on event"
    assert task.done()

    loop.close()
    asyncio.set_event_loop_policy(None)  # restore the default policy

In this test, we create an Event, schedule wait_on_event as a task, and call run_until_blocked to run it until it blocks on the event.wait() call. At this point, we can perform any necessary assertions, such as checking that the task has not finished yet. We then set the event and call run_until_blocked again, which resumes the coroutine and runs it to completion.

This pattern allows for deterministic simulation of tasks in an event loop and can be used to test coroutines that block on external events.

Hope this helps!

Answered By: CristianGabriel

The default event loop runs every callback that is already scheduled in each "pass". So if you schedule your pause with loop.call_soon right after creating your tasks, it is called once each of those tasks has taken its first step, which is the desired point:

import asyncio

def event():
    print("blah")
    breakpoint()  # pause here to inspect or poke the tasks
    print("bleh")

async def worker(id):
    print(f"starting task {id}")
    await asyncio.sleep(0.1)
    print(f"ending task {id}")

async def main():
    t = []
    for id in (1, 2, 3):
        t.append(asyncio.create_task(worker(id)))
    loop = asyncio.get_running_loop()
    # Queued behind the first step of each task created above.
    loop.call_soon(event)
    await asyncio.sleep(0.2)

And running this on the REPL:

In [8]: asyncio.run(main())
starting task 1
starting task 2
starting task 3
blah
> <ipython-input-3-450374919d79>(4)event()
-> print("bleh")
(Pdb) 
Exception in callback event() at <ipython-input-3-450374919d79>:1
[...]
bdb.BdbQuit
ending task 1
ending task 2
ending task 3
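
The ordering this relies on can be checked without a debugger. The following is a minimal sketch, not part of the original answer: the name demo_call_soon_order and the recorded labels are made up for illustration. It records the order in which the task steps and the call_soon callback run.

```python
import asyncio


def demo_call_soon_order():
    """Return the order in which task steps and a call_soon callback run."""
    order = []

    async def worker(i):
        order.append(f"start {i}")
        await asyncio.sleep(0.01)
        order.append(f"end {i}")

    async def main():
        tasks = [asyncio.create_task(worker(i)) for i in (1, 2)]
        # Scheduled after the tasks, so it is queued behind their first step.
        asyncio.get_running_loop().call_soon(order.append, "callback")
        await asyncio.gather(*tasks)

    asyncio.run(main())
    return order


if __name__ == "__main__":
    print(demo_call_soon_order())
```

The first three entries should be "start 1", "start 2" and "callback": the callback runs after every task has reached its first await, but before any of them resumes.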

Answered By: jsbueno

After some experimenting I came up with something. Here is the usage first:

# This coroutine could spawn many others. Keeping it simple here
async def wait_on_event(event: asyncio.Event) -> int:
    print("Waiting")
    await event.wait()
    print("Done")
    return 42


def test_deterministic_loop():
    loop = DeterministicLoop()
    event = asyncio.Event()
    task = loop.add_coro(wait_on_event(event))

    assert not task.done()
    loop.step()
    # prints Waiting
    assert not task.done()
    assert not loop.done()

    event.set()
    loop.step()
    # prints Done
    assert task.done()
    assert task.result() == 42
    assert loop.done()

The implementation:

"""Module for testing facilities. Don't use these in production!"""
import asyncio
from enum import IntEnum
from typing import Any, Optional, TypeVar, cast
from collections.abc import Coroutine, Awaitable


def _get_other_tasks(loop: Optional[asyncio.AbstractEventLoop]) -> set[asyncio.Task]:
    """Get a set of currently scheduled tasks in an event loop that are not the current task"""
    current = asyncio.current_task(loop)
    tasks = asyncio.all_tasks(loop)
    if current is not None:
        tasks.discard(current)
    return tasks



# Works on python 3.9, cannot guarantee on other versions
def _get_unblocked_tasks(tasks: set[asyncio.Task]) -> set[asyncio.Task]:
    """Get the subset of tasks that can make progress. This is the most magic
    function, and is heavily dependent on eventloop implementation and python version"""

    def is_not_blocked(task: asyncio.Task):
        # pylint: disable-next=protected-access
        wait_for = cast(Optional[asyncio.Future], task._fut_waiter)  # type: ignore
        if wait_for is None:
            return True
        return wait_for.done()

    return set(filter(is_not_blocked, tasks))


class TasksState(IntEnum):
    RUNNING = 0
    BLOCKED = 1
    DONE = 2


def _get_tasks_state(
    prev_tasks: set[asyncio.Task], cur_tasks: set[asyncio.Task]
) -> TasksState:
    """Given set of tasks for previous and current pass of the event loop,
    determine the overall state of the tasks. Are the tasks making progress,
    blocked, or done?"""
    if not cur_tasks:
        return TasksState.DONE

    unblocked: set[asyncio.Task] = _get_unblocked_tasks(cur_tasks)
    # check if there are tasks that can make progress
    if unblocked:
        return TasksState.RUNNING

    # if no tasks appear to make progress, check if this and last step the state
    # has been constant
    elif prev_tasks == cur_tasks:
        return TasksState.BLOCKED

    return TasksState.RUNNING


async def _stop_when_blocked():
    """Schedule this task to stop the event loop when all other tasks are
    blocked, or they all complete"""
    prev_tasks: set[asyncio.Task] = set()
    loop = asyncio.get_running_loop()
    while True:
        tasks = _get_other_tasks(loop)
        state = _get_tasks_state(prev_tasks, tasks)
        prev_tasks = tasks

        # stop the event loop if all other tasks cannot make progress
        if state == TasksState.BLOCKED:
            loop.stop()

        # finish this task too, if no other tasks exist
        if state == TasksState.DONE:
            break

        # yield back to the event loop
        await asyncio.sleep(0.0)

    loop.stop()


T = TypeVar("T")


class DeterministicLoop:
    """An event loop for writing deterministic tests."""

    def __init__(self):
        self.loop = asyncio.get_event_loop_policy().new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.stepper_task = self.loop.create_task(_stop_when_blocked())
        self.tasks: list[asyncio.Task] = []

    def add_coro(self, coro: Coroutine[Any, Any, T]) -> asyncio.Task[T]:
        """Add a coroutine to the set of running coroutines, so they can be stepped through"""
        if self.done():
            raise RuntimeError("No point in adding more tasks. All tasks have finished")
        task = self.loop.create_task(coro)
        self.tasks.append(task)
        return task

    def step(self, awaitable: Optional[Awaitable[T]] = None) -> Optional[T]:
        if self.done() or not self.tasks:
            raise RuntimeError(
                "No point in stepping. No tasks to step or all are finished"
            )

        step_future: Optional[asyncio.Future[T]] = None
        if awaitable is not None:
            step_future = asyncio.ensure_future(awaitable, loop=self.loop)

        # stepper_task should halt us if we're blocked or all tasks are done
        self.loop.run_forever()

        if step_future is not None:
            assert (
                step_future.done()
            ), "Can't step the event loop, where the step function itself might get blocked"
            return step_future.result()
        return None

    def done(self) -> bool:
        return self.stepper_task.done()

This question has puzzled me since I first read it, because it’s almost doable with standard asyncio functions. The key is Alexander’s "magic" is_not_blocked function, which I give verbatim below (except for moving it to the outer indentation level and renaming it _is_not_blocked). I also use his wait_on_event coroutine and his test_deterministic_loop function. I added some extra checks to show how to start and stop other tasks, and how to drive the event loop step by step until all tasks are finished.

Instead of his DeterministicLoop class, I use a function run_until_blocked that makes only standard asyncio function calls. The two lines of code:

loop.call_soon(loop.stop)
loop.run_forever()

are a convenient means of advancing the loop by exactly one cycle. And asyncio already provides a method for obtaining all the tasks that run within a given event loop, so there is no need to store them independently.
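
To see the one-cycle behaviour in isolation, here is a small sketch. The names step_once and demo_single_steps are hypothetical and do not come from any of the answers. A coroutine yields to the loop three times with asyncio.sleep(0), and the recorded progress is snapshotted after each single step.

```python
import asyncio


def step_once(loop):
    """Advance the loop by one iteration using only public asyncio calls."""
    loop.call_soon(loop.stop)
    loop.run_forever()


def demo_single_steps():
    loop = asyncio.new_event_loop()
    progress = []

    async def counter():
        for i in range(3):
            progress.append(i)
            await asyncio.sleep(0)  # yield to the loop once

    try:
        task = loop.create_task(counter())
        snapshots = []
        for _ in range(4):
            step_once(loop)
            snapshots.append(list(progress))
        return snapshots, task.done()
    finally:
        loop.close()


if __name__ == "__main__":
    print(demo_single_steps())
```

Each call to step_once lets the coroutine run up to its next await and no further. The fourth step is the one in which the coroutine returns and the task finishes.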

A comment on Alexander’s "magic" function: the comments in the asyncio.Task source code describe the "private" variable _fut_waiter as an important invariant, so it is very unlikely to change in future versions. I think it’s quite safe in practice.
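
As an illustration of what _fut_waiter holds at each stage, here is a sketch under the assumption that the CPython implementation exposes this private attribute, as it does on 3.9. The helper names describe and demo_fut_waiter_states are invented for this example.

```python
import asyncio


def describe(task):
    """Classify a pending task by inspecting its private _fut_waiter."""
    waiter = task._fut_waiter  # private CPython attribute, not a public API
    if waiter is None:
        return "runnable"  # not started yet, or just yielded to the loop
    return "woken" if waiter.done() else "blocked"


def demo_fut_waiter_states():
    loop = asyncio.new_event_loop()
    try:
        fut = loop.create_future()

        async def waits():
            return await fut

        def step():
            loop.call_soon(loop.stop)
            loop.run_forever()

        task = loop.create_task(waits())
        states = [describe(task)]  # before the task has run at all
        step()
        states.append(describe(task))  # parked on the pending future
        fut.set_result(42)
        states.append(describe(task))  # future done, wakeup not yet run
        step()
        return states, task.result()
    finally:
        loop.close()


if __name__ == "__main__":
    print(demo_fut_waiter_states())
```

Only the "blocked" state means the task cannot make progress. The other two states are the cases in which is_not_blocked returns True.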

import asyncio
from typing import Optional, cast

def _is_not_blocked(task: asyncio.Task):
    # pylint: disable-next=protected-access
    wait_for = cast(Optional[asyncio.Future], task._fut_waiter)  # type: ignore
    if wait_for is None:
        return True
    return wait_for.done()

def run_until_blocked():
    """Runs steps of the event loop until all tasks are blocked."""
    loop = asyncio.get_event_loop()
    # Always run one step.
    loop.call_soon(loop.stop)
    loop.run_forever()
    # Continue running until all tasks are blocked
    while any(_is_not_blocked(task) for task in asyncio.all_tasks(loop)):
        loop.call_soon(loop.stop)
        loop.run_forever()
        
# This coroutine could spawn many others. Keeping it simple here
async def wait_on_event(event: asyncio.Event) -> int:
    print("Waiting")
    await event.wait()
    print("Done")
    return 42

def test_deterministic_loop():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    event = asyncio.Event()
    task = loop.create_task(wait_on_event(event))
    assert not task.done()
    run_until_blocked()
    print("Task done", task.done())
    assert not task.done()
    print("Tasks running", asyncio.all_tasks(loop))
    assert asyncio.all_tasks(loop)
    event.set()
    # You can start and stop tasks
    loop.run_until_complete(asyncio.sleep(2.0))
    run_until_blocked()
    print("Task done", task.done())
    assert task.done()
    print("Tasks running", asyncio.all_tasks(loop))
    assert task.result() == 42
    assert not asyncio.all_tasks(loop)
    # If you create a task you must loop run_until_blocked until
    # the task is done.
    task2 = loop.create_task(asyncio.sleep(2.0))
    assert not task2.done()
    while not task2.done():
        assert asyncio.all_tasks(loop)
        run_until_blocked()
    assert task2.done()
    assert not asyncio.all_tasks(loop)
    
test_deterministic_loop()

Answered By: Paul Cornelius

Here is a simple generalized implementation.

After each loop iteration it records the last bytecode instruction of every pending task. If all tasks stay parked at the same instruction (waiting on an Event, a Semaphore, whatever) for a fixed number of iterations, run_until_blocked returns control to the caller until it is called once again.
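
The signal this heuristic relies on can be sketched on its own. In the example below, last_instruction and demo_lasti are hypothetical names, and the two futures stand in for whatever the task is waiting on. The offset reported by task.get_stack() stays the same while the task is parked at one await and changes once it moves on to the next.

```python
import asyncio


def last_instruction(task):
    """Bytecode offset at which the task's coroutine is currently suspended."""
    return task.get_stack()[-1].f_lasti


def demo_lasti():
    loop = asyncio.new_event_loop()
    try:
        first, second = loop.create_future(), loop.create_future()

        async def two_waits():
            await first
            await second

        def step():
            loop.call_soon(loop.stop)
            loop.run_forever()

        task = loop.create_task(two_waits())
        step()
        a = last_instruction(task)  # parked at the first await
        step()
        b = last_instruction(task)  # no progress was possible
        first.set_result(None)
        step()
        c = last_instruction(task)  # now parked at the second await
        second.set_result(None)
        step()  # let the task finish before closing the loop
        return a, b, c
    finally:
        loop.close()


if __name__ == "__main__":
    print(demo_lasti())
```

Because only the offset is compared, a task that keeps returning to the same await inside a loop looks identical to a blocked one, which is why the implementation waits for several unchanged iterations before giving up.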

import asyncio

MAX_LOOP_ITER = 100

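def run_until_blocked(loop):
    # Per-task bookkeeping: the last bytecode offset seen and how many
    # consecutive iterations the task has stayed at that offset.
    stuck = {}

    while True:
        # Advance the loop by exactly one iteration.
        loop.call_soon(loop.stop)
        loop.run_forever()

        pending = asyncio.all_tasks(loop=loop)
        # Forget tasks that finished during this iteration.
        stuck = {task: info for task, info in stuck.items() if task in pending}

        for task in pending:
            lasti = task.get_stack()[-1].f_lasti
            if task in stuck and stuck[task]["lasti"] == lasti:
                stuck[task]["iter"] += 1
            else:
                stuck[task] = {"iter": 0, "lasti": lasti}

        # Return once every remaining task has been stuck long enough
        # (trivially true when no tasks remain).
        if all(val["iter"] >= MAX_LOOP_ITER for val in stuck.values()):
            break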

async def wait_on_event(event: asyncio.Event):
    print("Waiting on event")
    await event.wait()
    print("Done waiting on event")
    return 42

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

event = asyncio.Event()
coro = wait_on_event(event)
task = loop.create_task(coro)

run_until_blocked(loop)
event.set()
run_until_blocked(loop)

print(task.result())

Answered By: Felipe