credit.losses.weighted_loss
===========================

.. py:module:: credit.losses.weighted_loss


Attributes
----------

.. autoapisummary::

   credit.losses.weighted_loss.logger


Classes
-------

.. autoapisummary::

   credit.losses.weighted_loss.VariableTotalLoss2D


Functions
---------

.. autoapisummary::

   credit.losses.weighted_loss.latitude_weights
   credit.losses.weighted_loss.variable_weights


Module Contents
---------------

.. py:data:: logger

.. py:function:: latitude_weights(conf)

   Calculate latitude-based weights for the loss function.

   The weights are derived from the cosine of the latitude and
   normalized by their mean, for use in loss functions over
   geospatial data.

   :param conf: Configuration dictionary containing the
                path to the latitude weights file.
   :type conf: dict

   :returns: A 2D tensor of weights with dimensions
             corresponding to latitude and longitude.
   :rtype: torch.Tensor
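
The cosine weighting described for ``latitude_weights`` can be sketched as follows. This is a minimal illustration and not the module's implementation: the helper name ``cos_latitude_weights``, its arguments, and the use of NumPy instead of ``torch`` are assumptions, and the configuration lookup that reads the latitude file is left out.

```python
import numpy as np


def cos_latitude_weights(lat_deg, n_lon):
    """Hypothetical helper: cosine-of-latitude weights normalized by their mean."""
    lat_deg = np.asarray(lat_deg, dtype=float)
    # Grid cells shrink toward the poles, so weight each row by cos(latitude).
    w = np.cos(np.deg2rad(lat_deg))
    # Divide by the mean so the weights average to 1.
    w = w / w.mean()
    # Repeat each latitude weight along the longitude axis: shape (n_lat, n_lon).
    return np.repeat(w[:, None], n_lon, axis=1)


weights = cos_latitude_weights([-60.0, 0.0, 60.0], n_lon=4)
```

Because the weights are divided by their mean, applying them rescales the contribution of each latitude row without changing the overall magnitude of an unweighted average.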


.. py:function:: variable_weights(conf, channels, frames)

   Create variable-specific weights for different atmospheric
   and surface channels.

   This function loads weights for different atmospheric variables
   (e.g., U, V, T, Q) and surface variables (e.g., SP, t2m) from
   the configuration file. It then combines them into a single
   weight tensor for use in loss calculations.

   :param conf: Configuration dictionary containing the
                variable weights.
   :type conf: dict
   :param channels: Number of channels for atmospheric variables.
   :type channels: int
   :param frames: Number of time frames.
   :type frames: int

   :returns: A tensor containing the combined weights for
             all variables.
   :rtype: torch.Tensor
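
One way to picture how ``variable_weights`` combines per-variable weights into a single tensor is the sketch below. The helper ``combine_variable_weights``, the dictionary layout (one weight per vertical level for each atmospheric variable, one scalar per surface variable), and the use of NumPy are assumptions made for illustration. The actual configuration keys and the handling of ``channels`` and ``frames`` are not shown in this reference.

```python
import numpy as np


def combine_variable_weights(atmos_weights, surface_weights):
    """Hypothetical helper: flatten per-variable weights into one 1D array.

    atmos_weights: dict mapping variable name -> list of per-level weights.
    surface_weights: dict mapping variable name -> single weight.
    """
    # Atmospheric variables contribute one weight per level, in dict order.
    parts = [np.asarray(levels, dtype=float) for levels in atmos_weights.values()]
    # Surface variables contribute one weight each, appended at the end.
    parts.append(np.asarray(list(surface_weights.values()), dtype=float))
    return np.concatenate(parts)


atmos = {"U": [1.0, 2.0], "T": [3.0, 4.0]}   # assumed: two levels per variable
surface = {"SP": 0.5, "t2m": 2.0}
var_w = combine_variable_weights(atmos, surface)
```

The resulting vector has one entry per output channel, so it can be reshaped to broadcast over the channel dimension of a prediction.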


.. py:class:: VariableTotalLoss2D(conf, validation=False)

   Bases: :py:obj:`torch.nn.Module`


   Custom loss function class for 2D geospatial data
   with optional spectral and power loss components.

   The loss combines a base loss (e.g., L1, MSE) with optional
   spectral and power loss components, and can apply latitude
   and variable-specific weights.

   :param conf: Configuration dictionary containing loss
                function settings and weights.
   :type conf: dict
   :param validation: If True, the loss function
                      is used in validation mode. Defaults to False.
   :type validation: bool, optional


   .. py:attribute:: conf


   .. py:attribute:: training_loss


   .. py:attribute:: vars


   .. py:attribute:: lat_weights
      :value: None



   .. py:attribute:: var_weights
      :value: None



   .. py:attribute:: use_spectral_loss


   .. py:attribute:: use_power_loss


   .. py:attribute:: validation
      :value: False



   .. py:method:: forward(target, pred)

      Calculate the total loss for the given target and prediction.

      This method computes the base loss between the target and prediction,
      applies latitude and variable weights, and optionally adds spectral
      and power loss components.

      :param target: Ground truth tensor.
      :type target: torch.Tensor
      :param pred: Predicted tensor.
      :type pred: torch.Tensor

      :returns: The computed loss value.
      :rtype: torch.Tensor
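
The weighting step described for ``VariableTotalLoss2D.forward`` can be illustrated with a small sketch. It assumes a squared-error base loss, inputs shaped ``(batch, channel, lat, lon)``, and NumPy arrays in place of ``torch`` tensors. The function name ``weighted_mse`` is hypothetical, and the optional spectral and power terms are omitted because their definitions are not given here.

```python
import numpy as np


def weighted_mse(target, pred, lat_weights=None, var_weights=None):
    """Hypothetical sketch: squared error with latitude and variable weights."""
    # Element-wise base loss, shape (batch, channel, lat, lon).
    loss = (pred - target) ** 2
    if lat_weights is not None:
        # (lat, lon) weights broadcast over the batch and channel axes.
        loss = loss * lat_weights
    if var_weights is not None:
        # One weight per channel, reshaped to broadcast over batch and space.
        loss = loss * var_weights.reshape(1, -1, 1, 1)
    return loss.mean()


target = np.zeros((1, 2, 2, 2))
pred = np.ones((1, 2, 2, 2))
lat_w = np.array([[0.5, 0.5], [1.5, 1.5]])   # mean-normalized rows
var_w = np.array([1.0, 3.0])
loss_value = weighted_mse(target, pred, lat_w, var_w)
```

With unit errors everywhere, the result is the mean of the combined weights, which makes the effect of each weight easy to check by hand.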



