Useful Python decorators for Data Scientists

Marton Trencseni - Sun 22 May 2022 - Python

Introduction

In this post, I will show some @decorators that may be useful for Data Scientists. It may also be useful to revisit previous Bytepawn posts on decorators.

The ipython notebook is up on GitHub.

@parallel

Let's assume I write a really inefficient way to find primes:

from random import randint, seed
from time import time

from sympy import isprime

def generate_primes(domain: int=1000*1000, num_attempts: int=1000) -> list[int]:
    primes: set[int] = set()
    seed(time())
    for _ in range(num_attempts):
        candidate: int = randint(4, domain)
        if isprime(candidate):
            primes.add(candidate)
    return sorted(primes)

print(len(generate_primes()))

Outputs something like:

88

Then I realize that I could get a "free" speedup by running the original generate_primes() on all my CPU threads in parallel. Since this is a pretty common pattern, it makes sense to define a @parallel decorator:

from multiprocessing import cpu_count
from typing import Callable

from joblib import Parallel, delayed

def parallel(func=None, args=(), merge_func=lambda x: x, parallelism=cpu_count()):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            results = Parallel(n_jobs=parallelism)(
                delayed(func)(*args, **kwargs) for _ in range(parallelism))
            return merge_func(results)
        return inner
    if func is None:
        # decorator was used like @parallel(...)
        return decorator
    else:
        # decorator was used like @parallel, without parens
        return decorator(func)

With this, we can parallelize our function by adding a single line:

from itertools import chain

@parallel(merge_func=lambda li: sorted(set(chain(*li))))
def generate_primes(...): # same signature, nothing changes
    ... # same code, nothing changes

print(len(generate_primes()))

Outputs something like:

1281

In my case, my MacBook has 8 cores and 16 threads (cpu_count() is 16), so I generated roughly 16x as many primes. Notes:

  • The only overhead is having to define a merge_func, which merges the results of the different runs of the function into one result, to hide the parallelism from outside callers of the decorated function (generate_primes() in this case). In this toy example, I just merge the lists and make sure the primes are unique by using set().
  • There are many Python libraries and approaches (e.g. threads vs processes) to achieve parallelism. This example uses process parallelism with joblib.Parallel(), which works well on Darwin + python3 + ipython and avoids locking on the Python Global Interpreter Lock (GIL); a thread-based variant is sketched below.
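
As an illustration of the threads-versus-processes point, here is a minimal sketch of a thread-based variant, assuming joblib is available; parallel_threads is an illustrative name, not part of the code above:

from multiprocessing import cpu_count
from typing import Callable

from joblib import Parallel, delayed

# Hypothetical thread-based variant of @parallel: same structure,
# but asks joblib to schedule the runs on threads instead of worker processes.
def parallel_threads(func=None, merge_func=lambda x: x, parallelism=cpu_count()):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            results = Parallel(n_jobs=parallelism, prefer='threads')(
                delayed(func)(*args, **kwargs) for _ in range(parallelism))
            return merge_func(results)
        return inner
    return decorator if func is None else decorator(func)

With CPU-bound pure-Python work such as isprime(), threads will not give a comparable speedup because of the GIL; the process-based version above remains the better fit for this example.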

@production

Sometimes we write a big, complicated pipeline with extra steps that we only want to run in certain environments, e.g. do something on our local dev environment but not in production, or vice versa. It'd be nice to be able to decorate functions so they only run in certain environments and do nothing elsewhere.

One way to achieve this is with a few simple decorators: @production for code we only want to run in prod, @development for code we only want to run in dev; we can even introduce an @inactive which turns the function off altogether. The benefit of this approach is that the deployment history and current state are tracked in code/GitHub. Also, we can make these changes in one line, leading to cleaner commits; e.g. @inactive is cleaner than a big commit where an entire block of code is commented out.

from socket import gethostname

production_servers = [...]

def production(func: Callable):
    def inner(*args, **kwargs):
        if gethostname() in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is not a production server, skipping function decorated with @production...')
    return inner

def development(func: Callable):
    def inner(*args, **kwargs):
        if gethostname() not in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is a production server, skipping function decorated with @development...')
    return inner

def inactive(func: Callable):
    def inner(*args, **kwargs):
        print('Skipping function decorated with @inactive...')
    return inner

@production
def foo():
    print('Running in production, touching databases!')

foo()

@development
def foo():
    print('Running in production, touching databases!')

foo()

@inactive
def foo():
    print('Running in production, touching databases!')

foo()

Output:

Running in production, touching databases!
This host is a production server, skipping function decorated with @development...
Skipping function decorated with @inactive...

This idea can be adapted to other frameworks/environments.
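
For example, here is a minimal sketch of the same idea keyed off an environment variable instead of the hostname; the variable name APP_ENV and the function names are illustrative assumptions, not something used elsewhere in this post:

import os

# Hypothetical variant: gate on an environment variable rather than the hostname.
def only_in_env(env_name: str):
    def decorator(func):
        def inner(*args, **kwargs):
            if os.environ.get('APP_ENV') == env_name:
                return func(*args, **kwargs)
            print(f'APP_ENV is not {env_name}, skipping {func.__name__}()...')
        return inner
    return decorator

@only_in_env('production')
def run_etl():
    print('Running the production ETL...')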

@deployable

At my current work, we use Airflow for ETL/data pipelines. We have a rich library of helper functions which internally construct the appropriate DAG, so users (Data Scientists) don't have to worry about it.

The most commonly used one is dag_vertica_create_table_as(), which runs a SELECT on our Vertica DWH and dumps the result into a table every night:

dag = dag_vertica_create_table_as(
    table='my_aggregate_table',
    owner='Marton Trencseni (marton.trencseni@maf.ae)',
    schedule_interval='@daily',
    ...
    select="""
    SELECT
        ...
    FROM
        ...
    """
)

This then becomes a query on the DWH, roughly like:

CREATE TABLE my_aggregate_table AS
SELECT ...

In reality it's more complicated: we first run the query for today, and conditionally delete yesterday's table if today's was successfully created. This conditional logic (and some other accidental complexity specific to our environment, such as having to issue GRANTs) results in the DAG having 9 steps, but that is beyond the scope of this article.

Over the last 2 years we have created almost 500 DAGs, so we scaled up our Airflow EC2 instances and introduced separate development and production environments. It'd be nice to have a way to tag DAGs by whether they should run on dev or prod, track this in the code/GitHub, and use the same mechanism to make sure DAGs don't accidentally run in the wrong environment.

There are about 10 similar convenience functions, such as dag_vertica_create_or_replace_view_as() and dag_vertica_train_predict_model(), and we'd like all calls of these dag_xxxx() functions to be switchable between production and development (or skipped everywhere).

However, the @production and @development decorators from the previous section won't work here, because we don't want to switch dag_vertica_create_table_as() to never run in one of the environments. We want to be able to set this per invocation, and have the feature in all of our dag_xxxx() functions without copy/pasting code. What we want is to add a deploy parameter to all of our dag_xxxx() functions (with a good default), so we can just set this parameter in our DAGs for added security. We can achieve this with the @deployable decorator:

development_servers = [...]

def deployable(func):
    def inner(*args, **kwargs):
        if 'deploy' in kwargs:
            if kwargs['deploy'].lower() in ['production', 'prod'] and gethostname() not in production_servers:
                print('This host is not a production server, skipping...')
                return
            if kwargs['deploy'].lower() in ['development', 'dev'] and gethostname() not in development_servers:
                print('This host is not a development server, skipping...')
                return
            if kwargs['deploy'].lower() in ['skip', 'none']:
                print('Skipping...')
                return
            del kwargs['deploy'] # to avoid func() throwing an unexpected keyword exception
        return func(*args, **kwargs)
    return inner

Then we can add the decorator to our function definitions (1 line added for each):

@deployable
def dag_vertica_create_table_as(...): # same signature, nothing changes
    ... # same code, nothing changes

@deployable
def dag_vertica_create_or_replace_view_as(...): # same signature, nothing changes
    ... # same code, nothing changes

@deployable
def dag_vertica_train_predict_model(...): # same signature, nothing changes
    ... # same code, nothing changes

If we stop here, nothing happens and nothing breaks. However, now we can go to the DAG files where we use these functions and add one line:

dag = dag_vertica_create_table_as(
    deploy='development', # the function will return None on production
    ...
)
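
The decorator above simply runs the function unchanged when the caller passes no deploy argument. If we also want the "good default" mentioned earlier, one option (an assumption on our part, not part of the original code; deployable_with_default is an illustrative name) is a thin wrapper that fills the argument in before delegating to @deployable:

def deployable_with_default(func, default='production'):
    def inner(*args, **kwargs):
        # Fill in a deploy target when the caller does not pass one,
        # then reuse the checks from @deployable above.
        kwargs.setdefault('deploy', default)
        return deployable(func)(*args, **kwargs)
    return inner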

@redirect (stdout)

Sometimes we write a big function which also calls other code, and all sorts of messages are print()ed. Or, we may be debugging with a bunch of print()s and want to add line numbers to the printouts so it's easier to refer to them. In these cases, @redirect may be useful. This decorator redirects print()'s standard output to our own line-by-line printer, and we can do whatever we'd like with it (including throwing it away):

from contextlib import redirect_stdout
from io import StringIO
from math import floor, log

def redirect(func=None, line_print: Callable = None):
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            with StringIO() as buf, redirect_stdout(buf):
                func(*args, **kwargs)
                output = buf.getvalue()
            lines = output.splitlines()
            if line_print is not None:
                for line in lines:
                    line_print(line)
            else:
                width = floor(log(len(lines), 10)) + 1
                for i, line in enumerate(lines, 1):
                    print(f'{i:0{width}}: {line}')
        return inner
    if func is None:
        # decorator was used like @redirect(...)
        return decorator
    else:
        # decorator was used like @redirect, without parens
        return decorator(func)

If we use @redirect without specifying an explicit line_print function, it will print the lines, but with line numbers added:

@redirect
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(10)

Output:

01: Line #1
02: Line #2
03: Line #3
04: Line #4
05: Line #5
06: Line #6
07: Line #7
08: Line #8
09: Line #9
10: Line #10

If we want to save all printed text to a variable, we can also achieve that:

lines = []
def save_lines(line):
    lines.append(line)

@redirect(line_print=save_lines)
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(3)
print(lines)

Output:

['Line #1', 'Line #2', 'Line #3']

The actual heavy lifting of redirecting stdout is done by contextlib.redirect_stdout, as shown in this StackOverflow thread.
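
For reference, here is the context manager on its own, without the decorator, as a minimal standalone sketch:

from contextlib import redirect_stdout
from io import StringIO

# Everything print()ed inside the with-block goes into the buffer,
# not to the terminal; read it back before the buffer is closed.
with StringIO() as buf, redirect_stdout(buf):
    print('hello')
    captured = buf.getvalue()

print(captured.splitlines())  # prints: ['hello']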

@stacktrace

The next decorator pattern is @stacktrace, which emits useful messages when functions are called and values are returned from functions:

from sys import settrace

def stacktrace(func=None, exclude_files=['anaconda']):
    def tracer_func(frame, event, arg):
        co = frame.f_code
        func_name = co.co_name
        caller_filename = frame.f_back.f_code.co_filename
        if func_name == 'write':
            return # ignore write() calls from print statements
        for file in exclude_files:
            if file in caller_filename:
                return # ignore in ipython notebooks
        args = str(tuple([frame.f_locals[arg] for arg in frame.f_code.co_varnames]))
        if args.endswith(',)'):
            args = args[:-2] + ')'
        if event == 'call':
            print(f'--> Executing: {func_name}{args}')
            return tracer_func
        elif event == 'return':
            print(f'--> Returning: {func_name}{args} -> {repr(arg)}')
        return
    def decorator(func: Callable):
        def inner(*args, **kwargs):
            settrace(tracer_func)
            func(*args, **kwargs)
            settrace(None)
        return inner
    if func is None:
        # decorator was used like @stacktrace(...)
        return decorator
    else:
        # decorator was used like @stacktrace, without parens
        return decorator(func)

With this, we can decorate the topmost function where we want tracing to start, and we will get useful output about the branching:

def b():
    print('...')

@stacktrace
def a(arg):
    print(arg)
    b()
    return 'world'

a('foo')

Output:

--> Executing: a('foo')
foo
--> Executing: b()
...
--> Returning: b() -> None
--> Returning: a('foo') -> 'world'

The only trick here is hiding the parts of the callstack that are not interesting. In my case, I'm running this code in ipython on Anaconda, so I hide the parts of the callstack where the code lives in a file that has anaconda in its path (otherwise I would get about 50-100 useless callstack entries in the snippet above). This is accomplished by the exclude_files parameter of the decorator.
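
Since the decorator also supports the parenthesized form, the list of excluded path fragments can be adjusted; for example ('site-packages' here is just an illustrative extra value):

@stacktrace(exclude_files=['anaconda', 'site-packages'])
def a(arg):
    print(arg)
    b()
    return 'world'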

@traceclass

Similar to the above, we can define a @traceclass decorator to use with classes, to get traces of their members' execution. This was included in the previous decorator post, where it was just called @trace and had a bug (since fixed in the original post). The decorator:

def traceclass(cls: type):
    def make_traced(cls: type, method_name: str, method: Callable):
        def traced_method(*args, **kwargs):
            print(f'--> Executing: {cls.__name__}::{method_name}()')
            return method(*args, **kwargs)
        return traced_method
    # wrap every callable attribute defined directly on the class
    for name in list(cls.__dict__.keys()):
        if callable(getattr(cls, name)) and name != '__class__':
            setattr(cls, name, make_traced(cls, name, getattr(cls, name)))
    return cls

We can use it like:

@traceclass
class Foo:
    i: int = 0
    def __init__(self, i: int = 0):
        self.i = i
    def increment(self):
        self.i += 1
    def __str__(self):
        return f'This is a {self.__class__.__name__} object with i = {self.i}'

f1 = Foo()
f2 = Foo(4)
f1.increment()
print(f1)
print(f2)

Output:

--> Executing: Foo::__init__()
--> Executing: Foo::__init__()
--> Executing: Foo::increment()
--> Executing: Foo::__str__()
This is a Foo object with i = 1
--> Executing: Foo::__str__()
This is a Foo object with i = 4

Conclusion

In Python, functions are first-class citizens, and decorators are powerful syntactic sugar that exploits this to give programmers a seemingly "magic" way to construct useful compositions of functions and classes. The decorators above may be useful specifically for Data Scientists working in ipython notebooks.

Thanks to Zsolt for bugfixes and improvement suggestions.