Useful Python decorators for Data Scientists
Marton Trencseni - Sun 22 May 2022 - Python
Introduction
In this post, I will show some @decorators that may be useful for Data Scientists. It may also be useful to revisit previous Bytepawn posts on decorators:
- Building a toy Python @dataclass decorator
- Python decorator patterns
- all Bytepawn posts tagged with python
The ipython notebook is up on Github.
@parallel
Let's assume I write a really inefficient function to find primes:
```python
from random import randint, seed
from time import time

from sympy import isprime

def generate_primes(domain: int=1000*1000, num_attempts: int=1000) -> list[int]:
    primes: set[int] = set()
    seed(time())
    for _ in range(num_attempts):
        # test random candidates and keep the ones that are prime
        candidate: int = randint(4, domain)
        if isprime(candidate):
            primes.add(candidate)
    return sorted(primes)

print(len(generate_primes()))
```
Outputs something like:
88
Then I realize that I could get a "free" speedup if I run the original generate_primes() on all my CPU threads in parallel. Since this is a common pattern, it makes sense to define a @parallel decorator:
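```python
from functools import wraps
from multiprocessing import cpu_count
from typing import Callable

from joblib import Parallel, delayed

def parallel(func=None, merge_func=lambda x: x, parallelism=cpu_count()):
    def decorator(func: Callable):
        @wraps(func)
        def inner(*args, **kwargs):
            # run func(*args, **kwargs) `parallelism` times, each in a worker process
            results = Parallel(n_jobs=parallelism)(delayed(func)(*args, **kwargs) for _ in range(parallelism))
            # combine the list of per-run results into a single return value
            return merge_func(results)
        return inner
    if func is None:
        # decorator was used like @parallel(...)
        return decorator
    else:
        # decorator was used like @parallel, without parens
        return decorator(func)
```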
With this, we can parallelize our function with one line:
```python
from itertools import chain

@parallel(merge_func=lambda li: sorted(set(chain(*li))))
def generate_primes(...):  # same signature, nothing changes
    ...  # same code, nothing changes

print(len(generate_primes()))
```
Outputs something like:
1281
In my case, my Macbook has 8 cores and 16 threads (cpu_count() is 16), so the function ran 16 times in parallel and returned roughly 16x as many primes (a bit fewer, since primes found by more than one run are deduplicated by the merge). Notes:
- The only overhead is having to define a merge_func, which merges the results of the different runs of the function into one result, hiding the parallelism from callers of the decorated function (generate_primes() in this case). In this toy example, I just concatenate the lists and make sure the primes are unique by using set().
- There are many Python libraries and approaches (e.g. threads vs processes) to achieve parallelism. This example uses process parallelism with joblib.Parallel(), which works well on Darwin + python3 + ipython; since each run is a separate process, the runs don't contend for the Python Global Interpreter Lock (GIL).
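For comparison, here is a minimal sketch of the same decorator interface built on threads, using the standard library's concurrent.futures.ThreadPoolExecutor instead of joblib. The names parallel_threads and fetch_ids are hypothetical. Because of the GIL, this version most likely won't speed up CPU-bound code like isprime(); it is only a reasonable fit when each run mostly waits on I/O (network, disk, database).

```python
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from itertools import chain
from os import cpu_count
from typing import Any, Callable, Optional

def parallel_threads(func: Optional[Callable] = None, merge_func: Callable = lambda x: x,
                     parallelism: int = cpu_count() or 1):
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def inner(*args, **kwargs) -> Any:
            # submit `parallelism` identical runs to a thread pool and wait for all of them
            with ThreadPoolExecutor(max_workers=parallelism) as pool:
                futures = [pool.submit(func, *args, **kwargs) for _ in range(parallelism)]
                results = [f.result() for f in futures]
            return merge_func(results)
        return inner
    # support both @parallel_threads and @parallel_threads(...)
    return decorator if func is None else decorator(func)

@parallel_threads(merge_func=lambda li: sorted(set(chain(*li))), parallelism=4)
def fetch_ids(batch: int) -> list[int]:
    # hypothetical stand-in for an I/O-bound call, e.g. an HTTP request
    return [batch, batch + 1]

print(fetch_ids(10))  # [10, 11]
```

Keeping the same func=None signature means switching between the process-based and thread-based versions is a one-word change at the decoration site.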
@production
Sometimes we write a big, complicated pipeline with extra steps which we only want to run in certain environments, e.g. steps that should run in our local dev environment but not in production, or vice versa. It'd be nice to be able to decorate functions so that they only run in certain environments and do nothing elsewhere.
One way to achieve this is with a few simple decorators: @production for steps we only want to run on prod, @development for steps we only want to run in dev, and we can even introduce an @inactive which turns the function off altogether. The benefit of this approach is that the deployment history and current state are tracked in code/Github. Also, we can make these changes in one line, leading to cleaner commits; e.g. adding @inactive is cleaner than a big commit where an entire block of code is commented out.
```python
from functools import wraps
from socket import gethostname
from typing import Callable

production_servers = [...]  # hostnames of our production servers

def production(func: Callable):
    @wraps(func)
    def inner(*args, **kwargs):
        if gethostname() in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is not a production server, skipping function decorated with @production...')
    return inner

def development(func: Callable):
    @wraps(func)
    def inner(*args, **kwargs):
        if gethostname() not in production_servers:
            return func(*args, **kwargs)
        else:
            print('This host is a production server, skipping function decorated with @development...')
    return inner

def inactive(func: Callable):
    @wraps(func)
    def inner(*args, **kwargs):
        print('Skipping function decorated with @inactive...')
    return inner

@production
def foo():
    print('Running in production, touching databases!')

foo()

@development
def foo():
    print('Running in production, touching databases!')

foo()

@inactive
def foo():
    print('Running in production, touching databases!')

foo()
```
Output:
Running in production, touching databases!
This host is a production server, skipping function decorated with @development...
Skipping function decorated with @inactive...
This idea can be adapted to other frameworks/environments.
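As one possible adaptation, here is a sketch of a single parameterized decorator that reads the environment name from an environment variable instead of comparing hostnames. The variable name DEPLOY_ENV, its default of 'development', and the names run_in, load_to_warehouse and dump_debug_tables are all hypothetical choices for illustration.

```python
import os
from functools import wraps
from typing import Callable

def run_in(*environments: str, env_var: str = 'DEPLOY_ENV', default: str = 'development') -> Callable:
    """Only run the decorated function if the current environment is one of `environments`."""
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def inner(*args, **kwargs):
            # read the environment at call time, not at decoration time
            current = os.environ.get(env_var, default)
            if current in environments:
                return func(*args, **kwargs)
            print(f'Environment is {current!r}, skipping {func.__name__}()...')
            return None
        return inner
    return decorator

@run_in('production')
def load_to_warehouse():
    return 'loaded'

@run_in('development', 'staging')
def dump_debug_tables():
    return 'dumped'
```

With this, @run_in('production') plays the role of @production, and @run_in() with no environments never runs, like @inactive. Note that it must always be used with parentheses.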
@deployable
At my current work, we use Airflow for ETL/data pipelines. We have a rich library of helper functions which internally construct the appropriate DAG, so users (Data Scientists) don't have to worry about it.
The most commonly used one is dag_vertica_create_table_as(), which runs a SELECT on our Vertica DWH and dumps the result into a table every night:
```python
dag = dag_vertica_create_table_as(
    table='my_aggregate_table',
    owner='Marton Trencseni (marton.trencseni@maf.ae)',
    schedule_interval='@daily',
    ...
    select="""
        SELECT
            ...
        FROM
            ...
    """
)
```
This then becomes a query on the DWH, roughly like:
```sql
CREATE TABLE my_aggregate_table AS
SELECT ...
```
In reality it's more complicated: we first run the query for today, and conditionally delete yesterday's if today's was successfully created. This conditional logic (and some other accidental complexity specific to our environment, such as having to issue GRANTs) results in the DAG having 9 steps, but this is not the point here, and is beyond the scope of the article.
Over the last 2 years we have created almost 500 DAGs, so we scaled up our Airflow EC2 instances and introduced separate development and production environments. It'd be nice to have a way to tag each DAG with whether it should run on dev or prod, track this in the code/Github, and use the same mechanism to make sure DAGs don't accidentally run in the wrong environment.
There are about 10 similar convenience functions, such as dag_vertica_create_or_replace_view_as() and dag_vertica_train_predict_model(), and we'd like all calls of these dag_xxx() functions to be switchable between production and development (or skipped everywhere).
However, the @production and @development decorators from the previous section won't work here, because we don't want to switch dag_vertica_create_table_as() to never run in one of the environments. We want to be able to set it per invocation, and have this feature in all of our dag_xxx() functions without having to copy/paste code. What we want is to add a deploy parameter to all of our dag_xxx() functions (with a good default: if it's omitted, the function runs as before), so we can just add this parameter in our DAGs for added safety. We can achieve this with the @deployable decorator:
```python
from functools import wraps
from socket import gethostname

development_servers = [...]  # hostnames of our development servers, like production_servers above

def deployable(func):
    @wraps(func)
    def inner(*args, **kwargs):
        # remove 'deploy' so func() doesn't raise TypeError on an unexpected keyword argument
        deploy = kwargs.pop('deploy', None)
        if deploy is not None:
            deploy = deploy.lower()
            if deploy in ['production', 'prod'] and gethostname() not in production_servers:
                print('This host is not a production server, skipping...')
                return
            if deploy in ['development', 'dev'] and gethostname() not in development_servers:
                print('This host is not a development server, skipping...')
                return
            if deploy in ['skip', 'none']:
                print('Skipping...')
                return
        return func(*args, **kwargs)
    return inner
```
Then we can add the decorator to our function definitions (1 line added for each):
```python
@deployable
def dag_vertica_create_table_as(...):  # same signature, nothing changes
    ...  # same code, nothing changes

@deployable
def dag_vertica_create_or_replace_view_as(...):  # same signature, nothing changes
    ...  # same code, nothing changes

@deployable
def dag_vertica_train_predict_model(...):  # same signature, nothing changes
    ...  # same code, nothing changes
```
If we stop here, nothing happens and nothing breaks: calls without a deploy argument behave exactly as before. However, now we can go to the DAG files where we use these functions and add 1 line:
```python
dag = dag_vertica_create_table_as(
    deploy='development',  # the function will return None on production
    ...
)
```
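Because @deployable calls gethostname(), it is awkward to exercise outside the real servers. Below is a sketch of a variant where the current environment is passed in explicitly; deployable_in, ALIASES and make_dag are hypothetical names, not part of our Airflow library. It also makes deploy a keyword-only parameter of the wrapper, and raises on an unrecognized value. In the version above, a typo such as deploy='prodution' matches none of the checks, so the function simply runs.

```python
from functools import wraps
from typing import Callable, Optional

# hypothetical mapping from accepted deploy= spellings to a canonical target
ALIASES = {
    'production': 'production', 'prod': 'production',
    'development': 'development', 'dev': 'development',
    'skip': 'skip', 'none': 'skip',
}

def deployable_in(current_env: str) -> Callable:
    """Like @deployable, but with the current environment passed in explicitly."""
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def inner(*args, deploy: Optional[str] = None, **kwargs):
            # deploy is keyword-only and never reaches func()
            if deploy is not None:
                target = ALIASES.get(deploy.lower())
                if target is None:
                    raise ValueError(f'Unknown deploy value: {deploy!r}')
                if target != current_env:  # 'skip' never matches an environment
                    print(f'deploy={deploy!r} but this is {current_env}, skipping...')
                    return None
            return func(*args, **kwargs)
        return inner
    return decorator

@deployable_in('development')  # e.g. on a dev Airflow host
def make_dag(table: str) -> str:
    return f'DAG for {table}'  # stand-in for building a real DAG

print(make_dag('t1'))                       # DAG for t1
print(make_dag('t1', deploy='dev'))         # DAG for t1
print(make_dag('t1', deploy='production'))  # prints the skip message, then None
```

In a real deployment, current_env would be derived once, for example by looking up gethostname() in the server lists, and the resulting decorator applied to every dag_xxx() function.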
@redirect (stdout)
Sometimes we write a big function which also calls other code, and all sorts of messages are printed with print(). Or we may be chasing a bug with a bunch of print() calls, and want to add line numbers to the printouts so it's easier to refer to them. In these cases, @redirect may be useful. This decorator redirects everything print()ed to standard output into our own line-by-line printer, and we can do whatever we'd like with it (including throwing it away):
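```python
from contextlib import redirect_stdout
from functools import wraps
from io import StringIO
from typing import Callable, Optional

def redirect(func=None, line_print: Optional[Callable] = None):
    def decorator(func: Callable):
        @wraps(func)
        def inner(*args, **kwargs):
            # capture everything func() prints to stdout in an in-memory buffer
            with StringIO() as buf, redirect_stdout(buf):
                result = func(*args, **kwargs)
                output = buf.getvalue()
            lines = output.splitlines()
            if line_print is not None:
                for line in lines:
                    line_print(line)
            else:
                # zero-pad line numbers to the width of the largest one (also safe for 0 lines)
                width = len(str(len(lines)))
                for i, line in enumerate(lines, start=1):
                    print(f'{i:0{width}}: {line}')
            return result
        return inner
    if func is None:
        # decorator was used like @redirect(...)
        return decorator
    else:
        # decorator was used like @redirect, without parens
        return decorator(func)
```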
If we use @redirect without specifying an explicit line_print() function, it will print the lines, but with line numbers added:
```python
@redirect
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(10)
```
Output:
01: Line #1
02: Line #2
03: Line #3
04: Line #4
05: Line #5
06: Line #6
07: Line #7
08: Line #8
09: Line #9
10: Line #10
If we want to save all printed text to a variable, we can also achieve that:
```python
lines = []

def save_lines(line):
    lines.append(line)

@redirect(line_print=save_lines)
def print_lines(num_lines):
    for i in range(num_lines):
        print(f'Line #{i+1}')

print_lines(3)
print(lines)
```
Output:
['Line #1', 'Line #2', 'Line #3']
The actual heavy lifting of redirecting stdout is done by contextlib.redirect_stdout, as shown in this StackOverflow thread.
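The @redirect above buffers all output and only hands it to line_print() after the function returns. If you want each line delivered as soon as it's printed, e.g. for a long-running function, one option is to pass redirect_stdout() a small file-like object instead of a StringIO. This is a sketch; LineWriter and redirect_streaming are hypothetical names, and the object only implements write() and flush(), which is enough for print() but not for code that expects a full text stream.

```python
import sys
from contextlib import redirect_stdout
from functools import wraps
from typing import Callable

class LineWriter:
    """Minimal file-like object: forwards every completed line to line_print()."""
    def __init__(self, line_print: Callable[[str], None]):
        self.line_print = line_print
        self.pending = ''
        self.target = sys.stdout  # the real stdout, captured before redirection

    def write(self, text: str) -> int:
        self.pending += text
        *complete, self.pending = self.pending.split('\n')
        for line in complete:
            # restore the real stdout while calling line_print(), so that a
            # line_print() which itself calls print() doesn't recurse into us
            with redirect_stdout(self.target):
                self.line_print(line)
        return len(text)

    def flush(self) -> None:
        pass

    def finish(self) -> None:
        # hand over a last line that was printed without a trailing newline
        if self.pending:
            self.line_print(self.pending)
            self.pending = ''

def redirect_streaming(line_print: Callable[[str], None]):
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def inner(*args, **kwargs):
            writer = LineWriter(line_print)
            try:
                with redirect_stdout(writer):
                    return func(*args, **kwargs)
            finally:
                writer.finish()
        return inner
    return decorator

@redirect_streaming(lambda line: print(f'[log] {line}'))
def noisy():
    print('step 1')
    print('step 2')
    return 'done'

noisy()  # prints "[log] step 1" and "[log] step 2", returns 'done'
```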
@stacktrace
The next decorator pattern is @stacktrace, which emits useful messages when functions are called and when they return values:
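```python
from functools import wraps
from sys import gettrace, settrace
from typing import Callable

def stacktrace(func=None, exclude_files=('anaconda',)):
    def tracer_func(frame, event, arg):
        if event not in ('call', 'return'):
            return tracer_func  # keep tracing this frame, ignore 'line' and 'exception' events
        co = frame.f_code
        func_name = co.co_name
        caller_filename = frame.f_back.f_code.co_filename if frame.f_back else ''
        if func_name == 'write':
            return  # ignore write() calls from print statements
        for file in exclude_files:
            if file in caller_filename:
                return  # ignore calls made from ipython/Anaconda internals
        # only show the parameters: other locals are not assigned yet at 'call' time
        param_names = co.co_varnames[:co.co_argcount + co.co_kwonlyargcount]
        args = str(tuple(frame.f_locals.get(name) for name in param_names))
        if args.endswith(',)'):
            args = args[:-2] + ')'  # ('foo',) -> ('foo')
        if event == 'call':
            print(f'--> Executing: {func_name}{args}')
            return tracer_func
        else:
            print(f'--> Returning: {func_name}{args} -> {repr(arg)}')
    def decorator(func: Callable):
        @wraps(func)
        def inner(*args, **kwargs):
            previous = gettrace()
            settrace(tracer_func)
            try:
                return func(*args, **kwargs)
            finally:
                settrace(previous)  # restore any tracer (e.g. a debugger) active before
        return inner
    if func is None:
        # decorator was used like @stacktrace(...)
        return decorator
    else:
        # decorator was used like @stacktrace, without parens
        return decorator(func)
```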
With this, we can decorate the topmost function where we want tracing to start, and we will get useful output about the call tree:
```python
def b():
    print('...')

@stacktrace
def a(arg):
    print(arg)
    b()
    return 'world'

a('foo')
```
Output:
--> Executing: a('foo')
foo
--> Executing: b()
...
--> Returning: b() -> None
--> Returning: a('foo') -> 'world'
The only trick here is hiding the parts of the call stack that are not interesting. In my case, I'm running this code in ipython over Anaconda, so I hide calls made from files which have anaconda in their path (otherwise I would get about 50-100 useless call stack entries in the snippet above). This is accomplished by the exclude_files parameter of the decorator.
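A natural extension is to indent the trace by call depth, so nested calls are visible at a glance. The sketch below is a hypothetical calltree decorator (not from the original notebook) that keeps a depth counter in the tracer and takes an emit callback, so the trace can be collected instead of printed. Unlike @stacktrace, it does not filter write() calls or library frames, so in ipython you would want to add the same exclusions.

```python
import sys
from functools import wraps
from typing import Callable, Optional

def calltree(func: Optional[Callable] = None, emit: Callable[[str], None] = print):
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def inner(*args, **kwargs):
            depth = 0

            def tracer(frame, event, arg):
                nonlocal depth
                code = frame.f_code
                if event == 'call':
                    # parameters only: other locals are not assigned yet at call time
                    names = code.co_varnames[:code.co_argcount + code.co_kwonlyargcount]
                    shown = ', '.join(repr(frame.f_locals[n]) for n in names)
                    emit(f'{"  " * depth}--> {code.co_name}({shown})')
                    depth += 1
                elif event == 'return':
                    depth -= 1
                    emit(f'{"  " * depth}<-- {code.co_name} -> {arg!r}')
                return tracer  # keep receiving events for this frame

            previous = sys.gettrace()
            sys.settrace(tracer)
            try:
                return func(*args, **kwargs)
            finally:
                sys.settrace(previous)
        return inner
    return decorator if func is None else decorator(func)

def b(x: int) -> int:
    return x * 2

@calltree
def a(x: int) -> int:
    return b(x) + 1

a(3)
```

Calling a(3) should print one line per call and return, with b's lines indented under a's.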
@traceclass
Similarly to the above, we can define a @traceclass decorator which we use with classes, to get traces of their methods' execution. This was included in the previous decorator post, where it was just called @trace and had a bug (since fixed in the original post). The decorator:
```python
from functools import wraps
from typing import Callable

def traceclass(cls: type):
    def make_traced(cls: type, method_name: str, method: Callable):
        @wraps(method)
        def traced_method(*args, **kwargs):
            print(f'--> Executing: {cls.__name__}::{method_name}()')
            return method(*args, **kwargs)
        return traced_method
    # iterate over a snapshot, since setattr() writes back into cls.__dict__
    for name in list(cls.__dict__.keys()):
        if callable(getattr(cls, name)) and name != '__class__':
            setattr(cls, name, make_traced(cls, name, getattr(cls, name)))
    return cls
```
We can use it like:
```python
@traceclass
class Foo:
    i: int = 0

    def __init__(self, i: int = 0):
        self.i = i

    def increment(self):
        self.i += 1

    def __str__(self):
        return f'This is a {self.__class__.__name__} object with i = {self.i}'

f1 = Foo()
f2 = Foo(4)
f1.increment()
print(f1)
print(f2)
```
Output:
--> Executing: Foo::__init__()
--> Executing: Foo::__init__()
--> Executing: Foo::increment()
--> Executing: Foo::__str__()
This is a Foo object with i = 1
--> Executing: Foo::__str__()
This is a Foo object with i = 4
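One limitation worth knowing: getattr(cls, name) unwraps @staticmethod and @classmethod, so the loop in @traceclass replaces them with plain functions, and calling such a method on an instance then fails (a static method, for example, receives the instance as an extra argument). Below is a sketch that inspects the raw class attributes and re-wraps these two cases. traceclass_safe is a hypothetical name, and it only traces plain functions, static methods and class methods, leaving properties and nested classes alone.

```python
import inspect
from functools import wraps
from typing import Callable

def traceclass_safe(cls: type) -> type:
    def make_traced(label: str, func: Callable) -> Callable:
        @wraps(func)
        def traced(*args, **kwargs):
            print(f'--> Executing: {label}()')
            return func(*args, **kwargs)
        return traced

    # raw class attributes still show the staticmethod/classmethod wrappers
    for name, attr in list(vars(cls).items()):
        label = f'{cls.__name__}::{name}'
        if isinstance(attr, staticmethod):
            setattr(cls, name, staticmethod(make_traced(label, attr.__func__)))
        elif isinstance(attr, classmethod):
            setattr(cls, name, classmethod(make_traced(label, attr.__func__)))
        elif inspect.isfunction(attr):
            setattr(cls, name, make_traced(label, attr))
    return cls

@traceclass_safe
class Counter:
    def __init__(self, start: int = 0):
        self.n = start

    def inc(self) -> int:
        self.n += 1
        return self.n

    @staticmethod
    def double(x: int) -> int:
        return 2 * x

    @classmethod
    def make(cls) -> 'Counter':
        return cls(10)

c = Counter.make()   # traces Counter::make() and Counter::__init__()
print(c.inc())       # traces Counter::inc(), then prints 11
print(c.double(21))  # traces Counter::double(), then prints 42
```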
Conclusion
In Python, functions are first-class citizens, and decorators are powerful syntactic sugar exploiting this to give programmers a seemingly "magic" way to construct useful compositions of functions and classes. These are a handful of decorators that may be useful specifically for Data Scientists working in ipython notebooks.
Thanks Zsolt for bugfixes and improvement suggestions.