credit.trainers.trainerERA5gen1

Attributes

logger

Trainer

Classes

TrainerERA5Gen1

Helper class that provides a standard way to create an ABC using inheritance.

Module Contents#

credit.trainers.trainerERA5gen1.logger
class credit.trainers.trainerERA5gen1.TrainerERA5Gen1(model: torch.nn.Module, rank: int, conf: dict)

Bases: credit.trainers.base_trainer.BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

flag_mass_conserve = False
flag_water_conserve = False
flag_energy_conserve = False
opt_mass = None
opt_water = None
opt_energy = None
varnum_diag
static_dim_size
retain_graph
forecast_len
valid_history_len
valid_forecast_len
train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)

Train for one epoch.

Parameters:
  • epoch – Current epoch number.

  • trainloader – DataLoader for training.

  • optimizer – Optimizer that updates the model parameters.

  • criterion – Loss function used to compute the training loss.

  • scaler – Gradient scaler for mixed-precision training.

  • scheduler – Learning-rate scheduler.

  • metrics – Metric callable evaluated on the model outputs.

The full configuration dict is not an argument of this method; its data keys (accessed here for schema stability) come from the conf passed to the constructor.

Returns:

Training metrics for the epoch.

Return type:

dict
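The method reports its epoch-level training metrics as a single dict. One common way to build such a dict is to average per-batch values. The sketch below illustrates that pattern only; it is not necessarily what TrainerERA5Gen1 does, and the helper name summarize_epoch and the metric keys are hypothetical.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, Iterable, List, Mapping


def summarize_epoch(batch_results: Iterable[Mapping[str, float]]) -> Dict[str, float]:
    """Average each metric over the batches in which it was reported.

    Hypothetical helper: keys such as "train_loss" are illustrative and
    are not taken from CREDIT.
    """
    collected: Dict[str, List[float]] = defaultdict(list)
    for result in batch_results:
        for name, value in result.items():
            collected[name].append(float(value))
    # One scalar per metric name, matching the documented ``dict`` return type.
    return {name: fmean(values) for name, values in collected.items()}
```

For example, summarize_epoch([{"train_loss": 1.0}, {"train_loss": 3.0}]) returns {"train_loss": 2.0}. Each metric is averaged over the batches that actually reported it rather than over the total batch count, so a metric logged only occasionally is not diluted.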

validate(epoch, valid_loader, criterion, metrics)

Validate for one epoch.

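Parameters:
  • epoch – Current epoch number.

  • valid_loader – DataLoader for validation.

  • criterion – Loss function used to compute the validation loss.

  • metrics – Metric callable evaluated on the model outputs.

The full configuration dict is not an argument of this method; it comes from the conf passed to the constructor.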

Returns:

Validation metrics for the epoch.

Return type:

dict
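Because train_one_epoch and validate each return a dict of metrics, a caller can alternate them once per epoch and collect the results. The sketch below relies only on the two method signatures documented above. run_epochs, EpochTrainer and StubTrainer are hypothetical names, the stub's metric keys are invented for demonstration, and this is not CREDIT's own training loop.

```python
from typing import Any, Dict, List, Protocol


class EpochTrainer(Protocol):
    """Structural type matching the two documented TrainerERA5Gen1 methods."""

    def train_one_epoch(self, epoch, trainloader, optimizer, criterion,
                        scaler, scheduler, metrics) -> dict: ...

    def validate(self, epoch, valid_loader, criterion, metrics) -> dict: ...


def run_epochs(trainer: EpochTrainer, num_epochs: int, trainloader: Any,
               valid_loader: Any, optimizer: Any, criterion: Any, scaler: Any,
               scheduler: Any, metrics: Any) -> List[Dict[str, Any]]:
    """Hypothetical driver: train, then validate, once per epoch."""
    history: List[Dict[str, Any]] = []
    for epoch in range(num_epochs):
        train_results = trainer.train_one_epoch(
            epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics
        )
        valid_results = trainer.validate(epoch, valid_loader, criterion, metrics)
        history.append({"epoch": epoch, "train": train_results, "valid": valid_results})
    return history


class StubTrainer:
    """Stand-in with the documented signatures; returns made-up metric dicts."""

    def train_one_epoch(self, epoch, trainloader, optimizer, criterion,
                        scaler, scheduler, metrics):
        return {"train_loss": 1.0 / (epoch + 1)}

    def validate(self, epoch, valid_loader, criterion, metrics):
        return {"valid_loss": 2.0 / (epoch + 1)}
```

With a real TrainerERA5Gen1, the same call would pass the actual DataLoaders, optimizer, loss, scaler, scheduler and metrics objects in place of the stub and placeholders. Typing the trainer as a Protocol keeps the driver independent of the concrete class, which is also what allows the stub to stand in for it.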

credit.trainers.trainerERA5gen1.Trainer