credit.trainers.utils
=====================

.. py:module:: credit.trainers.utils


Functions
---------

.. autoapisummary::

   credit.trainers.utils.cleanup
   credit.trainers.utils.cycle
   credit.trainers.utils.accum_log
   credit.trainers.utils.inject_flat_var_keys
   credit.trainers.utils.inject_postblock_info
   credit.trainers.utils.load_dataset
   credit.trainers.utils.load_dataloader
   credit.trainers.utils.load_model_states_and_optimizer


Module Contents
---------------

.. py:function:: cleanup()

.. py:function:: cycle(dl)
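
   The source of ``cycle`` is not reproduced on this page. A common pattern behind this name, assumed here for illustration only, is a generator that re-iterates a data loader forever. A training loop can then draw a fixed number of steps regardless of epoch boundaries:

   .. code-block:: python

      from itertools import islice
      from typing import Iterable, Iterator, TypeVar

      T = TypeVar("T")


      def cycle(dl: Iterable[T]) -> Iterator[T]:
          """Assumed behavior: yield items from ``dl`` forever, restarting it each pass."""
          while True:
              # Re-iterating ``dl`` (instead of caching items, as itertools.cycle
              # does) lets a shuffling DataLoader produce a new order on every pass.
              # Note: if ``dl`` is empty, this loops forever without yielding.
              for batch in dl:
                  yield batch


      # Draw 7 items from a 3-item iterable.
      print(list(islice(cycle([1, 2, 3]), 7)))  # [1, 2, 3, 1, 2, 3, 1]

   Unlike ``itertools.cycle``, which stores every element it has seen and replays them, this pattern holds no batches in memory and lets the loader reshuffle between passes.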

.. py:function:: accum_log(log, new_logs)
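
   No docstring is available for ``accum_log``. The sketch below assumes the usual meaning of the name: each value in ``new_logs`` is added to a running total stored under the same key in ``log``. Treat it as an illustration of that assumption, not as CREDIT's code:

   .. code-block:: python

      def accum_log(log: dict, new_logs: dict) -> dict:
          """Assumed behavior: add each value in ``new_logs`` to the running total in ``log``."""
          for key, new_value in new_logs.items():
              # Keys not seen before start from 0.0.
              log[key] = log.get(key, 0.0) + new_value
          # ``log`` is updated in place and also returned for convenience.
          return log


      log = {}
      accum_log(log, {"loss": 0.5})
      accum_log(log, {"loss": 0.25, "acc": 1.0})
      print(log)  # {'loss': 0.75, 'acc': 1.0}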

.. py:function:: inject_flat_var_keys(conf: dict) -> None

   Inject Gen1-compatible variable keys into ``conf["data"]`` for use by the metrics and loss classes.

   ``LatWeightedMetrics`` and ``VariableTotalLoss2D`` expect flat lists at:

   - ``conf["data"]["variables"]``
   - ``conf["data"]["surface_variables"]``
   - ``conf["data"]["diagnostic_variables"]``

   These lists are derived from the nested Gen2 source config, and ``conf`` is modified in place.
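
   The Gen2 source schema is not shown on this page. The sketch below assumes a hypothetical layout in which ``conf["data"]["sources"]`` maps each source name to name lists under ``"upper_air"``, ``"surface"`` and ``"diagnostic"``. It only illustrates the flat shape the metrics and loss classes expect and is not CREDIT's implementation:

   .. code-block:: python

      def inject_flat_var_keys_sketch(conf: dict) -> None:
          """Hypothetical: flatten per-source variable lists into Gen1-style keys."""
          # HYPOTHETICAL nested layout; the real Gen2 schema may differ.
          variables: list = []
          surface_variables: list = []
          diagnostic_variables: list = []
          for source in conf["data"]["sources"].values():
              variables.extend(source.get("upper_air", []))
              surface_variables.extend(source.get("surface", []))
              diagnostic_variables.extend(source.get("diagnostic", []))
          # Write the flat keys in place; like the documented function, return None.
          conf["data"]["variables"] = variables
          conf["data"]["surface_variables"] = surface_variables
          conf["data"]["diagnostic_variables"] = diagnostic_variables


      conf = {
          "data": {
              "sources": {
                  "era5": {"upper_air": ["U", "V", "T"], "surface": ["SP"]},
                  "obs": {"surface": ["t2m"], "diagnostic": ["tp"]},
              }
          }
      }
      inject_flat_var_keys_sketch(conf)
      print(conf["data"]["surface_variables"])  # ['SP', 't2m']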


.. py:function:: inject_postblock_info(conf: dict) -> None

   Inject post-block indices into ``conf["postblock"]`` so that Gen1 post-blocks are initialized correctly.


.. py:function:: load_dataset(conf: dict, is_train: bool) -> credit.datasets.multi_source.MultiSourceDataset

   Build a ``MultiSourceDataset`` for training (``is_train=True``) or validation (``is_train=False``).


.. py:function:: load_dataloader(conf: dict, dataset: credit.datasets.multi_source.MultiSourceDataset, rank: int, world_size: int, is_train: bool) -> torch.utils.data.DataLoader

   Build a ``DataLoader`` for ``dataset`` that uses a ``DistributedMultiStepBatchSampler`` for batching.
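
   The multi-step behavior of ``DistributedMultiStepBatchSampler`` is not documented here. As a simplified illustration of how ``rank`` and ``world_size`` are typically used, the hypothetical helper ``shard_batches`` (not part of CREDIT) partitions sample indices with the same padded, strided scheme that ``torch.utils.data.DistributedSampler`` uses when shuffling is disabled. It then groups each rank's share into batches:

   .. code-block:: python

      def shard_batches(num_samples: int, batch_size: int, rank: int, world_size: int) -> list:
          """Hypothetical helper: split sample indices across ranks, then into batches."""
          indices = list(range(num_samples))
          # Round the total up to a multiple of world_size so every rank gets the
          # same number of samples; pad by repeating indices from the start.
          total = -(-num_samples // world_size) * world_size
          indices += [indices[i % num_samples] for i in range(total - num_samples)]
          # Strided split: rank r keeps indices r, r + world_size, r + 2 * world_size, ...
          local = indices[rank:total:world_size]
          # Group this rank's indices into consecutive batches (the last may be short).
          return [local[i : i + batch_size] for i in range(0, len(local), batch_size)]


      for r in range(4):
          print(r, shard_batches(10, batch_size=2, rank=r, world_size=4))
      # 0 [[0, 4], [8]]
      # 1 [[1, 5], [9]]
      # 2 [[2, 6], [0]]
      # 3 [[3, 7], [1]]

   Padding gives every rank the same number of batches, so no rank finishes an epoch early and stalls the others at a collective operation. The cost is that a few samples are seen twice per epoch.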


.. py:function:: load_model_states_and_optimizer(conf, model, device)

   Load the model weights and the optimizer, scheduler, and gradient-scaler states.


