credit.preblock
===============

.. py:module:: credit.preblock


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/preblock/base/index
   /autoapi/credit/preblock/concat/index
   /autoapi/credit/preblock/log/index
   /autoapi/credit/preblock/regrid/index
   /autoapi/credit/preblock/scaler/index
   /autoapi/credit/preblock/sqrt/index


Attributes
----------

.. autoapisummary::

   credit.preblock.PREBLOCK_REGISTRY


Classes
-------

.. autoapisummary::

   credit.preblock.LogTransform
   credit.preblock.SqrtTransform
   credit.preblock.BridgeScalerTransformer
   credit.preblock.Regridder
   credit.preblock.ConcatToTensor


Functions
---------

.. autoapisummary::

   credit.preblock.build_preblocks
   credit.preblock.apply_preblocks


Package Contents
----------------

.. py:class:: LogTransform(variables: list[str], data_types: list[str] = None, base: str = 'e', eps: float = 1e-08)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Applies a log transformation to specified variables in a batch dict.

   Expected dict structure::

       batch[source][data_type]['source/var_type/var_shape/var_name']

   Config example::

       type: "log_transform"
       args:
           variables:
               - 'ERA5/prognostic/3D/Q'
           data_types:     # optional, defaults to ['input', 'target']
               - 'input'
               - 'target'
           base: 'e'       # optional, default 'e'. Options: 'e', '2', '10'
           eps: 1.0e-8     # optional, default 1e-8
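
   A minimal sketch of the transform described above. It assumes NumPy arrays
   stand in for ``torch.Tensor`` values and that ``eps`` is added before the
   logarithm is taken, rather than used as a clamp floor.
   ``log_transform_sketch`` is a hypothetical helper, not part of
   ``credit.preblock``.

   .. code-block:: python

      import numpy as np

      # Map the documented ``base`` options to NumPy log functions.
      _LOG_FNS = {"e": np.log, "2": np.log2, "10": np.log10}


      def log_transform_sketch(batch, variables, data_types=None, base="e", eps=1e-8):
          """Hypothetical stand-in for LogTransform.forward (not the credit code)."""
          data_types = data_types if data_types is not None else ["input", "target"]
          log_fn = _LOG_FNS[base]
          for by_type in batch.values():
              if not isinstance(by_type, dict):
                  continue  # skip non-nested entries
              for data_type in data_types:
                  fields = by_type.get(data_type, {})
                  for key in variables:
                      if key in fields:
                          # eps keeps exact zeros finite under the log
                          fields[key] = log_fn(fields[key] + eps)
          return batch

   The sketch walks every source and silently skips variable keys that are
   absent from a given data type; whether ``LogTransform`` does the same or
   raises an error is not stated here.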


   .. py:attribute:: variables


   .. py:attribute:: data_types
      :value: ['input', 'target']



   .. py:attribute:: eps


   .. py:method:: forward(batch: dict) -> dict


.. py:class:: SqrtTransform(variables: list[str], data_types: list[str] = None)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Applies a square-root transformation to specified variables in a batch dict.

   Expected dict structure::

       batch[source][data_type]['source/var_type/var_shape/var_name']

   Config example::

       type: "sqrt_transform"
       args:
           variables:
               - 'ERA5/prognostic/3D/Q'
           data_types:     # optional, defaults to ['input', 'target']
               - 'input'
               - 'target'
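
   A minimal sketch of the transform, assuming NumPy arrays stand in for
   ``torch.Tensor`` values. ``sqrt_transform_sketch`` is a hypothetical helper,
   not part of ``credit.preblock``. How ``SqrtTransform`` treats negative inputs
   is not documented here; ``np.sqrt`` returns NaN for them.

   .. code-block:: python

      import numpy as np


      def sqrt_transform_sketch(batch, variables, data_types=None):
          """Hypothetical stand-in for SqrtTransform.forward (not the credit code)."""
          data_types = data_types if data_types is not None else ["input", "target"]
          for by_type in batch.values():
              if not isinstance(by_type, dict):
                  continue  # skip non-nested entries
              for data_type in data_types:
                  fields = by_type.get(data_type, {})
                  for key in variables:
                      if key in fields:
                          # Element-wise square root; negative values become NaN.
                          fields[key] = np.sqrt(fields[key])
          return batch

   Restricting ``data_types`` to ``['input']`` leaves the target fields
   untouched, which is the main reason the argument exists.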


   .. py:attribute:: variables


   .. py:attribute:: data_types
      :value: ['input', 'target']



   .. py:method:: forward(batch: dict) -> dict


.. py:class:: BridgeScalerTransformer(scaler_path: str, variables: list[str], method: str)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Scaling preblock using a fitted bridgescaler dict.

   Applies per-variable z-score scaling (or its inverse) to tensors in a
   nested batch dict of the form ``batch[source][data_type][var_key]``.

   The scaler dict must have been fit with ``bridgescaler.scale_var_dict``
   using the same nested structure and saved with ``bridgescaler.save_scaler_dict``.

   Example config::

       type: "bridgescaler_transform"
       args:
           scaler_path: "/path/to/scaler.json"
           variables:
               - "era5/prognostic/3d/T"
               - "era5/prognostic/3d/U"
           method: "transform"
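
   The sketch below illustrates per-variable z-score scaling and its inverse on
   the nested batch dict. It does not use ``bridgescaler``: the fitted scaler is
   replaced by a hypothetical ``stats`` dict mapping each variable key to a
   ``(mean, std)`` pair, NumPy arrays stand in for tensors, and
   ``"inverse_transform"`` as the name of the inverse method is an assumption.

   .. code-block:: python

      import numpy as np


      def zscore_sketch(batch, variables, stats, method="transform"):
          """Hypothetical per-variable z-score preblock (not credit/bridgescaler code).

          ``stats`` maps a variable key to an assumed (mean, std) pair.
          """
          if method not in ("transform", "inverse_transform"):
              raise ValueError(f"unknown method: {method!r}")
          for by_type in batch.values():
              if not isinstance(by_type, dict):
                  continue  # skip non-nested entries
              for fields in by_type.values():
                  if not isinstance(fields, dict):
                      continue
                  for key in variables:
                      if key not in fields:
                          continue
                      mean, std = stats[key]
                      if method == "transform":
                          fields[key] = (fields[key] - mean) / std
                      else:
                          # Inverse: undo the scaling to recover physical units.
                          fields[key] = fields[key] * std + mean
          return batch

   Applying ``"transform"`` and then ``"inverse_transform"`` with the same
   statistics returns the original values, which is a quick sanity check for
   any saved scaler.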


   .. py:attribute:: variables


   .. py:attribute:: method


   .. py:attribute:: scaler_path


   .. py:attribute:: scaler


   .. py:method:: forward(batch: dict) -> dict


.. py:class:: Regridder(weight_file, variables: list[str], data_types: list[str] = None, reshape_to_xy=True, flip_axis=None)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Regridding layer that applies a weights file generated with the ESMF library.

   :param weight_file: path to the ESMF weights file
   :param variables: list of variable keys to regrid (e.g. ['era5/prognostic/3d/T'])
   :param data_types: list of data types to process (default: ['input', 'target'])
   :param reshape_to_xy: whether to reshape the flattened array back to xy coordinates
   :param flip_axis: axes to flip before regridding (e.g. [-1, -2])
   :type flip_axis: list, tuple, or None
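
   ESMF weight files (for example those written by ``ESMF_RegridWeightGen``)
   store the regridding operator as sparse triplets ``row``, ``col`` and ``S``
   with 1-based indices, mapping ``n_a`` source points to ``n_b`` destination
   points. The sketch below assumes those arrays have already been read from the
   file, builds a dense matrix for clarity (an efficient implementation would
   more likely keep it sparse), and assumes the row-major flattening of the last
   two axes matches the point ordering used when the weights were generated.
   ``build_weight_matrix`` and ``regrid_sketch`` are hypothetical helpers.

   .. code-block:: python

      import numpy as np


      def build_weight_matrix(row, col, S, n_a, n_b):
          """Dense (n_b, n_a) matrix from ESMF-style 1-based sparse triplets."""
          W = np.zeros((n_b, n_a))
          # Shift to 0-based indices; np.add.at sums any duplicate entries.
          np.add.at(W, (np.asarray(row) - 1, np.asarray(col) - 1), S)
          return W


      def regrid_sketch(x, W, dst_shape, flip_axis=None, reshape_to_xy=True):
          """Apply W over the last two (spatial) axes of x."""
          if flip_axis is not None:
              x = np.flip(x, axis=tuple(flip_axis))
          lead = x.shape[:-2]
          flat = x.reshape(*lead, -1)  # (..., n_a)
          out = flat @ W.T  # (..., n_b)
          if reshape_to_xy:
              out = out.reshape(*lead, *dst_shape)
          return out

   With ``reshape_to_xy=False`` the result keeps a single flattened destination
   axis of length ``n_b``, which is convenient for unstructured target grids.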


   .. py:attribute:: variables


   .. py:attribute:: data_types
      :value: ['input', 'target']



   .. py:attribute:: reshape_to_xy
      :value: True



   .. py:attribute:: flip_axis
      :value: None



   .. py:attribute:: n_a


   .. py:attribute:: n_b


   .. py:attribute:: dst_shape


   .. py:attribute:: _W
      :value: None



   .. py:attribute:: _W_device
      :value: None



   .. py:method:: _get_W(device)


   .. py:method:: _regrid(x: torch.Tensor) -> torch.Tensor


   .. py:method:: forward(batch: dict) -> dict


.. py:class:: ConcatToTensor(*args: Any, **kwargs: Any)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   End-of-chain preblock that concatenates a nested batch dict of tensors
   into a single input tensor (and optionally a target tensor).

   Expects a batch dict of the form::

       batch[source][data_type][var_name] -> torch.Tensor

   where tensor shapes are (batch, channel, time, lon, lat) and concatenation
   is performed along dim=1 (channel). Traversal order follows key insertion
   order: for each source, all var_names under a data_type are concatenated,
   then the next source, and so on.

   ``metadata`` keys are passed through as-is (not concatenated).

   Returns either::

       (input_tensor, metadata)                    # if no "target" data_type present
       (input_tensor, target_tensor, metadata)     # if "target" is present

   Example config::

       type: "concatenate_to_tensor"
       args: {}
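
   A minimal sketch of this traversal, assuming NumPy arrays stand in for
   tensors, that ``metadata`` lives at the top level of ``batch``, and that only
   the ``input`` and ``target`` data types are concatenated (how any other data
   types are handled is not described here). ``concat_sketch`` is a
   hypothetical helper, not ``ConcatToTensor.forward``.

   .. code-block:: python

      import numpy as np


      def concat_sketch(batch):
          """Concatenate nested variables along the channel axis (dim 1)."""
          metadata = batch.get("metadata")
          inputs, targets = [], []
          # Dict iteration follows insertion order: source, then variable.
          for source, by_type in batch.items():
              if source == "metadata":
                  continue  # passed through untouched
              inputs.extend(by_type.get("input", {}).values())
              targets.extend(by_type.get("target", {}).values())
          x = np.concatenate(inputs, axis=1)
          if targets:
              return x, np.concatenate(targets, axis=1), metadata
          return x, metadata

   Because channel order is fixed by dict insertion order, reordering sources or
   variables in the config changes which channel each variable occupies.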


   .. py:method:: forward(batch)


.. py:data:: PREBLOCK_REGISTRY

.. py:function:: build_preblocks(preblock_cfg: dict) -> torch.nn.ModuleDict

   Instantiates all preblocks from the config's 'preblocks' section.

   :param preblock_cfg: the full ``preblocks`` dict from the config, for
      example::

         {
             'era5_log_transform': {'type': 'log_transform', 'args': {...}},
             'era5_z_transform':   {'type': 'z_transform',   'args': {...}},
         }

   :returns: nn.ModuleDict of instantiated preblocks, ordered as in config.
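
   A minimal sketch of the registry lookup. It assumes each entry's ``type``
   string is a registry key mapping to a preblock class and that ``args`` are
   passed as keyword arguments. A plain ``OrderedDict`` stands in for
   ``torch.nn.ModuleDict``, which also preserves insertion order; ``ToyLog`` and
   ``TOY_REGISTRY`` are hypothetical and are not ``PREBLOCK_REGISTRY`` or its
   contents.

   .. code-block:: python

      from collections import OrderedDict


      class ToyLog:
          """Hypothetical preblock used only to exercise the sketch."""

          def __init__(self, variables, eps=1e-8):
              self.variables = variables
              self.eps = eps


      # Hypothetical registry: config "type" string -> preblock class.
      TOY_REGISTRY = {"log_transform": ToyLog}


      def build_preblocks_sketch(preblock_cfg, registry=TOY_REGISTRY):
          """Instantiate each configured preblock, preserving config order."""
          preblocks = OrderedDict()
          for name, spec in preblock_cfg.items():
              block_type = spec["type"]
              if block_type not in registry:
                  raise KeyError(f"unknown preblock type {block_type!r} for {name!r}")
              preblocks[name] = registry[block_type](**spec.get("args", {}))
          return preblocks

   Failing early on an unknown ``type`` surfaces config typos at build time
   rather than partway through the first batch.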


.. py:function:: apply_preblocks(preblocks: torch.nn.ModuleDict, batch: dict)

   Sequentially applies transform preblocks (dict→dict), then concatenates to tensors.

   Concatenation is always performed last and is not configurable.
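
   A minimal sketch of the documented ordering: every transform runs in config
   order on the batch dict, and concatenation happens exactly once at the end.
   The ``concat_fn`` parameter is hypothetical; it is added only so the sketch
   runs without ``ConcatToTensor``, and the real function takes just
   ``preblocks`` and ``batch``.

   .. code-block:: python

      def apply_preblocks_sketch(preblocks, batch, concat_fn):
          """Run dict->dict transforms in order, then concatenate exactly once."""
          for name, block in preblocks.items():
              batch = block(batch)  # each transform returns a (possibly mutated) dict
              if not isinstance(batch, dict):
                  raise TypeError(f"preblock {name!r} must return a dict")
          return concat_fn(batch)

   Since the transforms compose in order, swapping two entries in the config
   (for example a log transform and a scaler) generally changes the result.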


