eval_step(ds, batch_size, model, num_classes, loss_fn)

Eval step.

  • ds (Dataset) –

    dataset to iterate batches from.

  • batch_size (int) –

    size of each batch.

  • model (nn.Module) –

    model to evaluate.

  • num_classes (int) –

    number of classes.

  • loss_fn (torch.nn.modules.loss._WeightedLoss) –

    loss function to use between labels and predictions.

  • Tuple[float, np.array, np.array] –

    Returns the cumulative loss, ground truths and predictions.

def eval_step(
    ds: Dataset, batch_size: int, model: nn.Module, num_classes: int, loss_fn: torch.nn.modules.loss._WeightedLoss
) -> Tuple[float, np.array, np.array]:  # pragma: no cover, tested via train workload
    """Eval step.

    Runs the model over every batch of `ds` without tracking gradients and
    collects the loss, the ground-truth labels and the predicted labels.

    Args:
        ds (Dataset): dataset to iterate batches from.
        batch_size (int): size of each batch.
        model (nn.Module): model to evaluate.
        num_classes (int): number of classes.
        loss_fn (torch.nn.modules.loss._WeightedLoss): loss function to use between labels and predictions.

    Returns:
        Tuple[float, np.array, np.array]: cumulative loss (running mean of the
        per-batch losses), ground truths and predictions.

    Raises:
        ValueError: raised by `np.vstack` if `ds` yields no batches.
    """
    # Remember the caller's mode so evaluation does not silently leave the
    # model in eval mode for a subsequent training step.
    was_training = model.training
    model.eval()  # disable dropout and other train-only behaviour
    loss = 0.0
    y_trues, y_preds = [], []
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=utils.collate_fn)
    try:
        with torch.inference_mode():
            for i, batch in enumerate(ds_generator):
                z = model(batch)
                targets = F.one_hot(batch["targets"], num_classes=num_classes).float()  # one-hot (for loss_fn)
                J = loss_fn(z, targets).item()
                loss += (J - loss) / (i + 1)  # incremental mean over batches
                y_trues.extend(batch["targets"].cpu().numpy())  # ground truths were never collected before
                y_preds.extend(torch.argmax(z, dim=1).cpu().numpy())
    finally:
        model.train(was_training)
    return loss, np.vstack(y_trues), np.vstack(y_preds)


Training loop that each worker will execute.

  • config (dict) –

    arguments to use for training.

def train_loop_per_worker(config: dict) -> None:  # pragma: no cover, tested via train workload
    """Training loop that each worker will execute.

    Builds the model and training components from `config`, then for each
    epoch runs one train step and one eval step on this worker's dataset
    shards, updates the learning-rate scheduler and reports metrics plus a
    model checkpoint to the session.

    Args:
        config (dict): arguments to use for training. Must contain the keys
            `dropout_p`, `lr`, `lr_factor`, `lr_patience`, `batch_size`,
            `num_epochs` and `num_classes`.
    """
    # Hyperparameters
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]
    num_classes = config["num_classes"]

    # Get datasets
    train_ds = session.get_dataset_shard("train")
    val_ds = session.get_dataset_shard("val")

    # Model
    llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
    model = models.FinetunedLLM(llm=llm, dropout_p=dropout_p, embedding_dim=llm.config.hidden_size, num_classes=num_classes)
    model = train.torch.prepare_model(model)

    # Training components
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=lr_factor, patience=lr_patience)

    # Training
    batch_size_per_worker = batch_size // session.get_world_size()  # global batch size is split across workers
    for epoch in range(num_epochs):
        # Step
        train_loss = train_step(train_ds, batch_size_per_worker, model, num_classes, loss_fn, optimizer)
        val_loss, _, _ = eval_step(val_ds, batch_size_per_worker, model, num_classes, loss_fn)
        # Without this call the scheduler is never consulted and the learning
        # rate stays at its initial value regardless of lr_factor/lr_patience.
        scheduler.step(val_loss)

        # Checkpoint
        metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
        checkpoint = TorchCheckpoint.from_model(model=model)
        session.report(metrics, checkpoint=checkpoint)

