credit.trainers.base_trainer
============================

.. py:module:: credit.trainers.base_trainer

.. autoapi-nested-parse::

   base_trainer.py
   -------------------------------------------------------
   Content:

   - BaseTrainer

     - train_one_epoch
     - validate
     - save_checkpoint
     - save_fsdp_checkpoint
     - fit



Attributes
----------

.. autoapisummary::

   credit.trainers.base_trainer.logger


Classes
-------

.. autoapisummary::

   credit.trainers.base_trainer.BaseTrainer


Module Contents
---------------

.. py:data:: logger

.. py:class:: BaseTrainer(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`abc.ABC`


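   Abstract base class for trainers. Subclasses must implement
   :py:meth:`train_one_epoch` and :py:meth:`validate`; :py:meth:`fit` runs
   training and validation, and :py:meth:`save_checkpoint` and
   :py:meth:`save_fsdp_checkpoint` write checkpoints.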
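
   Because ``BaseTrainer`` declares :py:meth:`train_one_epoch` and
   :py:meth:`validate` as abstract, it cannot be instantiated directly; a
   concrete subclass must override both. The sketch below illustrates that
   contract with a hypothetical, dependency-free stand-in (``TrainerSketch``)
   built only on :py:mod:`abc`. It mirrors the abstract method names but not
   the real signatures or behavior of ``BaseTrainer``.

   .. code-block:: python

      from abc import ABC, abstractmethod
      from typing import Dict


      class TrainerSketch(ABC):
          """Hypothetical stand-in mirroring BaseTrainer's abstract interface."""

          @abstractmethod
          def train_one_epoch(self, epoch: int) -> Dict[str, float]:
              """Train for one epoch and return a dict of results."""

          @abstractmethod
          def validate(self, epoch: int) -> Dict[str, float]:
              """Evaluate on validation data and return a dict of results."""


      class IncompleteTrainer(TrainerSketch):
          # Only one abstract method is overridden, so this class stays abstract.
          def train_one_epoch(self, epoch: int) -> Dict[str, float]:
              return {"train_loss": 0.0}


      class ToyTrainer(TrainerSketch):
          # Both abstract methods are overridden, so this class can be instantiated.
          def train_one_epoch(self, epoch: int) -> Dict[str, float]:
              return {"train_loss": 1.0 / (epoch + 1)}

          def validate(self, epoch: int) -> Dict[str, float]:
              return {"valid_loss": 1.5 / (epoch + 1)}


      try:
          IncompleteTrainer()
      except TypeError as err:
          # Python refuses to instantiate a class with unimplemented abstract methods.
          print(f"TypeError: {err}")

      trainer = ToyTrainer()
      print(trainer.validate(epoch=2))  # {'valid_loss': 0.5}

   Calling ``IncompleteTrainer()`` raises :py:exc:`TypeError` because
   ``validate`` is still abstract, which is the same guard that applies to a
   ``BaseTrainer`` subclass that forgets one of the two methods.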


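   .. py:attribute:: model

      The model being trained, as passed to the constructor.


   .. py:attribute:: rank

      The rank of this process in distributed training, as passed to the
      constructor.


   .. py:attribute:: device

      The device on which the model is trained.
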

   .. py:method:: train_one_epoch(epoch: int, conf: Dict[str, Any], trainloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, scaler: torch.cuda.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any]) -> Dict[str, float]
      :abstractmethod:


      Train the model for one epoch.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param trainloader: The training data loader.
      :type trainloader: torch.utils.data.DataLoader
      :param optimizer: The optimizer.
      :type optimizer: torch.optim.Optimizer
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param scaler: The gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: The learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: The metrics to track during training.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the training results.
      :rtype: Dict[str, float]
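
      The return value is a flat mapping from result names to floats. One
      common way to produce such a mapping is to average per-batch values over
      the epoch. The helper below is a hypothetical sketch of that pattern;
      ``summarize_epoch`` and the ``train_loss`` and ``train_acc`` keys are
      illustrative assumptions, not names used by this module.

      .. code-block:: python

         from collections import defaultdict
         from typing import Dict, Iterable


         def summarize_epoch(batch_results: Iterable[Dict[str, float]]) -> Dict[str, float]:
             """Average each named value across all batches of an epoch (hypothetical helper)."""
             totals: Dict[str, float] = defaultdict(float)
             counts: Dict[str, int] = defaultdict(int)
             for result in batch_results:
                 for name, value in result.items():
                     totals[name] += value
                     counts[name] += 1
             # Divide each running total by the number of batches that reported it.
             return {name: totals[name] / counts[name] for name in totals}


         batches = [
             {"train_loss": 0.5, "train_acc": 0.75},
             {"train_loss": 0.25, "train_acc": 0.25},
         ]
         print(summarize_epoch(batches))  # {'train_loss': 0.375, 'train_acc': 0.5}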



   .. py:method:: validate(epoch: int, conf: Dict[str, Any], valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, metrics: Dict[str, Any]) -> Dict[str, float]
      :abstractmethod:


      Validate the model on the validation set.

      :param epoch: The current epoch number.
      :type epoch: int
      :param conf: The configuration dictionary.
      :type conf: Dict[str, Any]
      :param valid_loader: The validation data loader.
      :type valid_loader: torch.utils.data.DataLoader
      :param criterion: The loss function.
      :type criterion: torch.nn.Module
      :param metrics: The metrics to track during validation.
      :type metrics: Dict[str, Any]

      :returns: A dictionary containing the validation results.
      :rtype: Dict[str, float]



   .. py:method:: save_checkpoint(save_loc: str, state_dict: Dict[str, Any]) -> None

      Save a checkpoint of the model.

      :param save_loc: The location to save the checkpoint.
      :type save_loc: str
      :param state_dict: The state dictionary to save.
      :type state_dict: Dict[str, Any]



   .. py:method:: save_fsdp_checkpoint(save_loc: str, state_dict: Dict[str, Any]) -> None

      Save a checkpoint when training with Fully Sharded Data Parallel (FSDP).

      :param save_loc: The location to save the checkpoint.
      :type save_loc: str
      :param state_dict: The state dictionary to save.
      :type state_dict: Dict[str, Any]



   .. py:method:: fit(conf: Dict[str, Any], train_loader: torch.utils.data.DataLoader, valid_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, train_criterion: torch.nn.Module, valid_criterion: torch.nn.Module, scaler: torch.cuda.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any], rollout_scheduler: Optional[callable] = None, trial: bool = False) -> Dict[str, Any]

      Train the model on ``train_loader`` and evaluate it on ``valid_loader``.

      :param conf: Configuration dictionary.
      :type conf: Dict[str, Any]
      :param train_loader: DataLoader for training data.
      :type train_loader: torch.utils.data.DataLoader
      :param valid_loader: DataLoader for validation data.
      :type valid_loader: torch.utils.data.DataLoader
      :param optimizer: The optimizer to use for training.
      :type optimizer: torch.optim.Optimizer
      :param train_criterion: Loss function for training.
      :type train_criterion: torch.nn.Module
      :param valid_criterion: Loss function for validation.
      :type valid_criterion: torch.nn.Module
      :param scaler: Gradient scaler for mixed-precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler.LRScheduler
      :param metrics: Dictionary of metrics to track during training.
      :type metrics: Dict[str, Any]
      :param rollout_scheduler: Function to schedule rollout probability, if applicable.
      :type rollout_scheduler: Optional[callable]
      :param trial: Whether this is a trial run (e.g., for hyperparameter tuning).
      :type trial: bool

      :returns: Dictionary containing the best results from training.
      :rtype: Dict[str, Any]
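
      The returned dictionary holds the best results seen during training. A
      minimal sketch of that bookkeeping follows, assuming "best" means the
      epoch with the lowest validation loss. The selection criterion, the
      ``valid_loss`` key, and the ``fit_sketch``, ``train_fn`` and
      ``valid_fn`` names are hypothetical; the real method may choose the
      monitored value differently (for example, from ``conf``).

      .. code-block:: python

         from typing import Any, Callable, Dict


         def fit_sketch(
             epochs: int,
             train_fn: Callable[[int], Dict[str, float]],
             valid_fn: Callable[[int], Dict[str, float]],
             monitor: str = "valid_loss",
         ) -> Dict[str, Any]:
             """Train and validate each epoch, keeping the best epoch's results (hypothetical)."""
             best: Dict[str, Any] = {}
             for epoch in range(epochs):
                 results = {**train_fn(epoch), **valid_fn(epoch)}
                 # Lower is better for the monitored value; the earliest epoch wins ties.
                 if not best or results[monitor] < best[monitor]:
                     best = {"epoch": epoch, **results}
             return best


         valid_losses = [0.9, 0.6, 0.7]
         best = fit_sketch(
             epochs=3,
             train_fn=lambda e: {"train_loss": 1.0 - 0.1 * e},
             valid_fn=lambda e: {"valid_loss": valid_losses[e]},
         )
         print(best["epoch"], best["valid_loss"])  # 1 0.6

      Keeping a copy of the whole results dictionary, rather than only the
      monitored value, lets the caller report every metric from the best epoch.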



