credit.trainers.trainerERA5_ensemble#

Attributes#

Classes#

Gather

Custom autograd function for gathering tensors from all processes while preserving gradients.

TrainerERA5Ensemble

Helper class that provides a standard way to create an ABC using inheritance.

Functions#

gather_tensor(tensor)

Gathers tensors from all ranks and preserves autograd graph.

Module Contents#

credit.trainers.trainerERA5_ensemble.logger#

class credit.trainers.trainerERA5_ensemble.Gather(*args, **kwargs)#

Bases: torch.autograd.Function

Custom autograd function for gathering tensors from all processes while preserving gradients.

This layer performs an all_gather operation on the provided tensor across all distributed processes and concatenates them along the batch dimension (dim=0). The backward pass correctly routes gradients back to the originating processes.

This is useful for ensemble training, where computing the CRPS requires samples from all GPUs and gradients must still flow back through the gathered tensor.
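To make the CRPS use case concrete, here is a minimal sketch of the standard empirical ensemble CRPS estimator, mean |x_i - y| minus half the mean pairwise |x_i - x_j|. The function name `ensemble_crps` and the tensor layout (members on dimension 0) are assumptions for illustration; the loss actually used by this trainer may differ.

```python
import torch


def ensemble_crps(members: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: empirical CRPS averaged over all grid points.

    members: (M, ...) ensemble members, e.g. the output of a gather on dim 0.
    target:  (...) verifying observation with the same trailing shape.
    """
    # Mean absolute error of each member against the observation.
    skill = (members - target.unsqueeze(0)).abs().mean(dim=0)
    # Mean absolute difference over all member pairs (including i == j).
    pairwise = (members.unsqueeze(0) - members.unsqueeze(1)).abs()
    spread = pairwise.mean(dim=(0, 1))
    return (skill - 0.5 * spread).mean()


# Two members at 0 and 2, observation at 1.
members = torch.tensor([[0.0], [2.0]])
target = torch.tensor([1.0])
crps_value = ensemble_crps(members, target)
```

Because the pairwise term needs every member, each rank has to see the full ensemble, which is why the gather must stay differentiable.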

static forward(ctx, input)#

Gather tensors from all ranks and concatenate them on the batch dimension.

Parameters:
  • ctx – Context object to store information for backward pass

  • input – Tensor to be gathered across processes

Returns:

Concatenated tensor from all processes

static backward(ctx, grad_output)#

Distribute gradients back to their originating processes.

Parameters:
  • ctx – Context object with stored information from forward pass

  • grad_output – Gradient with respect to the forward output

Returns:

Gradient for the input tensor
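The forward and backward contract described above can be sketched as follows. This is an illustrative stand-in, not the CREDIT source: the class name `GatherSketch` is hypothetical, it assumes every rank holds the same local batch size, and summing gradients with `all_reduce` in the backward pass is one common choice rather than a confirmed detail of `Gather`. The sketch falls back to a single-process no-op when no process group is initialized, so it can be tried without launching a distributed job.

```python
import torch
import torch.distributed as dist


def _is_distributed() -> bool:
    return dist.is_available() and dist.is_initialized()


class GatherSketch(torch.autograd.Function):
    """Hypothetical gradient-preserving all_gather (assumes equal batch sizes)."""

    @staticmethod
    def forward(ctx, input):
        ctx.local_batch = input.shape[0]
        if not _is_distributed():
            ctx.rank = 0
            return input.clone()  # single process: nothing to gather
        ctx.rank = dist.get_rank()
        local = input.contiguous()
        chunks = [torch.empty_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(chunks, local)
        return torch.cat(chunks, dim=0)  # concatenate along the batch dimension

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        if _is_distributed():
            # Assumption: sum the gradient contributions from every rank.
            dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        # Keep only the slice that corresponds to this rank's input.
        start = ctx.rank * ctx.local_batch
        return grad[start:start + ctx.local_batch]


x = torch.randn(4, 3, requires_grad=True)
gathered = GatherSketch.apply(x)
gathered.sum().backward()
```

In a single process the gathered tensor equals the input and the gradient of `sum()` is all ones, which is a quick sanity check before running under multiple ranks.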

credit.trainers.trainerERA5_ensemble.gather_tensor(tensor)#

Gathers tensors from all ranks and preserves autograd graph.

This function allows you to gather tensors from all processes in a distributed setting while maintaining the autograd graph for backward passes. This is critical for operations that need to compute losses across all samples in a distributed training environment.

Parameters:

tensor – The tensor to gather across processes

Returns:

Tensor concatenated from all processes along dimension 0

Example

>>> import torch
>>> from credit.trainers.trainerERA5_ensemble import gather_tensor
>>> # On each GPU (requires an initialized torch.distributed process group)
>>> local_tensor = torch.randn(8, 128)  # local batch of embeddings
>>> # Gather embeddings from all GPUs (total batch_size * world_size)
>>> gathered_tensor = gather_tensor(local_tensor)
>>> # Now you can compute a loss that depends on all samples

class credit.trainers.trainerERA5_ensemble.TrainerERA5Ensemble(model: torch.nn.Module, rank: int, conf: dict)#

Bases: credit.trainers.base_trainer.BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

train_one_epoch(epoch, trainloader, optimizer, criterion, scaler, scheduler, metrics)#

Trains the model for one epoch.

Parameters:
  • epoch (int) – Current epoch number.

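  • conf (dict) – Configuration dictionary containing training settings. This name does not appear in the method signature above; the configuration is passed to the class constructor.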

  • trainloader (DataLoader) – DataLoader for the training dataset.

  • optimizer (torch.optim.Optimizer) – Optimizer used for training.

  • criterion (callable) – Loss function used for training.

  • scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed precision training.

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.

  • metrics (callable) – Function to compute metrics for evaluation.

Returns:

Dictionary containing training metrics and loss for the epoch.

Return type:

dict
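The roles of these arguments can be illustrated with a generic mixed-precision epoch loop. This is a minimal sketch under stated assumptions, not the `TrainerERA5Ensemble` implementation: the function name `train_one_epoch_sketch`, the `"train_loss"` key, the toy linear model, and the per-batch scheduler step are all hypothetical, and the scaler is disabled so the example runs on CPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(model, trainloader, optimizer, criterion, scaler, scheduler):
    """Hypothetical generic epoch loop returning a dict of epoch statistics."""
    model.train()
    losses = []
    for x, y in trainloader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        scaler.scale(loss).backward()  # scaling is a no-op when the scaler is disabled
        scaler.step(optimizer)         # calls optimizer.step() (after unscaling if enabled)
        scaler.update()
        scheduler.step()               # assumption: the schedule advances every batch
        losses.append(loss.item())
    return {"train_loss": sum(losses) / len(losses)}


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled for a CPU-only run
results = train_one_epoch_sketch(
    model, DataLoader(dataset, batch_size=8), optimizer, torch.nn.MSELoss(), scaler, scheduler
)
```

The returned dictionary mirrors the documented return type; the real method presumably reports more entries (for example the values produced by `metrics`), which this sketch omits.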

validate(epoch, valid_loader, criterion, metrics)#

Validates the model on the validation dataset.

Parameters:
  • epoch (int) – Current epoch number.

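  • conf (dict) – Configuration dictionary containing validation settings. This name does not appear in the method signature above; the configuration is passed to the class constructor.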

  • valid_loader (DataLoader) – DataLoader for the validation dataset.

  • criterion (callable) – Loss function used for validation.

  • metrics (callable) – Function to compute metrics for evaluation.

Returns:

Dictionary containing validation metrics and loss for the epoch.

Return type:

dict