Flatten nested Pydantic model

Question:

from typing import Union
from pydantic import BaseModel, Field


class Category(BaseModel):
    name: str = Field(alias="name")


class OrderItems(BaseModel):
    name: str = Field(alias="name")
    category: Category = Field(alias="category")
    unit: Union[str, None] = Field(alias="unit")
    quantity: int = Field(alias="quantity")

When instantiated like this:

OrderItems(**{'name': 'Test','category':{'name': 'Test Cat'}, 'unit': 'kg', 'quantity': 10})

It returns data like this:

OrderItems(name='Test', category=Category(name='Test Cat'), unit='kg', quantity=10)

But I want the output like this:

OrderItems(name='Test', category='Test Cat', unit='kg', quantity=10)

How can I achieve this?

Asked By: Russell


Answers:

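Try this when instantiating (note that this only validates if the category field on OrderItems is declared as str rather than Category):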

myCategory = Category(name="test cat")
OrderItems(
    name="test",
    category=myCategory.name,
    unit="kg",
    quantity=10)
Answered By: jjislam

Well, I was curious, so here’s the insane way:

class Category(BaseModel):
    name: str = Field(alias="name")


class OrderItems(BaseModel):
    name: str = Field(alias="name")
    category: Category = Field(alias="category")
    unit: Union[str, None] = Field(alias="unit")
    quantity: int = Field(alias="quantity")
    
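    def json(self, *args, **kwargs) -> str:
        # Flatten on a shallow copy so `self.category` stays a Category
        # and json() can safely be called more than once.
        flat = self.copy()
        flat.__dict__['category'] = self.category.name
        return super(OrderItems, flat).json(*args, **kwargs)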
    
c = Category(name='Dranks')
m = OrderItems(name='sodie', category=c, unit='can', quantity=1)
m.json()

And you get:

'{"name": "sodie", "category": "Dranks", "unit": "can", "quantity": 1}'

The sane way would probably be:

class Category(BaseModel):
    name: str = Field(alias="name")


class OrderItems(BaseModel):
    name: str = Field(alias="name")
    category: Category = Field(alias="category")
    unit: Union[str, None] = Field(alias="unit")
    quantity: int = Field(alias="quantity")
    
c = Category(name='Dranks')
m = OrderItems(name='sodie', category=c, unit='can', quantity=1)

r = m.dict()
r['category'] = r['category']['name']
Answered By: Vetsin

You should try as much as possible to define your schema the way you actually want the data to look in the end, not the way you might receive it from somewhere else.


UPDATE: Generalized solution (one nested field or more)

To generalize this problem, let’s assume you have the following models:

from pydantic import BaseModel


class Foo(BaseModel):
    x: bool
    y: str
    z: int


class _BarBase(BaseModel):
    a: str
    b: float

    class Config:
        orm_mode = True


class BarNested(_BarBase):
    foo: Foo


class BarFlat(_BarBase):
    foo_x: bool
    foo_y: str

Problem: You want to be able to initialize BarFlat with a foo argument just like BarNested, but the data to end up in the flat schema, wherein the fields foo_x and foo_y correspond to x and y on the Foo model (and you are not interested in z).

Solution: Define a custom root_validator with pre=True that checks if a foo key/attribute is present in the data. If it is, it validates the corresponding object against the Foo model, grabs its x and y values and then uses them to extend the given data with foo_x and foo_y keys:

from pydantic import BaseModel, root_validator
from pydantic.utils import GetterDict

...

class BarFlat(_BarBase):
    foo_x: bool
    foo_y: str

    @root_validator(pre=True)
    def flatten_foo(cls, values: GetterDict) -> GetterDict | dict[str, object]:
        foo = values.get("foo")
        if foo is None:
            return values
        # Assume `foo` must be valid `Foo` data:
        foo = Foo.validate(foo)
        return {
            "foo_x": foo.x,
            "foo_y": foo.y,
        } | dict(values)

Note that we need to be a bit more careful inside a root validator with pre=True because the values may be passed in the form of a GetterDict (as happens with from_orm), which is an immutable mapping-like object. So we cannot simply assign new values foo_x/foo_y to it like we would to a dictionary. But nothing is stopping us from returning the cleaned-up data in the form of a regular old dict.
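To make the read-only behavior concrete, here is a minimal sketch that wraps a plain object in a GetterDict by hand. The Row class is a hypothetical stand-in for an ORM object, and the import fallback is only there so the snippet also runs when pydantic v2 (which ships the v1 API under pydantic.v1) is installed.

```python
try:  # pydantic v2 ships the v1 API under pydantic.v1
    from pydantic.v1.utils import GetterDict
except ImportError:  # plain pydantic v1
    from pydantic.utils import GetterDict


class Row:
    """Hypothetical stand-in for an ORM object with one attribute."""

    def __init__(self) -> None:
        self.a = "eggs"


values = GetterDict(Row())

# Reading works like a mapping: lookups are forwarded to attribute access.
assert values.get("a") == "eggs"

# Writing does not work: GetterDict defines no item assignment.
try:
    values["foo_x"] = True
    assignment_failed = False
except TypeError:
    assignment_failed = True

# Copying into a regular dict gives us something we can extend freely.
data = dict(values)
data["foo_x"] = True
```

This is why the validator builds and returns a new dict instead of modifying values in place.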

To demonstrate, we can throw some test data at BarFlat:

test_dict = {"a": "spam", "b": 3.14, "foo": {"x": True, "y": ".", "z": 0}}
test_orm = BarNested(a="eggs", b=-1, foo=Foo(x=False, y="..", z=1))
test_flat = '{"a": "beans", "b": 0, "foo_x": true, "foo_y": ""}'
bar1 = BarFlat.parse_obj(test_dict)
bar2 = BarFlat.from_orm(test_orm)
bar3 = BarFlat.parse_raw(test_flat)
print(bar1.json(indent=4))
print(bar2.json(indent=4))
print(bar3.json(indent=4))

The output:

{
    "a": "spam",
    "b": 3.14,
    "foo_x": true,
    "foo_y": "."
}
{
    "a": "eggs",
    "b": -1.0,
    "foo_x": false,
    "foo_y": ".."
}
{
    "a": "beans",
    "b": 0.0,
    "foo_x": true,
    "foo_y": ""
}

The first example simulates a common situation, where the data is passed to us in the form of a nested dictionary. The second example is the typical database ORM object situation, where BarNested represents the schema we find in a database. The third is just to show that we can still correctly initialize BarFlat without a foo argument.

One caveat to note is that the validator does not get rid of the foo key if it finds it in the values. If your model is configured with Extra.forbid, that will lead to a validation error. In that case, you’ll just need an extra line where you coerce the original values to a dict first, then pop the "foo" key instead of getting it.
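A minimal sketch of that variant, assuming pydantic v1 and dictionary input only (ORM objects are not exercised here). BarFlatStrict is a hypothetical name for a flat model configured with Extra.forbid; the import fallback is only there so the snippet also runs with pydantic v2 installed.

```python
try:  # pydantic v2 ships the v1 API under pydantic.v1
    from pydantic.v1 import BaseModel, Extra, root_validator
except ImportError:  # plain pydantic v1
    from pydantic import BaseModel, Extra, root_validator


class Foo(BaseModel):
    x: bool
    y: str
    z: int


class BarFlatStrict(BaseModel):  # hypothetical name, not from the answer
    a: str
    b: float
    foo_x: bool
    foo_y: str

    class Config:
        extra = Extra.forbid

    @root_validator(pre=True)
    def flatten_foo(cls, values):
        # Work on a regular dict so the "foo" key can be removed.
        data = dict(values)
        foo = data.pop("foo", None)
        if foo is None:
            return data
        # Assume `foo` must be valid `Foo` data:
        foo = Foo.validate(foo)
        # Explicit foo_x/foo_y keys in the input take precedence.
        return {"foo_x": foo.x, "foo_y": foo.y, **data}


nested = {"a": "spam", "b": 3.14, "foo": {"x": True, "y": ".", "z": 0}}
bar = BarFlatStrict.parse_obj(nested)
print(bar.json(indent=4))
```

Because "foo" is popped before the field validation runs, the forbidden-extra check never sees it.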


Original post (flatten single field)

If you need the nested Category model for database insertion, but you want a "flat" order model with category being just a string in the response, you should split that up into two separate models.

Then in the response model you can define a custom validator with pre=True to handle the case when you attempt to initialize it providing an instance of Category or a dict for category.

Here is what I suggest:

from pydantic import BaseModel, validator


class Category(BaseModel):
    name: str


class OrderItemBase(BaseModel):
    name: str
    unit: str | None
    quantity: int


class OrderItemCreate(OrderItemBase):
    category: Category


class OrderItemResponse(OrderItemBase):
    category: str

    @validator("category", pre=True)
    def handle_category_model(cls, v: object) -> object:
        if isinstance(v, Category):
            return v.name
        if isinstance(v, dict) and "name" in v:
            return v["name"]
        return v

Here is a demo:

if __name__ == "__main__":
    insert_data = '{"name": "foo", "category": {"name": "bar"}, "quantity": 1}'
    insert_obj = OrderItemCreate.parse_raw(insert_data)
    print(insert_obj.json(indent=2))
    ...  # insert into DB
    response_obj = OrderItemResponse.parse_obj(insert_obj.dict())
    print(response_obj.json(indent=2))

Here is the output:

{
  "name": "foo",
  "unit": null,
  "quantity": 1,
  "category": {
    "name": "bar"
  }
}
{
  "name": "foo",
  "unit": null,
  "quantity": 1,
  "category": "bar"
}

One of the benefits of this approach is that the JSON Schema stays consistent with what you have on the model. If you use this in FastAPI that means the swagger documentation will actually reflect what the consumer of that endpoint receives. You could of course override and customize schema creation, but… why? Just define the model correctly in the first place and avoid headache in the future.
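To illustrate the schema point, here is a small sketch that compares the generated JSON Schema of the two models, assuming pydantic v1. The models are re-declared without the shared base class and without the validator to keep the snippet short; neither affects how the category field is described in the schema. The import fallback is only there so the snippet also runs with pydantic v2 installed.

```python
from typing import Optional

try:  # pydantic v2 ships the v1 API under pydantic.v1
    from pydantic.v1 import BaseModel
except ImportError:  # plain pydantic v1
    from pydantic import BaseModel


class Category(BaseModel):
    name: str


class OrderItemCreate(BaseModel):
    name: str
    unit: Optional[str]
    quantity: int
    category: Category  # nested model on the way in


class OrderItemResponse(BaseModel):
    name: str
    unit: Optional[str]
    quantity: int
    category: str  # plain string on the way out


create_schema = OrderItemCreate.schema()
response_schema = OrderItemResponse.schema()

# The create schema points at the nested Category definition,
# while the response schema describes category as a string.
print(create_schema["properties"]["category"])
print(response_schema["properties"]["category"])
```

Each model documents exactly the shape it accepts or returns, with no schema customization needed.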

Answered By: Daniil Fajnberg