Building a toy Python Enum class - Part II

Marton Trencseni - Thu 05 May 2022 - Python

Introduction

In the previous article I started to write a toy implementation of Python's Enum class. Here I will continue and add more features. The posts are an exercise in how to use Python's language features to build an easy-to-use interface, in this case a class which resembles old-school C++ enums.

The code for this post is up on GitHub. You should also check the official CPython implementation of Enum here on GitHub (it's 2018 lines of code).

Leaner auto()

A friend pointed out that the way I implemented auto() is quite inefficient and unnecessary! It can be simplified to:

class auto:
    pass

Also, we can change our code slightly so users can write either auto() or just auto when using this feature. In the first case we will find an object of type auto, in the second case we will find the type auto itself:

class EnumMeta(type):
    def __new__(metacls, cls, bases, classdict, **kwds):
        enumerations = {x: y for x, y in classdict.items() if not x.startswith('__')}
        # handle auto()
        next_value = 1
        for k, v in enumerations.items():
            if type(v) is auto or v is auto: # <--------------
                enumerations[k] = next_value
                next_value += 1
            else:
                next_value = v + 1
        enum = super().__new__(metacls, cls, bases, classdict, **kwds)
        enum._enumerations = enumerations
        return enum
    ...

class Color(Enum):
    RED: int = auto
    GREEN: int = auto
    BLUE: int = auto
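To see the two checks in isolation, here is a minimal sketch that pulls the value-assignment loop out into a standalone function. The name resolve is hypothetical and not part of the Enum code; the logic is the same as in the loop above.

```python
class auto:
    pass

def resolve(members):
    # hypothetical helper: same logic as the loop in EnumMeta.__new__()
    resolved, next_value = {}, 1
    for name, v in members.items():
        # auto() is an instance of auto; bare auto is the class object itself
        if type(v) is auto or v is auto:
            v = next_value
        resolved[name] = v
        next_value = v + 1
    return resolved

print(resolve({'RED': auto, 'GREEN': auto(), 'BLUE': 10, 'BLACK': auto}))
# {'RED': 1, 'GREEN': 2, 'BLUE': 10, 'BLACK': 11}
```

Note that type(auto) is type, not auto, which is why the second check (v is auto) is needed for the bare form.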

String representations

We want both Enum classes (such as Color) and Enum objects (such as Color(1)) to have nice string representations, like with the standard library Enum. This is easy:

class EnumMeta(type):
    ...
    def __str__(cls):
        return f'<enum \'{cls.__name__}\'>'
    ...

class Enum(metaclass=EnumMeta):
    ...
    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self.__key)
    ...

Note that in EnumMeta, the first argument of __str__() is cls (not self), which is a type object. This version gets called when the Color class itself gets printed out:

class Color(Enum):
    RED: int = auto()
    GREEN: int = auto()
    BLUE: int = auto()

print(Color)          # __str__() in EnumMeta
print(type(Color(1))) # __str__() in EnumMeta
print(Color(1))       # __str__() in Enum

Output:

<enum 'Color'>
<enum 'Color'>
Color.RED

Pretty!
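The dispatch rule behind this can be shown without any Enum machinery. The sketch below uses hypothetical names (Meta, Thing): str() on a class looks up __str__() on the class' type, the metaclass, while str() on an instance looks it up on the class.

```python
class Meta(type):
    def __str__(cls):
        # used when the class object itself is converted to a string
        return f"<meta-str {cls.__name__}>"

class Thing(metaclass=Meta):
    def __str__(self):
        # used when an instance is converted to a string
        return "thing-instance"

print(str(Thing))          # <meta-str Thing>
print(str(type(Thing())))  # <meta-str Thing>
print(str(Thing()))        # thing-instance
```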

Accessors

Next, let's make it so Color['RED'] and Color.RED return an appropriate color object. The first one, Color['RED'], is quite simple: when we use [] on an object, the class' __getitem__() is called. For Color['RED'], the object is the class itself, and its type is EnumMeta. So all we have to do is:

class EnumMeta(type):
    def __getitem__(cls, key):
        return cls(cls._enumerations[key])
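As a standalone illustration (Registry, Planet and _table are hypothetical names), a __getitem__() defined on the metaclass makes the class subscriptable, but not its instances:

```python
class Registry(type):
    def __getitem__(cls, key):
        # called for Planet[key], because type(Planet) is Registry
        return cls._table[key]

class Planet(metaclass=Registry):
    _table = {'EARTH': 3, 'MARS': 4}

print(Planet['EARTH'])  # 3

try:
    Planet()['EARTH']   # instances look for __getitem__ on Planet, which has none
    instance_subscriptable = True
except TypeError:
    instance_subscriptable = False
print(instance_subscriptable)  # False
```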

Making Color.RED work is a bit trickier. The basic idea is simple: when we use the dot operator on an object, the class' __getattr__() and/or __getattribute__() is called. To understand the difference between these two, check this post. __getattr__() is called only if the attribute is not found by the normal lookup, which lets the user return a computed value. In our case, RED is defined in the class body, so the normal lookup would find it and __getattr__() would never run. We have to override __getattribute__() instead, which Python always calls when an attribute is accessed with the dot operator. So, it's natural to think that a copy/paste version of the above __getitem__() will work here:

class EnumMeta(type):
    ...
    def __getattribute__(cls, key):
        return cls(cls._enumerations[key]) # recursively call itself, kills the kernel!

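The problem is that we read attributes with the dot operator all over the place, including inside this very method. The lookup below calls __getattribute__() again, which performs the same lookup again, and so on:

cls._enumerations

Assignments such as enum._enumerations = enumerations in the EnumMeta constructor are not affected, since they go through __setattr__().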
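A minimal sketch of this failure mode on an ordinary class (Naive and table are hypothetical names, not part of the Enum code). In a plain script the runaway recursion typically surfaces as a RecursionError:

```python
class Naive:
    def __init__(self):
        self.table = {'a': 1}   # assignment uses __setattr__, so this is fine

    def __getattribute__(self, name):
        # self.table is itself a dot lookup, so this method calls itself forever
        return self.table[name]

try:
    Naive().a
    outcome = "returned"
except RecursionError:
    outcome = "RecursionError"

print(outcome)  # RecursionError
```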

What we have to do is explicitly call the built-in implementation, object.__getattribute__(), like this:

class EnumMeta(type):
    def __getattribute__(cls, key):
        if key.startswith('_'):
            return object.__getattribute__(cls, key)
        else:
            return cls(object.__getattribute__(cls, '_enumerations')[key])
    ...

If we access an attribute that starts with an underscore, such as __class__ or our own _enumerations, the call is routed through Python's built-in attribute lookup. In all other cases, such as Color.RED, it returns a new object constructed on the fly. In the last line, cls(...) is Color(...), object.__getattribute__() is used to avoid recursively entering this function, and _enumerations[key] looks up the value in our own helper dictionary, where we store the Enum's cases. For Color.RED, key is 'RED' and _enumerations[key] is 1, so the whole thing becomes return Color(1).
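To make the difference between __getattr__() and __getattribute__() concrete, here is a small sketch on ordinary classes (Fallback and Intercept are hypothetical names). The second class uses the same underscore guard as above:

```python
class Fallback:
    x = 1
    def __getattr__(self, name):
        # only reached when the normal lookup fails
        return f"computed:{name}"

class Intercept:
    x = 1
    def __getattribute__(self, name):
        # reached for every dot access, even for attributes that exist
        if name.startswith('_'):
            return object.__getattribute__(self, name)
        return f"intercepted:{name}"

f, i = Fallback(), Intercept()
print(f.x, f.y)  # 1 computed:y
print(i.x, i.y)  # intercepted:x intercepted:y
```

Fallback never sees the access to x because the normal lookup finds it first, which is the reason __getattr__() is not enough for Color.RED.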

Membership and equality

Another useful feature of the standard library Enum is the ability to check Color.RED in Color, and of course Color.RED == Color.GREEN checks. These are easy:

class EnumMeta(type):
    ...
    def __contains__(cls, other):
        if type(other) == cls:
            return True
        else:
            return False  

print(Color['RED'] in Color)    # True
print(Color.RED in Color)       # True
print(Color(1) in Color)        # True
print(Color.RED in WeekendDay)  # False
print(Color.BLUE in WeekendDay) # False

Equality checks:

class Enum(metaclass=EnumMeta):
    ...
    def __eq__(self, other):
        if type(self) != type(other):
            return False
        else:
            return (self.__key == other.__key and self.__value == other.__value)

print(Color.RED == Color)               # False
print(Color.RED == Color.RED)           # True
print(Color.RED == Color['RED'])        # True
print(Color.RED == Color(1))            # True
print(Color.RED == Color(2))            # False
print(Color.RED == WeekendDay.SATURDAY) # False

Static objects

There is an issue with the implementation so far:

print(Color.RED is Color.RED)    # False
print(Color.RED is Color['RED']) # False
print(Color.RED is Color(1))     # False

At least the first one should be True, and probably all of them (in the standard library version, they're all True). Unlike equality, Python's is checks whether the variables are referring to the same object. But in our implementation so far, all of these calls return a new object (by doing cls(...)). This both breaks the above checks, and is also wasteful. Let's fix this.
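The difference between == and is can be seen on any class that defines __eq__(). This is a minimal sketch with a hypothetical Point class:

```python
class Point:
    def __init__(self, x):
        self.x = x

    def __eq__(self, other):
        # equality compares contents
        return type(self) is type(other) and self.x == other.x

a, b = Point(1), Point(1)
print(a == b)  # True:  same contents
print(a is b)  # False: two separate objects
print(a is a)  # True:  same object
```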

Let's create one instance for each of the Enum cases (RED, GREEN, BLUE in the example) up front, and always return references to these objects, in a transparent way, so the user doesn't notice it. The first step is relatively easy. We create these static instances when the class is defined:

class EnumMeta(type):
    def __new__(metacls, cls, bases, classdict, **kwds):
        ...
        # make "static" instances of each enumeration object
        enum._instances = {k: object.__new__(enum) for k, v in enumerations.items()}
        # initialize static instances
        for k, instance in enum._instances.items():
            instance.__init__(enumerations[k])
        ...

Now we have to change our code to always return these saved instances. In our toy Enum class, there are several ways to get an Enum object of a class such as Color: Color.RED, Color['RED'], Color(1). Making the first two return the saved instances is easy, because we already explicitly control what happens in those cases with our __getitem__() and __getattribute__() implementations. But what about when the user explicitly calls the constructor, like Color(1)? Our own functions currently do the same when returning Enum objects. If we get the constructor to return the static instances, we're done.

In Python, when the user calls the constructor, two things happen. First, the class' __new__() is called to construct the object (so this has a return value), and then __init__() is called to initialize the already created object, which is conventionally called self (this has no return value, as self is already given). Clearly we have to write a custom __new__() here, and also avoid unnecessary duplicate initialization in __init__(). The cleanest way I have found to make this work is to override the newly created Enum's __new__() when the type is being created:

class EnumMeta(type):
    def __new__(metacls, cls, bases, classdict, **kwds):
        ...
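        # map each value back to its key, e.g. 1 -> 'RED'
        reverse_enumerations = {y: x for x, y in enumerations.items()}
        # overwrite the new Enum's __new__() so that it returns the static instances
        enum.__new__ = lambda cls, value: enum._instances[reverse_enumerations[value]]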
        return enum
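One consequence of returning an existing object from __new__() is worth seeing in isolation: Python still calls __init__() on it every time, as long as the returned object is an instance of the class. The sketch below uses a hypothetical Cached class with a single shared instance:

```python
class Cached:
    _instance = None
    init_calls = 0

    def __new__(cls, value):
        # create the object once, then keep returning the same one
        if cls._instance is None:
            cls._instance = object.__new__(cls)
        return cls._instance

    def __init__(self, value):
        # runs on every Cached(...) call, even for the cached object
        type(self).init_calls += 1
        self.value = value

a = Cached(1)
b = Cached(2)
print(a is b)             # True
print(Cached.init_calls)  # 2
print(a.value)            # 2
```

Without a guard, the second call re-initializes the shared object, so an __init__() that is expensive or that overwrites state needs an early return.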

Short circuiting the __init__() is an easy further optimization:

class Enum(metaclass=EnumMeta):    
    def __init__(self, value):
        if hasattr(self, '_Enum__key') and hasattr(self, '_Enum__value'):
            return
        ...

The attributes are named _Enum__key and _Enum__value because of Python's name mangling of "private" class members: inside the Enum class body, self.__key is rewritten to self._Enum__key.
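Name mangling is easy to check directly. In this sketch (Base and Child are hypothetical names), the attribute is stored under the mangled name, and the prefix comes from the class where the code was written, not from the subclass:

```python
class Base:
    def __init__(self):
        self.__key = 'RED'   # stored as _Base__key

class Child(Base):
    pass

obj = Child()
print(hasattr(obj, '__key'))       # False: string literals are not mangled
print(hasattr(obj, '_Base__key'))  # True
print(vars(obj))                   # {'_Base__key': 'RED'}
```

This is why the check in Enum.__init__() uses '_Enum__key' even when self is a Color or a WeekendDay.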

By inserting some print() statements, we can verify that our code works as intended:

class Color(Enum):
    RED: int = auto
    GREEN: int = auto
    BLUE: int = auto

class WeekendDay(Enum):
    SATURDAY: int = auto
    SUNDAY: int = auto

print(Color.RED == Color)               # False
print(Color.RED == Color.RED)           # True
print(Color.RED is Color.RED)           # True
print(Color.RED == Color['RED'])        # True
print(Color.RED == Color(1))            # True
print(Color.RED == Color(2))            # False
print(Color.RED == WeekendDay.SATURDAY) # False
print(Color.RED is Color.RED)           # True
print(Color.RED is Color['RED'])        # True
print(Color.RED is Color(1))            # True
for _ in range(100):
    c = Color(3)
    d = WeekendDay(2)

Output:

Constructing new type Enum
Constructing new type Color
Initializing new <enum 'Color'>
Initializing new <enum 'Color'>
Initializing new <enum 'Color'>
Constructing new type WeekendDay
Initializing new <enum 'WeekendDay'>
Initializing new <enum 'WeekendDay'>
False
True
True
True
True
False
False
True
True
True

First the Enum type itself is constructed, then Color and its 3 instances, then WeekendDay and its 2 instances. The rest of the code, including the 200 constructor calls in the loop, doesn't create any more objects!

Final version

The final version is 71 lines of code:

class auto():
    pass

class EnumMeta(type):
    def __new__(metacls, cls, bases, classdict, **kwds):
        enumerations = {x: y for x, y in classdict.items() if not x.startswith('__')}
        # handle auto() and auto
        next_value = 1
        for k, v in enumerations.items():
            if type(v) is auto or v is auto:
                enumerations[k] = next_value
                next_value += 1
            else:
                next_value = v + 1
        enum = super().__new__(metacls, cls, bases, classdict, **kwds)
        enum._enumerations = enumerations
        reverse_enumerations = {y: x for x, y in enumerations.items()}
        # make "static" instances of each enumeration object
        enum._instances = {k: object.__new__(enum) for k, v in enumerations.items()}
        # initialize static instances
        for k, instance in enum._instances.items():
            instance.__init__(enumerations[k])
        # overwrite the new Enum's __new__() so that it returns the static instances
        enum.__new__ = lambda cls, value: enum._instances[reverse_enumerations[value]]
        return enum

    def __str__(cls):
        return f'<enum \'{cls.__name__}\'>'

    def __len__(cls):
        return len(cls._enumerations)

    def __iter__(cls):
        return (cls(value) for value in cls._enumerations.values())

    def __getitem__(cls, key):
        return cls(cls._enumerations[key])

    def __getattribute__(cls, key):
        if key.startswith('_'):
            return object.__getattribute__(cls, key)
        else:
            return cls(object.__getattribute__(cls, '_enumerations')[key])

    def __contains__(cls, other):
        if type(other) == cls:
            return True
        else:
            return False

class Enum(metaclass=EnumMeta):    
    def __init__(self, value):
        if hasattr(self, '_Enum__key') and hasattr(self, '_Enum__value'):
            return
        # make sure the passed in value is a valid enumeration value
        if value not in self.__class__._enumerations.values():
            raise ValueError(f'{value} is not a valid {self.__class__.__name__}')
        # save the actual enumeration value
        for k, v in self.__class__._enumerations.items():
            if v == value:
                self.__key = k
                self.__value = v

    def __str__(self):
        return "%s.%s" % (self.__class__.__name__, self.__key)

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        else:
            return (self.__key == other.__key and self.__value == other.__value)

You should also check the official CPython implementation of Enum here on GitHub (it's 2018 lines of code).

Conclusion

This was a great way to improve and practice my Python-fu. Highly recommended! Thanks to Zsolt for the bugfixes and improvement suggestions.