credit.trainers.trainer_gen2

Attributes

Classes

TrainerERA5Gen2

Helper class that provides a standard way to create an ABC using inheritance.

Module Contents

credit.trainers.trainer_gen2.logger
class credit.trainers.trainer_gen2.TrainerERA5Gen2(model: torch.nn.Module, rank: int, conf: dict)

Bases: credit.trainers.base_trainer.BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

ic_preblocks
step_preblocks
step_postblocks
rollout_postblocks
varnum_diag
retain_graph
forecast_len
backprop_on_timestep
valid_history_len
valid_forecast_len
skip_nan_prune
train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)

Train for one epoch.

The inner loop iterates over forecast_len autoregressive steps. For each step:
  1. Pull the next batch from the dataloader (raw, unnormalized).

  2. At t=1: IC-only preblocks produce ic_preprocessed (regridded statics), and rollout preblocks produce the final normalized input x.
     At t>1: assemble the rollout batch from corrected_pred (prognostic), ic_preprocessed (statics), and curr_batch (dynamic forcing); rollout preblocks then normalize and concatenate.

  3. Forward pass → y_pred_flat (flat, normalized).

  4. Apply postblocks: Reconstruct → inverse scaler → physics fixers. After this, full_data_dict["y_processed"] is a nested dict split by Reconstruct.

  5. Compute loss on y_pred_flat vs the normalized target from preblocks.
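
The control flow of the five steps above can be sketched in plain Python. This is a toy illustration, not the trainer's implementation: every helper here (ic_preblocks, rollout_preblocks, model, postblocks, mse, rollout) is a hypothetical stand-in operating on lists of floats, and the batch keys ("prognostic", "statics", "forcing", "target") are assumed for the example only. The backward pass and optimizer update are deliberately left out.

```python
# Toy sketch of the per-step rollout logic. All helpers are hypothetical
# stand-ins for the real preblocks, model, and postblocks.

def ic_preblocks(batch):
    """Stand-in for IC-only preblocks: keep the statics from the first batch."""
    return {"statics": batch["statics"]}

def rollout_preblocks(prognostic, statics, forcing, target, scale=10.0):
    """Stand-in for rollout preblocks: normalize, then concatenate into a flat input."""
    x = [v / scale for v in prognostic] + statics + forcing
    y = [v / scale for v in target]  # normalized target used by the loss
    return x, y

def model(x, n_out):
    """Stand-in forward pass: persistence of the prognostic channels."""
    return x[:n_out]

def postblocks(y_pred_flat, scale=10.0):
    """Stand-in for Reconstruct -> inverse scaler -> physics fixers."""
    denorm = [v * scale for v in y_pred_flat]
    # Toy 'fixer': clip negative values; result is a nested dict.
    return {"prognostic": [max(v, 0.0) for v in denorm]}

def mse(pred, target):
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)

def rollout(batches, forecast_len):
    ic_preprocessed, corrected_pred, losses = None, None, []
    batch_iter = iter(batches)
    for t in range(1, forecast_len + 1):
        curr_batch = next(batch_iter)                     # step 1: raw batch
        if t == 1:                                        # step 2, first step
            ic_preprocessed = ic_preblocks(curr_batch)
            prognostic = curr_batch["prognostic"]
        else:                                             # step 2, later steps
            prognostic = corrected_pred
        x, y = rollout_preblocks(
            prognostic,
            ic_preprocessed["statics"],
            curr_batch["forcing"],
            curr_batch["target"],
        )
        y_pred_flat = model(x, n_out=len(prognostic))     # step 3: flat, normalized
        y_processed = postblocks(y_pred_flat)             # step 4: physical units
        corrected_pred = y_processed["prognostic"]
        losses.append(mse(y_pred_flat, y))                # step 5: normalized space
    return losses, corrected_pred
```

Note that in this sketch the loss compares normalized quantities (y_pred_flat against the normalized target), while the value fed back into the next step is the post-processed prediction, mirroring the split described in steps 4 and 5.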

Parameters:
  • epoch – Current epoch number.

  • trainloader – DataLoader for training.

  • optimizer – Standard training objects.

  • criterion – Standard training objects.

  • scaler – Standard training objects.

  • scheduler – Standard training objects.

  • metrics – Standard training objects.

Returns:

Training metrics for the epoch.

Return type:

dict

validate(epoch, valid_loader, criterion, metrics)

Validate for one epoch.

Runs self.valid_forecast_len autoregressive steps per sample. Loss and metrics are computed only at the final step.
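
The "score only the final step" behavior can be sketched as follows. This is a minimal illustration under assumptions: validate_rollout, step_fn, and loss_fn are hypothetical names, and the state is a single float rather than a tensor.

```python
def validate_rollout(step_fn, loss_fn, state, targets, valid_forecast_len):
    """Roll step_fn forward and compute the loss only at the last step (toy sketch)."""
    loss = None
    for t in range(1, valid_forecast_len + 1):
        state = step_fn(state)  # one autoregressive step
        if t == valid_forecast_len:
            # Intermediate steps are never scored; only the final one is.
            loss = loss_fn(state, targets[t - 1])
    return loss, state
```

For example, with step_fn adding 1.0 per step and a squared-error loss_fn, a three-step rollout from 0.0 against a final target of 5.0 is scored as (3.0 - 5.0)**2; the first two targets are never read.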

Parameters:
  • epoch – Current epoch number.

  • valid_loader – DataLoader for validation.

  • criterion – Loss callable.

  • metrics – Metric callable.

Returns:

Validation metrics for the epoch.

Return type:

dict

credit.trainers.trainer_gen2.Trainer