credit.losses.power

Classes

PSDLoss – Power Spectral Density (PSD) Loss Function.

Module Contents

class credit.losses.power.PSDLoss(wavenum_init=20)

Bases: torch.nn.Module

Power Spectral Density (PSD) Loss Function.

This loss function computes the Power Spectral Density (PSD) of the predicted and target outputs and penalizes differences between the two spectra. This encourages the predictions to reproduce the frequency content of the target.
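The following minimal NumPy sketch illustrates the idea. It assumes a (lat, lon) field, takes the power spectrum along longitude with numpy.fft.rfft, averages it over latitude, and compares spectra using a mean absolute log difference. The helpers zonal_psd and log_spectral_distance are hypothetical. This page does not document which axis PSDLoss transforms or which distance it uses, so treat the sketch as an illustration of the technique rather than CREDIT's implementation.

```python
import numpy as np


def zonal_psd(field: np.ndarray) -> np.ndarray:
    """Power spectrum along the last axis (assumed longitude), averaged over the
    second-to-last axis (assumed latitude). Hypothetical helper."""
    coeffs = np.fft.rfft(field, axis=-1)  # complex coefficients, one per zonal wavenumber
    power = np.abs(coeffs) ** 2           # spectral power at each wavenumber
    return power.mean(axis=-2)            # mean over latitude rows


def log_spectral_distance(target: np.ndarray, pred: np.ndarray, eps: float = 1e-12) -> float:
    """Mean absolute difference of log power spectra (hypothetical comparison metric)."""
    t = zonal_psd(target)
    p = zonal_psd(pred)
    # eps guards against log(0) at wavenumbers with no power
    return float(np.mean(np.abs(np.log(p + eps) - np.log(t + eps))))


# Example: a smoothed prediction loses small-scale power relative to the target.
rng = np.random.default_rng(0)
target = rng.standard_normal((32, 64))  # (lat, lon) field with power at all scales
kernel = np.ones(5) / 5.0               # 5-point moving average along longitude
pred = np.apply_along_axis(lambda row: np.convolve(row, kernel, mode="same"), -1, target)

print(log_spectral_distance(target, target))  # 0.0
print(log_spectral_distance(target, pred))    # positive: blurred field lacks high-wavenumber power
```

Comparing log spectra treats a given relative error the same way at every wavenumber. Without the log, the large-scale, high-power wavenumbers would dominate the comparison.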

Parameters:

wavenum_init (int) – The wavenumber at which the loss calculation starts. Wavenumbers below this value are not considered. Defaults to 20.
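The sketch below shows what a starting wavenumber does. It makes the same assumptions as a zonal rfft spectrum averaged over latitude. The hypothetical helper truncated_psd_loss drops every spectral index below wavenum_init before comparing. A prediction whose only error is at large scale (wavenumber 2) is penalized with a cutoff of 0 but not with a cutoff of 20.

```python
import numpy as np


def zonal_psd(field: np.ndarray) -> np.ndarray:
    """Zonal power spectrum: rfft along longitude (last axis), averaged over latitude."""
    return (np.abs(np.fft.rfft(field, axis=-1)) ** 2).mean(axis=-2)


def truncated_psd_loss(target, pred, wavenum_init=20, eps=1e-12):
    """Compare log spectra only from index `wavenum_init` upward (hypothetical helper)."""
    t = zonal_psd(target)[..., wavenum_init:]  # drop wavenumbers below wavenum_init
    p = zonal_psd(pred)[..., wavenum_init:]
    return float(np.mean(np.abs(np.log(p + eps) - np.log(t + eps))))


lon = np.linspace(0.0, 2.0 * np.pi, 128, endpoint=False)
lat_rows = 16
# Target: a small-scale wave (k=30) plus a large-scale wave (k=2), repeated on every latitude row.
target = np.tile(np.sin(30 * lon) + np.sin(2 * lon), (lat_rows, 1))
# Prediction: same small-scale wave, but the large-scale wave is twice as strong.
pred = np.tile(np.sin(30 * lon) + 2.0 * np.sin(2 * lon), (lat_rows, 1))

print(truncated_psd_loss(target, pred, wavenum_init=0))   # > 0: the k=2 mismatch is counted
print(truncated_psd_loss(target, pred, wavenum_init=20))  # close to 0: k=2 lies below the cutoff
```

Raising the cutoff therefore focuses the loss on small-scale structure. Large-scale errors are left to other loss terms.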

wavenum_init = 20
forward(target, pred, weights=None)

Forward pass: compute the PSD loss between target and pred.

Parameters:
  • target (torch.Tensor) – Target tensor.

  • pred (torch.Tensor) – Predicted tensor.

  • weights (torch.Tensor, optional) – Latitude weights for the loss.

Returns:

PSD loss value.

Return type:

torch.Tensor
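This page says only that weights are latitude weights. It does not state how forward applies them. One plausible scheme, shown here purely as an assumption, averages the per-latitude power spectra using normalized weights such as cos(latitude). The function weighted_zonal_psd and the latitude grid are hypothetical.

```python
import numpy as np


def weighted_zonal_psd(field: np.ndarray, lat_weights: np.ndarray) -> np.ndarray:
    """Zonal power spectrum averaged over latitude with normalized weights (hypothetical)."""
    power = np.abs(np.fft.rfft(field, axis=-1)) ** 2  # (lat, k) power per latitude row
    w = lat_weights / lat_weights.sum()               # normalize so the weights sum to 1
    return (power * w[:, None]).sum(axis=-2)          # weighted mean over latitude


lats = np.deg2rad(np.linspace(-87.5, 87.5, 36))  # hypothetical 36-row latitude grid
cos_weights = np.cos(lats)                        # area weighting, a common choice
rng = np.random.default_rng(1)
field = rng.standard_normal((36, 72))             # (lat, lon) field

weighted = weighted_zonal_psd(field, cos_weights)
uniform = weighted_zonal_psd(field, np.ones(36))
# Uniform weights reduce to a plain mean over latitude.
print(np.allclose(uniform, (np.abs(np.fft.rfft(field, axis=-1)) ** 2).mean(axis=0)))  # True
```

Down-weighting high latitudes in this way stops the many closely spaced polar grid rows from dominating the averaged spectrum.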

get_psd(f_x, device, dtype)