credit.trainers.trainerWRF_multi
Attributes
- logger
Classes
- TrainerWRFMulti – Trainer class for handling the training, validation, and checkpointing of models.
Module Contents
- credit.trainers.trainerWRF_multi.logger
- class credit.trainers.trainerWRF_multi.TrainerWRFMulti(model: torch.nn.Module, rank: int, conf: dict)
Bases: credit.trainers.base_trainer.BaseTrainer
Trainer class for handling the training, validation, and checkpointing of models.
This class is responsible for executing the training loop, validating the model on a separate dataset, and managing checkpoints during training. It supports both single-GPU and distributed (FSDP, DDP) training.
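The checkpointing responsibility can be pictured with a small sketch. This is not CREDIT's implementation. The helpers `save_checkpoint` and `load_checkpoint` are hypothetical. They bundle the model, optimizer, scheduler and gradient-scaler states into one file, and only rank 0 writes it. That is enough for single-GPU or DDP runs. FSDP shards parameters across ranks, so it would need a collective step to gather a full state dict before saving, and that step is omitted here.

```python
import torch


def save_checkpoint(target, epoch, model, optimizer, scheduler, scaler, rank):
    """Hypothetical helper: write one checkpoint from rank 0 only.

    `target` may be a file path or a writable binary file object.
    Returns True if this process wrote the checkpoint.
    """
    if rank != 0:
        # Non-zero ranks skip the write so the file is not written concurrently.
        return False
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
    }
    torch.save(state, target)
    return True


def load_checkpoint(source, model, optimizer, scheduler, scaler):
    """Hypothetical helper: restore all states in place.

    Returns the epoch to resume from (the saved epoch plus one).
    """
    # Map tensors to the CPU on load; model.load_state_dict copies them
    # into the existing parameters.
    state = torch.load(source, map_location="cpu")
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    scheduler.load_state_dict(state["scheduler_state_dict"])
    scaler.load_state_dict(state["scaler_state_dict"])
    return state["epoch"] + 1
```

Storing the scheduler and scaler alongside the weights lets a resumed run continue the learning-rate schedule and loss scaling rather than restarting them.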
- model
The model to be trained.
- Type:
torch.nn.Module
- rank
The rank of the process in distributed training.
- Type:
int
- module
If True, use model with module parallelism (default: False).
- Type:
bool
- train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)
Perform training for one epoch and return training metrics.
- validate(epoch, valid_loader, criterion, metrics)
Validate the model on the validation dataset and return validation metrics.
- fit_deprecated(conf, train_loader, valid_loader, optimizer, train_criterion, valid_criterion, scaler, scheduler, metrics, trial=False)
Perform the full training loop across multiple epochs, including validation and checkpointing.
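The method summary above implies an outer loop that alternates `train_one_epoch` and `validate` and checkpoints along the way. The sketch below is minimal and framework-free. It makes three hypothetical assumptions:

- Both methods can be called with only the epoch number.
- Each returns a dict with a `"loss"` key.
- A checkpoint is saved whenever the validation loss improves.

The real keys and checkpoint policy may differ.

```python
import math
from typing import Callable, Dict, List


def fit(
    trainer,
    num_epochs: int,
    save_checkpoint: Callable[[int], None],
    start_epoch: int = 0,
) -> List[Dict[str, float]]:
    """Alternate training and validation, checkpointing on improvement.

    `trainer` is assumed (hypothetically) to expose train_one_epoch(epoch)
    and validate(epoch), each returning a dict with a "loss" key; the other
    arguments of the documented methods are omitted for brevity.
    """
    history = []
    best_loss = math.inf
    for epoch in range(start_epoch, num_epochs):
        train_results = trainer.train_one_epoch(epoch)
        valid_results = trainer.validate(epoch)
        history.append(
            {
                "epoch": epoch,
                "train_loss": train_results["loss"],
                "valid_loss": valid_results["loss"],
            }
        )
        # Save only when validation loss beats the best seen so far.
        if valid_results["loss"] < best_loss:
            best_loss = valid_results["loss"]
            save_checkpoint(epoch)
    return history
```

Passing `save_checkpoint` as a callable keeps the loop independent of how or where checkpoints are written.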
- train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)
Trains the model for one epoch. Training settings come from the conf dictionary passed to the constructor rather than from a method argument.
- Parameters:
epoch (int) – Current epoch number.
trainloader (DataLoader) – DataLoader for the training dataset.
optimizer (torch.optim.Optimizer) – Optimizer used for training.
criterion (callable) – Loss function used for training.
scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed precision training.
scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.
metrics (callable) – Function to compute metrics for evaluation.
- Returns:
Dictionary containing training metrics and loss for the epoch.
- Return type:
dict
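One epoch of mixed-precision training typically pairs `torch.autocast` with a `GradScaler`. The function below mirrors the method's name but is a standalone, hypothetical illustration rather than the CREDIT code. It omits the `metrics` argument and assumes the scheduler steps once per epoch. It returns only the mean loss, under an assumed `"loss"` key.

```python
import torch


def train_one_epoch(model, trainloader, optimizer, criterion, scaler, scheduler, device="cpu"):
    """Hypothetical sketch: one epoch of AMP training, returning the mean loss."""
    model.train()
    # Assumption of this sketch: autocast is enabled exactly when the scaler is.
    use_amp = scaler.is_enabled()
    total_loss, num_batches = 0.0, 0
    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        # autocast runs eligible ops in reduced precision when enabled.
        with torch.autocast(device_type=torch.device(device).type, enabled=use_amp):
            loss = criterion(model(x), y)
        # Scale the loss before backward to avoid fp16 gradient underflow;
        # scaler.step unscales gradients and skips the step on inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        num_batches += 1
    # Per-epoch scheduler step (some setups step per batch instead).
    scheduler.step()
    return {"loss": total_loss / max(num_batches, 1)}
```

With `GradScaler(enabled=False)`, for example on CPU, `scaler.scale` returns the loss unchanged and `scaler.step` simply calls `optimizer.step()`. The same loop then runs in full precision.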
- validate(epoch, valid_loader, criterion, metrics)
Validates the model on the validation dataset. Validation settings come from the conf dictionary passed to the constructor rather than from a method argument.
- Parameters:
epoch (int) – Current epoch number.
valid_loader (DataLoader) – DataLoader for the validation dataset.
criterion (callable) – Loss function used for validation.
metrics (callable) – Function to compute metrics for evaluation.
- Returns:
Dictionary containing validation metrics and loss for the epoch.
- Return type:
dict
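Validation follows the same pattern without gradient tracking. In this hypothetical sketch, which is not the CREDIT code, the loss is averaged per sample, assuming `criterion` uses mean reduction. Under DDP each rank sees only its shard of `valid_loader`. The running totals would therefore need to be combined with `torch.distributed.all_reduce` before dividing, which is left out here.

```python
import torch


@torch.no_grad()
def validate(model, valid_loader, criterion, device="cpu"):
    """Hypothetical sketch: per-sample mean validation loss, no gradients."""
    was_training = model.training
    model.eval()  # disables dropout and uses running batch-norm statistics
    total_loss, total_samples = 0.0, 0
    for x, y in valid_loader:
        x, y = x.to(device), y.to(device)
        loss = criterion(model(x), y)
        # Weight each batch mean by its size so a short final batch
        # is not over-counted.
        total_loss += loss.item() * x.shape[0]
        total_samples += x.shape[0]
    model.train(was_training)  # restore the caller's mode
    return {"loss": total_loss / max(total_samples, 1)}
```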