dags provides tools to combine several interrelated functions into one function. The order in which the functions are called is determined by a topological sort on a dag that is constructed from the function signatures. You can specify which of the function results will be returned in the combined function.
dags is a tiny library; all the hard work is done by the great NetworkX.
To see what dags does, let's look at a simple example: a few functions that perform basic calculations.
```python
def f(x, y):
    return x**2 + y**2


def g(y, z):
    return 0.5 * y * z


def h(f, g):
    return g / f
```
Assume that we are interested in a function that calculates h, given x, y and z.
We could hardcode this function as:
```python
def hardcoded_combined(x, y, z):
    _f = f(x, y)
    _g = g(y, z)
    return h(_f, _g)


hardcoded_combined(x=1, y=2, z=3)  # returns 0.6
```
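Evaluating the hardcoded function by hand for `x=1, y=2, z=3` shows where the result comes from:

```latex
\begin{aligned}
f(1, 2) &= 1^2 + 2^2 = 5 \\
g(2, 3) &= 0.5 \cdot 2 \cdot 3 = 3 \\
h(f, g) &= g / f = 3 / 5 = 0.6
\end{aligned}
```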
Instead, we can use dags to construct the same function:
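```python
from dags import concatenate_functions

# dags reads the signatures: h needs f and g, which in turn need x, y and z.
combined = concatenate_functions([h, f, g], targets="h")
combined(x=1, y=2, z=3)  # returns 0.6
```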
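To make the mechanism described at the top more concrete, here is a minimal sketch of the idea using only the standard library. It is not dags' implementation (dags delegates the graph work to NetworkX); the name `toy_concatenate` and its `return_type` parameter are hypothetical. The sketch reads each function's parameter names with `inspect.signature`, treats a parameter that matches another function's name as an edge in the dag, orders the calls with `graphlib.TopologicalSorter`, and returns the requested targets.

```python
import inspect
from graphlib import TopologicalSorter


def toy_concatenate(functions, targets, return_type="tuple"):
    """Hypothetical sketch: combine functions whose parameters name other functions."""
    funcs = {func.__name__: func for func in functions}
    # Each node depends on its parameters; parameters that are not function
    # names become inputs of the combined function.
    deps = {name: list(inspect.signature(func).parameters) for name, func in funcs.items()}
    # Topological order of the dag, keeping only nodes that are functions.
    order = [node for node in TopologicalSorter(deps).static_order() if node in funcs]
    single = isinstance(targets, str)
    target_list = [targets] if single else list(targets)

    def combined(**inputs):
        results = dict(inputs)
        # Call every function once, after all of its arguments are available.
        for name in order:
            results[name] = funcs[name](**{arg: results[arg] for arg in deps[name]})
        if single:
            return results[targets]
        if return_type == "dict":
            return {t: results[t] for t in target_list}
        values = [results[t] for t in target_list]
        return values if return_type == "list" else tuple(values)

    return combined


def f(x, y):
    return x**2 + y**2


def g(y, z):
    return 0.5 * y * z


def h(f, g):
    return g / f


combined = toy_concatenate([h, f, g], targets="h")
print(combined(x=1, y=2, z=3))  # 0.6
print(toy_concatenate([h, f, g], targets=["f", "g"], return_type="dict")(x=1, y=2, z=3))
# {'f': 5, 'g': 3.0}
```

Using `graphlib` keeps the sketch dependency-free, and a cycle among the functions makes `static_order()` raise `graphlib.CycleError`. For simplicity the sketch evaluates every function in the list, even ones that no target depends on.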
More examples can be found in the documentation. Some notable features:
- The dag is constructed once, when the combined function is created, so calling the combined function adds little overhead.
- If all individual functions are jax compatible, the combined function is jax compatible.
- When jitted or vmapped with jax, we have not seen any performance loss compared to hard coding the combined function.
- When there is more than one target, you can choose whether the results are returned as a tuple, list or dict, or pass in an aggregator that combines the multiple outputs.
- Since the relationships are discovered from function signatures, dags provides decorators to rename arguments, which makes it easy to wrap functions you do not control yourself.
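The renaming idea can be sketched with the standard library. The decorator below, `rename_args`, is hypothetical and is not dags' own decorator. It assumes that signature-based tools read parameter names through `inspect.signature`, which honors a function's `__signature__` attribute. The wrapper advertises the new names and translates keyword arguments back to the original ones before calling the wrapped function.

```python
import functools
import inspect


def rename_args(mapper):
    """Hypothetical decorator: expose a function's parameters under new names.

    ``mapper`` maps original parameter names to new ones.
    """

    def decorator(func):
        old_sig = inspect.signature(func)
        new_params = [
            param.replace(name=mapper.get(name, name))
            for name, param in old_sig.parameters.items()
        ]
        new_to_old = {new: old for old, new in mapper.items()}

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Translate renamed keyword arguments back to the original names.
            kwargs = {new_to_old.get(key, key): value for key, value in kwargs.items()}
            return func(*args, **kwargs)

        # Advertise the new names so signature-based tools discover them.
        wrapper.__signature__ = old_sig.replace(parameters=new_params)
        return wrapper

    return decorator


# A function "you do not control", with argument names that do not match your dag.
def power(base, exponent):
    return base**exponent


squared = rename_args({"base": "x", "exponent": "n"})(power)
print(inspect.signature(squared))  # (x, n)
print(squared(x=3, n=2))  # 9
```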
dags is available on PyPI and conda-forge. Install it with:

```console
$ pip install dags
# or
$ pixi add dags
# or
$ conda install -c conda-forge dags
```
The documentation is hosted on Read the Docs.