credit.trainers.trainer_gen2#
Attributes#
Classes#
TrainerERA5Gen2: Helper class that provides a standard way to create an ABC using inheritance.
Module Contents#
- credit.trainers.trainer_gen2.logger#
- class credit.trainers.trainer_gen2.TrainerERA5Gen2(model: torch.nn.Module, rank: int, conf: dict)#
Bases: credit.trainers.base_trainer.BaseTrainer
Helper class that provides a standard way to create an ABC using inheritance.
- ic_preblocks#
- step_preblocks#
- step_postblocks#
- rollout_postblocks#
- varnum_diag#
- retain_graph#
- forecast_len#
- backprop_on_timestep#
- valid_history_len#
- valid_forecast_len#
- skip_nan_prune#
- train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)#
Train for one epoch.
- The inner loop iterates over forecast_len autoregressive steps. For each step:
1. Pull the next batch from the dataloader (raw, unnormalized).
2. Build the model input x. At t=1, IC-only preblocks produce ic_preprocessed (regridded statics) and rollout preblocks produce the final normalized input x. At t>1, the rollout batch is assembled from corrected_pred (prognostic), ic_preprocessed (statics), and curr_batch (dynamic forcing), and rollout preblocks normalize and concatenate it.
3. Run the forward pass to get y_pred_flat (flat, normalized).
4. Apply postblocks: Reconstruct → inverse scaler → physics fixers. After this, full_data_dict["y_processed"] is a nested dict split by Reconstruct.
5. Compute the loss on y_pred_flat against the normalized target from the preblocks.
- Parameters:
epoch – Current epoch number.
trainloader – DataLoader for training.
optimizer – Standard training objects.
criterion – Standard training objects.
scaler – Standard training objects.
scheduler – Standard training objects.
metrics – Standard training objects.
- Returns:
Training metrics for the epoch.
- Return type:
dict
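The per-step flow described for train_one_epoch can be illustrated with a small toy in plain Python. This is only a sketch under stated assumptions, not the trainer's actual code: the scalar "fields", the normalization constants, and the helpers normalize, denormalize, toy_model, physics_fixer, and rollout are all hypothetical stand-ins for the real preblocks, model, and postblocks.

```python
# Hypothetical stand-ins; none of these names come from credit.trainers.trainer_gen2.
MEAN, STD = 10.0, 2.0  # assumed normalization statistics


def normalize(value):
    return (value - MEAN) / STD


def denormalize(value):
    return value * STD + MEAN


def toy_model(x):
    """Stand-in forward pass: average of the normalized input channels."""
    return sum(x) / len(x)


def physics_fixer(value):
    """Stand-in physics fixer: clamp the physical value to be non-negative."""
    return max(value, 0.0)


def rollout(batches, forecast_len):
    """Run forecast_len autoregressive steps; return per-step losses and the last corrected prediction."""
    losses = []
    statics = None
    corrected_pred = None
    for t in range(1, forecast_len + 1):
        curr_batch = batches[t - 1]  # raw, unnormalized
        if t == 1:
            # IC-only preprocessing runs once; statics are reused at later steps.
            statics = curr_batch["static"]
            prognostic = curr_batch["prognostic"]
        else:
            # Feed the corrected (physical-space) prediction back in.
            prognostic = corrected_pred
        # Rollout preblocks: normalize and concatenate into the model input.
        x = [normalize(prognostic), normalize(statics), normalize(curr_batch["forcing"])]
        y_pred_flat = toy_model(x)  # normalized prediction
        # Postblocks: inverse scaling, then the physics fixer.
        corrected_pred = physics_fixer(denormalize(y_pred_flat))
        # The loss compares the normalized prediction with the normalized target.
        target = normalize(curr_batch["target"])
        losses.append((y_pred_flat - target) ** 2)
    return losses, corrected_pred
```

Note that the loss is taken on the normalized prediction, while the value carried to the next step is the post-processed one; that split is the point of the sketch.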
- validate(epoch, valid_loader, criterion, metrics)#
Validate for one epoch.
Runs self.valid_forecast_len autoregressive steps per sample. Loss and metrics are computed only at the final step.
- Parameters:
epoch – Current epoch number.
valid_loader – DataLoader for validation.
criterion – Loss and metric callables.
metrics – Loss and metric callables.
- Returns:
Validation metrics for the epoch.
- Return type:
dict
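The "loss only at the final step" behaviour described for validate can be sketched as follows. This is a minimal illustration, assuming a generic one-step function and scalar state; validate_sample, step_fn, and loss_fn are hypothetical names and not part of the module.

```python
def validate_sample(step_fn, state, targets, valid_forecast_len, loss_fn):
    """Roll the state forward valid_forecast_len steps and score only the last one."""
    loss = None
    for t in range(1, valid_forecast_len + 1):
        state = step_fn(state)  # one autoregressive step
        if t == valid_forecast_len:
            # Intermediate steps are not scored; only the final prediction is.
            loss = loss_fn(state, targets[t - 1])
    return loss


def squared_error(pred, target):
    return (pred - target) ** 2
```

Scoring only the final step keeps validation cheap, but it means errors at intermediate lead times are visible only through their effect on the last prediction.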
- credit.trainers.trainer_gen2.Trainer#