How can I type-hint a function where the return type depends on the input type of an argument?

Question:

Assume that I have a function which converts Python data types to Postgres data types like this:

def map_type(input):
    if isinstance(input, int):
        return MyEnum(input)
    elif isinstance(input, str):
        return MyCustomClass(input)

I could type-hint this as:

def map_type(input: Union[int, str]) -> Union[MyEnum, MyCustomClass]: ...

But then code like the following would fail to type-check even though it is correct:

myvar = map_type('foobar')
print(myvar.property_of_my_custom_class)

Complete example (working code, but errors in type-hinting):

from typing import Union
from enum import Enum


class MyEnum(Enum):
    VALUE_1 = 1
    VALUE_2 = 2


class MyCustomClass:

    def __init__(self, value: str) -> None:
        self.value = value

    @property
    def myproperty(self) -> str:
        return 2 * self.value


def map_type(value: Union[int, str]) -> Union[MyEnum, MyCustomClass]:

    if isinstance(value, int):
        return MyEnum(value)
    elif isinstance(value, str):
        return MyCustomClass(value)
    raise TypeError('Invalid input type')


myvar1 = map_type(1)
print(myvar1.value, myvar1.name)

myvar2 = map_type('foobar')
print(myvar2.myproperty)

I’m aware that I could split up the mapping into two functions, but the aim is to have a generic type-mapping function.

I was also thinking about working with classes and polymorphism, but then how would I type-hint the methods of the topmost class, given that their output type would depend on the concrete instance type?
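For the polymorphic variant, one common approach (an assumption here, not something taken from the post) is to make the base class generic, so each subclass fixes its own input and output types and the base-class method is typed in terms of those type variables. The names `TypeMapper`, `IntMapper`, `StrMapper`, `InT`, `OutT` and the method name `convert` are hypothetical; the two target classes are copied from the complete example.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, TypeVar

InT = TypeVar("InT")    # hypothetical: the input type a mapper accepts
OutT = TypeVar("OutT")  # hypothetical: the output type a mapper produces


class MyEnum(Enum):
    VALUE_1 = 1
    VALUE_2 = 2


class MyCustomClass:
    def __init__(self, value: str) -> None:
        self.value = value

    @property
    def myproperty(self) -> str:
        return 2 * self.value


class TypeMapper(ABC, Generic[InT, OutT]):
    """Hypothetical base class: the topmost method is typed via InT/OutT."""

    @abstractmethod
    def convert(self, value: InT) -> OutT: ...


class IntMapper(TypeMapper[int, MyEnum]):
    def convert(self, value: int) -> MyEnum:
        return MyEnum(value)


class StrMapper(TypeMapper[str, MyCustomClass]):
    def convert(self, value: str) -> MyCustomClass:
        return MyCustomClass(value)


# A type checker sees StrMapper().convert(...) as returning MyCustomClass.
print(StrMapper().convert("foobar").myproperty)
```

Code written against `TypeMapper[InT, OutT]` keeps the link between the input and output types, while callers holding a concrete mapper get the concrete return type.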

Asked By: exhuma


Answers:

This is exactly what function overloads are for.

In short, you do the following:

from typing import Union, overload

# ...snip...

@overload
def map_type(value: int) -> MyEnum: ...

@overload
def map_type(value: str) -> MyCustomClass: ...

def map_type(value: Union[int, str]) -> Union[MyEnum, MyCustomClass]:
    if isinstance(value, int):
        return MyEnum(value)
    elif isinstance(value, str):
        return MyCustomClass(value)
    raise TypeError('Invalid input type')

Now, when you do map_type(3), mypy will understand that the return type is MyEnum.

At runtime, only the final definition actually runs: it replaces the two @overload-decorated stubs, which exist purely for the type checker.
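For reference, here is the question's complete example with the overloads applied, as a self-contained sketch that only combines code already shown in the question and this answer. Under mypy, `map_type(1)` should now be inferred as `MyEnum` and `map_type('foobar')` as `MyCustomClass`, so the attribute accesses at the bottom should type-check.

```python
from enum import Enum
from typing import Union, overload


class MyEnum(Enum):
    VALUE_1 = 1
    VALUE_2 = 2


class MyCustomClass:
    def __init__(self, value: str) -> None:
        self.value = value

    @property
    def myproperty(self) -> str:
        return 2 * self.value


# Signatures used only by the type checker.
@overload
def map_type(value: int) -> MyEnum: ...


@overload
def map_type(value: str) -> MyCustomClass: ...


# The implementation: this last definition is the one bound at runtime.
def map_type(value: Union[int, str]) -> Union[MyEnum, MyCustomClass]:
    if isinstance(value, int):
        return MyEnum(value)
    if isinstance(value, str):
        return MyCustomClass(value)
    raise TypeError('Invalid input type')


myvar1 = map_type(1)         # inferred as MyEnum
print(myvar1.value, myvar1.name)

myvar2 = map_type('foobar')  # inferred as MyCustomClass
print(myvar2.myproperty)
```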

Answered By: Michael0x2a

If your return type is the same as your input type (or, as my example shows, a variation of it), you can use the following strategy, which removes the need to add another @overload whenever a new type is introduced or supported.

from typing import TypeVar

T = TypeVar("T")  # the variable name must match the string

def filter_category(category: type[T]) -> list[type[T]]:
    # assume get_options() is some function that returns
    # an arbitrary number of classes of varying types
    return [
        option
        for option in get_options()
        if issubclass(option, category)
    ]

Calls to filter_category are then typed correctly by both your IDE (VS Code, PyCharm, etc.) and your static type checker (mypy):

# for instance, list_of_ints shows in your IDE as list[type[int]]
list_of_ints = filter_category(int)

# and list_of_bools shows as list[type[bool]]
list_of_bools = filter_category(bool)
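A runnable version of this answer's idea is sketched below. The `get_options()` body is a hypothetical stand-in (the answer leaves it unspecified) that returns a fixed list of classes, and the parameter is annotated as `type[T]` so that passing `category` to `issubclass` also type-checks.

```python
from typing import Any, TypeVar

T = TypeVar("T")


def get_options() -> list[Any]:
    # Hypothetical stand-in for get_options(): a fixed list of classes.
    return [int, bool, str, float, ValueError]


def filter_category(category: type[T]) -> list[type[T]]:
    # Keep only the options that are subclasses of the requested category.
    return [option for option in get_options() if issubclass(option, category)]


list_of_ints = filter_category(int)    # [int, bool]
list_of_bools = filter_category(bool)  # [bool]
print(list_of_ints, list_of_bools)
```

Because `bool` is a subclass of `int`, `filter_category(int)` also picks up `bool`; the inferred element type is still `type[int]`, which `bool` satisfies.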

This type definition is much more specific than falling back to Any:

from typing import Any

def overly_general_filter(category: Any) -> list[Any]:
    pass

which is effectively equivalent to leaving the types unspecified:

def equally_general_filter(category) -> list:
    pass

Answered By: Marc