credit.losses.weighted_loss
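Attributes
- logger
Classes
- VariableTotalLoss2D – Custom loss function class for 2D geospatial data.
Functions
- latitude_weights(conf) – Calculate latitude-based weights for the loss function.
- variable_weights(conf, channels, frames) – Create variable-specific weights for different atmospheric and surface channels.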
Module Contents
- credit.losses.weighted_loss.logger
- credit.losses.weighted_loss.latitude_weights(conf)
Calculate latitude-based weights for the loss function.
This function calculates weights from latitude values for use in loss functions on geospatial data. The weights are derived from the cosine of the latitude and normalized by their mean, so the average weight is 1.
- Parameters:
conf (dict) – Configuration dictionary containing the path to the latitude weights file.
- Returns:
A 2D tensor of weights with dimensions corresponding to latitude and longitude.
- Return type:
torch.Tensor
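The cosine-and-mean normalization described above can be sketched in a few lines. This is a minimal NumPy sketch, not the CREDIT implementation: the helper name `cosine_latitude_weights` is hypothetical, latitudes are assumed to be in degrees and passed in directly rather than read from the file referenced by `conf`, and the result is a NumPy array (it could be converted with `torch.from_numpy` to match the documented `torch.Tensor` return type).

```python
import numpy as np


def cosine_latitude_weights(lat_deg, n_lon):
    """Hypothetical helper: cos(latitude) weights normalized by their mean.

    lat_deg: 1D array of latitudes in degrees (assumed units).
    n_lon:   number of longitude points to broadcast across.
    Returns an array of shape (len(lat_deg), n_lon).
    """
    w = np.cos(np.deg2rad(np.asarray(lat_deg, dtype=np.float64)))
    w = w / w.mean()  # normalize so the mean weight is 1
    # Repeat each per-latitude weight along every longitude.
    return np.repeat(w[:, None], n_lon, axis=1)


lats = np.linspace(-90.0, 90.0, 5)  # [-90, -45, 0, 45, 90]
weights = cosine_latitude_weights(lats, n_lon=8)
print(weights.shape)  # (5, 8)
```

Because the weights average to 1, a spatially uniform error keeps the same mean after weighting, while errors near the poles count less than errors near the equator.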
- credit.losses.weighted_loss.variable_weights(conf, channels, frames)
Create variable-specific weights for different atmospheric and surface channels.
This function loads weights for different atmospheric variables (e.g., U, V, T, Q) and surface variables (e.g., SP, t2m) from the configuration file. It then combines them into a single weight tensor for use in loss calculations.
- Parameters:
conf (dict) – Configuration dictionary containing the variable weights.
channels (int) – Number of channels for atmospheric variables.
frames (int) – Number of time frames.
- Returns:
A tensor containing the combined weights for all variables.
- Return type:
torch.Tensor
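One way per-variable weights could be assembled into a single channel-weight tensor is sketched below. Everything about the configuration layout is an assumption: the keys `conf["loss"]["variable_weights"]`, `conf["data"]["variables"]` and `conf["data"]["surface_variables"]`, the helper name `combine_variable_weights`, and the channel ordering (each atmospheric variable repeated for `channels` levels, that block tiled over `frames`, surface variables appended last) are hypothetical and may differ from CREDIT. NumPy stands in for PyTorch.

```python
import numpy as np


def combine_variable_weights(conf, channels, frames):
    """Hypothetical sketch of variable_weights-style logic.

    Assumes conf["loss"]["variable_weights"] maps each variable name to a
    float, conf["data"]["variables"] lists atmospheric variables and
    conf["data"]["surface_variables"] lists surface variables (all assumed keys).
    """
    w = conf["loss"]["variable_weights"]
    atmos = conf["data"]["variables"]
    surface = conf["data"]["surface_variables"]

    # One weight per (variable, level) channel, then tiled for each frame.
    atmos_w = np.repeat([w[v] for v in atmos], channels).astype(np.float64)
    atmos_w = np.tile(atmos_w, frames)
    surf_w = np.array([w[v] for v in surface], dtype=np.float64)

    combined = np.concatenate([atmos_w, surf_w])
    # Shape (1, C, 1, 1) so it broadcasts over (batch, channel, lat, lon).
    return combined.reshape(1, -1, 1, 1)


conf = {
    "loss": {"variable_weights": {"U": 1.0, "V": 1.0, "T": 2.0, "Q": 0.5,
                                  "SP": 1.5, "t2m": 3.0}},
    "data": {"variables": ["U", "V", "T", "Q"],
             "surface_variables": ["SP", "t2m"]},
}
weights = combine_variable_weights(conf, channels=2, frames=1)
print(weights.shape)  # (1, 10, 1, 1)
```

The `(1, C, 1, 1)` shape lets the weights multiply a `(batch, C, lat, lon)` error field channel by channel through broadcasting.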
- class credit.losses.weighted_loss.VariableTotalLoss2D(conf, validation=False)
Bases: torch.nn.Module
Custom loss function class for 2D geospatial data with optional spectral and power loss components.
This class defines a loss function that combines a base loss (e.g., L1, MSE) with optional spectral and power loss components for 2D geospatial data. The loss function can incorporate latitude-based and variable-specific weights.
- Parameters:
conf (dict) – Configuration dictionary containing loss function settings and weights.
validation (bool, optional) – If True, the loss function is used in validation mode. Defaults to False.
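The optional spectral and power components can be illustrated with Fourier transforms along the longitude axis. The definitions below are one plausible formulation chosen for illustration, not the terms CREDIT actually computes: `spectral_l1` compares FFT amplitudes, `power_l1` compares squared amplitudes, and the function names, the L1 base loss and the `spectral_coeff`/`power_coeff` values are all hypothetical. NumPy stands in for PyTorch.

```python
import numpy as np


def spectral_l1(target, pred):
    """Hypothetical spectral term: L1 distance between FFT amplitudes
    computed along the last (longitude) axis."""
    amp_t = np.abs(np.fft.rfft(target, axis=-1))
    amp_p = np.abs(np.fft.rfft(pred, axis=-1))
    return np.mean(np.abs(amp_t - amp_p))


def power_l1(target, pred):
    """Hypothetical power term: L1 distance between power spectra
    (squared FFT amplitudes) along the longitude axis."""
    pow_t = np.abs(np.fft.rfft(target, axis=-1)) ** 2
    pow_p = np.abs(np.fft.rfft(pred, axis=-1)) ** 2
    return np.mean(np.abs(pow_t - pow_p))


def total_loss(target, pred, use_spectral=True, use_power=True,
               spectral_coeff=0.1, power_coeff=0.1):
    """Base L1 loss plus optional spectral and power terms.
    Coefficient names and default values are illustrative assumptions."""
    loss = np.mean(np.abs(target - pred))
    if use_spectral:
        loss += spectral_coeff * spectral_l1(target, pred)
    if use_power:
        loss += power_coeff * power_l1(target, pred)
    return loss


target = np.zeros((2, 8))
pred = np.ones((2, 8))
print(total_loss(target, pred))
```

Comparing spectra can penalize predictions that match the mean but lose small-scale structure, which a pointwise base loss alone may tolerate.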
- conf
- training_loss
- vars
- lat_weights = None
- var_weights = None
- use_spectral_loss
- use_power_loss
- validation = False
- forward(target, pred)
Calculate the total loss for the given target and prediction.
This method computes the base loss between the target and prediction, applies latitude and variable weights, and optionally adds spectral and power loss components.
- Parameters:
target (torch.Tensor) – Ground truth tensor.
pred (torch.Tensor) – Predicted tensor.
- Returns:
The computed loss value.
- Return type:
torch.Tensor
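The weighting step of `forward` can be sketched as follows, assuming inputs shaped `(batch, channels, lat, lon)`, latitude weights shaped `(lat, lon)` and variable weights shaped `(1, channels, 1, 1)`. The function name `weighted_base_loss` and the `kind` switch are hypothetical, NumPy stands in for PyTorch, and the optional spectral and power terms are left out to keep the focus on how the weights enter the base loss.

```python
import numpy as np


def weighted_base_loss(target, pred, lat_weights=None, var_weights=None,
                       kind="l1"):
    """Hypothetical sketch of applying latitude and variable weights.

    target, pred: arrays of shape (batch, channels, lat, lon).
    lat_weights:  optional (lat, lon) array, broadcast over batch/channels.
    var_weights:  optional (1, channels, 1, 1) array.
    """
    diff = pred - target
    # Element-wise base loss.
    if kind == "l1":
        err = np.abs(diff)
    elif kind == "mse":
        err = diff ** 2
    else:
        raise ValueError(f"unsupported base loss: {kind}")
    if lat_weights is not None:
        err = err * lat_weights  # (lat, lon) broadcasts over leading axes
    if var_weights is not None:
        err = err * var_weights  # per-channel scaling
    # With mean-normalized latitude weights, a plain mean acts as a
    # latitude-weighted average.
    return err.mean()


target = np.zeros((1, 2, 3, 4))
pred = np.ones((1, 2, 3, 4))
var_w = np.array([1.0, 3.0]).reshape(1, 2, 1, 1)
print(weighted_base_loss(target, pred, var_weights=var_w))  # 2.0
```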