What is the proper way to use descriptors as fields in Python dataclasses?
Question:
I’ve been playing around with Python dataclasses and was wondering: what is the most elegant or most Pythonic way to make one or some of the fields descriptors?
In the example below I define a Vector2D class that should be compared on its length.
from dataclasses import dataclass, field
from math import sqrt

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: int = field(init=False)

    def __post_init__(self):
        type(self).length = property(lambda s: sqrt(s.x**2+s.y**2))

Vector2D(3,4) > Vector2D(4,1)  # True
While this code works, it touches the class every time an instance is created. Is there a more readable / less hacky / more intended way to use dataclasses and descriptors together?
Just having length as a property and not a field will work, but then I have to write __lt__ et al. by myself.
Another solution I found is equally unappealing:
@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: int = field(init=False)

    @property
    def length(self):
        return sqrt(self.x**2+self.y**2)

    @length.setter
    def length(self, value):
        pass
Introducing a no-op setter is necessary because the dataclass-generated __init__ apparently tries to assign to length, even though there isn’t a default value and the field explicitly sets init=False…
Surely there has to be a better way right?
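A likely explanation for the setter requirement, assuming the standard library dataclasses module: the def length under @property rebinds the name in the class body, so the field(init=False) object is already gone when @dataclass runs. The decorator then sees the property object as an ordinary default value and creates a regular init=True field for it. The sketch below (length_field is just an illustrative local name) checks this with dataclasses.fields() and shows that, without the no-op setter, instantiation fails.

```python
from dataclasses import dataclass, field, fields
from math import sqrt


@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: int = field(init=False)

    # This def rebinds `length`, so the field(init=False) object above
    # is replaced before @dataclass ever inspects the class.
    @property
    def length(self):
        return sqrt(self.x**2 + self.y**2)


length_field = next(f for f in fields(Vector2D) if f.name == "length")
# @dataclass only saw the property object, i.e. a plain default value...
print(isinstance(length_field.default, property))  # True
# ...so it created an ordinary init=True field for it.
print(length_field.init)  # True

# The generated __init__ therefore runs `self.length = <property>`,
# which a property without a setter rejects.
try:
    Vector2D(3, 4)
except AttributeError as exc:
    print("AttributeError:", exc)
```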
Answers:
This might not answer your exact question, but you mentioned that the reason you didn’t want length to be a property rather than a field was that you would have to write __lt__ et al. by yourself.
While you do have to implement __lt__ yourself, you can actually get away with implementing just that (plus __eq__):
from functools import total_ordering
from dataclasses import dataclass
from math import sqrt

@total_ordering
@dataclass
class Vector2D:
    x: int
    y: int

    @property
    def length(self):
        return sqrt(self.x ** 2 + self.y ** 2)

    def __lt__(self, other):
        if not isinstance(other, Vector2D):
            return NotImplemented
        return self.length < other.length

    def __eq__(self, other):
        if not isinstance(other, Vector2D):
            return NotImplemented
        return self.length == other.length

print(Vector2D(3, 4) > Vector2D(4, 1))
This works because total_ordering fills in the remaining ordering methods (__le__, __gt__ and __ge__) based on __lt__ and __eq__.
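A quick check of the derived methods, as a self-contained sketch that repeats the class above. Note that because __eq__ compares lengths, two vectors with different coordinates but the same length compare equal; that is a consequence of this particular __eq__, not of total_ordering itself.

```python
from dataclasses import dataclass
from functools import total_ordering
from math import sqrt


@total_ordering
@dataclass
class Vector2D:
    x: int
    y: int

    @property
    def length(self):
        return sqrt(self.x ** 2 + self.y ** 2)

    def __lt__(self, other):
        if not isinstance(other, Vector2D):
            return NotImplemented
        return self.length < other.length

    # @dataclass does not overwrite an __eq__ defined in the class body.
    def __eq__(self, other):
        if not isinstance(other, Vector2D):
            return NotImplemented
        return self.length == other.length


# __gt__, __ge__ and __le__ were filled in by total_ordering.
print(Vector2D(3, 4) > Vector2D(4, 1))   # True
print(Vector2D(3, 4) >= Vector2D(4, 3))  # True (same length)
print(Vector2D(1, 0) <= Vector2D(0, 1))  # True
# Equality is by length, so different coordinates can compare equal.
print(Vector2D(3, 4) == Vector2D(4, 3))  # True
```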
I do not think that the example you present is a good use case for what you are trying to do. Nevertheless, for completeness’ sake, here is a possible solution to your problem:
from dataclasses import dataclass, field
from math import sqrt

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: float = field(default=property(lambda s: sqrt(s.x**2+s.y**2)), init=False)
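This works because dataclass leaves each field’s default value on the class as a class attribute, and for an init=False field the generated __init__ never assigns it, so attribute lookups on instances reach the property. (Mutable defaults are rejected: a list, dict or set default raises a ValueError, and since Python 3.11 any unhashable default does.)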
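The mechanism can be checked at runtime. The sketch below repeats the class from this answer (annotated as float, since sqrt returns a float) and inspects it using only inspect.signature() and vars().

```python
import inspect
from dataclasses import dataclass, field
from math import sqrt


@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: float = field(default=property(lambda s: sqrt(s.x**2 + s.y**2)), init=False)


# The default (the property object) is left on the class as a plain attribute.
print(isinstance(Vector2D.__dict__["length"], property))  # True

# init=False keeps `length` out of __init__...
print(list(inspect.signature(Vector2D).parameters))  # ['x', 'y']

# ...so instances never get their own `length` entry,
# and attribute lookup finds the property on the class.
v = Vector2D(3, 4)
print(vars(v))   # {'x': 3, 'y': 4}
print(v.length)  # 5.0

# order=True compares only the compare=True fields, here just `length`.
print(Vector2D(3, 4) > Vector2D(4, 1))  # True
```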
Although you could implement the @property and the comparison methods manually, doing so can cost you other desirable dataclass features, in this case for example passing hash=False to a field if you wanted to use your Vector2D in a dict. Additionally, letting dataclass implement the dunder methods for you makes your code less error-prone; e.g. you can’t forget to return NotImplemented, which is a common mistake.
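To illustrate the NotImplemented point, here is a sketch assuming the field(default=property(...), init=False) variant of Vector2D: the comparison methods that dataclass generates already return NotImplemented for objects of other types, so Python’s usual fallback rules apply.

```python
from dataclasses import dataclass, field
from math import sqrt


@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    length: float = field(default=property(lambda s: sqrt(s.x**2 + s.y**2)), init=False)


v = Vector2D(3, 4)

# The generated comparison methods return NotImplemented for other types...
print(v.__lt__(5) is NotImplemented)               # True
print(v.__eq__("not a vector") is NotImplemented)  # True

# ...so Python falls back correctly: == ends up as an identity check,
# while ordering against an unrelated type raises TypeError.
print(v == "not a vector")  # False
try:
    v < 5
except TypeError as exc:
    print("TypeError:", exc)
```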
The drawback is that writing the correct type hint is not easy, and there can be some minor caveats; but once the type hint is written, it can easily be reused anywhere.
The property (descriptor) type-hint:
import sys
from typing import Any, Optional, Protocol, TypeVar, overload

if sys.version_info < (3, 9):
    from typing import Type
else:
    from builtins import type as Type

IT = TypeVar("IT", contravariant=True)
CT = TypeVar("CT", covariant=True)
GT = TypeVar("GT", covariant=True)
ST = TypeVar("ST", contravariant=True)

class Property(Protocol[CT, GT, ST]):
    # Get default attribute from a class.
    @overload
    def __get__(self, instance: None, owner: Type[Any]) -> CT:
        ...

    # Get attribute from an instance.
    @overload
    def __get__(self, instance: IT, owner: Optional[Type[IT]] = ...) -> GT:
        ...

    def __get__(self, instance, owner=None):
        ...

    def __set__(self, instance: Any, value: ST) -> None:
        ...
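For intuition, the two __get__ overloads mirror what the built-in property does at runtime: accessed through the class (instance is None) it returns itself, and accessed through an instance it returns the getter’s result, while a read-only property’s __set__ always raises. The sketch uses a hypothetical plain Point class rather than a dataclass; reading this as the motivation for the Property[property, float, Never] annotation is an interpretation, not something the answer states.

```python
from math import hypot


class Point:
    """Hypothetical plain class, used only to exercise the built-in property."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    # Read-only: no setter is defined.
    length = property(lambda p: hypot(p.x, p.y))


p = Point(3, 4)
prop = Point.__dict__["length"]

# instance=None: __get__ returns the property object itself (the CT case).
print(prop.__get__(None, Point) is prop)  # True
# instance given: __get__ returns the getter's result (the GT case).
print(prop.__get__(p, Point))  # 5.0

# A read-only property's __set__ always raises (the ST = Never case).
try:
    prop.__set__(p, 1.0)
except AttributeError:
    print("read-only")
```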
From here, we can now type-hint our property object when using a dataclass. Use field(default=property(...)) if you need to use the other options in field(...).
import sys
import typing
from dataclasses import dataclass, field
from math import hypot

# Use for read-only property.
if sys.version_info < (3, 11):
    from typing import NoReturn as Never
else:
    from typing import Never

@dataclass(order=True)
class Vector2D:
    x: int = field(compare=False)
    y: int = field(compare=False)
    # Properties return themselves as their default class variable.
    # Read-only properties never allow setting a value.
    # If init=True, the generated __init__ would try to assign
    # self.length = Vector2D.length (the default), which a read-only property rejects.
    # Setting repr=False for consistency with init=False.
    length: Property[property, float, Never] = field(
        default=property(lambda v: hypot(v.x, v.y)),
        init=False,
        repr=False,
    )

v1 = Vector2D(3, 4)
v2 = Vector2D(6, 8)

if typing.TYPE_CHECKING:
    reveal_type(Vector2D.length)  # builtins.property
    reveal_type(v1.length)  # builtins.float

assert v1.length == 5.0
assert v2.length == 10.0
assert v1 < v2
Try it on mypy Playground.