Example Experiment
Below is a fully functioning experiment script example that implements all the
necessary stages for training an sklearn model on the iris dataset. This script
implements both `run()` and `get_params()`, and is fully self-contained. This
script can be found in the curifactory repo under
from dataclasses import dataclass
import curifactory as cf
from curifactory.caching import PickleCacher
from curifactory.reporting import JsonReporter
from sklearn.base import ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
@dataclass
class Args(cf.ExperimentArgs):
    """Parameters controlling one iris-classification run.

    Instances are created with keyword arguments for these fields (as in
    ``Args(name="simple_lr", balanced=True, ...)``), so the class has to be a
    dataclass for the annotated fields to become constructor parameters.
    """

    balanced: bool = False
    """Whether class weights should be balanced or not."""
    n: int = 100
    """The number of trees for a random forest."""
    seed: int = 42
    """The random state seed for data splitting and model training."""
    model_type: ClassifierMixin = LogisticRegression
    """The sklearn model class to use (the class itself, not an instance)."""
    test_percent: float = 0.25
    """The percentage of data to use for testing."""
@cf.stage(
    inputs=None, outputs=["training_data", "testing_data"], cachers=[PickleCacher] * 2
)
def load_data(record):
    """Load the iris dataset and split it into training and testing sets.

    Args:
        record: The record being processed; ``record.args`` supplies
            ``test_percent`` (size of the test split) and ``seed`` (random
            state for the split).

    Returns:
        Two ``(features, targets)`` tuples, training first and testing
        second, in the same order as the stage's declared outputs.
    """
    args: Args = record.args

    data = load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        data.data, data.target, test_size=args.test_percent, random_state=args.seed
    )

    return (x_train, y_train), (x_test, y_test)
@cf.stage(inputs=["training_data"], outputs=["model"], cachers=[PickleCacher])
def train_model(record, training_data):
    """Fit the configured sklearn classifier on the training split.

    Args:
        record: The record being processed; ``record.args`` supplies the
            model class and its hyperparameters.
        training_data: A ``(features, targets)`` pair; element 0 is passed to
            ``fit`` as the samples and element 1 as the labels.

    Returns:
        The fitted classifier.
    """
    args: Args = record.args

    # set up common arguments from passed parameters
    weight = "balanced" if args.balanced else None
    model_args = dict(class_weight=weight, random_state=args.seed)

    # set up model-specific arguments from parameters
    # model_type holds a class rather than an instance, so type(model_type)
    # would be the metaclass and never equal RandomForestClassifier; compare
    # the class itself instead.
    model_cls = args.model_type
    if isinstance(model_cls, type) and issubclass(model_cls, RandomForestClassifier):
        model_args["n_estimators"] = args.n

    # fit the parameterized model
    clf = model_cls(**model_args).fit(training_data[0], training_data[1])
    return clf
@cf.aggregate(outputs=["scores"], cachers=None)
def test_models(record, records):
    """Score every trained model on its own held-out testing split.

    Args:
        record: The record this aggregate stage runs on (not read here).
        records: The records to aggregate over; any record without a
            ``"model"`` entry in its state is skipped.

    Returns:
        A dict mapping each argument set's name to the value returned by its
        model's ``score`` method.
    """
    scores = {}

    # iterate through every record and score its associated model
    for prev_record in records:
        if "model" in prev_record.state:
            # testing_data is the (features, targets) pair stored by load_data
            x_test, y_test = prev_record.state["testing_data"]
            score = prev_record.state["model"].score(x_test, y_test)

            # store the result keyed to the argument set name
            scores[prev_record.args.name] = score

    return scores
def get_params():
    """Return the argument sets to run: one logistic regression, one random forest."""
    return [
        Args(name="simple_lr", balanced=True, model_type=LogisticRegression, seed=1),
        Args(name="simple_rf", model_type=RandomForestClassifier, seed=1),
    ]
def run(argsets, manager):
    """Train one model per argument set, then score all of them together.

    Args:
        argsets: The argument sets to run, one record is created for each.
        manager: The manager object every record is created with.
    """
    for argset in argsets:
        record = cf.Record(manager, argset)
        # Without running the stages the record never gains a "model" entry
        # and the aggregate below has nothing to score.
        # NOTE(review): assumes a stage call returns the record it was given,
        # so the calls can be sequenced like this - confirm against curifactory.
        record = load_data(record)
        record = train_model(record)

    # the aggregate stage runs on its own record with no argument set
    test_models(cf.Record(manager, None))