Building a toy Python @dataclass decorator

Marton Trencseni - Thu 12 May 2022 - Python

Introduction

Following the previous article on a toy implementation of a Python Enum class and the article on Python decorator patterns, this time I will write a toy implementation of the built-in @dataclass decorator. The official documentation for dataclasses is here. The ipython notebook is up on GitHub.

@dataclass features

@dataclass is a very useful feature of the Python standard library. It's a simple way to declare a class with typed variables, and get useful helper functions added to the class "for free":

from dataclasses import dataclass, asdict, astuple

@dataclass
class Fraction:
    numerator: int = 0
    denominator: int = 1

With @dataclass, we get free constructors for the member attributes. There are 3 ways to get a new Fraction object:

f = Fraction() # defaults, same as Fraction(0, 1)
f = Fraction(1, 2)
f = Fraction(numerator=1, denominator=2)

Note that the second and third would not work without @dataclass; we would get TypeError: Fraction() takes no arguments. We also get free equality checks (each of the members is compared), like:

f = Fraction(1, 2)
g = Fraction(1, 2)
f == g # True

Note that without @dataclass, the == equality would be False, since by default == only checks whether the two sides are the same object. We also get a useful string representation:

print(Fraction(1, 2)) # prints: Fraction(numerator=1, denominator=2)

By specifying order=True in the @dataclass decorator, we also get free <, <=, >, >= comparisons:

@dataclass(order=True)
class Fraction:
    numerator: int = 0
    denominator: int = 1

f = Fraction(1, 2)
g = Fraction(2, 3)
f <= g # True

The free comparison functions aren't very useful for the Fraction example, since they compare the attributes as tuples, in declaration order. So eg. f < g is the same as (f.numerator, f.denominator) < (g.numerator, g.denominator), which mathematically isn't the right expression.
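
To make this concrete, here is a small check using the standard library decorator. The cross-multiplication at the end is just one way to compare two fractions correctly, and it assumes positive denominators; the variable names are for illustration only.

```python
from dataclasses import dataclass

@dataclass(order=True)
class Fraction:
    numerator: int = 0
    denominator: int = 1

f = Fraction(1, 3)
g = Fraction(1, 2)

# the generated __lt__ compares the fields as tuples: (1, 3) < (1, 2)
generated = f < g

# mathematically 1/3 < 1/2; cross-multiplying (assumes positive denominators)
mathematical = f.numerator * g.denominator < g.numerator * f.denominator

print(generated, mathematical)  # False True
```

The numerators are equal, so the tuple comparison falls through to the denominators, where 3 < 2 fails.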

By specifying frozen=True in the decorator arguments, we get read-only objects, which are useful in eg. multithreading:

@dataclass(frozen=True)
class Fraction:
    numerator: int = 0
    denominator: int = 1

f = Fraction(1, 2)
f.numerator = 2 # FrozenInstanceError: cannot assign to field 'numerator'

Finally, at the top-level of dataclasses are 2 useful functions, asdict() and astuple():

asdict(Fraction(1, 2))  # {'numerator': 1, 'denominator': 2}
astuple(Fraction(1, 2)) # (1, 2)

The Python standard library @dataclass has a lot more features, but these are some of the most frequently used ones. Let's practice our Python fu and create a toy implementation of @dataclass. As seen before, the general skeleton will be:

def dataclass():
    def decorator(cls):
        # add useful features to cls
        ...
        return cls
    return decorator
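
As a quick illustration of this skeleton, here is a minimal sketch of a class decorator built the same way. The add_greeting name and the greet method are hypothetical and have nothing to do with dataclasses; the point is only that the inner function receives the class, modifies it and returns it.

```python
def add_greeting(greeting='hello'):
    # hypothetical decorator factory, for illustration only
    def decorator(cls):
        def greet(self):
            return f'{greeting} from {self.__class__.__name__}'
        cls.greet = greet  # attach the new method to the class
        return cls         # hand back the same (modified) class object
    return decorator

@add_greeting()
class Fraction:
    numerator: int = 0
    denominator: int = 1

message = Fraction().greet()
print(message)  # hello from Fraction
```

Note the parentheses in @add_greeting(): with this skeleton the outer function has to be called first, so that the inner decorator is what gets applied to the class.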

Finding the annotated member attributes

First we have to locate the annotated attributes declared by the user in the class object. The trick is telling them apart from all the built-in functions and attributes and any user defined functions. We can check in cls.__dict__ and dir(cls):

class Fraction:
    numerator: int = 0
    denominator: int = 1
    def mul(self, x: int):
        self.numerator *= x

print(Fraction.__dict__)
print()
print(dir(Fraction))

Output:

{'__module__': '__main__', '__annotations__': {'numerator': <class 'int'>, 'denominator': <class 'int'>}, 'numerator': 0, 'denominator': 1, 'mul': <function Fraction.mul at 0x...>, '__dict__': <attribute '__dict__' of 'Fraction' objects>, '__weakref__': <attribute '__weakref__' of 'Fraction' objects>, '__doc__': None}

['__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'denominator', 'mul', 'numerator']

Long story short, the right place to look is cls.__annotations__:

print(Fraction.__annotations__)

Output:

{'numerator': <class 'int'>, 'denominator': <class 'int'>}
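
The following sketch shows what does and does not end up in __annotations__. The Point class and its members are made up for this example.

```python
class Point:
    x: int = 0      # annotated, with a default
    y: int          # annotated, without a default
    scale = 2       # plain class attribute, no annotation

    def norm1(self):
        return abs(self.x) + abs(self.y)

fields = list(Point.__annotations__.keys())
print(fields)                     # ['x', 'y']
print('scale' in Point.__dict__)  # True, but it is not an annotated attribute
print(hasattr(Point, 'y'))        # False, y has no value on the class
```

Only the annotated names show up, in declaration order. Methods and unannotated class attributes are left out, and an annotation without a default (like y) is listed even though the class itself has no such attribute.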

Building the constructor

Now that we found the attributes, let's build the constructor:

def dataclass(...):
    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        # define initializer, unless defined by the user
        if '__init__' not in cls.__dict__:
            def __init__(self, *args, **kwargs):
                # positional arguments are matched to the attributes in declaration order
                for attrib, arg in zip(self.__class__._attribs, args):
                    # avoid our own __setattr__ in case it's a frozen dataclass:
                    object.__setattr__(self, attrib, arg)
                # keyword arguments are matched by name; attributes not passed keep the class-level default
                for attrib in self.__class__._attribs:
                    if attrib in kwargs:
                        # avoid our own __setattr__ in case it's a frozen dataclass:
                        object.__setattr__(self, attrib, kwargs[attrib])
            cls.__init__ = __init__
        ...
        return cls

First we save the attribute keys in cls._attribs, so we can conveniently access it in subsequent functions.

We check if the user explicitly defined their own constructor with if '__init__' not in cls.__dict__. If not, then we define the function and assign it with cls.__init__ = __init__ at the end. The constructor function itself takes *args and **kwargs. If the user calls it like Fraction(1, 2), then args will be (1, 2) and kwargs will be an empty dict {}. If the user calls it like Fraction(numerator=1, denominator=2), then args will be () and kwargs will be a dict like {'numerator': 1, 'denominator': 2}. So we just need to get these, and assign them to the appropriate member variable with object.__setattr__(), to avoid using cls.__setattr__() in case this is a frozen dataclass (see later). Any attribute that is not assigned on the instance falls back to the class-level default, which is why Fraction() behaves like Fraction(0, 1).
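
The two mechanisms used here can be tried out in isolation. In this sketch, capture and ReadOnly are hypothetical helpers that are not part of the toy implementation; they only show how Python packs the arguments and how object.__setattr__() bypasses a class's own __setattr__().

```python
def capture(*args, **kwargs):
    # hypothetical helper: return the arguments exactly as Python packs them
    return args, kwargs

print(capture(1, 2))                        # ((1, 2), {})
print(capture(numerator=1, denominator=2))  # ((), {'numerator': 1, 'denominator': 2})

class ReadOnly:
    # stand-in for a frozen class: every normal assignment is rejected
    def __setattr__(self, name, value):
        raise AttributeError(f'cannot assign to {name}')

obj = ReadOnly()
# object.__setattr__ skips ReadOnly.__setattr__ and stores the value directly
object.__setattr__(obj, 'numerator', 1)
print(obj.numerator)  # 1
```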

String conversion

String conversion is relatively straightforward:

    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        ...
        # define string conversion, unless defined by the user
        if '__str__' not in cls.__dict__:
            def __str__(self):
                kv_tuples = [(attrib, getattr(self, attrib)) for attrib in self.__class__._attribs]
                kv_str = ', '.join([f'{k}={v}' for (k, v) in kv_tuples])
                return f'{self.__class__.__name__}({kv_str})'
            cls.__str__ = __str__
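
To see the intermediate values, the same steps can be run by hand on a plain, undecorated class. This is only a sketch; attribs stands in for cls._attribs.

```python
class Fraction:
    numerator: int = 0
    denominator: int = 1

f = Fraction()
attribs = Fraction.__annotations__.keys()  # stands in for cls._attribs

kv_tuples = [(a, getattr(f, a)) for a in attribs]
kv_str = ', '.join(f'{k}={v}' for k, v in kv_tuples)
text = f'{f.__class__.__name__}({kv_str})'

print(kv_tuples)  # [('numerator', 0), ('denominator', 1)]
print(text)       # Fraction(numerator=0, denominator=1)
```

Since the f-string calls str() on each value, a string attribute would be shown without quotes. The standard library's generated __repr__ uses repr() for each field instead.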

Equality and comparators

Next, let's write equality and comparators. These can be switched on and off by passing the appropriate arguments to the decorator, so we have to check these:

def dataclass(..., **kwargs):
    def decorator(cls):
        ...
        if kwargs.get('eq', True):
            # define ==, unless defined by the user
            if '__eq__' not in cls.__dict__:
                def __eq__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) != getattr(other, attrib):
                            return False
                    return True            
                cls.__eq__ = __eq__

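Operators for <, <=, >, >= follow the same logic, so only the first is shown. Note that this toy version requires the comparison to hold for every attribute, which is simpler than the standard library's behaviour of comparing the attributes as a tuple: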

        if kwargs.get('order', False):
            # define <, <=, >, >=, unless defined by the user
            if '__lt__' not in cls.__dict__:
                def __lt__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) >= getattr(other, attrib):
                            return False
                    return True            
                cls.__lt__ = __lt__
        ...
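
The difference between the "every attribute" rule used here and a tuple comparison can be seen on plain tuples. In this sketch, all_less is a hypothetical helper that mirrors the loop in the toy __lt__.

```python
def all_less(a, b):
    # same idea as the toy __lt__: every attribute must be strictly smaller
    return all(x < y for x, y in zip(a, b))

f = (1, 3)  # stands for Fraction(1, 3)
g = (2, 2)  # stands for Fraction(2, 2)

print(all_less(f, g))  # False, because 3 < 2 fails
print(all_less(g, f))  # False, because 2 < 1 fails
print(f < g)           # True, tuple comparison is decided by the first differing element
```

With the "every attribute" rule some pairs are not ordered in either direction, while tuple comparison always orders two objects that differ.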

Frozen

If frozen=True is passed to the decorator, we have to disallow assignment to attributes, ie. f.numerator = 1 should raise an exception. When we write f.numerator = 1, internally this becomes Fraction.__setattr__(f, 'numerator', 1). So to disallow this, we just have to define our own __setattr__():

        if kwargs.get('frozen', False):
            # don't allow changing attributes
            def __setattr__(self, attrib, value):
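                if attrib not in self.__class__._attribs:
                    # use object.__setattr__ here; setattr(self, ...) would call this function again and recurse
                    object.__setattr__(self, attrib, value)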
                else:
                    raise AttributeError(f'dataclass is frozen, cannot assign to field \'{attrib}\'')
            cls.__setattr__ = __setattr__
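
One detail worth checking in a custom __setattr__() is how the allowed assignments get stored: calling setattr(self, ...) from inside it would re-enter the same method, so the sketch below delegates to object.__setattr__() instead. The Frozen class and its hard-coded _attribs tuple are assumptions made for this standalone example.

```python
class Frozen:
    _attribs = ('numerator', 'denominator')  # stands in for cls._attribs

    def __setattr__(self, attrib, value):
        if attrib in self.__class__._attribs:
            raise AttributeError(f"cannot assign to field '{attrib}'")
        # setattr(self, ...) here would call this method again and recurse,
        # so delegate to the base implementation instead
        object.__setattr__(self, attrib, value)

f = Frozen()
f.note = 'ok'        # not a listed attribute, so it is stored normally
try:
    f.numerator = 2  # a listed attribute, so it is rejected
except AttributeError as e:
    error = str(e)

print(f.note, '|', error)  # ok | cannot assign to field 'numerator'
```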

@dataclass vs @dataclass(...)

We need one final trick. When using this decorator, we want to allow users to write both @dataclass and @dataclass(...). In the first case, the class after the decorator gets passed directly to the dataclass function. In the second case, dataclass() is first called with the arguments (without the following class), and that must return a function which accepts the class. We want to support both behaviours in a single decorator function. The solution is quite simple:

def dataclass(cls=None, **kwargs):
    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        ...
        # this is where the code shown above is
        ...
        return cls
    if cls is None:
        # decorator was used like @dataclass(...)
        return decorator
    else:
        # decorator was used like @dataclass, without parens
        return decorator(cls)
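
The same trick works for any class decorator. Here is a minimal sketch with a hypothetical labelled decorator that just stores a label on the class, used in all three forms:

```python
def labelled(cls=None, **kwargs):
    # hypothetical decorator, for illustration only: stores a label on the class
    def decorator(cls):
        cls.label = kwargs.get('label', cls.__name__)
        return cls
    if cls is None:
        return decorator   # used as @labelled(...) or @labelled()
    return decorator(cls)  # used as @labelled

@labelled
class A:
    pass

@labelled(label='custom')
class B:
    pass

@labelled()
class C:
    pass

print(A.label, B.label, C.label)  # A custom C
```

The pattern relies on options being passed as keyword arguments. A positional argument such as @labelled('custom') would be mistaken for the class.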

Complete toy implementation

That's it! The complete toy implementation, including asdict() and astuple(), is:

def asdict(x):
    return {a:getattr(x, a) for a in x.__class__._attribs}

def astuple(x):
    return tuple(getattr(x, a) for a in x.__class__._attribs)

def dataclass(cls=None, **kwargs):
    def decorator(cls):
        cls._attribs = cls.__annotations__.keys()
        # define initializer, unless defined by the user
        if '__init__' not in cls.__dict__:
            def __init__(self, *args, **kwargs):
                # positional arguments are matched to the attributes in declaration order
                for attrib, arg in zip(self.__class__._attribs, args):
                    # avoid our own __setattr__ in case it's a frozen dataclass:
                    object.__setattr__(self, attrib, arg)
                # keyword arguments are matched by name; attributes not passed keep the class-level default
                for attrib in self.__class__._attribs:
                    if attrib in kwargs:
                        # avoid our own __setattr__ in case it's a frozen dataclass:
                        object.__setattr__(self, attrib, kwargs[attrib])
            cls.__init__ = __init__
        # define string conversion, unless defined by the user
        if '__str__' not in cls.__dict__:
            def __str__(self):
                kv_tuples = [(attrib, getattr(self, attrib)) for attrib in self.__class__._attribs]
                kv_str = ', '.join([f'{k}={v}' for (k, v) in kv_tuples])
                return f'{self.__class__.__name__}({kv_str})'
            cls.__str__ = __str__
        if kwargs.get('eq', True):
            # define ==, unless defined by the user
            if '__eq__' not in cls.__dict__:
                def __eq__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) != getattr(other, attrib):
                            return False
                    return True            
                cls.__eq__ = __eq__
        if kwargs.get('order', False):
            # define <, <=, >, >=, unless defined by the user
            if '__lt__' not in cls.__dict__:
                def __lt__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) >= getattr(other, attrib):
                            return False
                    return True            
                cls.__lt__ = __lt__
            if '__le__' not in cls.__dict__:
                def __le__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) > getattr(other, attrib):
                            return False
                    return True            
                cls.__le__ = __le__
            if '__gt__' not in cls.__dict__:
                def __gt__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) <= getattr(other, attrib):
                            return False
                    return True
                cls.__gt__ = __gt__
            if '__ge__' not in cls.__dict__:
                def __ge__(self, other):
                    for attrib in self.__class__._attribs:
                        if getattr(self, attrib) < getattr(other, attrib):
                            return False
                    return True
                cls.__ge__ = __ge__
        if kwargs.get('frozen', False):
            # don't allow changing attributes
            def __setattr__(self, attrib, value):
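                if attrib not in self.__class__._attribs:
                    # use object.__setattr__ here; setattr(self, ...) would call this function again and recurse
                    object.__setattr__(self, attrib, value)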
                else:
                    raise AttributeError(f'dataclass is frozen, cannot assign to field \'{attrib}\'')
            cls.__setattr__ = __setattr__
        return cls
    if cls is None:
        # decorator was used like @dataclass(...)
        return decorator
    else:
        # decorator was used like @dataclass, without parens
        return decorator(cls)

With this toy implementation, the examples shown at the beginning work just like with the Python standard library @dataclass, whose implementation is in dataclasses.py (1488 lines of code at the time of writing).

Conclusion

Similar to the previous articles, this is a good way to improve one's Python fu.