How to type annotate a function with return "shape" based on input "shape"?

Question:

The function below receives an instance of one type (e.g. int) and returns an instance of another type (e.g. str).

It may also receive a list of the input type (list[int]), in which case it returns a list of the output type (list[str]).

The way I found to achieve that is to use Union:

from typing import Union

def foo(a: Union[int, list[int]]) -> Union[str, list[str]]:
    ...

But the type checker reports an error when I pass the result to another function:

def bar(s: str):
    """this function should receive a single `str`"""
    ...

b = bar(foo(33)) # type check fails: Type "str | list[str]" cannot be assigned to type "str"

The warning raised by the type checker is:

"Type "str | list[str]" cannot be assigned to type "str""

The int and str types are just examples; they could be other classes.

How do I annotate the first function so that the type check doesn't fail?

i.e., how do I annotate the input and return types so that the return type depends on the "shape" of the input type?

(So that the type checker knows the return type is a single str in that example.)

Asked By: Diogo


Answers:

You’re looking for @typing.overload. Try the following:

from typing import overload, Union

@overload
def foo(a: list[int]) -> list[str]:
    ...
@overload
def foo(a: int) -> str:
    ...
def foo(a: Union[int, list[int]]) -> Union[str, list[str]]:
    # implementation goes here. it could be something like the following
    if isinstance(a, int):
        return str(a)
    # since we already checked if a was an int, it must be a list from here on
    return [str(elm) for elm in a]

Basically, this tells the type checker that when it sees a call to foo with a list[int] argument, the return type is list[str], and likewise that an int argument produces a str. With your original signature, the best the type checker can determine is that any combination of input and output types is possible, so it has to assume that foo(33) might return a list[str].
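To see the effect on callers, here is a minimal sketch that puts the overloaded foo next to the question's bar and a hypothetical second consumer, join_all, that expects a list. With the overloads in place, a type checker such as mypy or Pyright should select the int overload for foo(33) and the list[int] overload for foo([1, 2, 3]), so both calls are accepted. The bodies of bar and join_all are placeholders made up for illustration.

```python
from typing import Union, overload


@overload
def foo(a: list[int]) -> list[str]: ...
@overload
def foo(a: int) -> str: ...
def foo(a: Union[int, list[int]]) -> Union[str, list[str]]:
    # Runtime implementation; the @overload stubs above exist only for the type checker.
    if isinstance(a, int):
        return str(a)
    return [str(elm) for elm in a]


def bar(s: str) -> str:
    """Accepts a single str (placeholder body)."""
    return s.upper()


def join_all(items: list[str]) -> str:
    """Hypothetical consumer that needs a list[str]."""
    return ",".join(items)


single = bar(foo(33))            # foo(33) matches the int overload -> str
many = join_all(foo([1, 2, 3]))  # foo([1, 2, 3]) matches the list[int] overload -> list[str]
print(single, many)
```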

By the way, the modern way to express Union is with a pipe (|). This syntax was introduced in PEP 604 and is available in Python 3.10 and later. We can rewrite your implementation signature more succinctly as:

def foo(a: int | list[int]) -> str | list[str]:
    ...
Answered By: thisisrandy

In the example code you provided, foo is declared to return str | list[str], but bar is declared to accept only str. You are getting this error because your code doesn't answer the question: "What happens if foo returns list[str]?"
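One way to answer that question in code, without changing foo, is to narrow the union with isinstance before calling bar; type checkers understand this kind of narrowing. This sketch assumes foo keeps the original Union-only signature (no overloads). The wrapper call_bar and its handling of the list branch are hypothetical, just one possible way to deal with list[str].

```python
from typing import Union


def foo(a: Union[int, list[int]]) -> Union[str, list[str]]:
    # Original signature from the question: callers only see str | list[str].
    if isinstance(a, int):
        return str(a)
    return [str(elm) for elm in a]


def bar(s: str) -> str:
    """Accepts a single str (placeholder body)."""
    return s


def call_bar(value: Union[int, list[int]]) -> list[str]:
    """Hypothetical wrapper that handles both possible return shapes of foo."""
    result = foo(value)
    if isinstance(result, str):
        # Narrowed to str here, so a type checker accepts bar(result).
        return [bar(result)]
    # Narrowed to list[str] here; pass each element to bar individually.
    return [bar(item) for item in result]
```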

However, if you know that foo will return str in this specific case, you can use cast():

from typing import cast

b = bar(cast(str, foo(33)))
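Keep in mind that cast() does nothing at runtime: it returns its second argument unchanged and only changes what the type checker believes. If the assumption is wrong, the type checker is silenced and the mismatch only shows up later, if at all. A small sketch using the question's Union-only foo:

```python
from typing import Union, cast


def foo(a: Union[int, list[int]]) -> Union[str, list[str]]:
    if isinstance(a, int):
        return str(a)
    return [str(elm) for elm in a]


# The checker now treats this as str, and at runtime it really is a str.
ok = cast(str, foo(33))

# cast() performs no check: this is still a list at runtime,
# even though the type checker now believes it is a str.
wrong = cast(str, foo([1, 2]))
print(type(ok).__name__, type(wrong).__name__)  # str list
```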
Answered By: VoidTwo