credit.trainers.trainerERA5_ensemble
====================================

.. py:module:: credit.trainers.trainerERA5_ensemble


Attributes
----------

.. autoapisummary::

   credit.trainers.trainerERA5_ensemble.logger


Classes
-------

.. autoapisummary::

   credit.trainers.trainerERA5_ensemble.Gather
   credit.trainers.trainerERA5_ensemble.TrainerERA5Ensemble


Functions
---------

.. autoapisummary::

   credit.trainers.trainerERA5_ensemble.gather_tensor


Module Contents
---------------

.. py:data:: logger

.. py:class:: Gather(*args, **kwargs)

   Bases: :py:obj:`torch.autograd.Function`


   Custom autograd function for gathering tensors from all processes while preserving gradients.

   The forward pass performs an ``all_gather`` on the provided tensor across all
   distributed processes and concatenates the results along the batch dimension (dim=0).
   The backward pass routes gradients back to the originating processes.

   This is useful for ensemble computations such as the CRPS, where the loss must be
   computed between samples held on different GPUs while still backpropagating
   through the gathered tensor.


   .. py:method:: forward(ctx, input)
      :staticmethod:


      Gather tensors from all ranks and concatenate them on the batch dimension.

      :param ctx: Context object to store information for backward pass
      :param input: Tensor to be gathered across processes

      :returns: Concatenated tensor from all processes



   .. py:method:: backward(ctx, grad_output)
      :staticmethod:


      Distribute gradients back to their originating processes.

      :param ctx: Context object with stored information from forward pass
      :param grad_output: Gradient with respect to the forward output

      :returns: Gradient for the input tensor
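
The index bookkeeping behind ``forward`` and ``backward`` can be illustrated without a distributed backend. The following is a minimal single-process sketch, not the actual implementation: ``simulated_forward`` and ``simulated_backward`` are hypothetical names, plain Python lists stand in for tensors, and every rank is assumed to hold the same local batch size. The real ``Gather`` may additionally combine gradients across ranks before slicing; that step is not shown here.

```python
from typing import List

Row = List[float]


def simulated_forward(per_rank: List[List[Row]]) -> List[Row]:
    """Concatenate each rank's local batch along dim 0, in rank order."""
    gathered: List[Row] = []
    for local_batch in per_rank:
        gathered.extend(local_batch)
    return gathered


def simulated_backward(grad_output: List[Row], rank: int, local_batch_size: int) -> List[Row]:
    """Return the rows of grad_output that belong to this rank's input."""
    start = rank * local_batch_size
    return grad_output[start:start + local_batch_size]


# Two hypothetical ranks, each holding a local batch of two rows.
per_rank = [
    [[1.0, 2.0], [3.0, 4.0]],  # rank 0
    [[5.0, 6.0], [7.0, 8.0]],  # rank 1
]
gathered = simulated_forward(per_rank)

# A made-up gradient with one row per gathered sample.
grad_output = [[float(i), float(i)] for i in range(len(gathered))]

# Rank 1 receives only the rows that came from its own input (rows 2 and 3).
grad_rank1 = simulated_backward(grad_output, rank=1, local_batch_size=2)
```

The slice offset ``rank * local_batch_size`` only works because the gathered tensor is ordered by rank and all local batches have equal size; both are assumptions of this sketch.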



.. py:function:: gather_tensor(tensor)

   Gathers tensors from all ranks while preserving the autograd graph.

   This function allows you to gather tensors from all processes in a distributed
   setting while maintaining the autograd graph for backward passes. This is critical
   for operations that need to compute losses across all samples in a distributed
   training environment.

   :param tensor: The tensor to gather across processes

   :returns: Tensor concatenated from all processes along dimension 0

   .. rubric:: Example

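   >>> import torch
   >>> from credit.trainers.trainerERA5_ensemble import gather_tensor
   >>> # On each GPU, inside an initialized torch.distributed process group
   >>> local_tensor = torch.randn(8, 128)  # local batch of embeddings
   >>> # Gather embeddings from all GPUs; result has shape (8 * world_size, 128)
   >>> gathered_tensor = gather_tensor(local_tensor)
   >>> # A loss computed on gathered_tensor now depends on samples from every rank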
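
To see why a loss may need samples from every rank, consider the CRPS of an ensemble whose members are spread across GPUs. The sketch below is an illustration under stated assumptions, not the criterion this trainer uses: ``ensemble_crps`` is a hypothetical helper implementing the standard ensemble estimator (mean absolute error minus half the mean pairwise member distance), and plain Python lists stand in for the local and gathered tensors.

```python
from typing import Sequence


def ensemble_crps(members: Sequence[float], observation: float) -> float:
    """Standard ensemble CRPS estimator for a single scalar observation."""
    m = len(members)
    # Skill term: mean absolute error of the members against the observation.
    skill = sum(abs(x - observation) for x in members) / m
    # Spread term: half the mean absolute difference over all ordered member pairs.
    spread = sum(abs(xi - xj) for xi in members for xj in members) / (2 * m * m)
    return skill - spread


# Hypothetical ensemble members produced on two ranks for the same target.
rank0_members = [1.0, 3.0]
rank1_members = [2.0, 4.0]

# Stand-in for gathering along dim 0: members from all ranks in one sequence.
gathered_members = rank0_members + rank1_members

crps_all = ensemble_crps(gathered_members, observation=2.0)
crps_local = ensemble_crps(rank0_members, observation=2.0)
```

The spread term involves pairs of members that live on different ranks, so a score computed from the local members alone (``crps_local``) generally differs from the score over the full ensemble (``crps_all``).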


.. py:class:: TrainerERA5Ensemble(model: torch.nn.Module, rank: int, conf: dict)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for ERA5 ensemble models. It provides the per-epoch training and
   validation routines documented below.


   .. py:method:: train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler._LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict



