How to define `__str__` for `dataclass` that omits default values?

Question:

Given a dataclass instance, I would like print() or str() to only list the non-default field values. This is useful when the dataclass has many fields and only a few are changed.

# Example dataclass for the question: every field declares a default value.
@dataclasses.dataclass
class X:
  a: int = 1
  b: bool = False
  c: float = 2.0

x = X(b=True)  # Only `b` is given a value different from its declared default.
print(x)  # Desired output: X(b=True)
Asked By: Hugues

||

Answers:

The solution is to add a custom __str__() function:

@dataclasses.dataclass
class X:
  # Three fields, each with a default; __str__ below hides unchanged ones.
  a: int = 1
  b: bool = False
  c: float = 2.0

  def __str__(self):
    """Return `ClassName(name=value, ...)` listing only non-default fields."""
    shown = []
    for spec in dataclasses.fields(self):
      current = getattr(self, spec.name)
      # Keep the field only when its value differs from the declared default.
      if current != spec.default:
        shown.append(f'{spec.name}={current!r}')
    joined = ', '.join(shown)
    return f'{type(self).__name__}({joined})'

# print(), str() and the `!s` conversion all go through the custom __str__;
# repr() and `!r` still use the dataclass-generated __repr__, which lists every field.
x = X(b=True)
print(x)        # X(b=True)
print(str(x))   # X(b=True)
print(repr(x))  # X(a=1, b=True, c=2.0)
print(f'{x}, {x!s}, {x!r}')  # X(b=True), X(b=True), X(a=1, b=True, c=2.0)

This can also be achieved using a decorator:

def terse_str(cls):  # Decorator for class.
  """Install a `__str__` on `cls` that lists only non-default field values.

  The fields are looked up with `dataclasses.fields` each time `__str__` runs,
  so the decorator may be applied before `dataclasses.dataclass` processes the
  class (i.e. written beneath it).

  Args:
    cls: The class to modify in place.

  Returns:
    The same class object, with `__str__` replaced.
  """

  def _default_of(field):
    # A field declared with `default_factory` reports MISSING as its `default`,
    # so the factory has to be called to obtain a value to compare against.
    factory = field.default_factory
    if factory is not dataclasses.MISSING:
      return factory()
    return field.default

  def __str__(self):
    """Returns a string containing only the non-default field values."""
    shown = []
    for field in dataclasses.fields(self):
      value = getattr(self, field.name)
      if value != _default_of(field):
        # `!r` keeps the output consistent with the dataclass repr format
        # (e.g. strings stay quoted).
        shown.append(f'{field.name}={value!r}')
    return f'{type(self).__name__}({", ".join(shown)})'

  setattr(cls, '__str__', __str__)
  return cls

# Decorators apply bottom-up: `terse_str` installs __str__ on the class first,
# then `dataclasses.dataclass` processes it.
@dataclasses.dataclass
@terse_str
class X:
  a: int = 1
  b: bool = False
  c: float = 2.0
Answered By: Hugues

One improvement I would suggest is to call dataclasses.fields only once and cache the default values it yields. This helps performance because the current implementation calls dataclasses.fields, and looks up every default again, each time __str__ is invoked.

Here’s a simple example using a metaclass approach.

Note that I’ve also modified it slightly so that it handles fields that define a default_factory (for instance, fields with mutable default values).

from __future__ import annotations
import dataclasses


# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Build a function from source text and return it.

    The function is defined inside a generated wrapper `__create_fn__` whose
    parameters are the keys of `locals`; calling the wrapper with those values
    makes them available to the generated function as closure variables.

    Args:
        name: Name of the function to generate.
        args: Iterable of parameter names for the generated function.
        body: Iterable of source lines forming the function body. Each line is
            indented by two spaces; nested statements carry their own extra
            leading whitespace.
        globals: Globals mapping passed to `exec`.
        locals: Mapping of name -> object to expose to the generated function.

    Returns:
        The newly created function object.
    """
    if locals is None:
        locals = {}
    args = ','.join(args)
    # One body statement per source line, indented under the inner `def`.
    body = '\n'.join(f'  {b}' for b in body)
    # Compute the text of the entire function.
    txt = f' def {name}({args}):\n{body}'
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)


def terse_str(cls_name, bases, cls_dict):  # Metaclass for class
    """Create a class whose `__str__` lists only non-default field values.

    Intended for use as `metaclass=terse_str`: it receives the class name,
    bases and namespace, adds a bootstrap `__str__` to the namespace and
    returns `type(cls_name, bases, cls_dict)`.

    The bootstrap `__str__` runs on the first call only. It reads
    `dataclasses.fields`, captures each field's default (calling
    `default_factory` when one is set), generates a specialised `__str__`
    with `_create_fn`, stores it on `type(self)` and delegates to it.
    Later calls on that class go straight to the generated function.
    """

    def __str__(self):
        cls_fields: tuple[dataclasses.Field, ...] = dataclasses.fields(self)
        _locals = {}
        _body_lines = ['lines=[]']
        for f in cls_fields:
            name = f.name
            dflt_name = f'_dflt_{name}'
            dflt_factory = f.default_factory
            if dflt_factory is not dataclasses.MISSING:
                # Evaluate the factory once; the result is only compared against.
                _locals[dflt_name] = dflt_factory()
            else:
                _locals[dflt_name] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != _dflt_{name}:')
            # Leading space nests this line under the generated `if`.
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        # The generated line is itself an f-string; its single quotes must be
        # escaped here, and the class name is baked in from `cls_name`.
        _body_lines.append(f'return f\'{cls_name}({{", ".join(lines)}})\'')
        # noinspection PyShadowingNames
        __str__ = _create_fn('__str__', ('self', ), _body_lines, locals=_locals)
        # set the __str__ with the cached `dataclass.fields`
        setattr(type(self), '__str__', __str__)
        # on initial run, compute and return __str__()
        return __str__(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)


# `terse_str` is passed as the metaclass, so it is called with the class name,
# bases and namespace to build `X`; `dataclasses.dataclass` then processes the result.
@dataclasses.dataclass
class X(metaclass=terse_str):
    a: int = 1
    b: bool = False
    c: float = 2.0
    # Mutable default, supplied through a factory rather than a shared literal.
    d: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 3])


x1 = X(b=True)
# b=False equals its default and is omitted; c and d differ from theirs.
x2 = X(b=False, c=3, d=[1, 2])

print(x1)    # X(b=True)
print(x2)    # X(c=3, d=[1, 2])

Finally, here’s a quick and dirty test to confirm that caching is actually beneficial for repeated calls to str() or print:

import dataclasses
from timeit import timeit


def terse_str(cls):  # Decorator for class.
    """Install a `__str__` on `cls` that lists only non-default field values.

    The fields are looked up with `dataclasses.fields` on every `__str__`
    call; this is the uncached baseline for the benchmark.

    Args:
        cls: The class to modify in place.

    Returns:
        The same class object, with `__str__` replaced.
    """

    def _default_of(field):
        # A field declared with `default_factory` reports MISSING as its
        # `default`, so the factory is called to get a comparable value.
        factory = field.default_factory
        if factory is not dataclasses.MISSING:
            return factory()
        return field.default

    def __str__(self):
        """Returns a string containing only the non-default field values."""
        shown = []
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if value != _default_of(field):
                # `!r` matches the formatting used by the cached variant.
                shown.append(f'{field.name}={value!r}')
        return f'{type(self).__name__}({", ".join(shown)})'

    setattr(cls, '__str__', __str__)
    return cls


# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Build a function from source text and return it.

    The function is defined inside a generated wrapper `__create_fn__` whose
    parameters are the keys of `locals`; calling the wrapper with those values
    makes them available to the generated function as closure variables.

    Args:
        name: Name of the function to generate.
        args: Iterable of parameter names for the generated function.
        body: Iterable of source lines forming the function body. Each line is
            indented by two spaces; nested statements carry their own extra
            leading whitespace.
        globals: Globals mapping passed to `exec`.
        locals: Mapping of name -> object to expose to the generated function.

    Returns:
        The newly created function object.
    """
    if locals is None:
        locals = {}
    args = ','.join(args)
    # One body statement per source line, indented under the inner `def`.
    body = '\n'.join(f'  {b}' for b in body)
    # Compute the text of the entire function.
    txt = f' def {name}({args}):\n{body}'
    local_vars = ', '.join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
    ns = {}
    exec(txt, globals, ns)
    return ns['__create_fn__'](**locals)


def terse_str_meta(cls_name, bases, cls_dict):  # Metaclass for class
    """Create a class whose `__str__` lists only non-default field values.

    Intended for use as `metaclass=terse_str_meta`: it receives the class
    name, bases and namespace, adds a bootstrap `__str__` to the namespace and
    returns `type(cls_name, bases, cls_dict)`.

    The bootstrap `__str__` runs on the first call only. It reads
    `dataclasses.fields`, captures each field's default (calling
    `default_factory` when one is set), generates a specialised `__str__`
    with `_create_fn`, stores it on `type(self)` and delegates to it.
    Later calls on that class go straight to the generated function.
    """

    def __str__(self):
        cls_fields: tuple[dataclasses.Field, ...] = dataclasses.fields(self)
        _locals = {}
        _body_lines = ['lines=[]']
        for f in cls_fields:
            name = f.name
            dflt_name = f'_dflt_{name}'
            dflt_factory = f.default_factory
            if dflt_factory is not dataclasses.MISSING:
                # Evaluate the factory once; the result is only compared against.
                _locals[dflt_name] = dflt_factory()
            else:
                _locals[dflt_name] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != _dflt_{name}:')
            # Leading space nests this line under the generated `if`.
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        # The generated line is itself an f-string; its single quotes must be
        # escaped here, and the class name is baked in from `cls_name`.
        _body_lines.append(f'return f\'{cls_name}({{", ".join(lines)}})\'')
        # noinspection PyShadowingNames
        __str__ = _create_fn('__str__', ('self', ), _body_lines, locals=_locals)
        # set the __str__ with the cached `dataclass.fields`
        setattr(type(self), '__str__', __str__)
        # on initial run, compute and return __str__()
        return __str__(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)


# Benchmark baseline: __str__ comes from the `terse_str` decorator, which
# calls `dataclasses.fields` on every invocation.
@dataclasses.dataclass
@terse_str
class X:
    a: int = 1
    b: bool = False
    c: float = 2.0


# Benchmark candidate: same fields as `X`, but built through `terse_str_meta`,
# which replaces __str__ with a generated function on first use.
@dataclasses.dataclass
class X_Cached(metaclass=terse_str_meta):
    a: int = 1
    b: bool = False
    c: float = 2.0


# Each timed statement constructs an instance and converts it to a string;
# `globals=globals()` makes X and X_Cached visible to the timed code.
print(f"Simple:  {timeit('str(X(b=True))', globals=globals()):.3f}")
print(f"Cached:  {timeit('str(X_Cached(b=True))', globals=globals()):.3f}")

# Print both forms once so their output can be compared.
print()
print(X(b=True))
print(X_Cached(b=True))

Results:

Simple:  1.038
Cached:  0.289
Answered By: rv.kvetch
Categories: questions Tags: ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.