credit.trainers
===============

.. py:module:: credit.trainers


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/trainers/base_trainer/index
   /autoapi/credit/trainers/ic_optimization/index
   /autoapi/credit/trainers/trainerERA5/index
   /autoapi/credit/trainers/trainerERA5_Diffusion/index
   /autoapi/credit/trainers/trainerERA5_ensemble/index
   /autoapi/credit/trainers/trainerLES/index
   /autoapi/credit/trainers/trainerWRF/index
   /autoapi/credit/trainers/trainerWRF_multi/index
   /autoapi/credit/trainers/trainer_downscaling/index
   /autoapi/credit/trainers/trainer_om4_samudra/index
   /autoapi/credit/trainers/utils/index


Attributes
----------

.. autoapisummary::

   credit.trainers.logger
   credit.trainers.trainer_types


Classes
-------

.. autoapisummary::

   credit.trainers.TrainerERA5
   credit.trainers.TrainerERA5_Diffusion
   credit.trainers.TrainerEnsemble
   credit.trainers.Trainer404
   credit.trainers.TrainerIC
   credit.trainers.TrainerSamudra
   credit.trainers.TrainerLES
   credit.trainers.TrainerWRF
   credit.trainers.TrainerWRFMulti


Functions
---------

.. autoapisummary::

   credit.trainers.load_trainer


Package Contents
----------------

.. py:class:: TrainerERA5(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for models trained on ERA5 reanalysis data. It implements the
   per-epoch training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict
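
      The method returns a single dictionary of epoch-level results. As a
      minimal sketch, assuming per-batch values are reduced by averaging, the
      reduction could look like the following. The helper
      ``reduce_epoch_metrics`` and the keys ``train_loss`` and ``train_acc``
      are hypothetical and are not taken from CREDIT.

      .. code-block:: python

         from collections import defaultdict
         from statistics import mean


         def reduce_epoch_metrics(batch_results):
             """Average per-batch metric dicts into one epoch-level dict.

             Hypothetical helper: CREDIT's trainers may use other keys or
             other reductions.
             """
             accum = defaultdict(list)
             for result in batch_results:
                 for key, value in result.items():
                     accum[key].append(float(value))
             # One mean per metric name, keeping first-seen key order.
             return {key: mean(values) for key, values in accum.items()}


         batches = [
             {"train_loss": 0.50, "train_acc": 0.70},
             {"train_loss": 0.30, "train_acc": 0.90},
         ]
         epoch_results = reduce_epoch_metrics(batches)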



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict
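
   ``train_one_epoch`` and ``validate`` are called in sequence once per epoch.
   The sketch below shows that call order using a hypothetical ``StubTrainer``
   that only mimics the two method signatures documented above. The ``None``
   arguments stand in for the DataLoader, optimizer, criterion, scaler,
   scheduler, and metrics objects, and the returned keys are made up for
   illustration.

   .. code-block:: python

      class StubTrainer:
          """Hypothetical stand-in exposing the documented method signatures."""

          def __init__(self):
              self.loss = 1.0

          def train_one_epoch(self, epoch, conf, trainloader, optimizer,
                              criterion, scaler, scheduler, metrics):
              self.loss *= 0.5  # pretend each epoch halves the loss
              return {"train_loss": self.loss}

          def validate(self, epoch, conf, valid_loader, criterion, metrics):
              return {"valid_loss": self.loss * 1.1}


      def run_epochs(trainer, conf, num_epochs):
          history = []
          for epoch in range(num_epochs):
              # None placeholders replace the real loaders, optimizer, etc.
              train_results = trainer.train_one_epoch(
                  epoch, conf, trainloader=None, optimizer=None, criterion=None,
                  scaler=None, scheduler=None, metrics=None,
              )
              valid_results = trainer.validate(
                  epoch, conf, valid_loader=None, criterion=None, metrics=None,
              )
              # Merge both dictionaries into one record per epoch.
              history.append({"epoch": epoch, **train_results, **valid_results})
          return history


      history = run_epochs(StubTrainer(), conf={}, num_epochs=3)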



.. py:class:: TrainerERA5_Diffusion(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for diffusion models trained on ERA5 reanalysis data. It implements
   the per-epoch training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict



.. py:class:: TrainerEnsemble(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Ensemble variant of the ERA5 trainer. It implements the per-epoch training
   (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict



.. py:class:: Trainer404(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer that defines a ``setup(conf)`` hook in addition to the per-epoch
   training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: setup(conf)


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Train the model for one epoch.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param trainloader: The training data loader.
      :type trainloader: torch.utils.data.DataLoader
      :param optimizer: The optimizer.
      :type optimizer: torch.optim.Optimizer
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param scaler: The gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: The learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: The metrics to track during training.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the training results.
      :rtype: Dict[str, float]



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict



.. py:class:: TrainerIC(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for initial-condition (IC) optimization. It implements the per-epoch
   training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict



.. py:class:: TrainerSamudra(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for the Samudra ocean emulator on OM4 data. It implements the
   per-epoch training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict



.. py:class:: TrainerLES(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for large-eddy simulation (LES) data. It implements the per-epoch
   training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Train the model for one epoch.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param trainloader: The training data loader.
      :type trainloader: torch.utils.data.DataLoader
      :param optimizer: The optimizer.
      :type optimizer: torch.optim.Optimizer
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param scaler: The gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: The learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: The metrics to track during training.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the training results.
      :rtype: Dict[str, float]



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validate the model on the validation set.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param valid_loader: The validation data loader.
      :type valid_loader: torch.utils.data.DataLoader
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param metrics: The metrics to track during validation.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the validation results.
      :rtype: Dict[str, float]



.. py:class:: TrainerWRF(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer for Weather Research and Forecasting (WRF) model data. It implements
   the per-epoch training (``train_one_epoch``) and validation (``validate``) loops.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Train the model for one epoch.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param trainloader: The training data loader.
      :type trainloader: torch.utils.data.DataLoader
      :param optimizer: The optimizer.
      :type optimizer: torch.optim.Optimizer
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param scaler: The gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: The learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: The metrics to track during training.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the training results.
      :rtype: Dict[str, float]



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validate the model on the validation set.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param valid_loader: The validation data loader.
      :type valid_loader: torch.utils.data.DataLoader
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param metrics: The metrics to track during validation.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the validation results.
      :rtype: Dict[str, float]



.. py:class:: TrainerWRFMulti(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`


   Trainer class for handling the training, validation, and checkpointing of models.

   This class is responsible for executing the training loop, validating the model
   on a separate dataset, and managing checkpoints during training. It supports
   both single-GPU and distributed (FSDP, DDP) training.

   .. attribute:: model

      The model to be trained.

      :type: torch.nn.Module

   .. attribute:: rank

      The rank of the process in distributed training.

      :type: int

   .. attribute:: module

      If True, use model with module parallelism (default: False).

      :type: bool

   .. method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)
      :noindex:

      Perform training for one epoch and return training metrics.

   .. method:: validate(epoch, conf, valid_loader, criterion, metrics)
      :noindex:

      Validate the model on the validation dataset and return validation metrics.

   .. method:: fit_deprecated(conf, train_loader, valid_loader, optimizer, train_criterion, valid_criterion, scaler, scheduler, metrics, trial=False)

      Perform the full training loop across multiple epochs, including validation
      and checkpointing.


   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict
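
   The class description above presents ``fit_deprecated`` as a multi-epoch loop
   with validation and checkpointing, and lists a ``rank`` attribute for
   distributed runs. A minimal sketch of that pattern follows, assuming a
   checkpoint is written only when the validation loss improves and only on
   rank 0. This is a common DDP/FSDP convention, not confirmed for this class;
   ``fit_sketch``, ``save_checkpoint``, and the ``valid_loss`` key are
   hypothetical.

   .. code-block:: python

      def fit_sketch(train_step, valid_step, num_epochs, rank, save_checkpoint):
          """Run epochs, track the best validation loss, checkpoint on rank 0.

          train_step(epoch) and valid_step(epoch) are hypothetical callables
          returning metric dicts; save_checkpoint(epoch, loss) persists state.
          """
          best_loss = float("inf")
          for epoch in range(num_epochs):
              train_step(epoch)
              valid_loss = valid_step(epoch)["valid_loss"]
              if valid_loss < best_loss:
                  best_loss = valid_loss
                  # Only one process writes, so ranks do not race on the file.
                  if rank == 0:
                      save_checkpoint(epoch, valid_loss)
          return best_loss


      losses = [0.9, 0.7, 0.8, 0.6]
      saved = []
      best = fit_sketch(
          train_step=lambda epoch: {},
          valid_step=lambda epoch: {"valid_loss": losses[epoch]},
          num_epochs=len(losses),
          rank=0,
          save_checkpoint=lambda epoch, loss: saved.append(epoch),
      )

   Epochs 0, 1, and 3 improve on the best loss, so only those trigger a save;
   a process with a nonzero rank tracks the same best loss but never writes.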



.. py:data:: logger

.. py:data:: trainer_types

.. py:function:: load_trainer(conf)
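
   The package exposes a ``trainer_types`` attribute next to
   ``load_trainer(conf)``, which suggests a registry lookup driven by the
   configuration. The sketch below illustrates that pattern only. The
   ``conf["trainer"]["type"]`` key path, the registry keys, and the placeholder
   classes are assumptions, not the actual contents of ``trainer_types``, and
   this reference does not state whether the real function returns a class or
   an instance.

   .. code-block:: python

      class TrainerA:  # hypothetical placeholder classes
          pass


      class TrainerB:
          pass


      # Hypothetical registry mapping a config string to a trainer class.
      TRAINER_REGISTRY = {"type-a": TrainerA, "type-b": TrainerB}


      def load_trainer_sketch(conf):
          """Return the class named by conf["trainer"]["type"] (assumed key path)."""
          trainer_type = conf.get("trainer", {}).get("type")
          if trainer_type not in TRAINER_REGISTRY:
              raise ValueError(
                  f"Unknown trainer type {trainer_type!r}; "
                  f"choose from {sorted(TRAINER_REGISTRY)}"
              )
          return TRAINER_REGISTRY[trainer_type]


      trainer_cls = load_trainer_sketch({"trainer": {"type": "type-b"}})

   Failing loudly on an unknown key, and listing the valid choices, keeps a
   typo in the configuration from silently selecting the wrong trainer.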

