credit.trainers.trainer_gen2
============================

.. py:module:: credit.trainers.trainer_gen2


Attributes
----------

.. autoapisummary::

   credit.trainers.trainer_gen2.logger
   credit.trainers.trainer_gen2.Trainer


Classes
-------

.. autoapisummary::

   credit.trainers.trainer_gen2.TrainerERA5Gen2


Module Contents
---------------

.. py:data:: logger

.. py:class:: TrainerERA5Gen2(model: torch.nn.Module, rank: int, conf: dict)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Autoregressive trainer for ERA5 models that routes each rollout step through
   configurable preblock and postblock pipelines.


   .. py:attribute:: ic_preblocks


   .. py:attribute:: step_preblocks


   .. py:attribute:: step_postblocks


   .. py:attribute:: rollout_postblocks


   .. py:attribute:: varnum_diag


   .. py:attribute:: retain_graph


   .. py:attribute:: forecast_len


   .. py:attribute:: backprop_on_timestep


   .. py:attribute:: valid_history_len


   .. py:attribute:: valid_forecast_len


   .. py:attribute:: skip_nan_prune


   .. py:method:: train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Train for one epoch.

      The inner loop iterates over ``forecast_len`` autoregressive steps. For each step:

      1. Pull the next batch from the dataloader (raw, unnormalized).
      2. Build the normalized model input:

         - At ``t = 1``: IC-only preblocks produce ``ic_preprocessed`` (regridded
           statics); rollout preblocks produce the final normalized input ``x``.
         - At ``t > 1``: assemble the rollout batch from ``corrected_pred``
           (prognostic), ``ic_preprocessed`` (statics), and ``curr_batch``
           (dynamic forcing); rollout preblocks normalize and concatenate it into ``x``.

      3. Forward pass → ``y_pred_flat`` (flat, normalized).
      4. Apply postblocks: Reconstruct → inverse scaler → physics fixers.
         Afterwards, ``full_data_dict["y_processed"]`` is a nested dict split by
         Reconstruct.
      5. Compute the loss on ``y_pred_flat`` against the normalized target
         produced by the preblocks.
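
      The data flow above can be sketched in plain Python. This is a minimal,
      framework-free illustration under stated assumptions, not the CREDIT
      implementation: blocks are modeled as hypothetical dict-to-dict callables,
      the ``MEAN``/``STD`` constants, the ``clamp_nonnegative`` stand-in fixer, and
      the ``run_blocks``/``train_rollout`` helpers are invented for the example,
      Reconstruct is omitted, and no gradients or optimizer steps are taken.

      .. code-block:: python

         # Hypothetical sketch of the rollout data flow; not the CREDIT API.
         from typing import Callable, Dict, Iterator, List

         Sample = Dict[str, float]
         Block = Callable[[Sample], Sample]

         # Hypothetical normalization constants shared by pre- and postblocks.
         MEAN, STD = 10.0, 2.0


         def run_blocks(blocks: List[Block], data: Sample) -> Sample:
             """Apply each block in order; every block maps a dict to a new dict."""
             for block in blocks:
                 data = block(data)
             return data


         def normalize(d: Sample) -> Sample:
             return {k: (v - MEAN) / STD for k, v in d.items()}


         def inverse_scale(d: Sample) -> Sample:
             return {k: v * STD + MEAN for k, v in d.items()}


         def clamp_nonnegative(d: Sample) -> Sample:
             # Hypothetical stand-in "physics fixer": forbid negative values.
             return {k: max(v, 0.0) for k, v in d.items()}


         def train_rollout(batches: Iterator[Sample],
                           model: Callable[[List[float]], float],
                           forecast_len: int) -> List[float]:
             ic_preblocks: List[Block] = [lambda d: {"static": d["static"]}]  # keep statics
             step_preblocks: List[Block] = [normalize]
             postblocks: List[Block] = [inverse_scale, clamp_nonnegative]
             ic_preprocessed: Sample = {}
             corrected_pred: Sample = {}
             losses: List[float] = []
             for t in range(1, forecast_len + 1):
                 curr_batch = next(batches)  # 1. raw, unnormalized
                 if t == 1:
                     # 2a. IC-only preblocks run once; the raw batch is the rollout batch.
                     ic_preprocessed = run_blocks(ic_preblocks, curr_batch)
                     rollout_batch = dict(curr_batch)
                 else:
                     # 2b. Prognostic from the previous corrected prediction, statics
                     # from the IC, dynamic forcing (and target) from the current batch.
                     rollout_batch = {**curr_batch, **ic_preprocessed, **corrected_pred}
                 normed = run_blocks(step_preblocks, rollout_batch)
                 x = [normed["prognostic"], normed["static"], normed["forcing"]]
                 y_pred_flat = model(x)  # 3. normalized prediction
                 # 4. Postblocks: back to physical units, then the fixer.
                 corrected_pred = run_blocks(postblocks, {"prognostic": y_pred_flat})
                 # 5. Squared-error loss in normalized space.
                 losses.append((y_pred_flat - normed["target"]) ** 2)
             return losses


         def persistence(x: List[float]) -> float:
             return x[0]  # predict the normalized prognostic input unchanged


         batches = iter([
             {"prognostic": 12.0, "static": 10.0, "forcing": 11.0, "target": 14.0},
             {"prognostic": 99.0, "static": 10.0, "forcing": 8.0, "target": 12.0},
         ])
         losses = train_rollout(batches, persistence, forecast_len=2)

      With the persistence model, the stale ``prognostic`` value in the second
      batch (99.0) is overwritten by the corrected first-step prediction, so
      ``losses`` is ``[1.0, 0.0]``: the second step consumes the model's own
      output, which is what makes the rollout autoregressive.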

      :param epoch: Current epoch number.
      :param trainloader: DataLoader for training.
      :param optimizer: Optimizer that updates the model parameters.
      :param criterion: Loss function comparing the normalized prediction and target.
      :param scaler: Gradient scaler for mixed-precision training.
      :param scheduler: Learning-rate scheduler.
      :param metrics: Metrics object used to compute the training metrics.

      :returns: Training metrics for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, valid_loader, criterion, metrics)

      Validate for one epoch.

      Runs ``self.valid_forecast_len`` autoregressive steps per sample.
      Loss and metrics are computed only at the final step.
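
      Final-step-only scoring can be sketched as follows. The
      ``validate_rollout`` helper, the scalar state, and the toy ``model`` and
      ``criterion`` lambdas are hypothetical; the sketch mirrors only the rollout
      structure described above, not the actual ``validate`` code.

      .. code-block:: python

         # Hypothetical sketch of final-step-only validation; not the CREDIT API.
         from typing import Callable, Dict, List

         Step = Dict[str, float]


         def validate_rollout(steps: List[Step],
                              model: Callable[[float, float], float],
                              criterion: Callable[[float, float], float],
                              valid_forecast_len: int) -> float:
             """Roll forward valid_forecast_len steps and score only the final one."""
             if not 1 <= valid_forecast_len <= len(steps):
                 raise ValueError("valid_forecast_len must be between 1 and len(steps)")
             state = steps[0]["state"]  # initial condition
             for step in steps[:valid_forecast_len]:
                 # Each prediction is fed back in as the next step's input.
                 state = model(state, step["forcing"])
             # Intermediate targets are never compared; only the last step is scored.
             return criterion(state, steps[valid_forecast_len - 1]["target"])


         steps = [
             {"state": 1.0, "forcing": 1.0, "target": 100.0},
             {"forcing": 1.0, "target": 3.5},
         ]
         loss = validate_rollout(steps, lambda s, f: s + f, lambda p, y: abs(p - y), 2)

      The first step's target (100.0) never enters the loss when
      ``valid_forecast_len`` is 2; ``loss`` is ``0.5``, the gap between the
      two-step prediction (3.0) and the final target (3.5).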

      :param epoch: Current epoch number.
      :param valid_loader: DataLoader for validation.
      :param criterion: Loss function evaluated at the final rollout step.
      :param metrics: Metrics object evaluated at the final rollout step.

      :returns: Validation metrics for the epoch.
      :rtype: dict



.. py:data:: Trainer

