Example Experiment
Below is a fully functioning experiment script example that implements all the
necessary stages for training an sklearn model on the iris dataset. This script
implements both the run()
and a get_params()
, and is fully self
contained. This script can be found in the curifactory repo under
examples/minimal/experiments/iris.py
.
from dataclasses import dataclass
import curifactory as cf
from curifactory.caching import PickleCacher
from curifactory.reporting import JsonReporter
from sklearn.base import ClassifierMixin
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
@dataclass
class Args(cf.ExperimentArgs):
balanced: bool = False
"""Whether class weights should be balanced or not."""
n: int = 100
"""The number of trees for a random forest."""
seed: int = 42
"""The random state seed for data splitting and model training."""
model_type: ClassifierMixin = LogisticRegression
"""The sklearn model to use."""
test_percent: float = 0.25
"""The percentage of data to use for testing."""
@cf.stage(
inputs=None, outputs=["training_data", "testing_data"], cachers=[PickleCacher] * 2
)
def load_data(record):
args: Args = record.args
data = load_iris()
x_train, x_test, y_train, y_test = train_test_split(
data.data, data.target, test_size=args.test_percent, random_state=args.seed
)
return (x_train, y_train), (x_test, y_test)
@cf.stage(inputs=["training_data"], outputs=["model"], cachers=[PickleCacher])
def train_model(record, training_data):
args: Args = record.args
# set up common arguments from passed parameters
weight = "balanced" if args.balanced else None
model_args = dict(class_weight=weight, random_state=args.seed)
# set up model-specific from parameters
if type(args.model_type) == RandomForestClassifier:
model_args.update(dict(n_estimators=args.n))
# fit the parameterized model
clf = args.model_type(**model_args).fit(training_data[0], training_data[1])
return clf
@cf.aggregate(outputs=["scores"], cachers=None)
def test_models(record, records):
scores = {}
# iterate through every record and score its associated model
for prev_record in records:
if "model" in prev_record.state:
score = prev_record.state["model"].score(
prev_record.state["testing_data"][0],
prev_record.state["testing_data"][1],
)
# store the result keyed to the argument set name
scores[prev_record.args.name] = score
print(scores)
record.report(JsonReporter(scores))
return scores
def get_params():
return [
Args(name="simple_lr", balanced=True, model_type=LogisticRegression, seed=1),
Args(name="simple_rf", model_type=RandomForestClassifier, seed=1),
]
def run(argsets, manager):
for argset in argsets:
record = cf.Record(manager, argset)
train_model(load_data(record))
test_models(cf.Record(manager, None))