credit.loss
===========

.. py:module:: credit.loss


Attributes
----------

.. autoapisummary::

   credit.loss.logger


Classes
-------

.. autoapisummary::

   credit.loss.LogCoshLoss
   credit.loss.XTanhLoss
   credit.loss.XSigmoidLoss
   credit.loss.MSLELoss
   credit.loss.KCRPSLoss
   credit.loss.SpectralLoss2D
   credit.loss.PSDLoss
   credit.loss.VariableTotalLoss2D


Functions
---------

.. autoapisummary::

   credit.loss.load_loss
   credit.loss.latitude_weights
   credit.loss.variable_weights


Module Contents
---------------

.. py:data:: logger

.. py:function:: load_loss(loss_type, reduction='mean')

   Load a specified loss function by its type.

   This is a helper function for :py:class:`VariableTotalLoss2D`. It
   returns the loss module corresponding to ``loss_type``. Supported
   losses are MSE, MAE, MSLE, Huber, Log-Cosh, X-Tanh, and X-Sigmoid. The
   ``reduction`` argument selects how the element-wise losses are reduced
   (e.g., 'mean', 'sum'). Use ``reduction='none'`` when latitude or
   variable weights will be applied, so that the weights act on the
   unreduced, element-wise loss.

   :param loss_type: The type of loss function to load. Supported
                     values are "mse", "mae", "msle", "huber", "logcosh",
                     "xtanh", and "xsigmoid".
   :type loss_type: str
   :param reduction: Specifies the reduction to apply to the output:
                     'mean' (default), 'sum', or 'none'.
   :type reduction: str, optional

   :returns: The corresponding loss function.
   :rtype: torch.nn.Module

   :raises ValueError: If the specified `loss_type` is not supported.

   .. rubric:: Example

   >>> import torch
   >>> from credit.loss import load_loss
   >>> loss_fn = load_loss("mse")
   >>> loss = loss_fn(torch.randn(2, 4), torch.randn(2, 4))
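
   The advice to use ``reduction='none'`` with weights can be illustrated in
   plain PyTorch. This sketch uses ``torch.nn.MSELoss`` directly rather than
   ``load_loss``; the ``(batch, variable, lat, lon)`` layout and the weight
   values are hypothetical.

   .. code-block:: python

      import torch

      torch.manual_seed(0)

      # reduction="none" keeps one loss value per element instead of a scalar.
      loss_fn = torch.nn.MSELoss(reduction="none")

      # Hypothetical layout: (batch, variable, lat, lon).
      pred = torch.randn(2, 3, 4, 8)
      target = torch.randn(2, 3, 4, 8)

      # Hypothetical weights: one per grid point and one per variable.
      lat_w = torch.rand(4, 8) + 0.5
      var_w = torch.tensor([1.0, 0.5, 2.0]).view(1, 3, 1, 1)

      elementwise = loss_fn(pred, target)     # shape (2, 3, 4, 8)
      weighted = elementwise * lat_w * var_w  # weights broadcast over the batch
      loss = weighted.mean()                  # reduce only after weighting

   With ``reduction='mean'`` the loss would already be a scalar, so
   multiplying it by weights could no longer favor individual grid points or
   variables.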


.. py:class:: LogCoshLoss(reduction='mean')

   Bases: :py:obj:`torch.nn.Module`


   Log-Cosh Loss Function.

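   This loss function computes the logarithm of the hyperbolic cosine of the
   prediction error. Because :math:`\log\cosh(x)` behaves like :math:`x^2/2`
   for small errors and like :math:`|x| - \log 2` for large ones, it is less
   sensitive to outliers than the Mean Squared Error (MSE) loss.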

   :param reduction: Specifies the reduction to apply to the output.
                     'mean' | 'none'. 'mean': the output is averaged; 'none': no reduction is applied.
   :type reduction: str
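
   A minimal sketch of the Log-Cosh computation, assuming the error is
   ``y_t - y_prime_t`` and that ``reduction`` behaves as documented above.
   The helper name ``log_cosh_loss`` is hypothetical, and the identity
   :math:`\log\cosh(x) = x + \mathrm{softplus}(-2x) - \log 2` is a standard
   numerically stable rewrite, not necessarily what ``LogCoshLoss`` uses.

   .. code-block:: python

      import math

      import torch
      import torch.nn.functional as F


      def log_cosh_loss(y_t: torch.Tensor, y_prime_t: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
          """Log-cosh of the prediction error (hypothetical helper)."""
          err = y_t - y_prime_t
          # Stable identity: log(cosh(x)) = x + softplus(-2x) - log(2).
          loss = err + F.softplus(-2.0 * err) - math.log(2.0)
          return loss.mean() if reduction == "mean" else loss

   Evaluating ``torch.log(torch.cosh(err))`` directly overflows to ``inf`` in
   float32 once ``|err|`` reaches about 90; the rewritten form stays finite.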


   .. py:attribute:: reduction
      :value: 'mean'



   .. py:method:: forward(y_t, y_prime_t)

      Forward pass for Log-Cosh loss.

      :param y_t: Target tensor.
      :type y_t: torch.Tensor
      :param y_prime_t: Predicted tensor.
      :type y_prime_t: torch.Tensor

      :returns: Log-Cosh loss value.
      :rtype: torch.Tensor



.. py:class:: XTanhLoss(reduction='mean')

   Bases: :py:obj:`torch.nn.Module`


   X-Tanh Loss Function.

   This loss function computes the element-wise product of the prediction error
   and the hyperbolic tangent of the error, :math:`x \tanh(x)`. It grows like
   :math:`x^2` for small errors and like :math:`|x|` for large ones, which makes
   it more robust to outliers than traditional MSE.

   :param reduction: Specifies the reduction to apply to the output:
                     'mean' | 'none'. 'mean': the output is averaged; 'none': no reduction is applied.
   :type reduction: str
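
   A minimal sketch of the X-Tanh computation, assuming the error is
   ``y_t - y_prime_t``; ``xtanh_loss`` is a hypothetical helper, not the
   CREDIT implementation.

   .. code-block:: python

      import torch


      def xtanh_loss(y_t: torch.Tensor, y_prime_t: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
          """x * tanh(x) of the prediction error (hypothetical helper)."""
          err = y_t - y_prime_t
          # Non-negative and symmetric in err: ~err**2 near 0, ~|err| for large errors.
          loss = err * torch.tanh(err)
          return loss.mean() if reduction == "mean" else loss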


   .. py:attribute:: reduction
      :value: 'mean'



   .. py:method:: forward(y_t, y_prime_t)

      Forward pass for X-Tanh loss.

      :param y_t: Target tensor.
      :type y_t: torch.Tensor
      :param y_prime_t: Predicted tensor.
      :type y_prime_t: torch.Tensor

      :returns: X-Tanh loss value.
      :rtype: torch.Tensor



.. py:class:: XSigmoidLoss(reduction='mean')

   Bases: :py:obj:`torch.nn.Module`


   X-Sigmoid Loss Function.

   This loss function applies a sigmoid-based transformation to the prediction
   error. It is designed to penalize large errors non-linearly.

   :param reduction: Specifies the reduction to apply to the output.
                     'mean' | 'none'. 'mean': the output is averaged; 'none': no reduction is applied.
   :type reduction: str
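
   A sketch of one commonly used X-Sigmoid definition,
   :math:`2x\,\sigma(x) - x` with :math:`x` the prediction error. That
   ``XSigmoidLoss`` uses exactly this form is an assumption, and
   ``xsigmoid_loss`` is a hypothetical name.

   .. code-block:: python

      import torch


      def xsigmoid_loss(y_t: torch.Tensor, y_prime_t: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
          """2x * sigmoid(x) - x of the prediction error (assumed definition)."""
          err = y_t - y_prime_t
          # Algebraically equal to err * tanh(err / 2): symmetric and non-negative.
          loss = 2.0 * err * torch.sigmoid(err) - err
          return loss.mean() if reduction == "mean" else loss

   Since :math:`2\sigma(x) - 1 = \tanh(x/2)`, this form equals
   :math:`x \tanh(x/2)`, a gentler variant of X-Tanh.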


   .. py:attribute:: reduction
      :value: 'mean'



   .. py:method:: forward(y_t, y_prime_t)

      Forward pass for X-Sigmoid loss.

      :param y_t: Target tensor.
      :type y_t: torch.Tensor
      :param y_prime_t: Predicted tensor.
      :type y_prime_t: torch.Tensor

      :returns: X-Sigmoid loss value.
      :rtype: torch.Tensor



.. py:class:: MSLELoss(reduction='mean')

   Bases: :py:obj:`torch.nn.Module`


   Mean Squared Logarithmic Error (MSLE) Loss Function.

   This loss function computes the mean squared logarithmic error between the
   predicted and target values. Because errors are measured between
   logarithms, it penalizes relative rather than absolute differences, which
   is useful for targets that span several orders of magnitude.

   :param reduction: Specifies the reduction to apply to the output.
                     'mean' | 'none'. 'mean': the output is averaged; 'none': no reduction is applied.
   :type reduction: str
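
   A sketch using the common definition
   :math:`(\log(1 + \hat{y}) - \log(1 + y))^2`. Whether ``MSLELoss`` uses
   ``log1p`` or a different offset is an assumption, and ``msle_loss`` is a
   hypothetical name.

   .. code-block:: python

      import torch


      def msle_loss(prediction: torch.Tensor, target: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
          """Squared difference of log(1 + x); inputs must be greater than -1."""
          loss = (torch.log1p(prediction) - torch.log1p(target)) ** 2
          return loss.mean() if reduction == "mean" else loss

   For example, predicting 9 for a target of 99 incurs the same loss as
   predicting 99 for a target of 999, since both are off by a factor of 10
   in :math:`1 + y`.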


   .. py:attribute:: reduction
      :value: 'mean'



   .. py:method:: forward(prediction, target)

      Forward pass for MSLE loss.

      :param prediction: Predicted tensor.
      :type prediction: torch.Tensor
      :param target: Target tensor.
      :type target: torch.Tensor

      :returns: MSLE loss value.
      :rtype: torch.Tensor



.. py:class:: KCRPSLoss(reduction, biased: bool = False)

   Bases: :py:obj:`torch.nn.Module`


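   Kernel CRPS loss for ensemble forecasts, adapted from NVIDIA Modulus.

   Estimates the local Continuous Ranked Probability Score (CRPS) from a
   finite ensemble using the kernel form of the CRPS, at a cost of
   :math:`O(m \log m)` for an ensemble of size :math:`m`. The result is a
   map of CRPS values; it is not accumulated over lat/lon regions.

   The CRPS of a forecast distribution :math:`X` for an observation
   :math:`y` is

   .. math::
       \mathrm{CRPS}(X, y) = \mathbb{E}|X - y| - \frac{1}{2}\,\mathbb{E}|X - X'|,

   where :math:`X'` is an independent copy of :math:`X`. For ensemble members
   :math:`x_1, \dots, x_m`, the biased estimator is

   .. math::
       \frac{1}{m}\sum_{i=1}^{m} |x_i - y| - \frac{1}{2m^2}\sum_{i,j=1}^{m} |x_i - x_j|,

   and the unbiased ("fair") estimator of Zamo and Naveau (2018) replaces
   :math:`1/(2m^2)` with :math:`1/(2m(m-1))`:

   .. math::
       \frac{1}{m}\sum_{i=1}^{m} |x_i - y| - \frac{1}{2m(m-1)}\sum_{i,j=1}^{m} |x_i - x_j|.

   The biased estimator is biased high by :math:`\mathbb{E}|X - X'|/(2m)`, so
   it favors less dispersive ensembles; the fair estimator removes this bias
   for finite ensembles.

   In the underlying kernel implementation, ``pred`` holds the ensemble
   predictions, with the ensemble dimension assumed to be the leading
   dimension, and ``obs`` is the tensor or array containing the observation
   with respect to which the CRPS is computed.

   :param reduction: Specifies the reduction to apply to the output.
   :param biased: If False (default), use the unbiased (fair) estimator;
                  if True, use the biased estimator.
   :type biased: bool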



   .. py:attribute:: biased
      :value: False



   .. py:attribute:: batched_forward


   .. py:method:: forward(target, pred)


   .. py:method:: single_sample_forward(target, pred)

      Forward pass for KCRPS loss for a single sample.

      :param target: Target tensor.
      :type target: torch.Tensor
      :param pred: Predicted tensor.
      :type pred: torch.Tensor

      :returns: CRPS loss values at each lat/lon
      :rtype: torch.Tensor



   .. py:method:: _kernel_crps_implementation(pred: torch.Tensor, obs: torch.Tensor, biased: bool) -> torch.Tensor

      An O(m log m) implementation of the kernel CRPS formulas
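
      The :math:`O(m \log m)` cost comes from sorting the ensemble: for sorted
      members :math:`x_{(1)} \le \dots \le x_{(m)}`,
      :math:`\sum_{i,j}|x_i - x_j| = 2\sum_{k=1}^{m}(2k - m - 1)\,x_{(k)}`.
      The sketch below applies this identity. ``kernel_crps`` is a
      hypothetical name; it assumes ``pred`` has the ensemble dimension first
      and ``obs`` has the remaining shape, and it is not the Modulus or CREDIT
      code.

      .. code-block:: python

         import torch


         def kernel_crps(pred: torch.Tensor, obs: torch.Tensor, biased: bool = False) -> torch.Tensor:
             """CRPS map for an ensemble `pred` of shape (m, ...) and `obs` of shape (...)."""
             m = pred.shape[0]
             # First term: mean absolute error of the members against the observation.
             skill = (pred - obs.unsqueeze(0)).abs().mean(dim=0)
             # Second term: sum_{i,j} |x_i - x_j| = 2 * sum_k (2k - m - 1) * x_(k)
             # for members sorted in ascending order; sorting costs O(m log m).
             sorted_pred, _ = torch.sort(pred, dim=0)
             k = torch.arange(1, m + 1, dtype=pred.dtype, device=pred.device)
             coeff = (2 * k - m - 1).view(-1, *([1] * (pred.dim() - 1)))
             pair_sum = 2 * (coeff * sorted_pred).sum(dim=0)
             # Biased estimator divides by 2m^2; the fair estimator by 2m(m-1).
             denom = 2 * m * m if biased else 2 * m * (m - 1)
             return skill - pair_sum / denom

      The fair branch divides by :math:`m - 1` and therefore needs at least
      two ensemble members.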



.. py:class:: SpectralLoss2D(wavenum_init=20, reduction='none')

   Bases: :py:obj:`torch.nn.Module`


   Spectral Loss in 2D.

   This loss function compares the spectral (frequency-domain) content of the
   predicted and target outputs using the FFT. It encourages the prediction to
   have frequency characteristics similar to those of the target.

   :param wavenum_init: Wavenumber from which to start including spectral
                        components in the loss calculation.
   :type wavenum_init: int
   :param reduction: Specifies the reduction to apply to the output:
                     'mean' | 'none'. 'mean': the output is averaged; 'none': no reduction is applied.
   :type reduction: str
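
   A minimal sketch of a spectral loss of this kind, assuming it compares FFT
   amplitudes along ``fft_dim`` and keeps only wavenumbers at or above
   ``wavenum_init``. The function name is hypothetical, the optional
   ``weights`` argument is omitted, and the L1 comparison of amplitudes is an
   assumption rather than the documented behavior.

   .. code-block:: python

      import torch


      def spectral_loss_2d(output, target, wavenum_init=20, fft_dim=-1, reduction="none"):
          # Amplitude spectrum along fft_dim (e.g. longitude) via the real FFT.
          out_amp = torch.fft.rfft(output, dim=fft_dim).abs()
          tgt_amp = torch.fft.rfft(target, dim=fft_dim).abs()
          # Drop wavenumbers below wavenum_init.
          n_keep = out_amp.shape[fft_dim] - wavenum_init
          out_amp = out_amp.narrow(fft_dim, wavenum_init, n_keep)
          tgt_amp = tgt_amp.narrow(fft_dim, wavenum_init, n_keep)
          diff = (out_amp - tgt_amp).abs()
          return diff.mean() if reduction == "mean" else diff

   Because wavenumber 0 holds the field mean, a constant offset between
   prediction and target is invisible to this sketch whenever
   ``wavenum_init > 0``.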


   .. py:attribute:: wavenum_init
      :value: 20



   .. py:attribute:: reduction
      :value: 'none'



   .. py:method:: forward(output, target, weights=None, fft_dim=-1)

      Forward pass for Spectral Loss 2D.

      :param output: Predicted tensor.
      :type output: torch.Tensor
      :param target: Target tensor.
      :type target: torch.Tensor
      :param weights: Latitude weights for the loss.
      :type weights: torch.Tensor, optional
      :param fft_dim: The dimension to apply FFT.
      :type fft_dim: int

      :returns: Spectral loss value.
      :rtype: torch.Tensor



.. py:class:: PSDLoss(wavenum_init=20)

   Bases: :py:obj:`torch.nn.Module`


   Power Spectral Density (PSD) Loss Function.

   This loss function computes the Power Spectral Density (PSD) of the
   predicted and target outputs and compares them, encouraging the
   predictions to have frequency content similar to that of the target.

   :param wavenum_init: Wavenumber from which to start including spectral
                        components in the loss calculation.
   :type wavenum_init: int
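
   A sketch of a PSD comparison, assuming a 1D spectrum along the last axis
   averaged over the other dimensions, and an L1 distance between log spectra
   from ``wavenum_init`` onward. The helper names and the log-L1 distance are
   assumptions; ``PSDLoss`` may differ, for example in how ``get_psd``
   normalizes the spectrum or how ``weights`` are applied.

   .. code-block:: python

      import torch


      def zonal_psd(x: torch.Tensor) -> torch.Tensor:
          # Power per wavenumber along the last (assumed longitude) axis,
          # averaged over all leading dimensions.
          power = torch.fft.rfft(x, dim=-1).abs() ** 2 / x.shape[-1]
          return power.reshape(-1, power.shape[-1]).mean(dim=0)


      def psd_loss(target, pred, wavenum_init=20, eps=1e-12):
          psd_t = zonal_psd(target)[wavenum_init:]
          psd_p = zonal_psd(pred)[wavenum_init:]
          # Compare on a log scale so weak high-wavenumber power still contributes.
          return (torch.log(psd_p + eps) - torch.log(psd_t + eps)).abs().mean()

   On the log scale, a prediction that doubles the target's amplitude at every
   wavenumber gives a loss of :math:`\log 4`, regardless of how much power
   each wavenumber carries.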


   .. py:attribute:: wavenum_init
      :value: 20



   .. py:method:: forward(target, pred, weights=None)

      Forward pass for PSD loss.

      :param target: Target tensor.
      :type target: torch.Tensor
      :param pred: Predicted tensor.
      :type pred: torch.Tensor
      :param weights: Latitude weights for the loss.
      :type weights: torch.Tensor, optional

      :returns: PSD loss value.
      :rtype: torch.Tensor



   .. py:method:: get_psd(f_x, device, dtype)


.. py:function:: latitude_weights(conf)

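   Calculate latitude-based weights for the loss function.

   This function calculates weights from latitude values for use in loss
   functions on geospatial data. The weights are the cosine of the latitude,
   normalized by their mean so that the average weight is 1. Grid cells near
   the poles, which cover less area on a regular latitude-longitude grid,
   therefore contribute less to the loss.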

   :param conf: Configuration dictionary containing the
                path to the latitude weights file.
   :type conf: dict

   :returns: A 2D tensor of weights with dimensions corresponding to
             latitude and longitude.
   :rtype: torch.Tensor
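
   A minimal sketch of cosine-latitude weighting. ``latitude_weights`` reads
   latitudes from the file referenced in ``conf``; here, as an assumption for
   illustration, the latitudes in degrees are passed directly to a
   hypothetical helper.

   .. code-block:: python

      import torch


      def cosine_latitude_weights(lat_deg: torch.Tensor, n_lon: int) -> torch.Tensor:
          # cos(latitude) is proportional to grid-cell area on a regular lat-lon grid.
          w = torch.cos(torch.deg2rad(lat_deg))
          w = w / w.mean()  # normalize so the average weight is 1
          # Broadcast to a 2D (lat, lon) field.
          return w.unsqueeze(1).expand(-1, n_lon)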


.. py:function:: variable_weights(conf, channels, frames)

   Create variable-specific weights for different atmospheric
   and surface channels.

   This function loads weights for atmospheric variables (e.g., U, V, T, Q)
   and surface variables (e.g., SP, t2m) from the configuration and combines
   them into a single weight tensor for use in loss calculations.

   :param conf: Configuration dictionary containing the
                variable weights.
   :type conf: dict
   :param channels: Number of channels for atmospheric variables.
   :type channels: int
   :param frames: Number of time frames.
   :type frames: int

   :returns: A tensor containing the combined weights for all variables.
   :rtype: torch.Tensor
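
   A sketch of combining per-variable weights into one tensor. The channel
   ordering is a hypothetical layout chosen for illustration: each atmospheric
   variable fills ``channels`` consecutive slots, surface variables follow with
   one slot each, and the vector is tiled once per time frame. The helper name
   and its dictionary arguments are also hypothetical.

   .. code-block:: python

      import torch


      def combine_variable_weights(atm_weights: dict, sfc_weights: dict, channels: int, frames: int) -> torch.Tensor:
          # One block of `channels` identical weights per atmospheric variable.
          atm = [torch.full((channels,), float(w)) for w in atm_weights.values()]
          # One weight per surface variable.
          sfc = [torch.tensor([float(w)]) for w in sfc_weights.values()]
          per_frame = torch.cat(atm + sfc)
          # Hypothetical: repeat the same per-frame vector for every time frame.
          return per_frame.repeat(frames)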


.. py:class:: VariableTotalLoss2D(conf, validation=False)

   Bases: :py:obj:`torch.nn.Module`


   Custom loss function class for 2D geospatial data
   with optional spectral and power loss components.

   This class defines a loss function that combines a base loss
   (e.g., L1, MSE) with optional spectral and power loss components
   for 2D geospatial data. The loss function can incorporate latitude
   and variable-specific weights.

   :param conf: Configuration dictionary containing loss
                function settings and weights.
   :type conf: dict
   :param validation: If True, the loss function
                      is used in validation mode. Defaults to False.
   :type validation: bool, optional


   .. py:attribute:: conf


   .. py:attribute:: training_loss


   .. py:attribute:: vars


   .. py:attribute:: lat_weights
      :value: None



   .. py:attribute:: var_weights
      :value: None



   .. py:attribute:: use_spectral_loss


   .. py:attribute:: use_power_loss


   .. py:attribute:: validation
      :value: False



   .. py:method:: forward(target, pred)

      Calculate the total loss for the given target and prediction.

      This method computes the base loss between the target and prediction,
      applies latitude and variable weights, and optionally adds spectral
      and power loss components.

      :param target: Ground truth tensor.
      :type target: torch.Tensor
      :param pred: Predicted tensor.
      :type pred: torch.Tensor

      :returns: The computed loss value.
      :rtype: torch.Tensor
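
      A minimal sketch of the combination described above, assuming a
      ``(batch, variable, lat, lon)`` layout, an L1 base loss, and a single
      hypothetical coefficient ``spectral_coef`` for an optional spectral
      term. ``total_loss_sketch`` and its arguments are illustrative and do
      not reflect how ``VariableTotalLoss2D`` reads its settings from
      ``conf``.

      .. code-block:: python

         import torch
         import torch.nn.functional as F


         def total_loss_sketch(target, pred, lat_w=None, var_w=None, spectral_fn=None, spectral_coef=0.1):
             # Element-wise base loss, so weights can be applied before reducing.
             loss = F.l1_loss(pred, target, reduction="none")
             if lat_w is not None:
                 loss = loss * lat_w  # (lat, lon) weights broadcast over batch and variables
             if var_w is not None:
                 loss = loss * var_w.view(1, -1, 1, 1)  # one weight per variable
             total = loss.mean()
             if spectral_fn is not None:
                 # Optional extra term, e.g. a spectral or PSD loss returning a scalar.
                 total = total + spectral_coef * spectral_fn(pred, target)
             return total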