train_model(experiment_name=None, dataset_loc=None, train_loop_config=None, num_workers=1, cpu_per_worker=1, gpu_per_worker=0, num_samples=None, num_epochs=1, batch_size=256, results_fp=None)

Main train function to train our model as a distributed workload.

  • experiment_name (str) –

    name of the experiment for this training workload.

  • dataset_loc (str) –

    location of the dataset.

  • train_loop_config (str) –

    JSON-encoded string of arguments to use for training.

  • num_workers (int) –

    number of workers to use for training. Defaults to 1.

  • cpu_per_worker (int) –

    number of CPUs to use per worker. Defaults to 1.

  • gpu_per_worker (int) –

    number of GPUs to use per worker. Defaults to 0.

  • num_samples (int) –

    number of samples to use from dataset. If this is passed in, it will override the config. Defaults to None.

  • num_epochs (int) –

    number of epochs to train for. This value always overrides the config. Defaults to 1.

  • batch_size (int) –

    number of samples per batch. This value always overrides the config. Defaults to 256.

  • results_fp (str) –

    filepath to save results to. Defaults to None.

  • ray.air.result.Result –

    Returns the training results.

def train_model(
    experiment_name: Annotated[str, typer.Option(help="name of the experiment for this training workload.")] = None,
    dataset_loc: Annotated[str, typer.Option(help="location of the dataset.")] = None,
    train_loop_config: Annotated[str, typer.Option(help="arguments to use for training.")] = None,
    num_workers: Annotated[int, typer.Option(help="number of workers to use for training.")] = 1,
    cpu_per_worker: Annotated[int, typer.Option(help="number of CPUs to use per worker.")] = 1,
    gpu_per_worker: Annotated[int, typer.Option(help="number of GPUs to use per worker.")] = 0,
    num_samples: Annotated[int, typer.Option(help="number of samples to use from dataset.")] = None,
    num_epochs: Annotated[int, typer.Option(help="number of epochs to train for.")] = 1,
    batch_size: Annotated[int, typer.Option(help="number of samples per batch.")] = 256,
    results_fp: Annotated[str, typer.Option(help="filepath to save results to.")] = None,
) -> ray.air.result.Result:
    """Main train function to train our model as a distributed workload.

    Args:
        experiment_name (str): name of the experiment for this training workload.
        dataset_loc (str): location of the dataset.
        train_loop_config (str): JSON-encoded arguments to use for training.
        num_workers (int, optional): number of workers to use for training. Defaults to 1.
        cpu_per_worker (int, optional): number of CPUs to use per worker. Defaults to 1.
        gpu_per_worker (int, optional): number of GPUs to use per worker. Defaults to 0.
        num_samples (int, optional): number of samples to use from dataset.
            Always written into the training config. Defaults to None.
        num_epochs (int, optional): number of epochs to train for.
            Always written into the training config. Defaults to 1.
        batch_size (int, optional): number of samples per batch.
            Always written into the training config. Defaults to 256.
        results_fp (str, optional): filepath to save results to. Defaults to None.

    Returns:
        ray.air.result.Result: training results.
    """
    # Set up: CLI values take precedence over whatever the JSON config holds.
    train_loop_config = json.loads(train_loop_config)
    train_loop_config["num_samples"] = num_samples
    train_loop_config["num_epochs"] = num_epochs
    train_loop_config["batch_size"] = batch_size

    # NOTE(review): the argument lists of the Ray config objects below were
    # truncated in the source this block was recovered from. They are rebuilt
    # from this function's own parameters and locals; confirm against the
    # project's intended settings.

    # Scaling config
    scaling_config = ScalingConfig(
        num_workers=num_workers,
        use_gpu=bool(gpu_per_worker),
        resources_per_worker={"CPU": cpu_per_worker, "GPU": gpu_per_worker},
    )

    # Checkpoint config
    # NOTE(review): keeps only the best checkpoint by the reported "val_loss"; confirm.
    checkpoint_config = CheckpointConfig(
        num_to_keep=1,
        checkpoint_score_attribute="val_loss",
        checkpoint_score_order="min",
    )

    # MLflow callback
    # NOTE(review): no tracking_uri is passed because none is visible here;
    # add one if the project defines a tracking location.
    mlflow_callback = MLflowLoggerCallback(
        experiment_name=experiment_name,
        save_artifact=True,
    )

    # Run config
    run_config = RunConfig(
        callbacks=[mlflow_callback],
        checkpoint_config=checkpoint_config,
    )

    # Dataset
    ds = data.load_data(dataset_loc=dataset_loc, num_samples=train_loop_config["num_samples"])
    train_ds, val_ds = data.stratify_split(ds, stratify="tag", test_size=0.2)
    tags = train_ds.unique(column="tag")
    train_loop_config["num_classes"] = len(tags)

    # Dataset config: the splits are preprocessed explicitly below, so the
    # trainer is told not to fit or transform them again.
    dataset_config = {
        "train": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
        "val": DatasetConfig(fit=False, transform=False, randomize_block_order=False),
    }

    # Preprocess
    preprocessor = data.CustomPreprocessor()
    train_ds = preprocessor.fit_transform(train_ds)
    val_ds = preprocessor.transform(val_ds)
    train_ds = train_ds.materialize()
    val_ds = val_ds.materialize()

    # Trainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config=train_loop_config,
        scaling_config=scaling_config,
        run_config=run_config,
        datasets={"train": train_ds, "val": val_ds},
        dataset_config=dataset_config,
        preprocessor=preprocessor,
    )

    # Train
    results = trainer.fit()
    d = {
        "timestamp": datetime.datetime.now().strftime("%B %d, %Y %I:%M:%S %p"),
        "run_id": utils.get_run_id(experiment_name=experiment_name, trial_id=results.metrics["trial_id"]),
        "params": results.config["train_loop_config"],
        "metrics": utils.dict_to_list(results.metrics_dataframe.to_dict(), keys=["epoch", "train_loss", "val_loss"]),
    }
    logger.info(json.dumps(d, indent=2))
    if results_fp:  # pragma: no cover, saving results
        utils.save_dict(d, results_fp)
    return results

