How to get the index of a dataclass field

Question:

Say I have a simple dataclass instance

import dataclasses as dc

@dc.dataclass
class DCItem:
    """Minimal example record with two fields."""
    name: str  # declared first
    unit_price: float  # declared second

# NOTE: 11 is an int although unit_price is annotated as float; dataclasses
# do not check values against the annotations, so this is accepted as-is.
item = DCItem('test', 11)

Now I want to determine the position (index) of instance attribute item.unit_price. How can I make it simple to use and performant? I thought about using a get method using dc.asdict

@dc.dataclass
class DCItem:
    """Item record whose field positions can be looked up by name."""
    name: str
    unit_price: float

    def get_index(self, name):
        """Return the zero-based position of the field called ``name``.

        The field order is read from a fresh ``dc.asdict`` conversion of this
        instance on every call.  ``list.index`` raises ``ValueError`` when no
        field has that name.
        """
        field_names = list(dc.asdict(self))
        return field_names.index(name)

item.get_index('unit_price')  # 1

But this has two drawbacks:

  1. It’s not very performant, at least not for many instance attributes
  2. It loses the nice auto-completion feature of item.unit_price

Is there a solution that combines the features of a dataclass with that of IntEnum and enum.auto() without the above drawbacks?

Asked By: MiaPlan.de

||

Answers:

If the class is not going to be changed at runtime, you can cache indexes in a class attribute as a dictionary.

import dataclasses as dc

@dc.dataclass
class DCItem:
    """Item record with a cached name -> position lookup for its fields."""
    name: str
    unit_price: float

    @classmethod
    def get_index(cls, name):
        """Return the zero-based position of the dataclass field ``name``.

        The name -> index mapping is built once from ``dc.fields(cls)`` and
        stored on the class as ``_idx_mapping``; later calls are a plain
        dictionary lookup.

        Raises:
            KeyError: if ``name`` is not a field of this dataclass.
        """
        # Look in cls.__dict__ (not hasattr) so that a subclass builds its own
        # mapping instead of silently reusing the one cached on a parent.
        if '_idx_mapping' not in cls.__dict__:
            cls._idx_mapping = {
                fld.name: idx for idx, fld in enumerate(dc.fields(cls))
            }
        return cls._idx_mapping[name]


>>> item = DCItem('test', 11)
>>> item.get_index('unit_price')
1

Accessing a dictionary is fast – O(1) on average, and O(n) only in the rare worst case of many hash collisions.

>>> from timeit import timeit
>>> timeit("item.get_index('unit_price')", "from __main__ import item")
0.21372696105390787

For comparison, your solution is quite slow, as you mentioned:

>>> timeit("item.get_index('unit_price')", "from __main__ import item")
4.260601775022224

Note: I haven’t tested this class with inheritance.


EDIT: Solving the second point makes the solution more complex. I’ve come up with the following using Python descriptors.

import dataclasses as dc
from typing import Any
from collections import defaultdict


class IndexedField:
    """Value holder that remembers its declared type and its field position.

    The type and the index are fixed at construction time and exposed as
    read-only properties; only ``value`` can be reassigned afterwards.
    """

    def __init__(self, a_type: type, value: Any, index: int):
        # Reject a mismatching value before any state is stored.  This call can
        # be removed when type checking is not required.
        self._validate_type(a_type, value)
        self._a_type, self._value, self._index = a_type, value, index

    @staticmethod
    def _validate_type(a_type: type, value: Any):
        """Raise ``TypeError`` unless ``value`` is an instance of ``a_type``."""
        if isinstance(value, a_type):
            return
        raise TypeError(f"value is of type {type(value)} but {a_type} is expected")

    @property
    def index(self):
        """Position assigned to this field (read-only)."""
        return self._index

    @property
    def a_type(self):
        """Type every stored value must be an instance of (read-only)."""
        return self._a_type

    @property
    def value(self):
        """The wrapped value."""
        return self._value

    @value.setter
    def value(self, new_value):
        # Same check as in __init__; this call can be removed when type
        # checking is not required.
        self._validate_type(self._a_type, new_value)
        self._value = new_value

    def __repr__(self):
        cls_name = self.__class__.__name__
        details = f'a_type={self._a_type!r}, index={self._index!r}, value={self._value!r}'
        return f'{cls_name}({details})'


