credit.losses

Submodules

Attributes

logger

Classes

VariableTotalLoss2D

Custom loss function for 2D geospatial data with optional spectral and power loss components.

DownscalingLoss

Custom loss function for downscaling.

Functions

base_losses(conf[, reduction, validation])

Load a specified loss function by its type.

load_loss(conf[, reduction, validation])

Load the appropriate loss function based on the configuration.

Package Contents

credit.losses.base_losses(conf, reduction='mean', validation=False)

Load a specified loss function by its type.

Parameters:
  • conf (dict) – Configuration dictionary containing loss settings.

  • reduction (str, optional) – Reduction method to use when the loss parameters do not specify one. Defaults to ‘mean’.

  • validation (bool, optional) – If True, use the validation loss settings; otherwise, use the training loss. Defaults to False.

Returns:

Instantiated loss function.

Return type:

torch.nn.Module
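The type-based dispatch that base_losses performs can be pictured as a lookup from a loss name to a torch.nn loss class. This is a minimal sketch, not CREDIT's implementation. The registry names (‘mse’, ‘mae’, ‘huber’, ‘smoothl1’), the helper base_losses_sketch, and the optional ‘parameters’ block under conf['loss'] are hypothetical. Only the ‘loss’, ‘training_loss’, and ‘validation_loss’ keys come from the load_loss documentation on this page.

```python
import torch

# Hypothetical registry mapping config names to torch.nn loss classes;
# the names CREDIT actually accepts may differ.
LOSS_REGISTRY = {
    "mse": torch.nn.MSELoss,
    "mae": torch.nn.L1Loss,
    "huber": torch.nn.HuberLoss,
    "smoothl1": torch.nn.SmoothL1Loss,
}


def base_losses_sketch(conf, reduction="mean", validation=False):
    """Instantiate a loss by name from conf["loss"] (illustrative only)."""
    loss_conf = conf["loss"]
    # Use the validation loss only when requested and actually configured
    if validation and loss_conf.get("validation_loss"):
        loss_type = loss_conf["validation_loss"]
    else:
        loss_type = loss_conf["training_loss"]

    if loss_type not in LOSS_REGISTRY:
        raise ValueError(f"Unknown loss type: {loss_type!r}")

    # Hypothetical per-loss parameter block; `reduction` is only a fallback
    params = dict(loss_conf.get("parameters", {}))
    params.setdefault("reduction", reduction)
    return LOSS_REGISTRY[loss_type](**params)


conf = {"loss": {"training_loss": "mse", "validation_loss": "mae"}}
train_loss = base_losses_sketch(conf)                 # torch.nn.MSELoss, reduction="mean"
val_loss = base_losses_sketch(conf, validation=True)  # torch.nn.L1Loss
```

Because the reduction argument is applied with setdefault, an explicit reduction in the (hypothetical) parameter block always takes precedence over the function default.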

class credit.losses.VariableTotalLoss2D(conf, validation=False)

Bases: torch.nn.Module

Custom loss function class for 2D geospatial data with optional spectral and power loss components.

This class defines a loss function that combines a base loss (e.g., L1, MSE) with optional spectral and power loss components for 2D geospatial data. The loss function can incorporate latitude and variable-specific weights.

Parameters:
  • conf (dict) – Configuration dictionary containing loss function settings and weights.

  • validation (bool, optional) – If True, the loss function is used in validation mode. Defaults to False.
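The latitude and variable weighting described above can be sketched for tensors shaped (batch, variable, lat, lon). This sketch rests on assumptions and is not the class's actual code. The cos(latitude) weights normalized to a mean of 1 are a common convention in weather and climate ML but are assumed here. The tensor layout, the L1 base loss, and the helper names latitude_weights and weighted_l1 are also hypothetical.

```python
import torch
import torch.nn.functional as F


def latitude_weights(lat_deg: torch.Tensor) -> torch.Tensor:
    """cos(latitude) weights normalized to mean 1 (assumed convention)."""
    w = torch.cos(torch.deg2rad(lat_deg))
    return w / w.mean()


def weighted_l1(target, pred, lat_w, var_w):
    """L1 loss on (batch, var, lat, lon) tensors with latitude and variable weights."""
    # Element-wise base loss, no reduction yet
    err = F.l1_loss(pred, target, reduction="none")
    # Broadcast latitude weights over the lat axis and variable weights over the var axis
    err = err * lat_w.view(1, 1, -1, 1) * var_w.view(1, -1, 1, 1)
    return err.mean()


lat = torch.linspace(-90.0, 90.0, 19)
lat_w = latitude_weights(lat)
var_w = torch.tensor([1.0, 0.5, 2.0])  # hypothetical per-variable weights
target = torch.randn(4, 3, 19, 36)
pred = target + 0.1 * torch.randn_like(target)
loss = weighted_l1(target, pred, lat_w, var_w)
```

Normalizing the latitude weights to a mean of 1 keeps the weighted loss on roughly the same scale as the unweighted one. The polar rows, which cover less area, contribute less.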

conf
training_loss
vars
lat_weights = None
var_weights = None
use_spectral_loss
use_power_loss
validation = False
forward(target, pred)

Calculate the total loss for the given target and prediction.

This method computes the base loss between the target and prediction, applies latitude and variable weights, and optionally adds spectral and power loss components.

Parameters:
  • target (torch.Tensor) – Ground truth tensor.

  • pred (torch.Tensor) – Predicted tensor.

Returns:

The computed loss value.

Return type:

torch.Tensor
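The optional spectral and power components can be sketched as comparisons in 2D Fourier space over the last two (lat, lon) dimensions. This illustrates the general technique and is not CREDIT's formulation. The following choices are assumed: the spectral term is the mean absolute difference of Fourier amplitudes, the power term is the mean squared difference of power spectra, and the three terms are added with unit weights. All function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def spectral_loss(target, pred):
    """Mean |amplitude difference| of the 2D FFT over the last two dims."""
    amp_t = torch.fft.rfft2(target, norm="ortho").abs()
    amp_p = torch.fft.rfft2(pred, norm="ortho").abs()
    return (amp_p - amp_t).abs().mean()


def power_loss(target, pred):
    """Mean squared difference of 2D power spectra (squared amplitudes)."""
    pow_t = torch.fft.rfft2(target, norm="ortho").abs() ** 2
    pow_p = torch.fft.rfft2(pred, norm="ortho").abs() ** 2
    return F.mse_loss(pow_p, pow_t)


def total_loss(target, pred, use_spectral_loss=True, use_power_loss=True):
    """Base L1 loss plus the optional Fourier-space terms (unit weights assumed)."""
    loss = F.l1_loss(pred, target)
    if use_spectral_loss:
        loss = loss + spectral_loss(target, pred)
    if use_power_loss:
        loss = loss + power_loss(target, pred)
    return loss


x = torch.randn(2, 3, 32, 64)
shifted = torch.roll(x, shifts=5, dims=-1)  # same amplitude spectrum, different phase
spec = spectral_loss(x, shifted)            # close to zero
pixel = F.l1_loss(shifted, x)               # clearly nonzero
```

Amplitude-based terms ignore phase, so a circularly shifted field scores almost zero on them even though it is wrong pixel by pixel. This is why such terms supplement a pointwise base loss rather than replace it.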

class credit.losses.DownscalingLoss(conf, validation=False)

Bases: torch.nn.Module

Custom loss function for downscaling.

Parameters:
  • conf (dict) – Configuration dictionary containing loss function settings and weights.

  • validation (bool, optional) – If True, the loss function is used in validation mode. Defaults to False.

training_loss
use_power_loss
use_spectral_loss
spectral_lambda_reg
spectral_wavenum_init
validation = False
forward(target, pred)

Calculate the total loss for the given target and prediction.

This method computes the base loss between the target and prediction, applies optional variable weights, and optionally adds spectral and power loss components.

Parameters:
  • target (torch.Tensor) – Ground truth tensor.

  • pred (torch.Tensor) – Predicted tensor.

Returns:

The computed loss value.

Return type:

torch.Tensor
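DownscalingLoss exposes spectral_lambda_reg and spectral_wavenum_init attributes. One plausible reading, which is assumed here and not confirmed by this page, is a spectral penalty restricted to wavenumbers at or above spectral_wavenum_init and scaled by spectral_lambda_reg. The sketch below implements that reading with an MSE base loss. The function names, the default values, and the radial wavenumber definition are all hypothetical.

```python
import torch
import torch.nn.functional as F


def high_wavenumber_spectral_loss(target, pred, wavenum_init):
    """Mean |amplitude difference| over radial wavenumbers k >= wavenum_init."""
    h, w = target.shape[-2], target.shape[-1]
    # Integer wavenumbers along each axis of the rfft2 output
    ky = torch.fft.fftfreq(h) * h   # shape (h,)
    kx = torch.fft.rfftfreq(w) * w  # shape (w // 2 + 1,)
    k = torch.sqrt(ky[:, None] ** 2 + kx[None, :] ** 2)
    mask = (k >= wavenum_init).to(target.dtype)
    if mask.sum() == 0:
        # No wavenumbers above the cutoff: the penalty vanishes
        return target.new_zeros(())
    amp_diff = (torch.fft.rfft2(pred, norm="ortho").abs()
                - torch.fft.rfft2(target, norm="ortho").abs()).abs()
    # Average only over the masked (high-wavenumber) coefficients
    n_leading = amp_diff.numel() // mask.numel()
    return (amp_diff * mask).sum() / (mask.sum() * n_leading)


def downscaling_loss(target, pred, spectral_lambda_reg=0.1, spectral_wavenum_init=4):
    """MSE base loss plus a lambda-scaled high-wavenumber spectral penalty."""
    base = F.mse_loss(pred, target)
    spec = high_wavenumber_spectral_loss(target, pred, spectral_wavenum_init)
    return base + spectral_lambda_reg * spec
```

Under this reading, raising spectral_wavenum_init focuses the penalty on fine-scale detail, which is usually what a downscaling model struggles to reproduce. Setting spectral_lambda_reg to 0 recovers the plain base loss.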

credit.losses.logger
credit.losses.load_loss(conf, reduction='none', validation=False)

Load the appropriate loss function based on the configuration.

If latitude or variable weights are enabled, this function returns a weighted custom loss wrapper (such as VariableTotalLoss2D); otherwise, it loads a standard or custom loss from available_losses.

In validation mode, if the config specifies a separate validation loss, that loss type is used; otherwise, the training loss is used.

Parameters:
  • conf (dict) – Configuration dictionary. Must contain a ‘loss’ section with keys such as:
      - ‘training_loss’ (str): The primary loss function name.
      - ‘validation_loss’ (str, optional): An alternate loss for validation.
      - ‘use_latitude_weights’ (bool): Whether to use latitude-based weighting.
      - ‘use_variable_weights’ (bool): Whether to use variable-specific weighting.

  • reduction (str, optional) – Reduction method to apply to the loss (‘mean’, ‘sum’, or ‘none’). Default is ‘none’.

  • validation (bool, optional) – Whether the loss is being used for validation. Defaults to False.

Returns:

A loss function instance: either the weighted VariableTotalLoss2D or a standard/custom loss from available_losses.

Return type:

torch.nn.Module

Raises:

ValueError – If the requested loss type is not recognized in available_losses.
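The branching that load_loss performs can be sketched as a pure-Python decision function that reports which loss would be built instead of constructing it. This is a simplified sketch. The AVAILABLE_LOSSES names and the resolve_loss helper are hypothetical, and the real function instantiates torch modules and may validate names differently. Only the config keys are taken from the parameter list above.

```python
# Hypothetical stand-in for the names in available_losses; the real registry may differ.
AVAILABLE_LOSSES = {"mse", "mae", "huber", "logcosh"}


def resolve_loss(conf: dict, validation: bool = False) -> tuple[str, str]:
    """Return (wrapper, loss_type) describing which loss load_loss would build."""
    loss_conf = conf["loss"]
    # A separately configured validation loss overrides the training loss
    if validation and loss_conf.get("validation_loss"):
        loss_type = loss_conf["validation_loss"]
    else:
        loss_type = loss_conf["training_loss"]

    if loss_type not in AVAILABLE_LOSSES:
        raise ValueError(f"Loss type {loss_type!r} not found in available_losses")

    # Either weighting flag switches to the weighted custom wrapper
    if loss_conf.get("use_latitude_weights") or loss_conf.get("use_variable_weights"):
        return "VariableTotalLoss2D", loss_type
    return "available_losses", loss_type


conf = {
    "loss": {
        "training_loss": "mse",
        "validation_loss": "mae",
        "use_latitude_weights": True,
        "use_variable_weights": False,
    }
}
train_choice = resolve_loss(conf)                 # ("VariableTotalLoss2D", "mse")
val_choice = resolve_loss(conf, validation=True)  # ("VariableTotalLoss2D", "mae")
```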