train_step(ds, batch_size, model, num_classes, loss_fn, optimizer)

Train step.

  • ds (Dataset) –

    dataset to iterate batches from.

  • batch_size (int) –

    size of each batch.

  • model (nn.Module) –

    model to train.

  • num_classes (int) –

    number of classes.

  • loss_fn (torch.nn.modules.loss._WeightedLoss) –

    loss function to use between labels and predictions.

  • optimizer (torch.optim.Optimizer) –

    optimizer to use for updating the model's weights.

  • float –

    cumulative loss for the dataset.

def train_step(
    ds: Dataset,
    batch_size: int,
    model: nn.Module,
    num_classes: int,
    loss_fn: torch.nn.modules.loss._WeightedLoss,
    optimizer: torch.optim.Optimizer,
) -> float:  # pragma: no cover, tested via train workload
    """Train step.

    Runs one pass over `ds`, updating the model's weights after every batch.

    Args:
        ds (Dataset): dataset to iterate batches from.
        batch_size (int): size of each batch.
        model (nn.Module): model to train.
        num_classes (int): number of classes.
        loss_fn (torch.nn.modules.loss._WeightedLoss): loss function to use between labels and predictions.
        optimizer (torch.optim.Optimizer): optimizer to use for updating the model's weights.

    Returns:
        float: cumulative loss for the dataset (running mean of the per-batch
        losses); 0.0 if `ds` yields no batches.
    """
    # Explicitly enter training mode so dropout is active no matter what mode
    # the model (or any of its submodules) was left in before this call.
    model.train()
    loss = 0.0
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=utils.collate_fn)
    for i, batch in enumerate(ds_generator):
        optimizer.zero_grad()  # reset gradients
        z = model(batch)  # forward pass
        targets = F.one_hot(batch["targets"], num_classes=num_classes).float()  # one-hot (for loss_fn)
        J = loss_fn(z, targets)  # define loss
        J.backward()  # backward pass
        optimizer.step()  # update weights
        loss += (J.detach().item() - loss) / (i + 1)  # cumulative loss (incremental mean)
    return loss