How to check that a string is a string literal for mypy?

Question:

With this code

import os
from typing import Literal, get_args

Markets = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Markets] = list(get_args(Markets))


def foo(x: Markets) -> None:
    print(x)


market = os.environ.get("market")


if market not in MARKETS:
    raise ValueError


foo(market)

I get the following error.

Argument 1 to "foo" has incompatible type "str"; expected "Literal['BE', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GB', 'IT', 'LT', 'LV', 'NL', 'NO', 'PL', 'PT', 'SE']"  [arg-type]

How should I check the market variable so that mypy knows it is of the correct type?

Asked By: Prokie

Answers:

Alternative: Custom TypeGuard

Since Python 3.10, typing.TypeGuard lets you define your own type guards (on older versions it is available from typing_extensions), which can make this slightly more elegant:

import os
from typing import Literal, TypeGuard, get_args


MarketT = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[MarketT] = list(get_args(MarketT))


def is_valid_market(val: str) -> TypeGuard[MarketT]:
    return val in MARKETS


def foo(x: MarketT) -> None:
    print(x)


market = os.environ.get("market", "")
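reveal_type(market)  # mypy-only helper, remove before running: str
assert is_valid_market(market)
reveal_type(market)  # mypy-only helper, remove before running: narrowed to MarketT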

foo(market)

Running mypy over this will show you that before the assert the type is inferred as str, whereas after the assert it is narrowed to that union of string literals you defined earlier. This basically combines both the runtime check (that you already had) and the static narrowing into one.

Note: I still need to provide a str instance as the default for os.environ.get because otherwise market might still turn out to be None. We could instead annotate the val parameter in is_valid_market with Optional[str] to avoid another type checker error. This is just a matter of preference.


Original post

Yes, cast is the easiest way IMO:

import os
from typing import Literal, cast, get_args


Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Market] = list(get_args(Market))


def foo(x: Market) -> None:
    print(x)


market = cast(Market, os.environ.get("market"))
# reveal_type(market)

if market not in MARKETS:
    raise ValueError

foo(market)

Uncommenting the reveal_type statement and running mypy will give you the following:

note: Revealed type is "Union[Literal['BE'], Literal['DE'], Literal['DK'], Literal['EE'], Literal['ES'], Literal['FI'], Literal['FR'], Literal['GB'], Literal['IT'], Literal['LT'], Literal['LV'], Literal['NL'], Literal['NO'], Literal['PL'], Literal['PT'], Literal['SE']]"

So the type is correctly inferred as a union of those string literals.
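Keep in mind that cast only informs the type checker; at runtime it returns its argument unchanged and validates nothing. The small sketch below illustrates this, with "XX" as an arbitrary made-up value. It is why the membership check against MARKETS is still needed after the cast.

```python
from typing import Literal, cast, get_args

Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Market] = list(get_args(Market))

# mypy now treats `value` as Market, but nothing is checked at runtime.
value = cast(Market, "XX")
print(value)             # prints XX
print(value in MARKETS)  # prints False

# Even None passes through a cast untouched.
missing = cast(Market, None)
print(missing is None)   # prints True
```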

As a side note, semantically, I would say the name of your literal union should be Market, not Markets (maybe even MarketType or MarketT). It refers to the type of the variable that will represent a single market after all, not multiple. The list name on the other hand is fitting, since it refers to a collection of all the possible markets.

Answered By: Daniil Fajnberg

No need to use cast(..). Just type your variable:

def foo(x: Market) -> None:
    print(x)

market: Market = os.environ.get("market")

foo(market)

Answered By: Joan Flotats