mypy indexing pd.DataFrame with an Enum raises no overload variant error

Question:

The issue

Mypy reports a No overload variant of "__getitem__" of "DataFrame" matches argument type "MyEnum" error. Here the argument is an Enum member, but the same error would occur for any other custom type that is not one of the accepted scalar types. The first overload of __getitem__ has this signature:

    def __getitem__(self, Union[str, bytes, date, datetime, timedelta, bool, int, float, complex, Timestamp, Timedelta], /) -> Series[Any]

To reproduce

Here is a script, mypy_enum.py, that creates a pandas DataFrame whose columns are Enum members.

    from enum import Enum
    import pandas as pd

    class MyEnum(Enum):
        TAYYAR = "tayyar"
        HAYDAR = "haydar"

    df = pd.DataFrame(data=[[12.2, 10], [8.8, 15], [22.1, 14]], columns=[MyEnum.TAYYAR, MyEnum.HAYDAR])
    print(df[MyEnum.TAYYAR])

Running it with Python works as expected:

    > python mypy_enum.py
    0    12.2
    1     8.8
    2    22.1
    Name: MyEnum.TAYYAR, dtype: float64

Running mypy on it, however, reports an error:

    > mypy mypy_enum.py
    mypy_enum.py:12: error: No overload variant of "__getitem__" of "DataFrame" matches argument type "MyEnum"  [call-overload]
    mypy_enum.py:12: note: Possible overload variants:
    mypy_enum.py:12: note:     def __getitem__(self, Union[str, bytes, date, datetime, timedelta, bool, int, float, complex, Timestamp, Timedelta], /) -> Series[Any]
    mypy_enum.py:12: note:     def __getitem__(self, slice, /) -> DataFrame
    mypy_enum.py:12: note:     def [ScalarT] __getitem__(self, Union[Tuple[Any, ...], Series[bool], DataFrame, List[str], List[ScalarT], Index, ndarray[Any, dtype[str_]], ndarray[Any, dtype[bool_]], Sequence[Tuple[Union[str, bytes, date, datetime, timedelta, bool, int, float, complex, Timestamp, Timedelta], ...]]], /) -> DataFrame
    Found 1 error in 1 file (checked 1 source file)

Shouldn't __getitem__ accept the column type itself? How can this be addressed?

Asked By: anilbey


Answers:

This issue was caused by a bug in pandas-stubs, which has since been fixed in PR #596.

Before the fix, the first overload of __getitem__ had this signature:

    @overload
    def __getitem__(self, idx: Scalar | tuple[Hashable, ...]) -> Series: ...

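The fix added Hashable to the argument types accepted by that overload. Enum members are hashable (Enum defines __hash__), so MyEnum satisfies Hashable and the call now matches.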
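To see why adding Hashable is enough, here is a minimal, pandas-free sketch. column_label is a hypothetical helper, not part of pandas or pandas-stubs; it only shows that an Enum member passes both the runtime Hashable check and mypy's structural check on a Hashable parameter, while a list does not.

```python
from collections.abc import Hashable
from enum import Enum


class MyEnum(Enum):
    TAYYAR = "tayyar"
    HAYDAR = "haydar"


def column_label(key: Hashable) -> Hashable:
    # Hypothetical helper: mypy accepts any argument whose type provides
    # a usable __hash__, which is what the Hashable protocol checks.
    return key


print(isinstance(MyEnum.TAYYAR, Hashable))  # True: Enum defines __hash__
print(isinstance([1, 2], Hashable))  # False: list sets __hash__ to None
print(column_label(MyEnum.HAYDAR))  # MyEnum.HAYDAR
```

Because Hashable is checked structurally, no change to MyEnum itself is needed; any user-defined class with a working __hash__ would be accepted the same way.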

Here are all three overloads of __getitem__ after the fix:

    @overload
    def __getitem__(self, idx: Scalar | Hashable) -> Series: ...
    @overload
    def __getitem__(self, rows: slice) -> DataFrame: ...
    @overload
    def __getitem__(
        self,
        idx: Series[_bool]
        | DataFrame
        | Index
        | np_ndarray_str
        | np_ndarray_bool
        | list[_ScalarOrTupleT],
    ) -> DataFrame: ...
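The way these overloads select a return type from the argument type can be illustrated with a small sketch that does not depend on pandas. Table is a hypothetical class that only mimics the shape of the stubs: a list of labels returns a new Table, and any other hashable label returns a single column. Unlike the stubs above, the list overload is listed first here; since list is not hashable, the two overloads do not overlap, so this order is only a readability choice for the sketch.

```python
from __future__ import annotations

from collections.abc import Hashable
from enum import Enum
from typing import overload


class MyEnum(Enum):
    TAYYAR = "tayyar"
    HAYDAR = "haydar"


class Table:
    """Hypothetical column store; mimics only the overload shape."""

    def __init__(self, columns: dict[Hashable, list[float]]) -> None:
        self._columns = columns

    @overload
    def __getitem__(self, key: list[Hashable]) -> Table: ...
    @overload
    def __getitem__(self, key: Hashable) -> list[float]: ...

    def __getitem__(self, key: list[Hashable] | Hashable) -> Table | list[float]:
        # A list of labels selects several columns and returns a new Table.
        if isinstance(key, list):
            return Table({k: self._columns[k] for k in key})
        # Any other hashable label (str, int, Enum member, ...) selects one column.
        return self._columns[key]


t = Table({MyEnum.TAYYAR: [12.2, 8.8, 22.1], MyEnum.HAYDAR: [10.0, 15.0, 14.0]})
print(t[MyEnum.TAYYAR])  # [12.2, 8.8, 22.1]
print(type(t[[MyEnum.HAYDAR]]).__name__)  # Table
```

A type checker would infer list[float] for t[MyEnum.TAYYAR] and Table for t[[MyEnum.HAYDAR]], mirroring how df[MyEnum.TAYYAR] is typed as Series once the first pandas-stubs overload accepts Hashable.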

The indexing expression from the question now matches the first overload, so type checkers accept it.

Answered By: anilbey