credit.trainers.base_trainer
============================

.. py:module:: credit.trainers.base_trainer

.. autoapi-nested-parse::

   base_trainer.py
   -------------------------------------------------------
   Contents:

   - EMATracker
   - BaseTrainer (abstract)

     - train_one_epoch (abstract)
     - validate (abstract)
     - fit
     - _save_checkpoint (internal)



Attributes
----------

.. autoapisummary::

   credit.trainers.base_trainer._SummaryWriter
   credit.trainers.base_trainer.logger


Classes
-------

.. autoapisummary::

   credit.trainers.base_trainer.EMATracker
   credit.trainers.base_trainer.BaseTrainer


Module Contents
---------------

.. py:data:: _SummaryWriter
   :value: None


.. py:data:: logger

.. py:class:: EMATracker(model: torch.nn.Module, decay: float = 0.9999)

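   Exponential moving average (EMA) of model weights.

   Maintains a shadow copy of the model parameters, updated as::

       shadow = effective_decay * shadow + (1 - effective_decay) * param

   The decay is adaptive, so short runs are not dominated by the initial
   random weights::

       effective_decay = min(max_decay, (1 + step) / (10 + step))

   Here ``max_decay`` is the ``decay`` argument. The ratio
   ``(1 + step) / (10 + step)`` is 0.1 at step 0 and rises toward 1, so the
   effective decay reaches the ``max_decay`` cap after a finite number of steps
   (from step 89,990 onward when ``max_decay = 0.9999``). Early in training the
   EMA therefore follows the recent training weights closely, whatever the run
   length.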

   Usage::

       ema = EMATracker(model, decay=0.9999)

       # after each optimizer.step():
       ema.update(model)

       # before validation:
       ema.swap(model)  # model now holds the EMA weights
       # ... run validation ...
       ema.swap(model)  # restore the training weights

   Typical values for ``decay``: 0.9999 for long runs, 0.999 for short runs.
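
   As a plain-Python check of the adaptive schedule, the sketch below
   evaluates the formula at a few steps and solves for the first step at which
   the cap applies. ``effective_decay`` and ``steps_until_cap`` are
   hypothetical helpers written for illustration, not functions of this
   module, and the sketch assumes ``step`` counts completed updates starting
   from 0.

   .. code-block:: python

      import math
      from fractions import Fraction


      def effective_decay(step: int, max_decay: float) -> float:
          """Adaptive EMA decay: min(max_decay, (1 + step) / (10 + step))."""
          return min(max_decay, (1 + step) / (10 + step))


      def steps_until_cap(max_decay: str) -> int:
          """First step at which (1 + step) / (10 + step) >= max_decay.

          Rearranging gives step >= (10 * d - 1) / (1 - d). Fraction keeps the
          arithmetic exact, so rounding up cannot be thrown off by float error.
          """
          d = Fraction(max_decay)
          return math.ceil((10 * d - 1) / (1 - d))


      for step in (0, 10, 100, 1_000, 10_000, 100_000):
          print(step, round(effective_decay(step, 0.9999), 5))

      print(steps_until_cap("0.9999"))  # 89990
      print(steps_until_cap("0.999"))   # 8990

   Passing ``max_decay`` to ``steps_until_cap`` as a decimal string avoids the
   binary rounding error of a float literal such as ``0.9999``.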


   .. py:attribute:: decay
      :value: 0.9999



   .. py:attribute:: step
      :value: 0



   .. py:attribute:: shadow
      :type:  collections.OrderedDict


   .. py:method:: _is_spectral_norm_buffer(key: str, state: dict) -> bool
      :staticmethod:



   .. py:method:: update(model: torch.nn.Module)


   .. py:method:: swap(model: torch.nn.Module)

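      Exchange the model weights with the EMA shadow weights: the model receives
      the EMA weights and the shadow receives the training weights, so a second
      call restores the original state.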
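
      A minimal plain-Python sketch of the update and swap semantics, using a
      ``{name: float}`` dict as a stand-in for a model's parameters.
      ``DictEMA`` is hypothetical; it assumes the adaptive decay uses the step
      count *before* it is incremented, and it does not reproduce this class's
      handling of spectral-norm buffers or PyTorch tensors.

      .. code-block:: python

         from collections import OrderedDict
         from typing import Dict


         class DictEMA:
             """Hypothetical EMA over a {name: float} mapping (not the CREDIT class)."""

             def __init__(self, params: Dict[str, float], decay: float = 0.9999) -> None:
                 self.decay = decay  # upper bound (max_decay) on the adaptive decay
                 self.step = 0
                 # Floats are immutable, so a shallow copy is a real snapshot;
                 # tensors would need to be cloned.
                 self.shadow = OrderedDict(params)

             def update(self, params: Dict[str, float]) -> None:
                 d = min(self.decay, (1 + self.step) / (10 + self.step))
                 for name, value in params.items():
                     self.shadow[name] = d * self.shadow[name] + (1 - d) * value
                 self.step += 1

             def swap(self, params: Dict[str, float]) -> None:
                 # Exchange live and shadow values in place; a second call undoes it.
                 for name in params:
                     params[name], self.shadow[name] = self.shadow[name], params[name]


         weights = {"w": 0.0}
         ema = DictEMA(weights)
         weights["w"] = 1.0   # pretend an optimizer step moved the weight
         ema.update(weights)  # step 0: d = 0.1, so shadow = 0.1 * 0.0 + 0.9 * 1.0
         ema.swap(weights)    # weights["w"] now holds the EMA value
         print(weights["w"])
         ema.swap(weights)    # training value restored
         print(weights["w"])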



   .. py:method:: state_dict()


   .. py:method:: load_state_dict(d)


.. py:class:: BaseTrainer(model: torch.nn.Module, rank: int, conf: Dict[str, Any])

   Bases: :py:obj:`abc.ABC`


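   Abstract base class for trainers. Subclasses implement
   :py:meth:`train_one_epoch` and :py:meth:`validate`; :py:meth:`fit` runs the
   full training loop, and :py:meth:`_save_checkpoint` saves model, optimizer,
   scheduler, and scaler state.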


   .. py:attribute:: model


   .. py:attribute:: rank


   .. py:attribute:: device


   .. py:attribute:: conf


   .. py:attribute:: save_loc


   .. py:attribute:: mode


   .. py:attribute:: distributed


   .. py:attribute:: start_epoch


   .. py:attribute:: epochs


   .. py:attribute:: skip_validation


   .. py:attribute:: load_weights


   .. py:attribute:: use_scheduler


   .. py:attribute:: scheduler_type


   .. py:attribute:: amp


   .. py:attribute:: grad_max_norm


   .. py:attribute:: batches_per_epoch


   .. py:attribute:: valid_batches_per_epoch


   .. py:attribute:: ensemble_size


   .. py:attribute:: save_best_weights


   .. py:attribute:: save_backup_weights


   .. py:attribute:: stopping_patience


   .. py:attribute:: save_every_epoch


   .. py:attribute:: stop_after_epoch


   .. py:attribute:: num_epoch


   .. py:attribute:: save_metric_vars


   .. py:attribute:: train_one_epoch_mode


   .. py:attribute:: training_metric


   .. py:attribute:: direction


   .. py:method:: _model_state_dict(model: torch.nn.Module) -> dict
      :staticmethod:


      Return the model's state dict, unwrapping the DDP ``.module`` wrapper if present.



   .. py:method:: _load_model_state_dict(model: torch.nn.Module, state_dict: dict) -> None
      :staticmethod:


      Load *state_dict* into the model, unwrapping the DDP ``.module`` wrapper if present.
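
      Unwrapping matters because ``torch.nn.parallel.DistributedDataParallel``
      stores the original network as ``.module``, and the wrapper's own state
      dict prefixes every key with ``module.``. Saving the unwrapped state dict
      keeps checkpoints loadable without DDP. The helpers below are a
      hypothetical duck-typed sketch. The ``hasattr`` test is a simplification
      (a model with its own submodule named ``module`` would be unwrapped too);
      a stricter version would test ``isinstance(model, DistributedDataParallel)``.

      .. code-block:: python

         from typing import Any, Dict


         def unwrap(model: Any) -> Any:
             """Return model.module for DDP-style wrappers, else the model itself."""
             return model.module if hasattr(model, "module") else model


         def model_state_dict(model: Any) -> Dict[str, Any]:
             # Keys come from the inner network, so there is no "module." prefix.
             return unwrap(model).state_dict()


         def load_model_state_dict(model: Any, state_dict: Dict[str, Any]) -> None:
             unwrap(model).load_state_dict(state_dict)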



   .. py:method:: train_one_epoch(epoch: int, trainloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, scaler: torch.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any]) -> Dict[str, float]
      :abstractmethod:



   .. py:method:: validate(epoch: int, valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, metrics: Dict[str, Any]) -> Dict[str, float]
      :abstractmethod:



   .. py:method:: _log_batch_progress(epoch: int, results_dict: dict, optimizer: Optional[torch.optim.Optimizer], pbar, phase: str = 'train') -> None

      Update a tqdm progress bar with rolling-mean batch metrics.

      Reads any keys in *results_dict* that start with ``<phase>_`` and
      formats them as a space-separated description string. Always appends
      the current learning rate.

      :param epoch: Current epoch number.
      :param results_dict: Dict mapping metric names to lists of per-batch values.
      :param optimizer: Current optimizer (used to read the learning rate).
      :param pbar: tqdm progress bar to update.
      :param phase: Metric prefix, either ``"train"`` or ``"valid"``.
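
      A plain-Python sketch of the formatting step described above.
      ``format_progress`` and the exact layout of the string are hypothetical.
      The sketch takes the running mean of all values recorded so far for each
      matching key and expects the caller to read the learning rate, for
      example from ``optimizer.param_groups[0]["lr"]``; the result would then be
      passed to tqdm's ``pbar.set_description``.

      .. code-block:: python

         from statistics import mean
         from typing import Dict, List, Optional


         def format_progress(
             epoch: int,
             results_dict: Dict[str, List[float]],
             lr: Optional[float],
             phase: str = "train",
         ) -> str:
             """Join running means of the <phase>_* metrics into one description string."""
             prefix = f"{phase}_"
             parts = [f"Epoch: {epoch}"]
             for key, values in results_dict.items():
                 # Skip other phases' keys and metrics with no batches recorded yet.
                 if key.startswith(prefix) and values:
                     parts.append(f"{key}: {mean(values):.6f}")
             if lr is not None:
                 parts.append(f"lr: {lr:.3e}")
             return " ".join(parts)


         results = {"train_loss": [0.5, 0.3], "train_acc": [0.8, 0.9], "valid_loss": [0.4]}
         print(format_progress(3, results, lr=1e-3))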



   .. py:method:: _save_checkpoint(epoch: int, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, scaler: torch.amp.GradScaler) -> None

      Save model, optimizer, scheduler, and scaler state.
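
      One way to bundle the four pieces of state is a single dict, as in this
      hypothetical sketch. The key names and the ``None`` handling for an absent
      scheduler or scaler are assumptions, not the keys this method actually
      writes; with PyTorch the dict would then be written with
      ``torch.save(checkpoint, path)`` and restored with ``torch.load``.

      .. code-block:: python

         from typing import Any, Dict, Optional


         def build_checkpoint(
             epoch: int,
             model: Any,
             optimizer: Any,
             scheduler: Optional[Any],
             scaler: Optional[Any],
         ) -> Dict[str, Any]:
             """Collect everything needed to resume training into one dict."""
             return {
                 "epoch": epoch,
                 "model_state_dict": model.state_dict(),
                 "optimizer_state_dict": optimizer.state_dict(),
                 # Assumed: store None when scheduling or AMP is disabled.
                 "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
                 "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
             }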



   .. py:method:: fit(conf: Dict[str, Any], train_loader: torch.utils.data.DataLoader, valid_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, train_criterion: torch.nn.Module, valid_criterion: torch.nn.Module, scaler: torch.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any], rollout_scheduler: Optional[Callable] = None, trial: bool = False) -> Dict[str, Any]

      Run the full training loop.

      :param conf: Full configuration dict, passed through to
                   :py:meth:`train_one_epoch` and :py:meth:`validate` for
                   data-related settings; trainer settings are accessed via
                   ``self``.
      :param train_loader: DataLoader for the training data.
      :param valid_loader: DataLoader for the validation data.
      :param optimizer: Optimizer used to update the model.
      :param train_criterion: Loss used during training.
      :param valid_criterion: Loss used during validation.
      :param scaler: Gradient scaler for mixed-precision (AMP) training.
      :param scheduler: Learning-rate scheduler.
      :param metrics: Metrics passed to :py:meth:`train_one_epoch` and
                      :py:meth:`validate`.
      :param rollout_scheduler: Optional callable to schedule rollout probability.
      :param trial: Optuna trial object, or ``False``.

      :returns: Dict with the best epoch's results.
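
      The control flow of such a loop can be sketched in plain Python.
      ``fit_loop`` is hypothetical: it borrows the names of the
      ``training_metric``, ``direction``, and ``stopping_patience`` attributes,
      but assumes ``direction`` is ``"min"`` or ``"max"`` and omits
      checkpointing, schedulers, EMA, distributed training, and Optuna pruning.

      .. code-block:: python

         from typing import Callable, Dict


         def fit_loop(
             train_one_epoch: Callable[[int], Dict[str, float]],
             validate: Callable[[int], Dict[str, float]],
             epochs: int,
             training_metric: str,
             direction: str = "min",
             stopping_patience: int = 5,
         ) -> Dict[str, float]:
             """Track the best epoch and stop after `stopping_patience` epochs without improvement."""
             best: Dict[str, float] = {}
             best_value = float("inf") if direction == "min" else float("-inf")
             epochs_since_best = 0
             for epoch in range(epochs):
                 results = {**train_one_epoch(epoch), **validate(epoch)}
                 value = results[training_metric]
                 improved = value < best_value if direction == "min" else value > best_value
                 if improved:
                     best_value, best = value, {"epoch": epoch, **results}
                     epochs_since_best = 0  # a real trainer might save "best" weights here
                 else:
                     epochs_since_best += 1
                     if epochs_since_best >= stopping_patience:
                         break  # early stopping
             return best


         valid_losses = [1.0, 0.6, 0.7, 0.8, 0.9]
         best = fit_loop(
             train_one_epoch=lambda e: {"train_loss": 2 * valid_losses[e]},
             validate=lambda e: {"valid_loss": valid_losses[e]},
             epochs=len(valid_losses),
             training_metric="valid_loss",
             direction="min",
             stopping_patience=2,
         )
         print(best)

      With ``stopping_patience=2`` the loop stops after epoch 3, two epochs
      after the best validation loss at epoch 1.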



