Parameter files and argsets
Another goal of Curifactory is to allow effective parameterization of experiments. Where this might normally be done with a json or yaml file, Curifactory directly uses python files for experiment parameterization/configuration. This has a few advantages:
Arguments can be any python object, rather than simply a primitive type or dictionary.
Parameter files can reference/use other parameter files, allowing modularity and composition.
The resulting arguments that are passed into an experiment can be algorithmically generated or modified inside an arguments script file, with the full power of the python language! An example for how this might be useful is a single arguments script that generates 10 very similar argument sets for comparison, rather than having to individually define 10 different parameter configuration files. This could allow custom gridsearches for example.
Note
Throughout this documentation, we refer to “paramset” and “argset” as slightly different. A “paramset” refers
to a whole parameters script file, while an “argset” refers to a single Args
instance. A single paramset
returns one or more argsets in a list.
The Args
class
As discussed on the Getting Started page, To define possible arguments, there should be a
class that inherits curifactory.ExperimentArgs
, and for ease of use should have the
@dataclass
decorator. By default, the cookiecutter project places an Args
class
for this inside of the params/__init__.py
. Possible arguments are the variables within
this class, and by defining default values for each one, this allows an arguments file to define
only what it needs to change from the defaults.
An example Args
class is shown below:
from dataclasses import dataclass, field
from typing import List
from curifactory import ExperimentArgs
@dataclass
class Args(ExperimentArgs):
example_arg: str = ""
example_number_of_epochs: int = 10
# due to how dataclasses handle initialization, default lists and dictionaries need to
# be handled with field factory from the dataclasses package.
example_data: List[int] = field(default_factory=lambda: [1,2,3,4])
The actual parameter files (by default go in the params/
folder) are then each expected to define a
get_params()
function, which should return a list of Args
instances. A very simple example based on
the above Args
class might look like:
from params import Args
def get_params():
return [Args(name='test_params', example_number_of_epochs=15)]
Note
As Args
is a completely user-defined class, you can technically name this class whatever you
choose. The rest of this documentation is written under the assumption it is named “Args”.
Warning
While the arguments in your dataclass can be arbitrary types, weird issues can sometimes arise if you include non-serializable objects. We’ve run into problems with things like including a pytorch distributed strategy object as an argument, as it can end up in a weird recursive serialization loop when curifactory tries to get a serialized JSON string representation of the corresponding arguments.
Programmatic definition
The get_params()
function can contain arbitrary code, and this is where advantages 2 and 3 listed above can be
exploited. For instance, if we wanted to define a set of parameters for testing multiple different numbers of epochs,
we could return a list of multiple Args
, each with a different epochs number:
from params import Args
def get_params():
args = []
for i in range(5, 15):
args.append(Args(name=f"epochs_run_{i}", example_number_of_epochs=i))
return args
If we wanted to make parameter sets compositional, we can import one of the other parameter files and
reference its get_params()
call in the new one:
from params import base, Args
def get_params():
args = base.get_params()
args[0].name = 'modified' # assuming we know there's only one Args instance (otherwise we do this in a loop)
args[0].starting_data = [0, 2, 4, 6]
return args
In the above example, there’s another parameters file named base
, we get its arguments with base.get_params()
,
run our modifications, and return the modified argsets. In this way, any changes that get made to the base parameters also influence
this one, allowing for a form of parameter set hierarchy.
We can also create common functions for helping build up large amounts of argsets. As an example, we may frequently
wish to create “seeded” argsets, where we have the same arguments several times but with a different seed for sklearn
models or similar. Rather than manually define this, or reimplementing it in every relevant get_params()
function,
we could extract it as in this example:
from copy import deepcopy
from params import Args
def seed_set(args: Args, seed_count: int = 5):
seed_args = []
for i in range(seed_count):
# Make a copy of the passed args and apply a different seed
new_args = deepcopy(args)
new_args.name += f"_seed{i}"
new_args.seed = i
seed_args.append(new_args)
return seed_args
from params import Args
from params.common import seed_set
def get_params():
knn_args = Args(name="test_knn", model_type="knn")
svm_args = Args(name="test_svm", model_type="svm")
all_args = []
all_args.extend(seed_set(knn_args))
all_args.extend(seed_set(svm_args, 3))
return all_args
Using args
Every stage automatically has access to the currently relevant Args
instance, as it is part of
the passed record.
from curifactory import Record
import params
import src
@stage(['training_data'], ['model'])
def train_model(record: Record, training_data):
args: params.Args = record.args # use the type hinting to get good autocomplete in IDEs
if args.model_type == "knn":
# pass relevant args into the codebase functions
src.train_knn(args.seed)
# ...