trainerERA5_v1
Attributes
- logger
Classes
- TOADataLoader
- Trainer: subclass of credit.trainers.base_trainer.BaseTrainer implementing train_one_epoch and validate.
Module Contents
- trainerERA5_v1.logger
- class trainerERA5_v1.TOADataLoader(conf)
- TOA
- times_b = None
- days_of_year
- hours_of_day
- __call__(datetime_input)
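TOADataLoader is documented here only by its member names. The sketch below illustrates one plausible reading of them: a table of top-of-atmosphere (TOA) values, presumably solar forcing, tagged by day of year and hour of day, with `__call__` returning the value that matches a given datetime. The class `ToyTOALookup` and its constructor are hypothetical. The real class is built from `conf`, and its storage (for example, the role of `times_b`) and lookup logic may differ.

```python
from datetime import datetime
from typing import List, Sequence


class ToyTOALookup:
    """Hypothetical stand-in for TOADataLoader (illustration only).

    Assumes TOA values are precomputed and tagged with the calendar day of
    year (1-366) and the hour of day (0-23) they belong to.
    """

    def __init__(
        self,
        toa: Sequence[float],
        days_of_year: Sequence[int],
        hours_of_day: Sequence[int],
    ):
        if not (len(toa) == len(days_of_year) == len(hours_of_day)):
            raise ValueError("toa, days_of_year and hours_of_day must have equal length")
        self.TOA: List[float] = list(toa)
        self.days_of_year: List[int] = list(days_of_year)
        self.hours_of_day: List[int] = list(hours_of_day)

    def __call__(self, datetime_input: datetime) -> float:
        # tm_yday is 1-based: January 1 -> 1, February 1 -> 32.
        doy = datetime_input.timetuple().tm_yday
        hour = datetime_input.hour
        for value, d, h in zip(self.TOA, self.days_of_year, self.hours_of_day):
            if d == doy and h == hour:
                return value
        raise KeyError(f"no TOA entry for day {doy}, hour {hour}")


lookup = ToyTOALookup([100.0, 250.0], [1, 1], [0, 6])
print(lookup(datetime(2020, 1, 1, 6)))  # 250.0
```

Keying on day of year and hour, rather than the full timestamp, lets one year of precomputed values serve any year. Leap years shift day-of-year numbering after February 28, and a real implementation would need to handle that.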
- class trainerERA5_v1.Trainer(model: torch.nn.Module, rank: int)
Bases: credit.trainers.base_trainer.BaseTrainer
Trainer built on BaseTrainer that implements train_one_epoch and validate.
- train_one_epoch(epoch: int, conf: Dict[str, Any], trainloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, scaler: torch.cuda.amp.GradScaler, scheduler: torch.optim.lr_scheduler._LRScheduler, metrics: Dict[str, Any]) -> Dict[str, float]
Train the model for one epoch.
- Parameters:
epoch (int) – The current epoch number.
conf (Dict[str, Any]) – The configuration dictionary.
trainloader (torch.utils.data.DataLoader) – The training data loader.
optimizer (torch.optim.Optimizer) – The optimizer.
criterion (torch.nn.Module) – The loss function.
scaler (torch.cuda.amp.GradScaler) – The gradient scaler for mixed precision training.
scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning rate scheduler.
metrics (Dict[str, Any]) – The metrics to track during training.
- Returns:
A dictionary containing the training results.
- Return type:
Dict[str, float]
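As a rough illustration of how the documented optimizer, criterion, scaler and scheduler parameters typically interact, here is a minimal mixed-precision epoch loop in plain PyTorch. It is not the Trainer's implementation. The function name, the (input, target) batch format, the per-epoch scheduler step and the "train_loss" key are assumptions, and the conf and metrics arguments are omitted.

```python
from typing import Dict, Iterable, Optional, Tuple

import torch
from torch import nn


def train_one_epoch_sketch(
    model: nn.Module,
    loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scaler: torch.cuda.amp.GradScaler,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    device: torch.device = torch.device("cpu"),
) -> Dict[str, float]:
    """Hypothetical one-epoch loop; returns the mean batch loss."""
    model.train()
    total_loss, num_batches = 0.0, 0
    # Autocast only when loss scaling is active (mixed precision on a GPU).
    use_amp = scaler.is_enabled()
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            loss = criterion(model(inputs), targets)
        # Scale the loss so small float16 gradients do not underflow; step()
        # unscales the gradients and skips the update if they contain inf/NaN.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        num_batches += 1
    if scheduler is not None:
        scheduler.step()  # assumes a per-epoch schedule
    return {"train_loss": total_loss / max(num_batches, 1)}
```

Passing `torch.cuda.amp.GradScaler(enabled=False)` turns off both loss scaling and autocast in this sketch, so the same loop runs unchanged on CPU.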
- validate(epoch: int, conf: Dict[str, Any], valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, metrics: Dict[str, Any]) -> Dict[str, float]
Validate the model on the validation set.
- Parameters:
epoch (int) – The current epoch number.
conf (Dict[str, Any]) – The configuration dictionary.
valid_loader (torch.utils.data.DataLoader) – The validation data loader.
criterion (torch.nn.Module) – The loss function.
metrics (Dict[str, Any]) – The metrics to track during validation.
- Returns:
A dictionary containing the validation results.
- Return type:
Dict[str, float]
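A validation pass follows the same pattern without gradient updates. The following is a minimal sketch, not the Trainer's method. It assumes (input, target) batches, a criterion that returns a batch mean, and a hypothetical "valid_loss" key. It averages the loss per sample rather than per batch so that a smaller final batch is not over-weighted.

```python
from typing import Dict, Iterable, Tuple

import torch
from torch import nn


@torch.no_grad()
def validate_sketch(
    model: nn.Module,
    loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    criterion: nn.Module,
    device: torch.device = torch.device("cpu"),
) -> Dict[str, float]:
    """Hypothetical validation pass; returns the per-sample mean loss."""
    model.eval()  # disable dropout and use running batch-norm statistics
    total_loss, num_samples = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        # criterion is assumed to return a batch mean, so weight it by batch size.
        loss = criterion(model(inputs), targets)
        total_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)
    return {"valid_loss": total_loss / max(num_samples, 1)}
```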