credit.losses.covariance
Attributes
reduction_functions |
Classes
CovarianceWeightedMSELoss | Mean Squared Error weighted by the error covariance matrix across variables, levels and output times.
Functions
passthrough(in_val) |
Module Contents
- credit.losses.covariance.passthrough(in_val)
- credit.losses.covariance.reduction_functions
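Neither passthrough nor reduction_functions is documented here. One plausible reading, assumed and not confirmed by the source, is that reduction_functions maps each accepted `reduction` string to a torch function and that passthrough is the identity used for 'none'. A minimal sketch under that assumption (the dictionary contents are hypothetical):

```python
import torch


def passthrough(in_val):
    # Assumed behavior: return the input unchanged, i.e. no reduction ("none").
    return in_val


# Hypothetical mapping from reduction name to callable; the real
# credit.losses.covariance.reduction_functions may differ.
reduction_functions = {
    "mean": torch.mean,
    "none": passthrough,
    "sum": torch.sum,
    "min": torch.min,
    "max": torch.max,
}

# Example: reduce a tensor of per-element losses.
per_element = torch.tensor([1.0, 4.0, 2.0])
total = reduction_functions["sum"](per_element)
```

A lookup table like this lets the loss resolve its reduction once in the constructor instead of branching on the string in every forward call.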
- class credit.losses.covariance.CovarianceWeightedMSELoss(reduction: str = 'mean', batch_normalize: bool = False, off_diagonal_scale: float = 1.0, **kwargs)
Bases: torch.nn.Module
Mean Squared Error weighted by the error covariance matrix across variables, levels and output times. Assumes input Tensors have shape (batch, variable, time, lat, lon).
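The description does not spell out how the covariance enters the loss. One plausible reading is sketched here as a hypothetical helper, covariance_weighted_mse, which is not part of credit. It treats every (batch, lat, lon) point as a sample and flattens variables and output times into one feature vector, assuming vertical levels are folded into the variable axis. It then estimates the empirical error covariance C between those features and averages the quadratic form e^T W e, where W is C with its off-diagonal entries multiplied by off_diagonal_scale. Weighting by C itself (rather than its inverse, as in a Mahalanobis distance) and this role for off_diagonal_scale are both assumptions, not documented behavior.

```python
import torch


def covariance_weighted_mse(y_true, y_pred, off_diagonal_scale=1.0):
    """Hypothetical sketch; not the credit implementation.

    Inputs are assumed to have shape (batch, variable, time, lat, lon).
    """
    err = y_pred - y_true
    b, v, t, h, w = err.shape
    # One sample per (batch, lat, lon) point; features are all
    # (variable, time) pairs -> shape (samples, features).
    e = err.permute(0, 3, 4, 1, 2).reshape(b * h * w, v * t)
    # Empirical error covariance between features; torch.cov expects
    # rows = features and columns = observations.
    cov = torch.cov(e.T)
    # Keep the diagonal, scale the off-diagonal entries
    # (assumed meaning of off_diagonal_scale).
    eye = torch.eye(v * t, dtype=cov.dtype, device=cov.device)
    weights = cov * eye + off_diagonal_scale * cov * (1 - eye)
    # Weighted squared error e^T W e for each sample, averaged over samples.
    return torch.einsum("si,ij,sj->s", e, weights, e).mean()


torch.manual_seed(0)
y_true = torch.randn(2, 3, 2, 4, 5)
y_pred = y_true + 0.1 * torch.randn_like(y_true)
loss = covariance_weighted_mse(y_true, y_pred)
```

With off_diagonal_scale=0.0 only the per-feature error variances remain, so each squared error is weighted by the variance of its feature. For values between 0 and 1, W is a convex combination of two positive semi-definite matrices, which keeps the loss non-negative.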
- Parameters:
reduction (str) – One of 'mean', 'none', 'sum', 'min' or 'max'.
batch_normalize (bool) – If True, normalize each variable by the batch mean and standard deviation of y_true.
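A minimal sketch of what batch_normalize=True could do, assuming the statistics are computed per variable (dim 1) and pooled over batch, time, lat and lon, and that the y_true statistics are applied to both tensors. The helper name batch_normalize and the eps guard are hypothetical, not taken from credit.

```python
import torch


def batch_normalize(y_true, y_pred, eps=1e-6):
    """Hypothetical sketch of batch_normalize=True; not the credit code."""
    # Per-variable statistics, pooled over batch, time, lat and lon.
    dims = (0, 2, 3, 4)
    mean = y_true.mean(dim=dims, keepdim=True)
    std = y_true.std(dim=dims, keepdim=True)
    # Both tensors use the y_true statistics so their errors stay comparable;
    # eps guards against division by zero for constant variables.
    return (y_true - mean) / (std + eps), (y_pred - mean) / (std + eps)


torch.manual_seed(0)
y_true = 5.0 + 3.0 * torch.randn(4, 2, 3, 6, 6)
y_pred = y_true + torch.randn_like(y_true)
t_norm, p_norm = batch_normalize(y_true, y_pred)
```

Normalizing this way puts variables with very different units and magnitudes on a common scale before their errors are combined.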
- reduction = 'mean'
- reduction_function
- batch_normalize = False
- off_diagonal_scale = 1.0
- forward(y_true, y_pred)