credit.trainers.base_trainer


Attributes

logger

Classes

BaseTrainer

Abstract base class for model trainers.

Module Contents

credit.trainers.base_trainer.logger
class credit.trainers.base_trainer.BaseTrainer(model: torch.nn.Module, rank: int)

Bases: abc.ABC

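Abstract base class for model trainers. Subclasses must implement the abstract methods train_one_epoch and validate. The base class provides save_checkpoint, save_fsdp_checkpoint, and fit.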

model
rank
device
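Because BaseTrainer derives from abc.ABC, Python refuses to instantiate it. It also refuses any subclass that leaves train_one_epoch or validate unimplemented. The sketch below demonstrates that rule with a hypothetical stand-in class, TrainerContract, which copies the documented abstract signatures. It does not import torch or credit, and the subclass names (IncompleteTrainer, ToyTrainer) are illustrative only.

```python
import abc
from typing import Any, Dict


class TrainerContract(abc.ABC):
    """Hypothetical stand-in mirroring BaseTrainer's abstract interface."""

    def __init__(self, model: Any, rank: int) -> None:
        self.model = model
        self.rank = rank

    @abc.abstractmethod
    def train_one_epoch(self, epoch, conf, trainloader, optimizer,
                        criterion, scaler, scheduler, metrics) -> Dict[str, float]:
        ...

    @abc.abstractmethod
    def validate(self, epoch, conf, valid_loader, criterion, metrics) -> Dict[str, float]:
        ...


class IncompleteTrainer(TrainerContract):
    # Only train_one_epoch is overridden, so instantiation raises TypeError.
    def train_one_epoch(self, *args, **kwargs):
        return {"train_loss": 0.0}


class ToyTrainer(TrainerContract):
    # Both abstract methods are overridden, so this class can be instantiated.
    def train_one_epoch(self, *args, **kwargs):
        return {"train_loss": 0.5}

    def validate(self, *args, **kwargs):
        return {"valid_loss": 0.4}


try:
    IncompleteTrainer(model=None, rank=0)
except TypeError as err:
    print("cannot instantiate:", err)

trainer = ToyTrainer(model=None, rank=0)
print(trainer.validate(0, {}, None, None, {}))
```

A concrete trainer is therefore any subclass that fills in both abstract methods. It then inherits the non-abstract helpers such as fit and the checkpoint methods.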
abstract train_one_epoch(epoch: int, conf: Dict[str, Any], trainloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, scaler: torch.cuda.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any]) → Dict[str, float]

Train the model for one epoch.

Parameters:
  • epoch (int) – The current epoch number.

  • conf (Dict[str, Any]) – The configuration dictionary.

  • trainloader (torch.utils.data.DataLoader) – The training data loader.

  • optimizer (torch.optim.Optimizer) – The optimizer.

  • criterion (torch.nn.Module) – The loss function.

  • scaler (torch.cuda.amp.GradScaler) – The gradient scaler for mixed-precision training.

  • scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning rate scheduler.

  • metrics (Dict[str, Any]) – The metrics to track during training.

Returns:

A dictionary containing the training results.

Return type:

Dict[str, float]
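One plausible body for a train_one_epoch override follows the standard PyTorch GradScaler recipe: scale the loss, backpropagate, then call scaler.step and scaler.update. The function below is a minimal sketch of that recipe, not CREDIT's implementation. It rests on these assumptions:

  • Each batch is an (input, target) pair.

  • metrics maps a name to a callable taking (prediction, target).

  • The epoch and conf arguments are dropped because the sketch does not use them.

  • The scheduler is stepped once per epoch. Some schedulers are stepped per batch instead.

```python
from __future__ import annotations

from typing import Callable, Dict

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(
    model: nn.Module,
    trainloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: torch.cuda.amp.GradScaler,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]],
) -> Dict[str, float]:
    """One pass over trainloader using the GradScaler recipe (hypothetical sketch)."""
    model.train()
    totals = {"train_loss": 0.0, **{f"train_{k}": 0.0 for k in metrics}}
    n_batches = 0
    for x, y in trainloader:
        optimizer.zero_grad(set_to_none=True)
        # Autocast is enabled only when the scaler is; both are off in the CPU demo.
        with torch.autocast(device_type=x.device.type, enabled=scaler.is_enabled()):
            pred = model(x)
            loss = criterion(pred, y)
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscale grads; skip the step if they contain inf/NaN
        scaler.update()                # adjust the scale factor for the next iteration
        totals["train_loss"] += loss.item()
        with torch.no_grad():
            for name, fn in metrics.items():
                totals[f"train_{name}"] += float(fn(pred, y))
        n_batches += 1
    scheduler.step()  # assumption: one scheduler step per epoch
    # Average of per-batch values; equals the dataset mean only for equal batch sizes.
    return {k: v / max(n_batches, 1) for k, v in totals.items()}


torch.manual_seed(0)
model = nn.Linear(4, 1)
loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so the demo runs on CPU
mae = lambda p, t: (p - t).abs().mean().item()
results = train_one_epoch_sketch(model, loader, optimizer, nn.MSELoss(), scaler, scheduler, {"mae": mae})
print(results)
```

The returned dictionary maps metric names to floats, which matches the documented Dict[str, float] return type.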

abstract validate(epoch: int, conf: Dict[str, Any], valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, metrics: Dict[str, Any]) → Dict[str, float]

Validate the model on the validation set.

Parameters:
  • epoch (int) – The current epoch number.

  • conf (Dict[str, Any]) – The configuration dictionary.

  • valid_loader (torch.utils.data.DataLoader) – The validation data loader.

  • criterion (torch.nn.Module) – The loss function.

  • metrics (Dict[str, Any]) – The metrics to track during validation.

Returns:

A dictionary containing the validation results.

Return type:

Dict[str, float]
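A validate override typically differs from training in two ways: the model is put in eval mode, and no gradients are tracked. The function below is a minimal, hypothetical sketch. It uses the same assumed batch and metrics conventions as a training loop and omits the unused epoch and conf arguments.

```python
from __future__ import annotations

from typing import Callable, Dict

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no autograd graph is needed during validation
def validate_sketch(
    model: nn.Module,
    valid_loader: DataLoader,
    criterion: nn.Module,
    metrics: Dict[str, Callable[[torch.Tensor, torch.Tensor], float]],
) -> Dict[str, float]:
    """Average loss and metrics over valid_loader (hypothetical sketch)."""
    model.eval()  # disable dropout and use running batch-norm statistics
    totals = {"valid_loss": 0.0, **{f"valid_{k}": 0.0 for k in metrics}}
    n_batches = 0
    for x, y in valid_loader:
        pred = model(x)
        totals["valid_loss"] += criterion(pred, y).item()
        for name, fn in metrics.items():
            totals[f"valid_{name}"] += float(fn(pred, y))
        n_batches += 1
    return {k: v / max(n_batches, 1) for k, v in totals.items()}


# A zero-weight model predicts 0 everywhere, so against targets of 2.0
# the MSE is 4.0 and the MAE is 2.0 in every batch.
model = nn.Linear(4, 1)
with torch.no_grad():
    model.weight.zero_()
    model.bias.zero_()
loader = DataLoader(TensorDataset(torch.ones(8, 4), torch.full((8, 1), 2.0)), batch_size=4)
mae = lambda p, t: (p - t).abs().mean().item()
print(validate_sketch(model, loader, nn.MSELoss(), {"mae": mae}))
# → {'valid_loss': 4.0, 'valid_mae': 2.0}
```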

save_checkpoint(save_loc: str, state_dict: Dict[str, Any]) → None

Save a checkpoint of the model.

Parameters:
  • save_loc (str) – The location to save the checkpoint.

  • state_dict (Dict[str, Any]) – The state dictionary to save.
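The documentation describes save_loc only as "the location", so the sketch below assumes it is a file path. Two further behaviors are design choices of this hypothetical helper, not documented behavior of save_checkpoint:

  • Only rank 0 writes. The helper takes rank as a parameter to mirror the trainer's rank attribute.

  • The state dictionary is first written to a temporary file in the same directory and then moved into place with os.replace. This avoids leaving a truncated checkpoint behind if the process dies mid-write.

```python
import os
import tempfile
from typing import Any, Dict

import torch
from torch import nn


def save_checkpoint_sketch(save_loc: str, state_dict: Dict[str, Any], rank: int = 0) -> None:
    """Write state_dict to save_loc from rank 0 only (hypothetical helper)."""
    if rank != 0:
        return  # other ranks skip writing so they do not clobber the same file
    directory = os.path.dirname(os.path.abspath(save_loc))
    os.makedirs(directory, exist_ok=True)
    # Temp file in the same directory, so os.replace stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    try:
        torch.save(state_dict, tmp_path)
        os.replace(tmp_path, save_loc)  # move the finished file into place
    except BaseException:
        os.remove(tmp_path)  # clean up the partial temp file, then re-raise
        raise


model = nn.Linear(4, 1)
state = {"epoch": 3, "model_state_dict": model.state_dict()}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pt")
    save_checkpoint_sketch(path, state)
    restored = torch.load(path)
    print(restored["epoch"])  # 3
```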

save_fsdp_checkpoint(save_loc: str, state_dict: Dict[str, Any]) → None

Save a checkpoint when training with FSDP (Fully Sharded Data Parallel).

Parameters:
  • save_loc (str) – The location to save the checkpoint.

  • state_dict (Dict[str, Any]) – The state dictionary to save.

fit(conf: Dict[str, Any], train_loader: torch.utils.data.DataLoader, valid_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, train_criterion: torch.nn.Module, valid_criterion: torch.nn.Module, scaler: torch.cuda.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any], rollout_scheduler: callable | None = None, trial: bool = False) → Dict[str, Any]

Fit the model to the data.

Parameters:
  • conf (Dict[str, Any]) – Configuration dictionary.

  • train_loader (torch.utils.data.DataLoader) – DataLoader for training data.

  • valid_loader (torch.utils.data.DataLoader) – DataLoader for validation data.

  • optimizer (torch.optim.Optimizer) – The optimizer to use for training.

  • train_criterion (torch.nn.Module) – Loss function for training.

  • valid_criterion (torch.nn.Module) – Loss function for validation.

  • scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed-precision training.

  • scheduler (torch.optim.lr_scheduler.LRScheduler) – Learning rate scheduler.

  • metrics (Dict[str, Any]) – Dictionary of metrics to track during training.

  • rollout_scheduler (Optional[callable]) – Function to schedule rollout probability, if applicable.

  • trial (bool) – Whether this is a trial run (e.g., for hyperparameter tuning).

Returns:

Dictionary containing the best results from training.

Return type:

Dict[str, Any]
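Structurally, fit drives the two abstract methods epoch by epoch and returns the best results. The pure-Python sketch below shows one common shape for such a loop. It is not CREDIT's implementation and rests on these assumptions:

  • train_fn and valid_fn are hypothetical stand-ins for train_one_epoch and validate.

  • num_epochs, monitor, and patience are hypothetical parameters standing in for values that fit presumably reads from conf.

  • rollout_scheduler is assumed to have the signature (epoch, num_epochs) → float. Its value is simply forwarded to the training step.

  • Checkpointing and trial handling are omitted.

```python
from typing import Any, Callable, Dict, Optional


def fit_sketch(
    train_fn: Callable[[int, Optional[float]], Dict[str, float]],
    valid_fn: Callable[[int], Dict[str, float]],
    num_epochs: int,
    monitor: str = "valid_loss",
    patience: Optional[int] = None,
    rollout_scheduler: Optional[Callable[[int, int], float]] = None,
) -> Dict[str, Any]:
    """Hypothetical outer loop: train, validate, keep the best epoch, stop early."""
    best: Dict[str, Any] = {"epoch": -1, monitor: float("inf")}
    epochs_without_improvement = 0
    for epoch in range(num_epochs):
        # Per-epoch value from the optional schedule, forwarded to training.
        rollout_p = rollout_scheduler(epoch, num_epochs) if rollout_scheduler else None
        train_results = train_fn(epoch, rollout_p)
        valid_results = valid_fn(epoch)
        if valid_results[monitor] < best[monitor]:
            best = {"epoch": epoch, **train_results, **valid_results}
            epochs_without_improvement = 0
            # A real trainer would typically save a checkpoint here.
        else:
            epochs_without_improvement += 1
            if patience is not None and epochs_without_improvement >= patience:
                break  # early stopping: no improvement for `patience` epochs
    return best


valid_losses = [0.9, 0.6, 0.7, 0.8, 0.5]
seen_p = []


def train_fn(epoch, rollout_p):
    seen_p.append(rollout_p)
    return {"train_loss": 1.0 / (epoch + 1)}


def valid_fn(epoch):
    return {"valid_loss": valid_losses[epoch]}


def linear_ramp(epoch, num_epochs):
    # Hypothetical schedule: ramps linearly from 0.0 at the first epoch to 1.0 at the last.
    return epoch / max(num_epochs - 1, 1)


best = fit_sketch(train_fn, valid_fn, num_epochs=5, patience=2, rollout_scheduler=linear_ramp)
print(best)  # {'epoch': 1, 'train_loss': 0.5, 'valid_loss': 0.6}
```

With patience=2, the loop stops after epochs 2 and 3 both fail to beat epoch 1. As a result, the lower loss scheduled for epoch 4 is never reached. This is the usual trade-off of early stopping.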