class IndexedFieldDescriptor:
    """Data descriptor that wraps every assigned value in an ``IndexedField``.

    Each descriptor receives a zero-based index when it is bound to its owner
    class (``__set_name__``), in declaration order.  Assigning to the attribute
    on an instance stores ``IndexedField(a_type, value, index)`` in the
    instance ``__dict__``; reading the attribute returns that wrapper.

    Only attributes declared with this descriptor are counted.  A subclass
    continues numbering after the indexed fields it inherits, and a field it
    re-declares keeps the inherited index.  Inheriting from several indexed
    base classes at once is not specially handled and may yield overlapping
    indexes.
    """

    # Bookkeeping shared by all descriptor instances.  Both mappings are keyed
    # by the owner class object rather than by its ``__name__``: two classes
    # that merely share a name (different modules, or a class redefined in an
    # interactive session) must not continue each other's numbering.
    _class_last_index = defaultdict(int)
    _class_indexes = defaultdict(dict)

    def __init__(self, a_type) -> None:
        """Create a descriptor whose assigned values are checked against ``a_type``."""
        self._name = None   # attribute name, filled in by __set_name__
        self._type = a_type
        self._index = None  # field position, filled in by __set_name__

    def __get__(self, instance, owner):
        """Return the stored ``IndexedField``, or the descriptor on class access.

        Raises:
            AttributeError: if the attribute has not been assigned on
                ``instance`` yet.
        """
        if instance is None:
            return self
        try:
            return instance.__dict__[self._name]
        except KeyError:
            # Attribute access must fail with AttributeError so that hasattr()
            # and getattr(obj, name, default) behave normally.
            raise AttributeError(
                f"{type(instance).__name__!r} object has no attribute {self._name!r}"
            ) from None

    def __set_name__(self, owner, name):
        """Record the attribute name and assign the next free index for ``owner``."""
        self._name = name
        indexes = self._class_indexes[owner]
        if owner not in self._class_last_index:
            # First indexed field declared on this class: start after whatever
            # indexed fields the base classes already declared.
            for base in reversed(owner.__mro__[1:]):
                indexes.update(self._class_indexes.get(base, {}))
            self._class_last_index[owner] = max(indexes.values(), default=-1) + 1
        if name not in indexes:
            indexes[name] = self._class_last_index[owner]
            self._class_last_index[owner] += 1
        # Keep the index on the descriptor itself so __set__ does not depend on
        # the concrete class of the instance (which may be a subclass).
        self._index = indexes[name]

    def __set__(self, instance, value):
        """Store ``value`` on ``instance`` wrapped in an ``IndexedField``.

        ``IndexedField`` raises ``TypeError`` when ``value`` is not an instance
        of the type this descriptor was created with.
        """
        instance.__dict__[self._name] = IndexedField(self._type, value, self._index)


@dc.dataclass
class DCItem:
    # Each field's class-level value is an IndexedFieldDescriptor, so a value
    # assigned to an instance attribute is stored as an IndexedField wrapper
    # exposing .value, .index and .a_type.  The descriptor argument is the type
    # that assigned values are checked against.
    name: IndexedField = IndexedFieldDescriptor(str)
    unit_price: IndexedField = IndexedFieldDescriptor(float)


# 11.0 rather than 11: IndexedField checks the value with isinstance() against
# float, and an int is not an instance of float.
item = DCItem('test', 11.0)
print(item)
# Every attribute is an IndexedField, so the value, index and type are all
# reachable through normal (auto-completable) attribute access.
print(f"* name field value: {item.name.value!r}, name field index: {item.name.index!r}, name field type: {item.name.a_type!r}")
print(f"* unit_price field value: {item.unit_price.value!r}, unit_price field index: {item.unit_price.index!r}, unit_price field type: {item.unit_price.a_type!r}")

# The setup string imports `item` from __main__, so these timings only work
# when this file is executed as the main script.
from timeit import timeit
print(f'* Index access time: {timeit("item.name.index", "from __main__ import item")}')
print(f'* Value access time: {timeit("item.name.value", "from __main__ import item")}')

Output:

DCItem(name=IndexedField(a_type=<class 'str'>, index=0, value='test'), unit_price=IndexedField(a_type=<class 'float'>, index=1, value=11.0))
* name field value: 'test', name field index: 0, name field type: <class 'str'>
* unit_price field value: 11.0, unit_price field index: 1, unit_price field type: <class 'float'>
* Index access time: 0.2253845389932394
* Value access time: 0.2729280750500038
Answered By: dchrome