How to combine streams in anyio?

Question:

How to iterate over multiple streams at once in anyio, interleaving the items as they appear?

Let’s say, I want a simple equivalent of annotate-output. The simplest I could make is

#!/usr/bin/env python3

import dataclasses
from collections.abc import Sequence
from contextlib import asynccontextmanager
from typing import TypeVar

import anyio
import anyio.abc
import anyio.streams.text

# Demo workload: a bash loop that prints five numbered timestamps, sleeping
# 0.08 s after each one, and finishes with a single "." line.
SCRIPT = r"""
for idx in $(seq 1 5); do
    printf "%s  " "$idx"
    date -Ins
    sleep 0.08
done
echo "."
"""
# `-x` turns on bash command tracing (written to stderr by default), so the
# process produces output on both stdout and stderr.
CMD = ["bash", "-x", "-c", SCRIPT]


def print_data(data: str, is_stderr: bool) -> None:
    """Print one chunk as ``<flag>: <repr of data>``; flag is 1 if `is_stderr` else 0."""
    origin_flag = int(is_stderr)
    print(f"{origin_flag}: {data!r}")


# Type of the items produced by the source streams being combined.
T_Item = TypeVar("T_Item")  # TODO: covariant=True?


@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(anyio.abc.ObjectReceiveStream[tuple[int, T_Item]]):
    """Combine several receive streams into one, tagging each item with its origin.

    Items are delivered as ``(index, item)``, where ``index`` is the position of
    the source stream in ``streams``. One copier task per source is started on
    the first ``receive()`` call and feeds an in-memory queue created with
    ``max_buffer_size=max_buffer_size_items``.
    """

    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    max_buffer_size_items: int = 32

    def __post_init__(self) -> None:
        """Create the internal queue and bookkeeping; no tasks are started here."""
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Should be: `item_type=tuple[int, T_Item] | None`
        )
        # Indices of the source streams that have not been exhausted yet.
        self._pending = set(range(len(self.streams)))
        self._started = False
        self._closed = False
        self._task_group = anyio.create_task_group()

    async def _copier(self, idx: int) -> None:
        """Forward every item of ``streams[idx]`` into the queue as ``(idx, item)``."""
        assert idx in self._pending
        stream = self.streams[idx]
        async for item in stream:
            await self._queue_send.send((idx, item))
        assert idx in self._pending
        self._pending.remove(idx)
        await self._queue_send.send(None)  # Wake up the `receive` waiters, if any.

    async def _start(self) -> None:
        """Enter the task group and spawn one copier per source stream."""
        assert not self._started
        # NOTE(review): the task group is entered here and exited in `aclose()`;
        # confirm both calls always happen in the same task, as cancel scopes
        # are normally expected to be entered and exited by one task.
        await self._task_group.__aenter__()
        for idx in range(len(self.streams)):
            self._task_group.start_soon(self._copier, idx, name=f"_combined_receive_copier_{idx}")
        self._started = True

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next ``(index, item)`` pair from any of the source streams.

        Raises:
            anyio.ClosedResourceError: if ``aclose()`` has already been called.
            anyio.EndOfStream: once every source stream is exhausted and the
                queue holds no more items.
        """
        if self._closed:
            raise anyio.ClosedResourceError
        if not self._started:
            await self._start()

        while True:
            # Always drain without blocking first. A `None` wake-up marker may
            # sit in front of real items, so an empty `_pending` alone does not
            # mean there is nothing left to deliver.
            try:
                item = self._queue_receive.receive_nowait()
            except anyio.WouldBlock:
                if not self._pending:
                    raise anyio.EndOfStream from None
                item = await self._queue_receive.receive()
            if item is not None:
                return item

    async def aclose(self) -> None:
        """Cancel the copier tasks (if started) and close the internal queue."""
        self._closed = True
        try:
            if self._started:
                self._task_group.cancel_scope.cancel()
                self._started = False
                await self._task_group.__aexit__(None, None, None)
        finally:
            # Synchronous close: no await is needed, and repeated calls are harmless.
            self._queue_send.close()
            self._queue_receive.close()


async def amain(max_buffer_size_items: int = 32) -> None:
    """Run ``CMD`` and print its stdout/stderr chunks in the order they arrive.

    Args:
        max_buffer_size_items: Forwarded to ``CombinedReceiveStream`` as its
            ``max_buffer_size_items`` setting.
    """
    async with await anyio.open_process(CMD) as proc:
        assert proc.stdout is not None
        assert proc.stderr is not None
        # The position in this list is the index attached to each received item.
        raw_streams = [proc.stdout, proc.stderr]
        idx_to_is_stderr = {0: False, 1: True}  # just making it explicit
        streams = [anyio.streams.text.TextReceiveStream(stream) for stream in raw_streams]
        combined = CombinedReceiveStream(streams, max_buffer_size_items=max_buffer_size_items)
        async with combined as outputs:
            async for idx, data in outputs:
                print_data(data, is_stderr=idx_to_is_stderr[idx])


def main() -> None:
    """Synchronous entry point: run ``amain`` via ``anyio.run``."""
    anyio.run(amain)


# Run the demo only when executed as a script.
if __name__ == "__main__":
    main()

However, this CombinedReceiveStream solution is somewhat ugly, and I would expect that some solution should already exist. What am I overlooking?

Asked By: HoverHell

||

Answers:

This should be safer and more idiomatic.

class CtxObj:
    """
    Mixin that makes an object usable with `async with`, backed by `_ctx`.

    Subclasses provide `_ctx()`, which returns an async context manager;
    entering the object enters that manager and yields whatever it yields.

    Usage::
        class Foo(CtxObj):
            @asynccontextmanager
            async def _ctx(self):
                yield self # or whatever

        async with Foo() as self_or_whatever:
            pass
    """

    async def __aenter__(self):
        manager = self._ctx()  # pylint: disable=E1101
        # Keep the manager around so `__aexit__` can delegate to it.
        self.__ctx = manager  # pylint: disable=W0201
        return await manager.__aenter__()

    def __aexit__(self, *exc_info):
        # Plain `def`: the wrapped manager's awaitable is returned as-is and
        # awaited by the `async with` machinery.
        active = self.__ctx
        return active.__aexit__(*exc_info)


@dataclasses.dataclass(eq=False)
class CombinedReceiveStream(CtxObj):
    """Combine several receive streams into one, tagging each item with its origin.

    Use as ``async with CombinedReceiveStream(streams) as combined:``. The copier
    tasks are started when the block is entered and cancelled when it is left.
    Items are delivered as ``(index, item)``, where ``index`` is the position of
    the source stream in ``streams``.
    """

    streams: Sequence[anyio.abc.ObjectReceiveStream[T_Item]]
    max_buffer_size_items: int = 32

    def __post_init__(self) -> None:
        """Create the internal queue and the set of not-yet-exhausted stream indices."""
        self._queue_send, self._queue_receive = anyio.create_memory_object_stream(
            max_buffer_size=self.max_buffer_size_items,
            # Should be: `item_type=tuple[int, T_Item] | None`
        )
        self._pending = set(range(len(self.streams)))

    @asynccontextmanager
    async def _ctx(self):
        """Run one copier task per source stream for the duration of the context."""
        try:
            async with anyio.create_task_group() as tg:
                # Snapshot the indices: the copiers mutate `_pending` as they finish.
                for idx in tuple(self._pending):
                    tg.start_soon(self._copier, idx)
                if not self._pending:
                    # No sources means no copier will ever close the queue, so
                    # close it here or `receive()` would block forever.
                    self._queue_send.close()

                yield self
                tg.cancel_scope.cancel()
        finally:
            # Close both ends on every exit path, including early exit and
            # errors. Closing an already-closed end is harmless.
            self._queue_send.close()
            self._queue_receive.close()

    async def _copier(self, idx: int) -> None:
        """Forward every item of ``streams[idx]`` into the queue as ``(idx, item)``."""
        stream = self.streams[idx]
        async for item in stream:
            await self._queue_send.send((idx, item))
        self._pending.remove(idx)
        if not self._pending:
            # The last source is done: closing the send end makes `receive()`
            # raise `EndOfStream` once the buffered items are consumed.
            await self._queue_send.aclose()

    async def receive(self) -> tuple[int, T_Item]:
        """Return the next ``(index, item)`` pair.

        Raises:
            anyio.EndOfStream: once all sources are exhausted and the queue is empty.
        """
        return await self._queue_receive.receive()

    def __aiter__(self):
        return self

    async def __anext__(self) -> tuple[int, T_Item]:
        try:
            return await self.receive()
        except anyio.EndOfStream:
            raise StopAsyncIteration() from None
Answered By: Matthias Urlichs
Categories: questions Tags: , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.