credit.losses.covariance

Attributes

reduction_functions

Classes

CovarianceWeightedMSELoss

Mean Squared Error weighted by the error covariance matrix across variables, levels and output times.

Functions

passthrough(in_val)

Module Contents

credit.losses.covariance.passthrough(in_val)
credit.losses.covariance.reduction_functions
class credit.losses.covariance.CovarianceWeightedMSELoss(reduction: str = 'mean', batch_normalize: bool = False, off_diagonal_scale: float = 1.0, **kwargs)
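The reference lists passthrough and reduction_functions without describing them. Going only by their names and by the reduction options documented for CovarianceWeightedMSELoss, one plausible reading is that passthrough is the identity (serving reduction='none') and reduction_functions maps each option name to a PyTorch reduction. The sketch below assumes exactly that; the bodies are hypothetical and are not taken from the CREDIT source.

```python
import torch


def passthrough(in_val):
    # Assumed behavior: identity, so the per-element loss is returned unreduced.
    return in_val


# Hypothetical mapping from the documented reduction names to callables.
# torch.min / torch.max called with a single tensor return the global min / max.
reduction_functions = {
    "mean": torch.mean,
    "none": passthrough,
    "sum": torch.sum,
    "min": torch.min,
    "max": torch.max,
}
```

Under this reading, a loss module would store reduction_functions[reduction] once at construction time and call it on the unreduced loss in forward.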

Bases: torch.nn.Module

Mean Squared Error weighted by the error covariance matrix across variables, levels and output times. Input tensors are assumed to have shape (batch, variable, time, lat, lon).
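The description gives no formula. One standard way to weight squared errors by an error covariance matrix C is the Mahalanobis quadratic form e^T C^-1 e, where e is the error vector over the (variable, time) combinations at one grid point. The sketch below assumes that form, estimates C from the batch errors, assumes levels are already folded into the variable axis, and guesses that off_diagonal_scale multiplies only the off-diagonal entries of C. The function name covariance_weighted_mse and the eps jitter are hypothetical; this is not the CREDIT implementation.

```python
import torch


def covariance_weighted_mse(y_true, y_pred, off_diagonal_scale=1.0, eps=1e-6):
    """Mahalanobis-style squared error (hypothetical sketch).

    y_true, y_pred: tensors of shape (batch, variable, time, lat, lon).
    Returns one unreduced loss value per (batch, lat, lon) sample.
    """
    b, v, t, ny, nx = y_true.shape
    # Move (variable, time) to the end and flatten: each row is one grid point
    # of one batch member, each column one of the d = v * t variable/time pairs.
    err = (y_pred - y_true).permute(0, 3, 4, 1, 2).reshape(b * ny * nx, v * t)

    # Sample covariance of the errors across the d columns. It is detached so
    # the weights act as constants during backpropagation (a design choice).
    e = err.detach()
    centered = e - e.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / (e.shape[0] - 1)

    # Scale only the off-diagonal covariances; variances are left unchanged.
    diag = torch.diag(torch.diagonal(cov))
    cov = diag + off_diagonal_scale * (cov - diag)

    # Small jitter keeps the matrix invertible when columns are nearly collinear.
    cov = cov + eps * torch.eye(cov.shape[0], dtype=cov.dtype, device=cov.device)

    # Quadratic form e^T C^-1 e for every row, via a linear solve
    # instead of an explicit matrix inverse.
    weighted = torch.linalg.solve(cov, err.T).T
    return (err * weighted).sum(dim=1)
```

With off_diagonal_scale=0.0 this reduces to a per-column variance-normalized squared error, so the parameter interpolates between ignoring and fully using cross-covariances. Applying one of the reductions (for example .mean()) to the returned values yields a scalar loss.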

Parameters:
  • reduction (str) – Reduction applied to the loss; one of 'mean', 'none', 'sum', 'min', 'max'. Default: 'mean'.

  • batch_normalize (bool) – If True, normalize each variable using the batch mean and standard deviation of y_true. Default: False.
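A minimal sketch of the batch normalization described for batch_normalize, assuming the statistics are computed per variable (axis 1) by pooling over batch, time, lat and lon, and that the same y_true statistics are applied to y_pred. The helper name batch_normalize_pair and the eps guard are hypothetical.

```python
import torch


def batch_normalize_pair(y_true, y_pred, eps=1e-8):
    # Pool over every axis except variable (axis 1) of (batch, variable, time, lat, lon).
    dims = (0, 2, 3, 4)
    mean = y_true.mean(dim=dims, keepdim=True)
    std = y_true.std(dim=dims, keepdim=True)
    # Use the y_true statistics for both tensors so their errors stay comparable;
    # eps avoids division by zero for a constant variable.
    return (y_true - mean) / (std + eps), (y_pred - mean) / (std + eps)
```

Normalizing with y_true statistics only keeps variables with large physical ranges from dominating the loss, without letting the prediction shift its own scale.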

reduction = 'mean'
reduction_function
batch_normalize = False
off_diagonal_scale = 1.0
forward(y_true, y_pred)