Is it possible to create union fields in Python data classes (pydantic, dataclasses, attrs)?

Question:

I want to build a data structure with multiple fields in which exactly one of a selected group of fields must be set; if none, or more than one, of them is set, an error should be raised.

Here is how I want it to behave:

from typing import Optional
from pydantic import BaseModel


class BasicSpec(BaseModel):
    # Placeholder for the simple spec variant; it declares no fields.
    pass


class ComplexSpec(BaseModel):
    # Placeholder for the complex spec variant; it declares no fields.
    pass


class Spec(BaseModel):
    # Exactly one of `basic_spec` / `complex_spec` must be given a non-None
    # value; the constraint is enforced in `__init__` below.
    title: str
    basic_spec: Optional[BasicSpec] = None
    complex_spec: Optional[ComplexSpec] = None

    def __init__(self, **kwargs: object):
        """Check the one-of constraint, then hand off to pydantic's init.

        Args:
            **kwargs: field values for the model. Each value is an arbitrary
                input object, not a ``Spec``.

        Raises:
            ValueError: if more than one, or none, of the union fields is
                given a non-None value. The check runs before
                ``super().__init__``, so pydantic validation has not started
                when this is raised.
        """
        union_fields = ["basic_spec", "complex_spec"]
        # A union field passed explicitly as None counts as "not set".
        r = sum(1 for name in union_fields if kwargs.get(name) is not None)
        if r > 1:
            raise ValueError(f"Given more than one union fields: {union_fields}")
        elif r < 1:
            raise ValueError(f"Spec must have one of union fields: {union_fields}")

        super().__init__(**kwargs)

# Fail cases (uncomment one to see the ValueError raised by Spec.__init__):
# s = Spec(title="Test 1", basic_spec=BasicSpec(), complex_spec=ComplexSpec())  # two specs
# s = Spec(title="Test 1")  # no spec

# Happy path / Expected behavior
s = Spec(title="Test 1", basic_spec=BasicSpec())
# exclude_unset=True omits complex_spec, which was never passed in.
print(s.json(exclude_unset=True))
# Expected results
# {"title": "Test 1", "basic_spec": {}}

Here the data class requires exactly one of the union fields (`basic_spec`, `complex_spec`) to be set.

I wrote this solution, but I'm wondering whether there is a built-in way to do this in pydantic, or an equivalent feature in another library.

Asked By: Gny

||

Answers:

The simplest thing would be to have a single `spec` field of type `Union[BasicSpec, ComplexSpec]`. I'm assuming that isn't feasible in your use case for some reason.

Given that, the best pydantic-native solution I can think of is a @root_validator:

from typing import Optional
from pydantic import BaseModel, ValidationError, root_validator


from typing import Optional
from pydantic import BaseModel


class BasicSpec(BaseModel):
    # Placeholder for the simple spec variant; it declares no fields.
    pass


class ComplexSpec(BaseModel):
    # Placeholder for the complex spec variant; it declares no fields.
    pass


class Spec(BaseModel):
    # Exactly one of `basic_spec` / `complex_spec` must be non-None; enforced
    # by the root validator below.
    title: str
    basic_spec: Optional[BasicSpec] = None
    complex_spec: Optional[ComplexSpec] = None

    @root_validator(pre=True)
    def check_exactly_one_spec(cls, values):
        """Reject input that does not set exactly one of the two spec fields.

        Runs with ``pre=True``, so ``values`` is the raw input before field
        validation. A spec passed explicitly as ``None`` counts as unset.

        Raises:
            AssertionError: when zero or both specs are set; pydantic turns
                it into a ValidationError of type ``assertion_error``.
        """
        supplied = sum(
            1 for name in ("basic_spec", "complex_spec") if values.get(name) is not None
        )
        if supplied != 1:
            # Raise explicitly instead of using `assert`: assert statements are
            # stripped under `python -O`, which would silently skip the check.
            raise AssertionError('please supply exactly one spec')
        return values

# Fail cases: both specs supplied, then no spec supplied.
try:
    s = Spec(title="Test 1", basic_spec=BasicSpec(), complex_spec=ComplexSpec())
except ValidationError as v:
    print("Yay! test 1 failed! ", v)
try:
    s = Spec(title="Test 2")
except ValidationError as v:
    print("Yay! test 2 failed! ", v)
# Happy path / Expected behavior
s = Spec(title="Test 1", basic_spec=BasicSpec())
print(s.json(exclude_unset=True))
# "\N{...}" is a named Unicode escape; without the backslash the name is
# printed literally.
print("happy path is happy \N{grinning face}")
# Expected results
# {"title": "Test 1", "basic_spec": {}}

Output:

Yay! test 1 failed!  1 validation error for Spec
__root__
  please supply exactly one spec (type=assertion_error)
Yay! test 2 failed!  1 validation error for Spec
__root__
  please supply exactly one spec (type=assertion_error)
{"title": "Test 1", "basic_spec": {}}
happy path is happy 😀
Answered By: mmdanziger

@mmdanziger
Your answer was just what I was looking for. This is my full solution:

from typing import Any, Dict, Optional
from pydantic import BaseModel, root_validator


class SpecLike(BaseModel):
    # Common base for the spec variants below; it declares no fields.
    pass


class BasicSpec(SpecLike):
    # Placeholder for the simple spec variant; it declares no fields.
    pass


class ComplexSpec(SpecLike):
    # Placeholder for the complex spec variant; it declares no fields.
    pass


class Spec(BaseModel):
    # Exactly one of the union fields below must be given a non-None value.
    title: str
    basic_spec: Optional[BasicSpec] = None
    complex_spec: Optional[ComplexSpec] = None

    @root_validator(pre=True)
    def check_exactly_one_spec(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure exactly one union field is set to a non-None value.

        Runs before field validation (``pre=True``), so ``data`` holds the
        raw input. Passing a union field explicitly as ``None`` counts as not
        setting it.

        Raises:
            ValueError: if none, or more than one, of the union fields is
                set; pydantic reports it as a ValidationError.
        """
        union_fields = ("basic_spec", "complex_spec")
        specs = sum(1 for name in union_fields if data.get(name) is not None)
        # Require exactly one: a plain truthiness check would also accept two.
        # Raise rather than `assert` so the check survives `python -O`.
        if specs != 1:
            raise ValueError("please supply exactly one spec")
        return data


# Intended fail cases (exactly one spec is required):
# s = Spec(title="Test 1")  # no spec supplied
# s = Spec(title="Test 1", basic_spec=BasicSpec(), complex_spec=ComplexSpec())  # two specs supplied
# Happy path: the validator ignores a spec passed explicitly as None.
s = Spec(title="Test 1", basic_spec=None, complex_spec=ComplexSpec())
print(s.json(exclude_unset=True))
Answered By: Gny
Categories: questions Tags: ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with a check mark at the top-right corner.