credit.trainers.trainerWRF_multi#

Attributes#

logger

Classes#

TrainerWRFMulti

Trainer class for handling the training, validation, and checkpointing of models.

Module Contents#

credit.trainers.trainerWRF_multi.logger#
class credit.trainers.trainerWRF_multi.TrainerWRFMulti(model: torch.nn.Module, rank: int, conf: dict)#

Bases: credit.trainers.base_trainer.BaseTrainer

Trainer class for handling the training, validation, and checkpointing of models.

This class is responsible for executing the training loop, validating the model on a separate dataset, and managing checkpoints during training. It supports both single-GPU and distributed (FSDP, DDP) training.

model#

The model to be trained.

Type:

torch.nn.Module

rank#

The rank of the process in distributed training.

Type:

int
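As a rough illustration of where a rank value can come from: launchers such as torchrun export the process rank in the RANK environment variable. The helpers below are a hypothetical sketch and not part of this module; the names get_rank_from_env and is_main_process are invented here, and the sketch assumes that a missing variable means a single-process run.

```python
import os


def get_rank_from_env(env=None) -> int:
    """Return the process rank, defaulting to 0 for single-process runs.

    Hypothetical helper: assumes the launcher exports RANK (torchrun does).
    """
    env = os.environ if env is None else env
    return int(env.get("RANK", "0"))


def is_main_process(rank: int) -> bool:
    """Rank 0 conventionally handles logging and checkpoint writes."""
    return rank == 0
```

The resulting integer is the kind of value that would be passed as the rank argument of the constructor.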

module#

If True, use the model with module parallelism (default: False). Note that this flag does not appear in the constructor signature shown above.

Type:

bool

train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)

Perform training for one epoch and return training metrics.

validate(epoch, valid_loader, criterion, metrics)

Validate the model on the validation dataset and return validation metrics.

fit_deprecated(conf, train_loader, valid_loader, optimizer, train_criterion, valid_criterion, scaler, scheduler, metrics, trial=False)

Perform the full training loop across multiple epochs, including validation and checkpointing.
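The control flow of such a loop can be sketched with a stand-in object. Everything here is hypothetical: StubTrainer only mirrors the train_one_epoch(epoch, trainloader, ...) and validate(epoch, valid_loader, ...) signatures from the generated method entries and returns made-up numbers, run_epochs is an invented driver, and the dictionary keys are assumptions. It is not the fit_deprecated implementation, and checkpointing is omitted.

```python
class StubTrainer:
    """Stand-in with the documented method signatures; does no real training."""

    def train_one_epoch(self, epoch, trainloader, optimizer, criterion,
                        scaler, scheduler, metrics):
        # Placeholder result: pretend the loss shrinks each epoch.
        return {"train_loss": 1.0 / (epoch + 1)}

    def validate(self, epoch, valid_loader, criterion, metrics):
        # Placeholder result with a hypothetical key name.
        return {"valid_loss": 1.5 / (epoch + 1)}


def run_epochs(trainer, num_epochs, train_loader, valid_loader, optimizer,
               train_criterion, valid_criterion, scaler, scheduler, metrics):
    """Alternate training and validation, collecting one record per epoch."""
    history = []
    for epoch in range(num_epochs):
        train_results = trainer.train_one_epoch(
            epoch, train_loader, optimizer, train_criterion,
            scaler, scheduler, metrics)
        valid_results = trainer.validate(
            epoch, valid_loader, valid_criterion, metrics)
        history.append({"epoch": epoch, **train_results, **valid_results})
    return history


# Empty loaders and None placeholders are enough for the stub.
history = run_epochs(StubTrainer(), 3, [], [], None, None, None, None, None, None)
```

Keeping one merged record per epoch makes it easy to log or inspect training and validation results side by side.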

train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)#

Trains the model for one epoch.

Parameters:
  • epoch (int) – Current epoch number.

  • conf (dict) – Configuration dictionary containing training settings. This parameter does not appear in the signature above; the class constructor takes conf instead.

  • trainloader (DataLoader) – DataLoader for the training dataset.

  • optimizer (torch.optim.Optimizer) – Optimizer used for training.

  • criterion (callable) – Loss function used for training.

  • scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed precision training.

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.

  • metrics (callable) – Function to compute metrics for evaluation.

Returns:

Dictionary containing training metrics and loss for the epoch.

Return type:

dict
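One plausible way such a dictionary is assembled is by averaging per-batch values over the epoch. The sketch below assumes that approach; the source does not show the real aggregation, and summarize_epoch and the key names are hypothetical.

```python
from collections import defaultdict
from statistics import fmean


def summarize_epoch(batch_results):
    """Average each metric across a list of per-batch result dicts."""
    collected = defaultdict(list)
    for result in batch_results:
        for name, value in result.items():
            collected[name].append(float(value))
    # One mean per metric name; an empty epoch yields an empty dict.
    return {name: fmean(values) for name, values in collected.items()}


# Two hypothetical batches with made-up metric names.
epoch_summary = summarize_epoch([
    {"train_loss": 0.5, "train_mae": 2.0},
    {"train_loss": 0.25, "train_mae": 1.0},
])
```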

validate(epoch, valid_loader, criterion, metrics)#

Validates the model on the validation dataset.

Parameters:
  • epoch (int) – Current epoch number.

  • conf (dict) – Configuration dictionary containing validation settings. This parameter does not appear in the signature above; the class constructor takes conf instead.

  • valid_loader (DataLoader) – DataLoader for the validation dataset.

  • criterion (callable) – Loss function used for validation.

  • metrics (callable) – Function to compute metrics for evaluation.

Returns:

Dictionary containing validation metrics and loss for the epoch.

Return type:

dict
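A returned dictionary like this can be used to decide when to save a checkpoint. As a hedged sketch, assuming a lower-is-better entry under the hypothetical key valid_loss (the actual key names are not given here), a small tracker can report when a new best epoch occurs. BestMetricTracker is an invented name, not part of this module.

```python
class BestMetricTracker:
    """Remember the lowest value seen for one key of a results dict."""

    def __init__(self, key="valid_loss"):
        self.key = key  # hypothetical key name, assumed lower-is-better
        self.best_value = float("inf")
        self.best_epoch = None

    def update(self, epoch, results):
        """Return True when results[self.key] improves on the best so far."""
        value = results[self.key]
        if value < self.best_value:
            self.best_value = value
            self.best_epoch = epoch
            return True
        return False


tracker = BestMetricTracker()
# Feed three made-up validation results, one per epoch.
improved = [tracker.update(epoch, {"valid_loss": loss})
            for epoch, loss in enumerate([0.9, 0.7, 0.8])]
```

A caller would write a checkpoint only on epochs where update returns True.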