credit.losses
=============

.. py:module:: credit.losses


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/losses/almost_fair_crps/index
   /autoapi/credit/losses/base_losses/index
   /autoapi/credit/losses/covariance/index
   /autoapi/credit/losses/downscaling_loss/index
   /autoapi/credit/losses/kcrps/index
   /autoapi/credit/losses/les_loss/index
   /autoapi/credit/losses/logcosh/index
   /autoapi/credit/losses/msle/index
   /autoapi/credit/losses/power/index
   /autoapi/credit/losses/spectral/index
   /autoapi/credit/losses/weighted_loss/index
   /autoapi/credit/losses/xsigmoid/index
   /autoapi/credit/losses/xtanh/index


Attributes
----------

.. autoapisummary::

   credit.losses.logger


Classes
-------

.. autoapisummary::

   credit.losses.VariableTotalLoss2D
   credit.losses.DownscalingLoss


Functions
---------

.. autoapisummary::

   credit.losses.base_losses
   credit.losses.load_loss


Package Contents
----------------

.. py:function:: base_losses(conf, reduction='mean', validation=False)

   Instantiate the loss function whose type is specified in the configuration.

   :param conf: Configuration dictionary containing the loss settings.
   :type conf: dict
   :param reduction: Reduction method to use when the loss parameters in
                     ``conf`` do not specify one. Defaults to ``'mean'``.
   :type reduction: str, optional
   :param validation: If True, use the validation loss settings; otherwise
                      use the training loss settings. Defaults to False.
   :type validation: bool, optional

   :returns: Instantiated loss function.
   :rtype: torch.nn.Module
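
   The selection described above can be sketched as follows. This is a
   minimal illustration rather than CREDIT's implementation: the registry
   ``LOSS_REGISTRY``, the function ``select_base_loss``, and the loss names
   ``"mse"`` and ``"mae"`` are hypothetical, and the sketch ignores any
   additional loss parameters the real configuration may carry.

   .. code-block:: python

      import torch

      # Hypothetical name -> loss-class registry; CREDIT's names may differ.
      LOSS_REGISTRY = {
          "mse": torch.nn.MSELoss,
          "mae": torch.nn.L1Loss,
      }


      def select_base_loss(conf, reduction="mean", validation=False):
          """Hypothetical sketch of config-driven loss selection."""
          loss_conf = conf["loss"]
          loss_type = loss_conf["training_loss"]
          # Fall back to the training loss if no validation loss is configured.
          if validation and "validation_loss" in loss_conf:
              loss_type = loss_conf["validation_loss"]
          if loss_type not in LOSS_REGISTRY:
              raise ValueError(f"Unknown loss type: {loss_type!r}")
          return LOSS_REGISTRY[loss_type](reduction=reduction)


      conf = {"loss": {"training_loss": "mse", "validation_loss": "mae"}}
      train_loss = select_base_loss(conf)                 # torch.nn.MSELoss
      val_loss = select_base_loss(conf, validation=True)  # torch.nn.L1Loss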


.. py:class:: VariableTotalLoss2D(conf, validation=False)

   Bases: :py:obj:`torch.nn.Module`


   Custom loss function for 2D geospatial data with optional spectral
   and power loss components.

   Combines a base loss (e.g., L1 or MSE) with optional spectral and
   power loss terms. Latitude and variable-specific weights can also
   be applied.

   :param conf: Configuration dictionary containing loss
                function settings and weights.
   :type conf: dict
   :param validation: If True, the loss function
                      is used in validation mode. Defaults to False.
   :type validation: bool, optional
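
   Combining a base loss with an optional spectral-domain term can be
   sketched as below. This is a hedged illustration, not CREDIT's
   implementation: ``SketchTotalLoss2D`` and ``power_lambda`` are
   hypothetical, and the power term here (a mean absolute difference of
   log spatial power spectra) is an assumed stand-in for whatever CREDIT's
   power loss actually computes.

   .. code-block:: python

      import torch


      class SketchTotalLoss2D(torch.nn.Module):
          """Hypothetical base loss plus an optional power-spectrum term."""

          def __init__(self, base_loss, use_power_loss=False, power_lambda=0.1):
              super().__init__()
              self.base_loss = base_loss
              self.use_power_loss = use_power_loss
              self.power_lambda = power_lambda  # assumed weight, not a CREDIT setting

          @staticmethod
          def power_spectrum(x):
              # Power of the 2D real FFT over the last two (lat, lon) dimensions.
              return torch.fft.rfft2(x).abs() ** 2

          def forward(self, target, pred):
              # torch.nn losses take (input, target), hence the argument swap.
              loss = self.base_loss(pred, target)
              if self.use_power_loss:
                  # log1p compresses the large dynamic range of spectral power.
                  ps_t = torch.log1p(self.power_spectrum(target))
                  ps_p = torch.log1p(self.power_spectrum(pred))
                  loss = loss + self.power_lambda * (ps_t - ps_p).abs().mean()
              return loss


      target = torch.randn(2, 3, 16, 32)  # (batch, variables, lat, lon)
      pred = target + 0.1 * torch.randn_like(target)
      loss_fn = SketchTotalLoss2D(torch.nn.L1Loss(), use_power_loss=True)
      total = loss_fn(target, pred)  # scalar tensor

   Keeping the spectral term optional and additive means the base loss
   alone is recovered when ``use_power_loss`` is False.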


   .. py:attribute:: conf


   .. py:attribute:: training_loss


   .. py:attribute:: vars


   .. py:attribute:: lat_weights
      :value: None



   .. py:attribute:: var_weights
      :value: None



   .. py:attribute:: use_spectral_loss


   .. py:attribute:: use_power_loss


   .. py:attribute:: validation
      :value: False



   .. py:method:: forward(target, pred)

      Calculate the total loss for the given target and prediction.

      This method computes the base loss between the target and prediction,
      applies latitude and variable weights, and optionally adds spectral
      and power loss components.

      :param target: Ground truth tensor.
      :type target: torch.Tensor
      :param pred: Predicted tensor.
      :type pred: torch.Tensor

      :returns: The computed loss value.
      :rtype: torch.Tensor
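
      A minimal sketch of the latitude and variable weighting step, assuming
      a regular latitude-longitude grid, tensors shaped
      ``(batch, variables, lat, lon)``, cosine-of-latitude weights normalized
      to a mean of 1, and an L1 base loss. The function ``weighted_l1`` and
      its arguments are hypothetical; CREDIT's exact weighting scheme may
      differ.

      .. code-block:: python

         import torch


         def weighted_l1(target, pred, lat, var_weights):
             """Hypothetical latitude- and variable-weighted L1 loss.

             target, pred: (batch, variables, lat, lon)
             lat: (lat,) latitudes in degrees
             var_weights: (variables,) per-variable weights
             """
             # reduction="none" keeps elementwise errors so weights can be applied.
             err = torch.nn.functional.l1_loss(pred, target, reduction="none")
             # cos(latitude) down-weights high-latitude cells, which cover less area
             # on a regular lat-lon grid; normalizing keeps the mean weight at 1.
             w_lat = torch.cos(torch.deg2rad(lat))
             w_lat = w_lat / w_lat.mean()
             # Broadcast the weights over the lat and variable dimensions.
             err = err * w_lat.view(1, 1, -1, 1) * var_weights.view(1, -1, 1, 1)
             return err.mean()


         target = torch.zeros(1, 2, 3, 4)
         pred = torch.ones(1, 2, 3, 4)
         lat = torch.tensor([-60.0, 0.0, 60.0])
         loss = weighted_l1(target, pred, lat, torch.tensor([2.0, 0.0]))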



.. py:class:: DownscalingLoss(conf, validation=False)

   Bases: :py:obj:`torch.nn.Module`


   Custom loss function for downscaling that combines a base loss with
   optional variable weights and optional spectral and power loss components.

   :param conf: Configuration dictionary containing loss function
                settings and weights.
   :type conf: dict
   :param validation: If True, the loss function is used in validation
                      mode. Defaults to False.
   :type validation: bool, optional
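
   The ``spectral_lambda_reg`` and ``spectral_wavenum_init`` attributes
   suggest a spectral penalty scaled by a regularization coefficient and
   applied from some starting wavenumber upward. The sketch below shows one
   such term under exactly those assumptions; ``spectral_penalty`` is
   hypothetical, it takes spectra along the last (x) dimension only, and it
   should not be read as CREDIT's actual spectral loss.

   .. code-block:: python

      import torch


      def spectral_penalty(target, pred, lambda_reg=0.1, wavenum_init=8):
          """Hypothetical penalty on spectral amplitude at wavenumbers >= wavenum_init.

          Assumes wavenum_init < target.shape[-1] // 2 + 1 (the number of rfft bins).
          """
          # Amplitude spectrum along the last (x) dimension.
          amp_t = torch.fft.rfft(target, dim=-1).abs()
          amp_p = torch.fft.rfft(pred, dim=-1).abs()
          # Average over all other dimensions: one amplitude per wavenumber.
          n_k = amp_t.shape[-1]
          spec_t = amp_t.reshape(-1, n_k).mean(dim=0)
          spec_p = amp_p.reshape(-1, n_k).mean(dim=0)
          # Compare only the higher (finer-scale) wavenumbers, in log space.
          diff = torch.log1p(spec_t[wavenum_init:]) - torch.log1p(spec_p[wavenum_init:])
          return lambda_reg * diff.abs().mean()


      target = torch.randn(2, 1, 32, 64)
      # A smoothed "prediction" that lacks fine-scale detail.
      pred = torch.nn.functional.avg_pool2d(target, kernel_size=3, stride=1, padding=1)
      total = torch.nn.MSELoss()(pred, target) + spectral_penalty(target, pred)

   A penalty of this shape targets the loss of fine-scale variance that a
   pointwise loss such as MSE tends to tolerate in downscaled output.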


   .. py:attribute:: training_loss


   .. py:attribute:: use_power_loss


   .. py:attribute:: use_spectral_loss


   .. py:attribute:: spectral_lambda_reg


   .. py:attribute:: spectral_wavenum_init


   .. py:attribute:: validation
      :value: False



   .. py:method:: forward(target, pred)

      Calculate the total loss for the given target and prediction.

      This method computes the base loss between the target and
      prediction, applies optional variable weights, and optionally
      adds spectral and power loss components.

      :param target: Ground truth tensor.
      :type target: torch.Tensor
      :param pred: Predicted tensor.
      :type pred: torch.Tensor

      :returns: The computed loss value.
      :rtype: torch.Tensor



.. py:data:: logger

.. py:function:: load_loss(conf, reduction='none', validation=False)

   Load the appropriate loss function based on the configuration.

   If latitude or variable weighting is enabled, this function returns a
   weighted loss wrapper such as :py:class:`VariableTotalLoss2D`; otherwise
   it loads a standard or custom loss from ``available_losses``.

   In validation mode, if the configuration specifies a separate validation
   loss, that loss type is used; otherwise the training loss is used.

   :param conf: Configuration dictionary. Must contain a ``'loss'`` section
                with keys such as:

                - ``'training_loss'`` (str): Name of the primary loss function.
                - ``'validation_loss'`` (str, optional): Alternate loss for validation.
                - ``'use_latitude_weights'`` (bool): Whether to apply latitude-based weighting.
                - ``'use_variable_weights'`` (bool): Whether to apply variable-specific weighting.
   :type conf: dict
   :param reduction: Reduction method to apply to the loss (``'mean'``,
                     ``'sum'``, or ``'none'``). Defaults to ``'none'``.
   :type reduction: str, optional
   :param validation: Whether the loss is being used for validation.
                      Defaults to False.
   :type validation: bool, optional

   :returns: A loss function instance: either a weighted wrapper
             (:py:class:`VariableTotalLoss2D`) or a standard/custom loss
             from ``available_losses``.
   :rtype: torch.nn.Module

   :raises ValueError: If the requested loss type is not found in
                       ``available_losses``.
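
   The dispatch logic described above can be sketched as follows. This is a
   hedged illustration rather than CREDIT's code: ``WeightedLossPlaceholder``
   stands in for a wrapper such as :py:class:`VariableTotalLoss2D`, and the
   ``AVAILABLE_LOSSES`` registry and its loss names are hypothetical. The
   usage lines show that, with ``reduction='none'``, the returned loss yields
   an elementwise tensor rather than a scalar.

   .. code-block:: python

      import torch


      class WeightedLossPlaceholder(torch.nn.Module):
          """Hypothetical stand-in for a weighted wrapper like VariableTotalLoss2D."""

          def __init__(self, conf, validation=False):
              super().__init__()
              self.conf = conf
              self.validation = validation


      # Hypothetical registry standing in for available_losses.
      AVAILABLE_LOSSES = {"mse": torch.nn.MSELoss, "mae": torch.nn.L1Loss}


      def dispatch_loss(conf, reduction="none", validation=False):
          """Hypothetical sketch of the weighted-vs-standard dispatch."""
          loss_conf = conf["loss"]
          # Any weighting option routes to the weighted wrapper.
          if loss_conf.get("use_latitude_weights") or loss_conf.get("use_variable_weights"):
              return WeightedLossPlaceholder(conf, validation=validation)
          loss_type = loss_conf["training_loss"]
          if validation and "validation_loss" in loss_conf:
              loss_type = loss_conf["validation_loss"]
          if loss_type not in AVAILABLE_LOSSES:
              raise ValueError(f"Unknown loss type: {loss_type!r}")
          return AVAILABLE_LOSSES[loss_type](reduction=reduction)


      plain = dispatch_loss({"loss": {"training_loss": "mse"}})
      elementwise = plain(torch.zeros(2, 3), torch.ones(2, 3))  # shape (2, 3)
      weighted = dispatch_loss({"loss": {"training_loss": "mse", "use_latitude_weights": True}})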


