Pydantic root_validator: is it OK to use one for the entire model instead of individual validators?

Question:

I am using Pydantic to create a Timeseries model based on pandas Timestamp (start, end) and Timedelta (period) objects. The model will be used by a small data analysis program with a number of configurations/scenarios.

I need to instantiate and validate aspects of the Timeseries model based on two bool (include_end_period, allow_future) and one optional int (max_periods) config params. I then need to derive three new fields (timezone, total_duration, total_periods) and perform some additional validations.

Because several validations depend on the values of other fields, I was unable to achieve the desired result with the typical @validator methods. In particular, I would often get a KeyError for a missing field instead of the expected ValueError. The best solution I’ve found is to instead create one long @root_validator(pre=True) method.

from pydantic import BaseModel, ValidationError, root_validator, conint
from pandas import Timestamp, Timedelta


class Timeseries(BaseModel):
    start: Timestamp
    end: Timestamp
    period: Timedelta
    include_end_period: bool = False
    allow_future: bool = True
    max_periods: conint(gt=0, strict=True) | None = None
    
    # Derived values, do not pass as params
    timezone: str | None
    total_duration: Timedelta
    total_periods: conint(gt=0, strict=True)
    
    class Config:
        extra = 'forbid'
        validate_assignment = True
    
    
    @root_validator(pre=True)
    def _validate_model(cls, values):
        
        # Validate input values
        if values['start'] > values['end']:
            raise ValueError('Start timestamp cannot be later than end')
        if values['start'].tzinfo != values['end'].tzinfo:
            raise ValueError('Start, end timezones do not match')
        if values['period'] <= Timedelta(0):
            raise ValueError('Period must be a positive amount of time')
        
        # Set timezone
        timezone = values['start'].tzname()
        if 'timezone' in values and values['timezone'] != timezone:
            raise ValueError('Timezone param does not match start timezone')
        values['timezone'] = timezone
        
        # Set duration (add 1 period if including end period)
        total_duration = values['end'] - values['start']
        if values['include_end_period']:
            total_duration += values['period']
        if 'total_duration' in values and values['total_duration'] != total_duration:
            error_context = ' + 1 period (included end period)' if values['include_end_period'] else ''
            raise ValueError(f'Duration param does not match end - start timestamps{error_context}')
        values['total_duration'] = total_duration
        
        # Set total_periods
        periods_float: float = values['total_duration'] / values['period']
        if periods_float != int(periods_float):
            raise ValueError('Total duration not divisible by period length')
        total_periods = int(periods_float)
        if 'total_periods' in values and values['total_periods'] != total_periods:
            raise ValueError('Total periods param does not match')
        values['total_periods'] = total_periods
        
        # Validate future
        if not values['allow_future']:
            # Get current timestamp to floor of period (subtract 1 period if including end period)
            max_end: Timestamp = Timestamp.now(tz=values['timezone']).floor(freq=values['period'])
            if values['include_end_period']:
                max_end -= values['period']
            if values['end'] > max_end:
                raise ValueError('End period is future or current (incomplete)')
        
        # Validate derived values
        if values['total_duration'] < Timedelta(0):
            raise ValueError('Total duration must be positive amount of time')
        if values['max_periods'] and values['total_periods'] > values['max_periods']:
            raise ValueError('Total periods exceeds max periods param')
        
        return values

Instantiating the model in the happy case, using all config checks:

start = Timestamp('2023-03-01T00:00:00Z')
end = Timestamp('2023-03-02T00:00:00Z')
period = Timedelta('5min')

try:
    ts = Timeseries(start=start, end=end, period=period,
                    include_end_period=True, allow_future=False, max_periods=10000)
    print(ts.dict())
except ValidationError as e:
    print(e)

Output:

"""
{'start': Timestamp('2023-03-01 00:00:00+0000', tz='UTC'),
 'end': Timestamp('2023-03-02 00:00:00+0000', tz='UTC'),
 'period': Timedelta('0 days 00:05:00'),
 'include_end_period': True,
 'allow_future': False,
 'max_periods': 10000,
 'timezone': 'UTC',
 'total_duration': Timedelta('1 days 00:05:00'),
 'total_periods': 289}
"""

Here I believe all my validation is working as expected, and delivers the expected ValueErrors instead of less helpful KeyErrors. Is this approach reasonable? It seems to go against the typical/recommended approach, and the @root_validator documentation is quite brief compared to that of the @validator.

I am also unsatisfied that I need to list the derived values (timezone, total_duration, total_periods) at the top of the model. This implies they can/should be passed when instantiating, and requires extra logic in my validator to check whether they were passed and whether they match the derived values. If I omitted them, they would not benefit from the default validation of type, constraints, etc., and I would be forced to change the config to extra='allow'. I would appreciate any tips on how to improve this.

Thank you!

Asked By: sha2fiddy

||

Answers:

For testing purposes it is usually a good idea not to have such large functions. Even if you want to go for the root_validator approach, you can (and IMO should) still divide up the logic into distinct, semantically sensible methods.
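To illustrate what I mean by dividing up the logic, here is a minimal sketch using only the standard library (datetime and timedelta stand in for the pandas types). The function names check_order, check_period and count_periods are made up for this example; a root validator could call helpers like these one after another, and each one can be unit-tested on its own.

```python
from datetime import datetime, timedelta, timezone


def check_order(start: datetime, end: datetime) -> None:
    """Reject a range whose start is later than its end."""
    if start > end:
        raise ValueError("Start timestamp cannot be later than end")


def check_period(period: timedelta) -> None:
    """Reject a zero or negative period."""
    if period <= timedelta(0):
        raise ValueError("Period must be a positive amount of time")


def count_periods(
    start: datetime,
    end: datetime,
    period: timedelta,
    include_end_period: bool = False,
) -> int:
    """Return the number of whole periods between start and end."""
    duration = end - start
    # timedelta supports exact modulo and floor division, so no float
    # comparison is needed to test divisibility.
    if duration % period != timedelta(0):
        raise ValueError("Total duration not divisible by period length")
    return duration // period + (1 if include_end_period else 0)


# Same data as in the question: one day in 5-minute periods.
start = datetime(2023, 3, 1, tzinfo=timezone.utc)
end = datetime(2023, 3, 2, tzinfo=timezone.utc)
period = timedelta(minutes=5)

check_order(start, end)
check_period(period)
total_periods = count_periods(start, end, period, include_end_period=True)
print(total_periods)
```

With the data from the question this gives the same 289 periods as the original model (288 periods in one day plus the included end period).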

But I would suggest a slightly different approach altogether. Since timezone, total_duration and total_periods are derived from other fields and that process is not very expensive, I would define properties for those instead of having them as fields.

This has the advantage that you don’t need to compute their values in advance, which means you don’t need the pre=True approach and can utilize previously validated field values in field-specific validators.

Root validators still make sense when you really need to ensure that many distinct fields, taken together, follow a certain logic.

Here is what I propose:

from collections.abc import Mapping
from typing import Any

from pandas import Timedelta, Timestamp
from pydantic import BaseModel, conint, root_validator, validator

AnyMap = Mapping[str, Any]


class Timeseries(BaseModel):
    start: Timestamp
    end: Timestamp
    period: Timedelta
    include_end_period: bool = False
    allow_future: bool = True
    max_periods: conint(gt=0, strict=True) | None = None

    @validator("end")
    def ensure_end_consistent_with_start(
        cls,
        v: Timestamp,
        values: AnyMap,
    ) -> Timestamp:
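        if "start" not in values:
            # start failed its own validation and is absent from values;
            # skip the comparison rather than raise a KeyError
            return v
        val_start: Timestamp = values["start"]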
        if v < val_start:
            raise ValueError("Start timestamp cannot be later than end")
        if val_start.tzinfo != v.tzinfo:
            raise ValueError("Start, end timezones do not match")
        return v

    @validator("period")
    def ensure_period_is_positive(cls, v: Timedelta) -> Timedelta:
        if v <= Timedelta(0):
            raise ValueError("Period must be a positive amount of time")
        return v

    @validator("period")
    def ensure_period_divides_duration(
        cls,
        v: Timedelta,
        values: AnyMap,
    ) -> Timedelta:
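        if "start" not in values or "end" not in values:
            # an earlier field failed validation; nothing to compare against
            return v
        duration: float = (values["end"] - values["start"]) / v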
        if int(duration) != duration:
            raise ValueError("Total duration not divisible by period length")
        return v

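    # skip_on_failure=True: do not run this check if a field already failed
    # validation, since that field would be missing from values (KeyError)
    @root_validator(skip_on_failure=True)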
    def ensure_end_is_allowed(cls, values: AnyMap) -> AnyMap:
        if values["allow_future"]:
            return values
        val_period: Timedelta = values["period"]
        val_end: Timestamp = values["end"]
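        # Pass the tzinfo object itself; tzname() can return an abbreviation
        # (e.g. "CEST") that is not a valid time zone name
        max_end = Timestamp.now(tz=val_end.tzinfo).floor(freq=val_period)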
        if values["include_end_period"]:
            max_end -= val_period
        if val_end > max_end:
            raise ValueError("End period is future or current (incomplete)")
        return values

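    # skip_on_failure=True: do not run this check if a field already failed
    # validation, since that field would be missing from values (KeyError)
    @root_validator(skip_on_failure=True)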
    def ensure_num_periods_allowed(cls, values: AnyMap) -> AnyMap:
        periods = int((values["end"] - values["start"]) / values["period"])
        if values["include_end_period"]:
            periods += 1
        if values["max_periods"] and periods > values["max_periods"]:
            raise ValueError("Total periods exceeds max periods param")
        return values

    @property
    def timezone(self) -> str | None:
        return self.start.tzname()

    @property
    def total_duration(self) -> Timedelta:
        total_duration = self.end - self.start
        if self.include_end_period:
            total_duration += self.period
        return total_duration

    @property
    def total_periods(self) -> int:
        return int(self.total_duration / self.period)

When to switch from field validators to root validators is, I guess, a matter of personal preference. For example, you could argue that ensure_period_divides_duration should be a root validator, since it uses the values of three fields.

Your example data of course works with this model as well.

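One thing to note is that the range constraint on total_periods is largely redundant anyway once you validate that end is not before start and that period is positive and evenly divides the total duration. The one exception is end == start with include_end_period=False, which the validators above accept and which yields zero periods.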
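As a quick check of the boundary case for that constraint, here is a small sketch that mirrors the total_duration and total_periods properties using only the standard library. The helper name total_periods_for is made up for this example, and datetime/timedelta stand in for the pandas types.

```python
from datetime import datetime, timedelta


def total_periods_for(
    start: datetime,
    end: datetime,
    period: timedelta,
    include_end_period: bool = False,
) -> int:
    """Mirror of the total_duration / total_periods properties."""
    duration = end - start
    if include_end_period:
        duration += period
    return duration // period


moment = datetime(2023, 3, 1)
period = timedelta(minutes=5)

# start == end passes the "start not later than end" check
# and a zero duration is divisible by any positive period.
zero_case = total_periods_for(moment, moment, period)
one_case = total_periods_for(moment, moment, period, include_end_period=True)
print(zero_case, one_case)
```

So an empty range gives zero periods unless the end period is included. If that should be rejected, the comparison in the end validator could be made strict instead of keeping a gt=0 constraint on a derived field.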

You could also argue that even something as simple as total_duration should not be a property. In that case you could make it a method called get_total_duration.

But if you have those "derived" fields, you’ll always run into the issue of having to check that whatever was passed by the user is consistent with the rest of the data.

I believe most of this headache will be gone once Pydantic v2 drops, which promises computed fields (see the plan for v2).
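For reference, that feature shipped in Pydantic v2 as the computed_field decorator. Below is a rough sketch of how the derived values might look with it, assuming Pydantic v2 is installed. To keep it self-contained I use the standard library datetime and timedelta instead of the pandas types, and I leave out the validators, so treat it as an illustration rather than a drop-in replacement for the model above.

```python
from datetime import datetime, timedelta, timezone

from pydantic import BaseModel, computed_field  # requires Pydantic v2


class Timeseries(BaseModel):
    start: datetime
    end: datetime
    period: timedelta
    include_end_period: bool = False

    @computed_field
    @property
    def total_duration(self) -> timedelta:
        # Derived on access; cannot be passed to the constructor as a field.
        duration = self.end - self.start
        if self.include_end_period:
            duration += self.period
        return duration

    @computed_field
    @property
    def total_periods(self) -> int:
        return self.total_duration // self.period


ts = Timeseries(
    start=datetime(2023, 3, 1, tzinfo=timezone.utc),
    end=datetime(2023, 3, 2, tzinfo=timezone.utc),
    period=timedelta(minutes=5),
    include_end_period=True,
)
# Computed fields are included in the serialized output.
dumped = ts.model_dump()
print(dumped["total_duration"], dumped["total_periods"])
```

Unlike plain properties, the computed fields appear in model_dump(), which covers the original wish to have the derived values show up alongside the regular fields without accepting them as input.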

Answered By: Daniil Fajnberg