credit.transforms
=================

.. py:module:: credit.transforms

.. autoapi-nested-parse::

   transforms.py provides the normalization and tensor-conversion transforms used by the ``credit`` data pipeline.

   -------------------------------------------------------
   Content:
       - load_transforms
       - NormalizeState
       - Normalize_ERA5_and_Forcing
       - BridgescalerScaleState
       - NormalizeState_Quantile
       - NormalizeTendency
       - ToTensor
       - ToTensor_ERA5_and_Forcing
       - NormalizeState_Quantile_Bridgescalar
       - ToTensor_BridgeScaler



Attributes
----------

.. autoapisummary::

   credit.transforms.logger


Classes
-------

.. autoapisummary::

   credit.transforms.NormalizeState
   credit.transforms.Normalize_ERA5_and_Forcing
   credit.transforms.BridgescalerScaleState
   credit.transforms.NormalizeState_Quantile
   credit.transforms.NormalizeTendency
   credit.transforms.ToTensor
   credit.transforms.ToTensor_ERA5_and_Forcing
   credit.transforms.NormalizeState_Quantile_Bridgescalar
   credit.transforms.ToTensor_BridgeScaler


Functions
---------

.. autoapisummary::

   credit.transforms.device_compatible_to
   credit.transforms.load_transforms


Module Contents
---------------

.. py:data:: logger

.. py:function:: device_compatible_to(tensor: torch.Tensor, device: torch.device) -> torch.Tensor

   Safely move a tensor to a device, casting it to float32 when the target is MPS (Metal Performance Shaders). This avoids the runtime error raised on macOS because MPS does not support float64.

   :param tensor: Input tensor to move.
   :type tensor: torch.Tensor
   :param device: Target device.
   :type device: torch.device

   :returns: Tensor moved to device (cast to float32 if device is MPS).
   :rtype: torch.Tensor
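
   A minimal sketch of the documented behavior, assuming the function simply
   branches on ``device.type``. The body below is illustrative and is not the
   library source.

   .. code-block:: python

      import torch


      def device_compatible_to(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
          """Move ``tensor`` to ``device``; cast to float32 when the target is MPS."""
          if device.type == "mps":
              # MPS has no float64 support, so cast while moving.
              return tensor.to(device=device, dtype=torch.float32)
          # Any other device keeps the original dtype.
          return tensor.to(device)


      x = torch.zeros(3, dtype=torch.float64)
      y = device_compatible_to(x, torch.device("cpu"))
      print(y.dtype)  # torch.float64: CPU keeps double precision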


.. py:function:: load_transforms(conf, scaler_only=False)

   Load transforms.

   :param conf: path to config
   :type conf: str
   :param scaler_only: If True, return only the scaler; if False, return the scaler followed by ToTensor.
   :type scaler_only: bool

   :returns: the scaler transform, or the scaler composed with ToTensor.
   :rtype: tf.tensor
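
   The ``scaler_only`` switch can be pictured with the sketch below.
   ``Compose`` is a minimal stand-in for a transform pipeline, and
   ``load_transforms_sketch``, ``double``, and ``to_tuple`` are hypothetical
   toys; the real function builds its scaler and ToTensor objects from ``conf``.

   .. code-block:: python

      from typing import Any, Callable, List


      class Compose:
          """Apply callables in order (a minimal stand-in for a transform pipeline)."""

          def __init__(self, transforms: List[Callable[[Any], Any]]) -> None:
              self.transforms = transforms

          def __call__(self, sample: Any) -> Any:
              for transform in self.transforms:
                  sample = transform(sample)
              return sample


      def load_transforms_sketch(scaler, to_tensor, scaler_only=False):
          """Hypothetical helper mirroring the documented scaler_only switch."""
          if scaler_only:
              return scaler
          return Compose([scaler, to_tensor])


      def double(values):  # toy stand-in for the scaler
          return [v * 2 for v in values]


      def to_tuple(values):  # toy stand-in for ToTensor
          return tuple(values)


      pipeline = load_transforms_sketch(double, to_tuple)
      print(pipeline([1, 2, 3]))  # (2, 4, 6)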


.. py:class:: NormalizeState(conf)

   Class to normalize state.


   .. py:attribute:: mean_ds


   .. py:attribute:: std_ds


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: levels


   .. py:method:: __call__(sample: credit.data.Sample, inverse: bool = False) -> credit.data.Sample

      Normalize via the mean and standard deviation loaded from the provided scaler file/s.

      :param sample: batch.
      :param inverse: if True, apply the inverse transform.

      :returns: transformed sample.
      :rtype: credit.data.Sample



   .. py:method:: transform_dataset(DS: xarray.Dataset) -> xarray.Dataset


   .. py:method:: transform_array(x: torch.Tensor) -> torch.Tensor

      Transform from unscaled to scaled values.

      :param x: batch.

      :returns: transformed x.



   .. py:method:: transform(sample: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]

      Transform from unscaled to scaled values.

      :param sample: batch.

      :returns: transformed sample.



   .. py:method:: inverse_transform(x: torch.Tensor) -> torch.Tensor

      Inverse transform from scaled back to unscaled values.

      :param x: batch.

      :returns: inverse transformed x.
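
      Assuming the scaler is a standard per-variable z-score built from
      ``mean_ds`` and ``std_ds`` (an assumption; the exact variable and level
      handling is not shown on this page), the transform and this inverse form
      the round trip sketched below, with NumPy arrays standing in for the
      datasets.

      .. code-block:: python

         import numpy as np


         def transform(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
             """Scale raw values to zero mean and unit variance per variable."""
             return (x - mean) / std


         def inverse_transform(z: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
             """Map scaled values back to physical units."""
             return z * std + mean


         # Two hypothetical variables (rows) at two grid points (columns).
         x = np.array([[280.0, 290.0], [1000.0, 1010.0]])
         mean = np.array([[285.0], [1005.0]])
         std = np.array([[5.0], [5.0]])

         z = transform(x, mean, std)  # values [[-1, 1], [-1, 1]]
         x_back = inverse_transform(z, mean, std)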



.. py:class:: Normalize_ERA5_and_Forcing(conf)

   Class to normalize ERA5 and Forcing Datasets.


   .. py:attribute:: mean_ds


   .. py:attribute:: std_ds


   .. py:attribute:: mean_tensors


   .. py:attribute:: std_tensors


   .. py:attribute:: levels


   .. py:attribute:: varname_upper_air


   .. py:attribute:: num_upper_air


   .. py:attribute:: flag_surface


   .. py:attribute:: flag_dyn_forcing


   .. py:attribute:: flag_diagnostic


   .. py:attribute:: flag_forcing


   .. py:attribute:: flag_static


   .. py:method:: __call__(sample: credit.data.Sample, inverse: bool = False) -> credit.data.Sample

      Normalize ERA5 and Forcing.

      :param sample: batch.
      :param inverse: whether to transform or inverse transform the sample.

      :returns: normalized sample (inverse transformed if ``inverse`` is True).
      :rtype: credit.data.Sample



   .. py:method:: transform_dataset(DS: xarray.Dataset) -> xarray.Dataset


   .. py:method:: transform_array(x: torch.Tensor) -> torch.Tensor

      Transform y_pred.

      Transform the prediction variables via the provided scaler file/s.
      Dynamic forcing, forcing, and static variables are not transformed.

      :param x: batch.

      :returns: transformed x.



   .. py:method:: transform(sample: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]

      Transform training batches.

      Transform handles forcing and static variables as follows:

      - Forcing and static variables are not transformed; users are expected to
        transform them beforehand and save the transformed values to file.
      - All other variables (upper-air, surface, dynamic forcing, diagnostics)
        are transformed.

      :param sample: batch.

      :returns: transformed sample.
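
      A minimal sketch of this pass-through rule on a dictionary sample. The
      keys ``"upper_air"``, ``"surface"``, ``"forcing"``, and ``"static"`` and
      the scalar statistics are hypothetical placeholders, not the keys the
      class actually uses.

      .. code-block:: python

         from typing import Dict, Iterable

         import numpy as np


         def transform_sample(
             sample: Dict[str, np.ndarray],
             mean: Dict[str, float],
             std: Dict[str, float],
             pre_scaled: Iterable[str] = ("forcing", "static"),
         ) -> Dict[str, np.ndarray]:
             """Z-score every entry except those already scaled on disk."""
             skip = set(pre_scaled)
             out = {}
             for key, value in sample.items():
                 if key in skip:
                     out[key] = value  # forcing/static: passed through unchanged
                 else:
                     out[key] = (value - mean[key]) / std[key]
             return out


         sample = {
             "upper_air": np.array([250.0, 260.0]),
             "surface": np.array([101000.0]),
             "forcing": np.array([0.3]),
             "static": np.array([1.0]),
         }
         mean = {"upper_air": 255.0, "surface": 100000.0}
         std = {"upper_air": 5.0, "surface": 1000.0}
         scaled = transform_sample(sample, mean, std)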



   .. py:method:: inverse_transform(x: torch.Tensor) -> torch.Tensor

      Inverse transform y_pred.

      Inverse transform the prediction variables. Dynamic forcing, forcing,
      and static variables are not transformed.

      :param x: batch.

      :returns: inverse transformed x.



   .. py:method:: inverse_transform_input(x: torch.Tensor) -> torch.Tensor

      Inverse transform for input x.

      Forcing and static variables are not transformed
      (they were not transformed in the transform function).

      :param x: batch.

      :returns: inverse transformed x.
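
      The split between transformed and untouched channels can be sketched as
      follows. The sketch assumes a ``(batch, channel, lat, lon)`` layout in
      which forcing and static channels are stacked after the ``n_norm``
      normalized channels; both the layout and the ``n_norm`` parameter are
      assumptions for illustration.

      .. code-block:: python

         import numpy as np


         def transform_array(x, mean, std, n_norm):
             """Normalize channels [0, n_norm); leave trailing forcing/static channels as-is."""
             out = x.copy()
             out[:, :n_norm] = (x[:, :n_norm] - mean[:, None, None]) / std[:, None, None]
             return out


         def inverse_transform_input(x, mean, std, n_norm):
             """Undo transform_array on the same leading channels."""
             out = x.copy()
             out[:, :n_norm] = x[:, :n_norm] * std[:, None, None] + mean[:, None, None]
             return out


         rng = np.random.default_rng(0)
         x = rng.normal(size=(2, 4, 3, 3))  # 3 normalized channels + 1 static channel
         mean = np.array([0.5, -1.0, 2.0])
         std = np.array([2.0, 0.5, 1.5])
         z = transform_array(x, mean, std, n_norm=3)
         x_back = inverse_transform_input(z, mean, std, n_norm=3)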



.. py:class:: BridgescalerScaleState(conf)

   Bases: :py:obj:`object`


   Convert to rescaled tensor using Bridgescaler.


   .. py:attribute:: scaler_file


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: n_levels


   .. py:attribute:: var_levels
      :value: []



   .. py:attribute:: n_surface_variables


   .. py:attribute:: n_3dvar_levels


   .. py:attribute:: scaler_df


   .. py:attribute:: scaler_3d


   .. py:attribute:: scaler_surf


   .. py:method:: inverse_transform(x: torch.Tensor) -> torch.Tensor

      Inverse transform a tensor batch back to unscaled values using Bridgescaler.

      :param x: batch.

      :returns: inverse transformed batch.



   .. py:method:: transform_array(x: torch.Tensor) -> torch.Tensor

      Transform a tensor batch to scaled values using Bridgescaler.

      :param x: batch.

      :returns: transformed batch.



   .. py:method:: transform(sample: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]

      Transform a sample dictionary to scaled values using Bridgescaler.

      :param sample: batch.

      :returns: transformed batch.
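
   The attributes ``n_levels`` and ``n_surface_variables`` suggest a channel
   axis made of per-level upper-air blocks followed by surface variables. The
   sketch below shows that split and its inverse with NumPy; the
   variable-major, level-minor channel order is an assumption, not something
   this page documents.

   .. code-block:: python

      import numpy as np


      def split_channels(x: np.ndarray, n_vars_3d: int, n_levels: int):
          """Split (batch, channel, lat, lon) into upper-air and surface arrays."""
          n_3d = n_vars_3d * n_levels
          upper = x[:, :n_3d].reshape(x.shape[0], n_vars_3d, n_levels, *x.shape[2:])
          surface = x[:, n_3d:]
          return upper, surface  # (batch, var, level, lat, lon), (batch, surf, lat, lon)


      def merge_channels(upper: np.ndarray, surface: np.ndarray) -> np.ndarray:
          """Inverse of split_channels: flatten levels back into the channel axis."""
          flat = upper.reshape(upper.shape[0], -1, *upper.shape[3:])
          return np.concatenate([flat, surface], axis=1)


      x = np.arange(2 * 14 * 5 * 6, dtype=float).reshape(2, 14, 5, 6)
      upper, surface = split_channels(x, n_vars_3d=3, n_levels=4)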



.. py:class:: NormalizeState_Quantile(conf)

   Class to normalize state with a quantile scaler.


   .. py:attribute:: scaler_file


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: levels


   .. py:attribute:: scaler_df


   .. py:attribute:: scaler_3ds


   .. py:attribute:: scaler_surfs


   .. py:attribute:: scaler_3d


   .. py:attribute:: scaler_surf


   .. py:method:: __call__(sample: credit.data.Sample, inverse: bool = False) -> credit.data.Sample

      Normalize via quantile transform using the provided scaler file/s.

      :param sample: batch.
      :param inverse: if True, apply the inverse transform.

      :returns: transformed sample.
      :rtype: credit.data.Sample



   .. py:method:: inverse_transform(x: torch.Tensor) -> torch.Tensor

      Inverse quantile transform of a tensor batch.

      :param x: batch.

      :returns: inverse transformed x.



   .. py:method:: transform(sample: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]

      Quantile transform of a sample dictionary.

      :param sample: batch.

      :returns: transformed batch.
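
   For intuition, an empirical quantile transform maps each value to its
   cumulative probability under the training data, and the inverse maps
   probabilities back to values. The sketch below implements that idea with
   ``np.quantile`` and ``np.interp``; it is a swapped-in illustration, not the
   fitted scalers this class loads from ``scaler_file``.

   .. code-block:: python

      import numpy as np


      def fit_quantiles(train: np.ndarray, n_quantiles: int = 101) -> np.ndarray:
          """Store evenly spaced empirical quantiles of the training data."""
          return np.quantile(train, np.linspace(0.0, 1.0, n_quantiles))


      def transform(x: np.ndarray, quantiles: np.ndarray) -> np.ndarray:
          """Map raw values to [0, 1]; values outside the training range are clipped."""
          probs = np.linspace(0.0, 1.0, quantiles.size)
          return np.interp(x, quantiles, probs)


      def inverse_transform(u: np.ndarray, quantiles: np.ndarray) -> np.ndarray:
          """Map probabilities in [0, 1] back to physical values."""
          probs = np.linspace(0.0, 1.0, quantiles.size)
          return np.interp(u, probs, quantiles)


      rng = np.random.default_rng(42)
      quantiles = fit_quantiles(rng.normal(size=10_000))
      x = np.array([-1.0, 0.0, 1.0])
      u = transform(x, quantiles)  # roughly [0.16, 0.5, 0.84]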



.. py:class:: NormalizeTendency(variables, surface_variables, base_path)

   Normalize tendency.


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: base_path


   .. py:attribute:: mean


   .. py:attribute:: std


   .. py:method:: transform(tensor, surface_tensor)

      Normalize the input tensor and surface tensor.

      :param tensor: batch.
      :type tensor: torch.Tensor
      :param surface_tensor: surface batch.
      :type surface_tensor: torch.Tensor

      :returns: transformed torch tensors.
      :rtype: torch.Tensor



   .. py:method:: inverse_transform(tensor, surface_tensor)

      Inverse transform the input tensor and surface tensor.

      :param tensor: batch.
      :type tensor: torch.Tensor
      :param surface_tensor: surface batch.
      :type surface_tensor: torch.Tensor

      :returns: inverse transformed torch tensors.
      :rtype: torch.Tensor
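
   A sketch of normalizing tendencies, assuming a tendency is the change
   between consecutive states, x(t+1) - x(t), and that upper-air and surface
   tensors carry separate statistics. The ``mean``/``std`` dictionaries and
   their ``"upper"``/``"surface"`` keys are hypothetical stand-ins for the
   class attributes.

   .. code-block:: python

      import numpy as np


      def transform(tensor, surface_tensor, mean, std):
          """Scale upper-air and surface tendencies with their own statistics."""
          return (
              (tensor - mean["upper"]) / std["upper"],
              (surface_tensor - mean["surface"]) / std["surface"],
          )


      def inverse_transform(tensor, surface_tensor, mean, std):
          """Map normalized tendencies back to physical increments."""
          return (
              tensor * std["upper"] + mean["upper"],
              surface_tensor * std["surface"] + mean["surface"],
          )


      mean = {"upper": 1.0, "surface": 0.0}
      std = {"upper": 0.5, "surface": 2.0}
      x_now, x_next = np.array([280.0, 281.0]), np.array([281.0, 283.0])
      s_now, s_next = np.array([1000.0]), np.array([998.0])

      dz, ds = transform(x_next - x_now, s_next - s_now, mean, std)  # [0, 2] and [-1]
      dx, dsurf = inverse_transform(dz, ds, mean, std)
      x_pred = x_now + dx  # adding the recovered tendency reproduces x_next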



.. py:class:: ToTensor(conf)

   Convert variables from xr.Datasets to PyTorch tensors.


   .. py:attribute:: conf


   .. py:attribute:: hist_len


   .. py:attribute:: for_len


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: allvars


   .. py:attribute:: static_variables


   .. py:method:: __call__(sample: credit.data.Sample) -> credit.data.Sample

      Reshape the sample and convert it to torch tensors.

      :param sample: batch.
      :type sample: credit.data.Sample

      :returns: sample with reshaped torch tensors.
      :rtype: credit.data.Sample
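
      The reshaping step can be sketched with NumPy arrays standing in for the
      dataset variables. The layout assumed here (upper-air fields shaped
      ``(time, level, lat, lon)``, surface fields ``(time, lat, lon)``, the
      first ``hist_len`` steps as input and the next ``for_len`` as target) and
      the variable names ``U``, ``V``, ``SP`` are illustrative only;
      ``torch.from_numpy`` could then wrap each array as a tensor without
      copying.

      .. code-block:: python

         from typing import Dict, Tuple

         import numpy as np


         def stack_sample(
             upper_air: Dict[str, np.ndarray],
             surface: Dict[str, np.ndarray],
             hist_len: int,
             for_len: int,
         ) -> Tuple[np.ndarray, np.ndarray]:
             """Flatten levels into channels, append surface fields, split input/target in time."""
             upper = np.concatenate(list(upper_air.values()), axis=1)  # (T, V*L, lat, lon)
             surf = np.stack(list(surface.values()), axis=1)  # (T, S, lat, lon)
             stacked = np.concatenate([upper, surf], axis=1)
             return stacked[:hist_len], stacked[hist_len:hist_len + for_len]


         T, L, H, W = 3, 2, 4, 5
         upper_air = {"U": np.zeros((T, L, H, W)), "V": np.ones((T, L, H, W))}
         surface = {"SP": np.full((T, H, W), 2.0)}
         x, y = stack_sample(upper_air, surface, hist_len=2, for_len=1)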



.. py:class:: ToTensor_ERA5_and_Forcing(conf)

   Class to convert ERA5 and Forcing Datasets to torch tensors.


   .. py:attribute:: conf


   .. py:attribute:: output_dtype
      :value: Ellipsis



   .. py:attribute:: hist_len


   .. py:attribute:: for_len


   .. py:attribute:: flag_surface


   .. py:attribute:: flag_dyn_forcing


   .. py:attribute:: flag_diagnostic


   .. py:attribute:: flag_forcing


   .. py:attribute:: flag_static


   .. py:attribute:: varname_upper_air


   .. py:attribute:: flag_upper_air
      :value: True



   .. py:attribute:: num_forcing_static
      :value: 0



   .. py:method:: __call__(sample: credit.data.Sample) -> credit.data.Sample

      Convert variables to input/output torch tensors.

      :param sample: batch.
      :type sample: credit.data.Sample

      :returns: sample with converted torch tensors.
      :rtype: credit.data.Sample
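
      The ``flag_*`` attributes suggest that each variable group is included
      only when configured. The sketch below concatenates the enabled groups
      along the channel axis and counts the forcing and static channels, in the
      spirit of ``num_forcing_static``; the group names and their order are
      assumptions.

      .. code-block:: python

         from typing import Dict, Tuple

         import numpy as np

         GROUP_ORDER = ("upper_air", "surface", "dyn_forcing", "forcing", "static")


         def assemble_input(
             groups: Dict[str, np.ndarray], flags: Dict[str, bool]
         ) -> Tuple[np.ndarray, int]:
             """Concatenate enabled (batch, channel, lat, lon) groups; count forcing+static channels."""
             enabled = [name for name in GROUP_ORDER if flags.get(name, False)]
             x = np.concatenate([groups[name] for name in enabled], axis=1)
             num_forcing_static = sum(
                 groups[name].shape[1] for name in enabled if name in ("forcing", "static")
             )
             return x, num_forcing_static


         groups = {
             "upper_air": np.zeros((1, 8, 3, 3)),
             "surface": np.zeros((1, 2, 3, 3)),
             "forcing": np.zeros((1, 1, 3, 3)),
             "static": np.zeros((1, 3, 3, 3)),
         }
         flags = {"upper_air": True, "surface": True, "forcing": True, "static": True}
         x, num_forcing_static = assemble_input(groups, flags)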



.. py:class:: NormalizeState_Quantile_Bridgescalar(conf)

   Class to use the bridgescaler quantile functionality.

   This requires some workarounds; efficiency could be improved by
   retraining the bridgescaler.


   .. py:attribute:: scaler_file


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: levels


   .. py:attribute:: scaler_df


   .. py:attribute:: scaler_3ds


   .. py:attribute:: scaler_surfs


   .. py:attribute:: scaler_3d


   .. py:attribute:: scaler_surf


   .. py:method:: __call__(sample: credit.data.Sample, inverse: bool = False) -> credit.data.Sample

      Normalize via quantile transform with bridgescaler, using the provided scaler file/s.

      :param sample: batch.
      :type sample: credit.data.Sample
      :param inverse: if True, apply the inverse transform.
      :type inverse: bool

      :returns: transformed sample.
      :rtype: credit.data.Sample



   .. py:method:: inverse_transform(x: torch.Tensor) -> torch.Tensor

      Inverse transform.

      Inverse transform via provided scaler file/s.

      :param x: batch.

      :returns: inverse transformed torch tensor.



   .. py:method:: transform(sample)

      Transform.

      Transform via provided scaler file/s.

      :param sample: batch.
      :type sample: iterator

      :returns: transformed torch tensor.
      :rtype: torch.Tensor
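
   One common workaround when applying a tabular scaler to gridded data is
   reshaping ``(batch, channel, lat, lon)`` into ``(samples, features)`` with
   one feature per channel, then reshaping back. The sketch below shows that
   round trip in NumPy; whether this class performs exactly these steps is an
   assumption.

   .. code-block:: python

      import numpy as np


      def to_tabular(x: np.ndarray) -> np.ndarray:
          """(batch, channel, lat, lon) -> (batch * lat * lon, channel)."""
          channels = x.shape[1]
          return np.moveaxis(x, 1, -1).reshape(-1, channels)


      def from_tabular(table: np.ndarray, shape) -> np.ndarray:
          """Inverse of to_tabular for an original array of the given shape."""
          batch, channels, lat, lon = shape
          return np.moveaxis(table.reshape(batch, lat, lon, channels), -1, 1)


      x = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)
      table = to_tabular(x)  # shape (40, 3): one column per channel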



.. py:class:: ToTensor_BridgeScaler(conf)

   Convert to reshaped tensor.


   .. py:attribute:: conf


   .. py:attribute:: hist_len


   .. py:attribute:: for_len


   .. py:attribute:: variables


   .. py:attribute:: surface_variables


   .. py:attribute:: allvars


   .. py:attribute:: static_variables


   .. py:attribute:: latN


   .. py:attribute:: lonN


   .. py:attribute:: levels


   .. py:attribute:: one_shot


   .. py:method:: __call__(sample: credit.data.Sample) -> credit.data.Sample

      Reshape the sample and convert it to torch tensors.

      :param sample: batch.
      :type sample: credit.data.Sample

      :returns: sample with reshaped torch tensors.
      :rtype: credit.data.Sample



