pydantic and subclasses of abstract class

Question:

I am trying to use pydantic with a schema that looks like the following:

from abc import ABC
from typing import List

from pydantic import BaseModel

class Base(BaseModel, ABC):
    common: int

class Child1(Base):
    child1: int

class Child2(Base):
    child2: int

class Response(BaseModel):
    events: List[Base]


events = [{'common':1, 'child1': 10}, {'common': 2, 'child2': 20}]

resp = Response(events=events)

resp.events
# Out[49]: [<Base common=1>, <Base common=2>]

It only took the fields of the Base class and ignored the rest. How can I use pydantic with this kind of inheritance? I want events to be a list of instances of subclasses of Base.

Asked By: Apostolos


Answers:

The best approach right now would be to use Union, something like:

class Response(BaseModel):
    events: List[Union[Child2, Child1, Base]]

Note that the order in the Union matters: pydantic will match your input data against Child2, then Child1, then Base, so your events data above should be validated correctly. See this warning about Union order.
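To make the ordering point concrete, here is a minimal, self-contained sketch of the Union approach applied to the models from the question. The third event (only a `common` field) is an extra case added for illustration; it is not in the original data. Treat this as a sketch rather than the definitive behaviour of every pydantic release, since Union matching rules have changed between versions.

```python
from abc import ABC
from typing import List, Union

from pydantic import BaseModel


class Base(BaseModel, ABC):
    common: int


class Child1(Base):
    child1: int


class Child2(Base):
    child2: int


class Response(BaseModel):
    # Most specific models first; the permissive Base goes last,
    # because it accepts any payload that has a "common" field.
    events: List[Union[Child2, Child1, Base]]


events = [
    {'common': 1, 'child1': 10},
    {'common': 2, 'child2': 20},
    {'common': 3},  # illustrative extra case: only Base can accept this
]

resp = Response(events=events)
kinds = [type(e).__name__ for e in resp.events]
print(kinds)
```

If Base were listed first, a left-to-right matcher would accept every payload as a plain Base and the child fields would be dropped, which is the symptom described in the question.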

In the future, discriminators might be able to do something similar in a more powerful way.
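Discriminated unions were later added to pydantic (from 1.9 onwards, as far as I know). Below is a hedged sketch of what that could look like for this schema. The `kind` field and its values `'C1'` and `'C2'` are assumptions made for the example; the models in the question have no such field, so the payload would need to carry one. It also assumes Python 3.9+ for `typing.Annotated`.

```python
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class Base(BaseModel):
    common: int


class Child1(Base):
    kind: Literal['C1']  # hypothetical tag field, not in the question
    child1: int


class Child2(Base):
    kind: Literal['C2']  # hypothetical tag field, not in the question
    child2: int


# The discriminator tells pydantic which model to use by looking at "kind",
# so the order of the Union members no longer decides the outcome.
Event = Annotated[Union[Child1, Child2], Field(discriminator='kind')]


class Response(BaseModel):
    events: List[Event]


events = [
    {'kind': 'C1', 'common': 1, 'child1': 10},
    {'kind': 'C2', 'common': 2, 'child2': 20},
]

resp = Response(events=events)
kinds = [type(e).__name__ for e in resp.events]
print(kinds)
```

The trade-off is that every payload must include the tag, but in exchange the selection is explicit and validation errors point at the one model that was actually chosen.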

There’s also more information on related matters in this issue.

Answered By: SColvin

I approached this problem by building a custom validator:

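from abc import ABC
from typing import List

from pydantic import BaseModel

class Base(BaseModel, ABC):
    common: int

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        # Accept only objects that are already instances of a Base subclass
        # (isinstance, not issubclass, since v is an object and not a class).
        if not isinstance(v, Base):
            raise ValueError("Invalid Object")

        return v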

class Child1(Base):
    child1: int

class Child2(Base):
    child2: int

class Response(BaseModel):
    events: List[Base]

Answered By: syntactic

Here are two ways to do this more generically, plus a third one that can be customised:

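from typing import List, Union

from pydantic import BaseModel

def all_subclasses(cls):
    # Direct subclasses first, then their descendants, recursively.
    return list(cls.__subclasses__()) + [
        s for c in cls.__subclasses__() for s in all_subclasses(c)
    ]

def model_instance(cls):
    # Reversed so that deeper (more specific) subclasses tend to be tried first.
    return Union[tuple(all_subclasses(cls)[::-1])]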

##########################
class Response(BaseModel):
    events: List[model_instance(Base)] 

Better, I think, and more readable:

class ModelInstanceMeta(type):
    def __getitem__(cls, item):
        if isinstance(item, tuple):
            raise ValueError("ModelInstance takes only one subfield")
        # Caution: the order of the subclasses in the Union matters.
        return Union[tuple(all_subclasses(item)[::-1])]
        
class ModelInstance(metaclass=ModelInstanceMeta):
    pass

#############################
class Response(BaseModel):
    events: List[ModelInstance[Base]] 

Finally, this one is, I think, more complete, and you can imagine a custom validation function that depends on what is found in the payload (e.g. the payload could have a type keyword that acts as a switch between one type and another).

Note that you may also want to return a copy of the model (when the value is already a model) rather than the model itself, to match pydantic's default behaviour.

from typing import List

from pydantic import ValidationError, BaseModel

def all_subclasses(cls):
    return list(cls.__subclasses__()) + [
        s for c in cls.__subclasses__() for s in all_subclasses(c)
    ]

class ModelInstanceMeta(type):
    def __getitem__(cls, item):
        if isinstance(item, tuple):
            raise ValueError("ModelInstance takes only one subfield ")
        return type("ModelInstance["+item.__name__+"]", (cls,), {'__BaseClass__': item})


class ModelInstance(metaclass=ModelInstanceMeta):
    __BaseClass__ = BaseModel
    @classmethod
    def __get_validators__(cls):
        yield cls.validate 

    @classmethod
    def validate(cls, value):
        if isinstance( value, cls.__BaseClass__ ):
            return value 
        
        errors = []
        #################
        # replace this with something more custom if needed 
        for SubClass in all_subclasses(cls.__BaseClass__)[::-1]:
            for validator in SubClass.__get_validators__():
                try:
                    return validator(value)
                except (ValidationError, ValueError, AttributeError, KeyError) as err:
                    errors.append(err)
        ####
        if errors:
            raise ValueError("\n".join(str(err) for err in errors))
        else:
            raise ValueError( "cannot find a valid subclass")

#######################
class Response(BaseModel):
    events: List[ModelInstance[Base]] 

Note1: I can imagine that for complex subclassing the result can get out of control, as it depends on the order in which the subclasses are found. Maybe the order could be fixed by trying to validate first against the subclass with the highest number of fields.
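The ordering idea from Note1 could be sketched as follows. This is only an illustration under assumptions: it uses plain annotated classes and `typing.get_type_hints` to count fields so that it runs without pydantic, and `by_field_count` and `GrandChild` are hypothetical names introduced here. With real pydantic models you would probably count the model's declared fields instead.

```python
from typing import get_type_hints


def all_subclasses(cls):
    # Direct subclasses first, then their descendants, recursively.
    direct = cls.__subclasses__()
    return direct + [s for c in direct for s in all_subclasses(c)]


def by_field_count(cls):
    # Hypothetical helper: subclasses with the most annotated fields
    # (inherited ones included) come first, so the most specific
    # candidates are tried before the more permissive ones.
    return sorted(
        all_subclasses(cls),
        key=lambda c: len(get_type_hints(c)),
        reverse=True,
    )


class Base:
    common: int


class Child1(Base):
    child1: int


class GrandChild(Child1):  # hypothetical deeper subclass
    extra: int


class Child2(Base):
    child2: int


order = [c.__name__ for c in by_field_count(Base)]
print(order)
```

In the third approach, `by_field_count(cls.__BaseClass__)` could stand in for `all_subclasses(cls.__BaseClass__)[::-1]` in the validation loop. Subclasses with the same number of fields are still ambiguous, so this reduces the problem rather than solving it.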

Note2: I am raising ValueError because I had problems with ValidationError; I may not have understood how it works.

@SColvin Don’t you think something native to pydantic could be implemented?

Edit1

A bonus, following on from what I did above: if you want to be able to switch from one subclassed model to another based on the payload, here is a little trick:

from enum import Enum 
def strict(value):
    return Enum("Strict", {"V":value}).V 

Then use it in your model:

class Child1(Base):
    child1: int
    kind = strict("C1")

class Child2(Base):
    child2: int
    kind = strict("C2")

class Response(BaseModel):
    events: List[ModelInstance[Base]]

events = [{'common': 1, 'child1': 10, 'kind': 'C1'}, {'common': 2, 'child2': 20, 'kind': 'C2'}]

response = Response(events=events)
assert isinstance( response.events[0], Child1)
assert isinstance( response.events[1], Child2)

Answered By: Sylvain Guieu