credit.trainers.trainerERA5gen1
Attributes
- logger
- Trainer
Classes
- TrainerERA5Gen1 – ERA5 trainer built on BaseTrainer; implements train_one_epoch and validate.
Module Contents
- credit.trainers.trainerERA5gen1.logger
- class credit.trainers.trainerERA5gen1.TrainerERA5Gen1(model: torch.nn.Module, rank: int, conf: dict)
Bases: credit.trainers.base_trainer.BaseTrainer
ERA5 trainer that implements train_one_epoch and validate on top of the abstract BaseTrainer.
- flag_mass_conserve = False
- flag_water_conserve = False
- flag_energy_conserve = False
- opt_mass = None
- opt_water = None
- opt_energy = None
- varnum_diag
- static_dim_size
- retain_graph
- forecast_len
- valid_history_len
- valid_forecast_len
- train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)
Train for one epoch.
- Parameters:
epoch – Current epoch number.
trainloader – DataLoader for training.
optimizer – Optimizer that updates the model parameters.
criterion – Loss callable.
scaler – Gradient scaler for mixed-precision training.
scheduler – Learning-rate scheduler.
metrics – Metric callable.
conf – Documented but not in the method signature; presumably the full configuration dict passed to the constructor (data keys are read from it here for schema stability).
- Returns:
Training metrics for the epoch.
- Return type:
dict
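The call pattern for train_one_epoch can be sketched in plain Python. StubTrainer, run_training and the "train_loss" key are hypothetical stand-ins, not part of credit: the stub only averages a per-batch loss, whereas the real method would also run the backward pass and step the optimizer, scaler and scheduler. The sketch shows the argument order and that each call returns one dict of metrics per epoch.

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple

Batch = Tuple[float, float]  # (prediction, target) pair standing in for a real batch


class StubTrainer:
    """Hypothetical stand-in mirroring TrainerERA5Gen1.train_one_epoch's signature."""

    def train_one_epoch(
        self,
        epoch: int,
        trainloader: Iterable[Batch],
        optimizer: Any,
        criterion: Callable[[float, float], float],
        scaler: Any,
        scheduler: Any,
        metrics: Any,
    ) -> Dict[str, float]:
        # Average the per-batch loss over the epoch. A real trainer would also
        # backpropagate and step the optimizer, scaler and scheduler here.
        losses = [criterion(pred, target) for pred, target in trainloader]
        return {"train_loss": sum(losses) / len(losses)}


def run_training(trainer, num_epochs, trainloader, optimizer, criterion,
                 scaler, scheduler, metrics) -> List[Dict[str, Any]]:
    """Call train_one_epoch once per epoch and collect the returned dicts."""
    history = []
    for epoch in range(num_epochs):
        results = trainer.train_one_epoch(
            epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics
        )
        history.append({"epoch": epoch, **results})
    return history


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


history = run_training(StubTrainer(), 2, [(1.0, 0.0), (2.0, 0.0)],
                       None, squared_error, None, None, None)
print(history)  # [{'epoch': 0, 'train_loss': 2.5}, {'epoch': 1, 'train_loss': 2.5}]
```

Keeping the per-epoch results as dicts lets the caller merge them with the validation dict for the same epoch before logging.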
- validate(epoch, valid_loader, criterion, metrics)
Validate for one epoch.
- Parameters:
epoch – Current epoch number.
valid_loader – DataLoader for validation.
criterion – Loss callable.
metrics – Metric callable.
conf – Documented but not in the method signature; presumably the full configuration dict passed to the constructor.
- Returns:
Validation metrics for the epoch.
- Return type:
dict
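Because validate returns one dict of metrics per epoch, a caller can feed it to an early-stopping check. This is a minimal sketch, assuming a hypothetical "valid_loss" key and an EarlyStopping helper that is not part of credit; it signals a stop once the tracked metric fails to improve for patience consecutive epochs.

```python
from typing import Dict, Optional


class EarlyStopping:
    """Hypothetical helper: track a validation metric and signal when it stalls."""

    def __init__(self, key: str = "valid_loss", patience: int = 3) -> None:
        self.key = key
        self.patience = patience
        self.best_value = float("inf")
        self.best_epoch: Optional[int] = None
        self.bad_epochs = 0  # consecutive epochs without improvement

    def step(self, epoch: int, valid_results: Dict[str, float]) -> bool:
        """Record one epoch's validation dict; return True if training should stop."""
        value = valid_results[self.key]
        if value < self.best_value:
            self.best_value = value
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2)
losses = [0.9, 0.7, 0.8, 0.75, 0.6]  # stand-ins for successive validate() results
for epoch, loss in enumerate(losses):
    if stopper.step(epoch, {"valid_loss": loss}):
        break
print(stopper.best_epoch, epoch)  # 1 3
```

Here training halts at epoch 3, after two epochs without beating the 0.7 reached at epoch 1, so the later 0.6 is never seen.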
- credit.trainers.trainerERA5gen1.Trainer