credit.trainers.trainerWRF

Attributes

logger

Classes

Trainer

Helper class that provides a standard way to create an ABC using inheritance.

Module Contents

credit.trainers.trainerWRF.logger

class credit.trainers.trainerWRF.Trainer(model: torch.nn.Module, rank: int)

Bases: credit.trainers.base_trainer.BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

Train the model for one epoch.

Parameters:
  • epoch (int) – The current epoch number.

  • conf (Dict[str, Any]) – The configuration dictionary.

  • trainloader (torch.utils.data.DataLoader) – The training data loader.

  • optimizer (torch.optim.Optimizer) – The optimizer.

  • criterion (torch.nn.Module) – The loss function.

  • scaler (torch.cuda.amp.GradScaler) – The gradient scaler for mixed precision training.

  • scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning rate scheduler.

  • metrics (Dict[str, Any]) – The metrics to track during training.

Returns:

A dictionary containing the training results.

Return type:

Dict[str, float]
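To illustrate how the arguments listed above typically fit together, here is a minimal sketch of a one-epoch loop. It is not the Trainer's actual implementation: train_one_epoch_sketch is a hypothetical stand-in function, the result keys ("train_loss", "mae") are invented for the example, and metrics is assumed to map names to callables of the form fn(prediction, target). The conf argument is accepted only to mirror the documented signature, and the gradient scaler is disabled so that the sketch runs on CPU.

```python
from typing import Any, Callable, Dict

import torch
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(
    model: torch.nn.Module,
    epoch: int,
    conf: Dict[str, Any],
    trainloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    scaler: GradScaler,
    scheduler: Any,
    metrics: Dict[str, Callable],
) -> Dict[str, float]:
    """Hypothetical stand-in for illustration; not the Trainer's real code."""
    model.train()
    totals = {"train_loss": 0.0, **{name: 0.0 for name in metrics}}
    n_batches = 0
    for x, y in trainloader:
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        # With an enabled scaler the loss is scaled before backward();
        # with enabled=False these calls fall through to the plain versions.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        totals["train_loss"] += loss.item()
        for name, fn in metrics.items():  # assumed: fn(prediction, target)
            totals[name] += float(fn(pred.detach(), y))
        n_batches += 1
    scheduler.step()  # one scheduler step per epoch in this sketch
    # Average the per-batch sums into epoch-level floats.
    return {name: total / max(n_batches, 1) for name, total in totals.items()}


# Toy setup: a linear model on random data.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
loader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
scaler = GradScaler(enabled=False)  # disabled so the sketch runs without a GPU
metrics = {"mae": lambda p, t: (p - t).abs().mean()}

results = train_one_epoch_sketch(
    model, 0, {}, loader, optimizer, torch.nn.MSELoss(), scaler, scheduler, metrics
)
print(results)
```

The returned dictionary holds one float per tracked quantity, which matches the documented Dict[str, float] return type; the real method may report different keys.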

validate(epoch, conf, valid_loader, criterion, metrics)

Validate the model on the validation set.

Parameters:
  • epoch (int) – The current epoch number.

  • conf (Dict[str, Any]) – The configuration dictionary.

  • valid_loader (torch.utils.data.DataLoader) – The validation data loader.

  • criterion (torch.nn.Module) – The loss function.

  • metrics (Dict[str, Any]) – The metrics to track during validation.

Returns:

A dictionary containing the validation results.

Return type:

Dict[str, float]
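A validation pass usually differs from training in that the model is put in evaluation mode and no gradients are computed. The sketch below shows that pattern with the same argument roles as validate. It is an assumption about the general shape, not the Trainer's actual code: validate_sketch, the "valid_loss" and "mae" keys, and the callable form of metrics are all hypothetical, and conf is accepted only to mirror the documented signature.

```python
from typing import Any, Callable, Dict

import torch
from torch.utils.data import DataLoader, TensorDataset


def validate_sketch(
    model: torch.nn.Module,
    epoch: int,
    conf: Dict[str, Any],
    valid_loader: DataLoader,
    criterion: torch.nn.Module,
    metrics: Dict[str, Callable],
) -> Dict[str, float]:
    """Hypothetical stand-in for illustration; not the Trainer's real code."""
    model.eval()  # switch off dropout and batch-norm updates
    totals = {"valid_loss": 0.0, **{name: 0.0 for name in metrics}}
    n_batches = 0
    with torch.no_grad():  # no gradients are tracked during validation
        for x, y in valid_loader:
            pred = model(x)
            totals["valid_loss"] += criterion(pred, y).item()
            for name, fn in metrics.items():  # assumed: fn(prediction, target)
                totals[name] += float(fn(pred, y))
            n_batches += 1
    # Average the per-batch sums into epoch-level floats.
    return {name: total / max(n_batches, 1) for name, total in totals.items()}


# Toy setup: a linear model on random data.
torch.manual_seed(0)
val_model = torch.nn.Linear(4, 1)
val_loader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=8)
weights_before = [p.detach().clone() for p in val_model.parameters()]

val_results = validate_sketch(
    val_model, 0, {}, val_loader, torch.nn.MSELoss(),
    {"mae": lambda p, t: (p - t).abs().mean()},
)
print(val_results)
```

Because there is no optimizer, scaler, or scheduler in the signature, a pass like this leaves the model weights unchanged and only reports aggregated values.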