Building a toy Python @dataclass decorator
Marton Trencseni - Thu 12 May 2022 - Python
Introduction
Following the previous article on a toy implementation of a Python Enum class and the article on Python decorator patterns, this time I will write a toy implementation of the built-in @dataclass decorator. The official documentation for dataclasses is here. The ipython notebook is up on Github.
@dataclass features
@dataclass is a very useful feature of the Python standard library. It's a simple way to declare a class with typed variables, and get useful helper functions added to the class "for free":
from dataclasses import *
@dataclass
class Fraction:
    numerator: int = 0
    denominator: int = 1
With @dataclass, we get free constructors for the member attributes. There are 3 ways to get a new Fraction object:
f = Fraction() # defaults, same as Fraction(0, 1)
f = Fraction(1, 2)
f = Fraction(numerator=1, denominator=2)
Note that the second and third would not work without @dataclass; we would get TypeError: Fraction() takes no arguments. We also get free equality checks (it compares each of the members), like:
f = Fraction(1, 2)
g = Fraction(1, 2)
f == g # True
Note that without @dataclass, the == comparison would be False, since the default equality compares object identity. We also get a useful string representation:
print(Fraction(1, 2)) # prints: Fraction(numerator=1, denominator=2)
By specifying order=True in the @dataclass decorator, we also get free <, <=, >, >= comparisons:
@dataclass(order=True)
class Fraction:
    numerator: int = 0
    denominator: int = 1
f = Fraction(1, 2)
g = Fraction(2, 3)
f <= g # True
The free comparison functions aren't very useful for the Fraction example, since they compare the attributes as a tuple, in declaration order. So e.g. f < g is the same as (f.numerator, f.denominator) < (g.numerator, g.denominator), which mathematically isn't the right expression for comparing fractions.
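To make the difference concrete, here is a small sketch comparing the generated ordering with the mathematically correct one. The cross-multiplication check is my own addition for illustration and assumes positive denominators:

```python
from dataclasses import dataclass

@dataclass(order=True)
class Fraction:
    numerator: int = 0
    denominator: int = 1

f = Fraction(1, 2)
g = Fraction(1, 3)

# The generated __lt__ compares the fields as tuples, in declaration order.
generated = f < g
as_tuples = (f.numerator, f.denominator) < (g.numerator, g.denominator)

# Illustration only: cross-multiplication, assuming positive denominators.
mathematical = f.numerator * g.denominator < g.numerator * f.denominator

print(generated, as_tuples, mathematical)  # True True False
```

So 1/2 sorts before 1/3 under the generated ordering, even though 1/2 is the larger fraction.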
By specifying frozen=True in the decorator arguments, we get read-only objects, which are useful in e.g. multithreading:
@dataclass(frozen=True)
class Fraction:
    numerator: int = 0
    denominator: int = 1
f = Fraction(1, 2)
f.numerator = 2 # FrozenInstanceError: cannot assign to field 'numerator'
Finally, at the top level of dataclasses are 2 useful functions, asdict() and astuple():
asdict(Fraction(1, 2)) # {'numerator': 1, 'denominator': 2}
astuple(Fraction(1, 2)) # (1, 2)
The Python standard library @dataclass has a lot more features, but these are some of the most frequently used ones. Let's practice our Python fu and create a toy implementation of @dataclass. As seen before, the general skeleton will be:
def dataclass():
    def decorator(cls):
        # add useful features to cls
        ...
        return cls
    return decorator
Finding the annotated member attributes
First we have to locate the annotated attributes declared by the user in the class object. The trick is telling them apart from all the built-in functions and attributes and any user-defined functions. We can check in cls.__dict__ and dir(cls):
class Fraction:
    numerator: int = 0
    denominator: int = 1
    def mul(self, x: int):
        self.numerator *= x
print(Fraction.__dict__)
print()
print(dir(Fraction))
Output:
{'__module__': '__main__', '__annotations__': {'numerator': <class 'int'>, 'denominator': <class 'int'>}, 'numerator': 0, 'denominator': 1, '__dict__': <attribute '__dict__' of 'Fraction' objects>, '__weakref__': <attribute '__weakref__' of 'Fraction' objects>, '__doc__': None}
['__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'denominator', 'numerator']
Long story short, the right place to look is cls.__annotations__:
print(Fraction.__annotations__)
Output:
{'numerator': <class 'int'>, 'denominator': <class 'int'>}
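As a quick check, the sketch below adds an un-annotated class attribute (scale, a hypothetical name that is not part of the original example) and shows that only the annotated names show up, in declaration order:

```python
class Fraction:
    numerator: int = 0
    denominator: int = 1
    scale = 10  # hypothetical: no annotation, so it is not listed as a field

    def mul(self, x: int):
        self.numerator *= x

# Iterating the annotations dict yields the annotated names in order.
fields = list(Fraction.__annotations__)
# The default values are ordinary class attributes.
defaults = {name: getattr(Fraction, name) for name in fields}

print(fields)    # ['numerator', 'denominator']
print(defaults)  # {'numerator': 0, 'denominator': 1}
```

Neither the method mul nor the plain attribute scale appears, which is exactly the filtering we need.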
Building the constructor
Now that we found the attributes, let's build the constructor:
def dataclass(...):
    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        # define initializer, unless defined by the user
        if '__init__' not in cls.__dict__:
            def __init__(self, *args, **kwargs):
                if len(args) > 0:
                    for attrib, arg in zip(self.__class__._attribs, args):
                        # avoid our own __setattr__ in case it's a frozen dataclass:
                        object.__setattr__(self, attrib, arg)
                elif len(kwargs) > 0:
                    for attrib in self.__class__._attribs:
                        # avoid our own __setattr__ in case it's a frozen dataclass:
                        object.__setattr__(self, attrib, kwargs[attrib])
            cls.__init__ = __init__
        ...
        return cls
First we save the attribute keys in cls._attribs, so we can conveniently access them in subsequent functions.
We check if the user explicitly defined their own constructor with if '__init__' not in cls.__dict__. If not, then we define the function and assign it with cls.__init__ = __init__ at the end. The constructor function itself takes *args and **kwargs. If the user calls it like Fraction(1, 2), then args will be (1, 2) and kwargs will be an empty dict {}. If the user calls it like Fraction(numerator=1, denominator=2), then args will be () and kwargs will be a dict like {'numerator': 1, 'denominator': 2}. If the user calls it like Fraction(), nothing is assigned on the instance, and attribute lookups fall back to the class-level defaults. So we just need to get these, and assign them to the appropriate member variable with object.__setattr__(), to avoid using cls.__setattr__() in case this is a frozen dataclass (see later).
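The following standalone sketch shows what ends up in args and kwargs for each call style, and how zip() pairs positional values with attribute names. The names capture and ATTRIBS are made up for this illustration:

```python
ATTRIBS = ['numerator', 'denominator']  # stands in for cls._attribs

def capture(*args, **kwargs):
    """Return the positional and keyword arguments exactly as received."""
    return args, kwargs

positional = capture(1, 2)
keyword = capture(numerator=1, denominator=2)

# zip() stops at the shorter input, so missing trailing values are skipped.
paired = dict(zip(ATTRIBS, (1, 2)))
partial = dict(zip(ATTRIBS, (1,)))

print(positional)  # ((1, 2), {})
print(keyword)     # ((), {'numerator': 1, 'denominator': 2})
print(paired)      # {'numerator': 1, 'denominator': 2}
print(partial)     # {'numerator': 1}
```

One consequence for the toy constructor as written: a partial positional call such as Fraction(1) falls back to the class-level default for denominator, while a partial keyword call such as Fraction(numerator=1) raises KeyError, because every attribute is looked up in kwargs.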
String conversion
String conversion is relatively straightforward:
def decorator(cls):
    cls._attribs = cls.__annotations__.keys()
    ...
    # define string conversion, unless defined by the user
    if '__str__' not in cls.__dict__:
        def __str__(self):
            kv_tuples = [(attrib, getattr(self, attrib)) for attrib in self.__class__._attribs]
            kv_str = ', '.join([f'{k}={v}' for (k, v) in kv_tuples])
            return f'{self.__class__.__name__}({kv_str})'
        cls.__str__ = __str__
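The same joining technique can be tried outside the decorator. In this sketch, Point and format_fields are hypothetical names; the use_repr flag illustrates that the standard library generates __repr__ and formats each value with repr(), so strings appear quoted:

```python
class Point:
    """Hypothetical class used only to exercise the formatter."""
    def __init__(self, x, label):
        self.x = x
        self.label = label

def format_fields(obj, names, use_repr=False):
    """Build 'ClassName(k=v, ...)' from the named attributes of obj."""
    fmt = repr if use_repr else str
    kv_str = ', '.join(f'{name}={fmt(getattr(obj, name))}' for name in names)
    return f'{obj.__class__.__name__}({kv_str})'

p = Point(1, 'a')
plain = format_fields(p, ['x', 'label'])
quoted = format_fields(p, ['x', 'label'], use_repr=True)

print(plain)   # Point(x=1, label=a)
print(quoted)  # Point(x=1, label='a')
```

Since the toy version only defines __str__, print() shows the nice form, but repr() and the interactive prompt still show the default object representation.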
Equality and comparators
Next, let's write equality and comparators. These can be switched on and off by passing the appropriate arguments to the decorator, so we have to check these:
def dataclass(..., **kwargs):
    def decorator(cls):
        ...
        if kwargs.get('eq', True):
            # define ==, unless defined by the user
            if '__eq__' not in cls.__dict__:
                def __eq__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) != getattr(other, attrib):
                            return False
                    return True
                cls.__eq__ = __eq__
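One difference worth knowing: the generated __eq__ in the standard library only compares instances of the same class and returns NotImplemented otherwise, while the toy version above raises AttributeError if other lacks one of the attributes. A minimal sketch of the stricter check, written by hand on a hypothetical Pair class:

```python
class Pair:
    """Hypothetical class with a hand-written, type-checked __eq__."""
    _attribs = ('a', 'b')

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __eq__(self, other):
        if other.__class__ is not self.__class__:
            # Let Python try the reflected comparison, then fall back to identity.
            return NotImplemented
        return all(getattr(self, n) == getattr(other, n) for n in self._attribs)

same = Pair(1, 2) == Pair(1, 2)
different = Pair(1, 2) == Pair(1, 3)
other_type = Pair(1, 2) == (1, 2)

print(same, different, other_type)  # True False False
```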
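Operators for <, <=, >, >= follow the same logic; only the first is shown. Note that this toy version requires the comparison to hold for every attribute, whereas the standard library compares the fields as a tuple, in order: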
        if kwargs.get('order', False):
            # define <, <=, >, >=, unless defined by the user
            if '__lt__' not in cls.__dict__:
                def __lt__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) >= getattr(other, attrib):
                            return False
                    return True
                cls.__lt__ = __lt__
            ...
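Requiring every attribute to satisfy the comparison is simpler than, but not equivalent to, the tuple comparison used by the standard library. The sketch below puts the two side by side on plain tuples; all_less and all_greater_equal are illustrative helper names mirroring the toy __lt__ and __ge__:

```python
def all_less(a, b):
    """Mirror of the toy __lt__: every position must be strictly smaller."""
    return all(x < y for x, y in zip(a, b))

def all_greater_equal(a, b):
    """Mirror of the toy __ge__: every position must be greater or equal."""
    return all(x >= y for x, y in zip(a, b))

f, g = (1, 3), (2, 2)

toy_lt = all_less(f, g)           # 1 < 2 holds, but 3 < 2 does not
toy_ge = all_greater_equal(f, g)  # 1 >= 2 already fails
tuple_lt = f < g                  # decided by the first differing position

print(toy_lt, toy_ge, tuple_lt)  # False False True
```

With the toy rules neither f < g nor f >= g holds for this pair, so the toy ordering is only partial, which matters if the objects are ever sorted.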
Frozen
If frozen=True is passed to the decorator, we have to disallow assignment to attributes, i.e. f.numerator = 1 should raise an exception. When we write f.numerator = 1, internally this becomes Fraction.__setattr__(f, 'numerator', 1). So to disallow this, we just have to define our own __setattr__():
        if kwargs.get('frozen', False):
            # don't allow changing attributes
            def __setattr__(self, attrib, value):
                if attrib not in self.__class__._attribs:
                    # delegate to object.__setattr__; calling setattr() here would recurse into this method
                    object.__setattr__(self, attrib, value)
                else:
                    raise AttributeError(f'dataclass is frozen, cannot assign to field \'{attrib}\'')
            cls.__setattr__ = __setattr__
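Here is a standalone sketch of the same idea on a hand-written class (FrozenPair is a hypothetical name). It delegates to object.__setattr__() for non-field attributes, because calling setattr(self, ...) from inside __setattr__() would re-enter the same method:

```python
class FrozenPair:
    """Hypothetical hand-written equivalent of a frozen toy dataclass."""
    _attribs = ('numerator', 'denominator')

    def __init__(self, numerator, denominator):
        # Bypass our own __setattr__, which would reject these fields.
        object.__setattr__(self, 'numerator', numerator)
        object.__setattr__(self, 'denominator', denominator)

    def __setattr__(self, attrib, value):
        if attrib in self._attribs:
            raise AttributeError(f"cannot assign to field '{attrib}'")
        # setattr(self, ...) here would call this method again and recurse.
        object.__setattr__(self, attrib, value)

f = FrozenPair(1, 2)
try:
    f.numerator = 5
    blocked = False
except AttributeError:
    blocked = True

f.note = 'not a field'  # non-field attributes are still assignable

print(blocked, f.numerator, f.note)  # True 1 not a field
```

The standard library is stricter here: on a frozen dataclass it also rejects assignment to new, non-field attributes.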
@dataclass vs @dataclass(...)
We need one final trick. When using this decorator, we want to allow users to write both @dataclass and @dataclass(...). In the first case, the class after the decorator gets passed directly to the dataclass function. In the second case, dataclass() is first called with the arguments (without the following class), and that must return a function which accepts the class. We want to support both behaviours in a single decorator function. The solution is quite simple:
def dataclass(cls=None, **kwargs):
    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        ...
        # this is where the code shown above is
        ...
        return cls
    if cls is None:
        # decorator was used like @dataclass(...)
        return decorator
    else:
        # decorator was used like @dataclass, without parens
        return decorator(cls)
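The pattern is easy to try in isolation. In this minimal sketch, tagged is a hypothetical decorator that does nothing except record the keyword arguments it was called with, so all three call styles can be checked:

```python
def tagged(cls=None, **options):
    """Hypothetical decorator usable as @tagged, @tagged() or @tagged(key=value)."""
    def decorator(cls):
        cls._options = options  # record how the decorator was called
        return cls
    if cls is None:
        # called with parentheses: return the real decorator
        return decorator
    # called without parentheses: cls is the decorated class
    return decorator(cls)

@tagged
class A:
    pass

@tagged(frozen=True)
class B:
    pass

@tagged()
class C:
    pass

print(A._options, B._options, C._options)  # {} {'frozen': True} {}
```

The trick relies on options being keyword-only in practice: the first positional parameter is reserved for the class.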
Complete toy implementation
That's it! The complete toy implementation, including asdict() and astuple(), is:
def asdict(x):
    return {a: getattr(x, a) for a in x.__class__._attribs}

def astuple(x):
    return tuple(getattr(x, a) for a in x.__class__._attribs)

def dataclass(cls=None, **kwargs):
    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        # define initializer, unless defined by the user
        if '__init__' not in cls.__dict__:
            def __init__(self, *args, **kwargs):
                if len(args) > 0:
                    for attrib, arg in zip(self.__class__._attribs, args):
                        # avoid our own __setattr__ in case it's a frozen dataclass:
                        object.__setattr__(self, attrib, arg)
                elif len(kwargs) > 0:
                    for attrib in self.__class__._attribs:
                        # avoid our own __setattr__ in case it's a frozen dataclass:
                        object.__setattr__(self, attrib, kwargs[attrib])
            cls.__init__ = __init__
        # define string conversion, unless defined by the user
        if '__str__' not in cls.__dict__:
            def __str__(self):
                kv_tuples = [(attrib, getattr(self, attrib)) for attrib in self.__class__._attribs]
                kv_str = ', '.join([f'{k}={v}' for (k, v) in kv_tuples])
                return f'{self.__class__.__name__}({kv_str})'
            cls.__str__ = __str__
        if kwargs.get('eq', True):
            # define ==, unless defined by the user
            if '__eq__' not in cls.__dict__:
                def __eq__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) != getattr(other, attrib):
                            return False
                    return True
                cls.__eq__ = __eq__
        if kwargs.get('order', False):
            # define <, <=, >, >=, unless defined by the user
            if '__lt__' not in cls.__dict__:
                def __lt__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) >= getattr(other, attrib):
                            return False
                    return True
                cls.__lt__ = __lt__
            if '__le__' not in cls.__dict__:
                def __le__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) > getattr(other, attrib):
                            return False
                    return True
                cls.__le__ = __le__
            if '__gt__' not in cls.__dict__:
                def __gt__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) <= getattr(other, attrib):
                            return False
                    return True
                cls.__gt__ = __gt__
            if '__ge__' not in cls.__dict__:
                def __ge__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) < getattr(other, attrib):
                            return False
                    return True
                cls.__ge__ = __ge__
        if kwargs.get('frozen', False):
            # don't allow changing attributes
            def __setattr__(self, attrib, value):
                if attrib not in self.__class__._attribs:
                    # delegate to object.__setattr__; calling setattr() here would recurse into this method
                    object.__setattr__(self, attrib, value)
                else:
                    raise AttributeError(f'dataclass is frozen, cannot assign to field \'{attrib}\'')
            cls.__setattr__ = __setattr__
        return cls
    if cls is None:
        # decorator was used like @dataclass(...)
        return decorator
    else:
        # decorator was used like @dataclass, without parens
        return decorator(cls)
With this toy implementation, the examples shown at the beginning work just like the Python standard library @dataclass, whose implementation is in dataclasses.py (it's 1488 lines of code).
Conclusion
Similar to the previous articles, this is a good way to improve one's Python fu.