Restrict possible values in hydra structured configs

Question:

I'm trying to adapt my app to the Hydra framework. I use a structured config schema, and I want to restrict the possible values of some fields. Is there any way to do that?

Here is my code:

my_app.py:

from dataclasses import dataclass

import hydra


@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"


@hydra.main(config_path="configs", config_name="config")
def main(cfg: Config):
    print(cfg)


if __name__ == "__main__":
    main()

configs/config.yaml:

# this value is invalid;
# I need Hydra to raise an exception in this case
some_value: "barrr"
Asked By: NShiny


Answers:

A few options:

1) If your acceptable values are enumerable, use an Enum type:

from enum import Enum
from dataclasses import dataclass

class SomeValue(Enum):
    foo = 1
    bar = 2

@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: SomeValue = SomeValue.foo

If validating some_value needs no custom logic, this is the solution I would recommend.

2) If you are using YAML files, you can use OmegaConf to register a custom resolver:

# my_python_file.py
import hydra
from omegaconf import OmegaConf

def check_some_value(value: str) -> str:
    assert value in ("foo", "bar")
    return value

OmegaConf.register_new_resolver("check_foo_bar", check_some_value)

@hydra.main(...)
...

if __name__ == "__main__":
    main()

# my_yaml_file.yaml
some_value: ${check_foo_bar:foo}

When you access cfg.some_value in your Python code, the resolver runs, and an error is raised if the value fails the check in check_some_value (recent OmegaConf versions may wrap the AssertionError in an InterpolationResolutionError).

3) After config composition is complete, you can call OmegaConf.to_object to create an instance of your dataclass. This means that the dataclass’s __post_init__ method will be called.

import hydra
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf

@dataclass
class Config:
    # possible values are 'foo' and 'bar'
    some_value: str = "foo"

    def __post_init__(self) -> None:
        assert self.some_value in ("foo", "bar")

@hydra.main(config_path="configs", config_name="config")
def main(dict_cfg: DictConfig):
    cfg: Config = OmegaConf.to_object(dict_cfg)
    print(cfg)

if __name__ == "__main__":
    main()
Answered By: Jasha