credit.trainers.ic_optimization
===============================

.. py:module:: credit.trainers.ic_optimization


Attributes
----------

.. autoapisummary::

   credit.trainers.ic_optimization.logger


Classes
-------

.. autoapisummary::

   credit.trainers.ic_optimization.ForecastProcessor
   credit.trainers.ic_optimization.TimeStepper
   credit.trainers.ic_optimization.TrainerIC


Functions
---------

.. autoapisummary::

   credit.trainers.ic_optimization.save_netcdf_increment


Module Contents
---------------

.. py:data:: logger

.. py:function:: save_netcdf_increment(darray_upper_air: xarray.DataArray, darray_single_level: xarray.DataArray, nc_filename: str, forecast_hour: int, meta_data: dict, conf: dict, name_tag: str)

   Save CREDIT model prediction output to a netCDF file, optionally
   interpolating the output to pressure levels.

   :param darray_upper_air: upper air variable predictions
   :type darray_upper_air: xr.DataArray
   :param darray_single_level: surface variable predictions
   :type darray_single_level: xr.DataArray
   :param nc_filename: description string included in the output file names
   :type nc_filename: str
   :param forecast_hour: number of hours elapsed since the forecast initialization time
   :type forecast_hour: int
   :param meta_data: metadata dictionary for output variables
   :type meta_data: dict
   :param conf: configuration dictionary for training and/or rollout
   :type conf: dict
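
   As a worked illustration of ``forecast_hour``, the sketch below computes the
   lead time in whole hours from an initialization time and a valid time. The
   helper ``forecast_hour_from_times`` is hypothetical and is not part of this
   module; it only shows the arithmetic implied by the parameter description.

   .. code-block:: python

      from datetime import datetime, timedelta


      def forecast_hour_from_times(init_time: datetime, valid_time: datetime) -> int:
          """Return whole hours between initialization and valid time.

          Hypothetical helper, not part of credit.trainers.ic_optimization.
          """
          elapsed = valid_time - init_time
          if elapsed < timedelta(0):
              raise ValueError("valid_time must not precede init_time")
          # Dividing two timedeltas with // yields an int (whole hours here).
          return elapsed // timedelta(hours=1)


      init = datetime(2020, 1, 1, 0)
      valid = datetime(2020, 1, 2, 6)
      print(forecast_hour_from_times(init, valid))  # 30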


.. py:class:: ForecastProcessor(conf, device)

   .. py:attribute:: conf


   .. py:attribute:: device


   .. py:attribute:: batch_size


   .. py:attribute:: ensemble_size


   .. py:attribute:: lead_time_periods


   .. py:attribute:: latlons


   .. py:attribute:: meta_data


   .. py:method:: process(y_pred, datetimes, save_datetimes, nametag)


.. py:class:: TimeStepper(dataset)

   .. py:attribute:: dataset


   .. py:attribute:: _active
      :value: False



   .. py:method:: __iter__()


   .. py:method:: reset(idx=0)

      Initialize a new sample, starting from forecast step 0.



   .. py:method:: __next__()

      Advance to the next forecast step, stopping once ``forecast_len + 1`` is reached.
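
   A minimal sketch of the iteration pattern described above, assuming that
   ``reset(idx)`` selects a sample and that iteration then yields forecast steps
   ``0`` through ``forecast_len`` (``forecast_len + 1`` steps in total).
   ``ToyDataset``, its ``get`` method, and ``SketchTimeStepper`` are hypothetical
   stand-ins; how the real ``TimeStepper`` reads from its ``dataset`` attribute is
   not documented here.

   .. code-block:: python

      class ToyDataset:
          """Hypothetical dataset returning (sample index, forecast step) pairs."""

          def __init__(self, forecast_len: int):
              self.forecast_len = forecast_len

          def get(self, idx: int, step: int):
              return (idx, step)


      class SketchTimeStepper:
          """Hypothetical stand-in mirroring the documented TimeStepper interface."""

          def __init__(self, dataset):
              self.dataset = dataset
              self._active = False  # no sample has been initialized yet

          def __iter__(self):
              return self

          def reset(self, idx=0):
              # Start a new sample at forecast step 0.
              self.idx = idx
              self.step = 0
              self._active = True

          def __next__(self):
              if not self._active:
                  raise StopIteration
              if self.step >= self.dataset.forecast_len + 1:
                  # Steps 0 .. forecast_len have all been produced.
                  self._active = False
                  raise StopIteration
              item = self.dataset.get(self.idx, self.step)
              self.step += 1
              return item


      stepper = SketchTimeStepper(ToyDataset(forecast_len=2))
      stepper.reset(idx=5)
      print(list(stepper))  # [(5, 0), (5, 1), (5, 2)]

   In this sketch, iterating before calling ``reset`` yields nothing, because
   ``_active`` starts as ``False``, matching the documented default.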



.. py:class:: TrainerIC(model: torch.nn.Module, rank: int)

   Bases: :py:obj:`credit.trainers.base_trainer.BaseTrainer`




   .. py:method:: train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)

      Trains the model for one epoch.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing training settings.
      :type conf: dict
      :param trainloader: DataLoader for the training dataset.
      :type trainloader: DataLoader
      :param optimizer: Optimizer used for training.
      :type optimizer: torch.optim.Optimizer
      :param criterion: Loss function used for training.
      :type criterion: callable
      :param scaler: Gradient scaler for mixed precision training.
      :type scaler: torch.cuda.amp.GradScaler
      :param scheduler: Learning rate scheduler.
      :type scheduler: torch.optim.lr_scheduler._LRScheduler
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing training metrics and loss for the epoch.
      :rtype: dict
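
      The loop below is a minimal PyTorch sketch of one training epoch using the
      documented ``optimizer``, ``criterion``, ``scaler``, and ``scheduler`` roles.
      It is not CREDIT's implementation: ``sketch_train_one_epoch`` is
      hypothetical, ``epoch``, ``conf``, and ``metrics`` are omitted, and stepping
      the scheduler once per epoch is an assumption. The gradient scaler is
      created with ``enabled=False`` so the example also runs on CPU.

      .. code-block:: python

         import torch
         from torch import nn
         from torch.utils.data import DataLoader, TensorDataset


         def sketch_train_one_epoch(model, trainloader, optimizer, criterion, scaler, scheduler):
             """Hypothetical single-epoch loop; not CREDIT's implementation."""
             model.train()
             total_loss, n_batches = 0.0, 0
             for x, y in trainloader:
                 optimizer.zero_grad()  # clear gradients from the previous batch
                 loss = criterion(model(x), y)
                 # With a disabled scaler, scale() returns the loss unchanged
                 # and step() simply calls optimizer.step().
                 scaler.scale(loss).backward()
                 scaler.step(optimizer)
                 scaler.update()
                 total_loss += loss.item()
                 n_batches += 1
             if scheduler is not None:
                 scheduler.step()  # assumption: one scheduler step per epoch
             return {"train_loss": total_loss / max(n_batches, 1)}


         torch.manual_seed(0)
         x = torch.randn(32, 4)
         y = x.sum(dim=1, keepdim=True)
         loader = DataLoader(TensorDataset(x, y), batch_size=8)
         model = nn.Linear(4, 1)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
         scaler = torch.cuda.amp.GradScaler(enabled=False)
         results = sketch_train_one_epoch(model, loader, optimizer, nn.MSELoss(), scaler, scheduler)
         print(results["train_loss"])

      With ``StepLR(step_size=1, gamma=0.9)``, the single per-epoch scheduler step
      lowers the learning rate from 0.1 to 0.09.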



   .. py:method:: validate(epoch, conf, valid_loader, criterion, metrics)

      Validates the model on the validation dataset.

      :param epoch: Current epoch number.
      :type epoch: int
      :param conf: Configuration dictionary containing validation settings.
      :type conf: dict
      :param valid_loader: DataLoader for the validation dataset.
      :type valid_loader: DataLoader
      :param criterion: Loss function used for validation.
      :type criterion: callable
      :param metrics: Function to compute metrics for evaluation.
      :type metrics: callable

      :returns: Dictionary containing validation metrics and loss for the epoch.
      :rtype: dict
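
      A matching validation sketch, again hypothetical and omitting ``epoch``,
      ``conf``, and ``metrics``: the model is switched to evaluation mode and the
      loss is averaged over batches under ``torch.no_grad()``, so no gradients
      are accumulated and no weights change.

      .. code-block:: python

         import torch
         from torch import nn
         from torch.utils.data import DataLoader, TensorDataset


         def sketch_validate(model, valid_loader, criterion):
             """Hypothetical validation loop; not CREDIT's implementation."""
             model.eval()  # e.g. disables dropout, uses running batch-norm stats
             total_loss, n_batches = 0.0, 0
             with torch.no_grad():  # no autograd graph is built
                 for x, y in valid_loader:
                     total_loss += criterion(model(x), y).item()
                     n_batches += 1
             return {"valid_loss": total_loss / max(n_batches, 1)}


         # Toy check: a linear model whose weights exactly reproduce the target.
         torch.manual_seed(0)
         x = torch.randn(16, 4)
         y = x.sum(dim=1, keepdim=True)
         model = nn.Linear(4, 1)
         with torch.no_grad():
             model.weight.fill_(1.0)
             model.bias.zero_()
         loader = DataLoader(TensorDataset(x, y), batch_size=8)
         results = sketch_validate(model, loader, nn.MSELoss())
         print(results)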



