credit.preblock
===============

.. py:module:: credit.preblock


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/preblock/base/index
   /autoapi/credit/preblock/concat/index
   /autoapi/credit/preblock/log/index
   /autoapi/credit/preblock/norm/index
   /autoapi/credit/preblock/regrid/index
   /autoapi/credit/preblock/scaler/index
   /autoapi/credit/preblock/sqrt/index


Attributes
----------

.. autoapisummary::

   credit.preblock._BRIDGESCALER_AVAILABLE
   credit.preblock.PREBLOCK_REGISTRY
   credit.preblock._VALID_SECTIONS


Classes
-------

.. autoapisummary::

   credit.preblock.LogTransform
   credit.preblock.SqrtTransform
   credit.preblock.Regridder
   credit.preblock.ConcatToTensor
   credit.preblock.ERA5Normalizer


Functions
---------

.. autoapisummary::

   credit.preblock._build_preblock_section
   credit.preblock.build_preblocks
   credit.preblock._run_preblock_group
   credit.preblock.apply_preblocks


Package Contents
----------------

.. py:class:: LogTransform(variables: list[str], data_types: list[str] = None, base: str = 'e', eps: float = 1e-08)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Applies a log transformation to specified variables in a batch dict.

   Expected dict structure::

       batch[source][data_type]['source/var_type/var_shape/var_name']

   Config example::

       type: "log_transform"
       args:
           variables:
               - 'ERA5/prognostic/3D/Q'
           data_types:     # optional, defaults to ['input', 'target']
               - 'input'
               - 'target'
           base: 'e'       # optional, default 'e'. Options: 'e', '2', '10'
           eps: 1.0e-8     # optional, default 1e-8


   .. py:attribute:: variables


   .. py:attribute:: data_types
      :value: ['input', 'target']



   .. py:attribute:: eps


   .. py:method:: forward(batch: dict) -> dict

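A minimal sketch of what this transform does to the documented ``batch[source][data_type][key]`` layout. It assumes ``eps`` is added before the log is taken, and it uses plain Python lists of floats in place of tensors so it runs without torch. ``log_transform`` is a hypothetical helper, not the class's actual ``forward``.

```python
import math

# Log functions for the three documented ``base`` options.
_LOG_FNS = {"e": math.log, "2": math.log2, "10": math.log10}


def log_transform(batch, variables, data_types=("input", "target"), base="e", eps=1e-8):
    """Hypothetical stand-in for LogTransform.forward.

    Returns a new nested dict where every listed variable under a selected
    data type becomes log_base(x + eps); everything else is copied through.
    Assumptions: eps is added inside the log; lists stand in for tensors.
    """
    log_fn = _LOG_FNS[base]
    out = {}
    for source, by_type in batch.items():
        out[source] = {}
        for data_type, fields in by_type.items():
            transform = data_type in data_types
            out[source][data_type] = {
                key: [log_fn(v + eps) for v in values]
                if transform and key in variables
                else values
                for key, values in fields.items()
            }
    return out


batch = {
    "ERA5": {
        "input": {"ERA5/prognostic/3D/Q": [0.0, 1.0], "ERA5/prognostic/3D/T": [280.0]},
        "target": {"ERA5/prognostic/3D/Q": [9.0]},
    }
}
# eps=1.0 only to keep the numbers readable: Q becomes log10(x + 1).
# T is not listed, so it passes through unchanged.
logged = log_transform(batch, ["ERA5/prognostic/3D/Q"], base="10", eps=1.0)
```

Building a new dict rather than mutating ``batch`` keeps the raw batch available, which matters if the same batch is later fed to another preblock group.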

.. py:class:: SqrtTransform(variables: list[str], data_types: list[str] = None)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Applies a sqrt transformation to specified variables in a batch dict.

   Expected dict structure::

       batch[source][data_type]['source/var_type/var_shape/var_name']

   Config example::

       type: "sqrt_transform"
       args:
           variables:
               - 'ERA5/prognostic/3D/Q'
           data_types:     # optional, defaults to ['input', 'target']
               - 'input'
               - 'target'


   .. py:attribute:: variables


   .. py:attribute:: data_types
      :value: ['input', 'target']



   .. py:method:: forward(batch: dict) -> dict

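The optional ``data_types`` argument restricts which branches of the batch are transformed. A minimal sketch, again with lists standing in for tensors; ``sqrt_transform`` is a hypothetical helper, not the class's actual ``forward``.

```python
import math


def sqrt_transform(batch, variables, data_types=("input", "target")):
    """Hypothetical stand-in for SqrtTransform.forward (lists stand in for tensors)."""
    out = {}
    for source, by_type in batch.items():
        out[source] = {
            data_type: {
                key: [math.sqrt(v) for v in values]
                if data_type in data_types and key in variables
                else values
                for key, values in fields.items()
            }
            for data_type, fields in by_type.items()
        }
    return out


batch = {
    "ERA5": {
        "input": {"ERA5/prognostic/3D/Q": [4.0, 9.0]},
        "target": {"ERA5/prognostic/3D/Q": [16.0]},
    }
}
# Transform inputs only; the target branch is copied through untouched.
rooted = sqrt_transform(batch, ["ERA5/prognostic/3D/Q"], data_types=["input"])
```

Note that ``math.sqrt`` raises ``ValueError`` on negative inputs, whereas ``torch.sqrt`` returns NaN, so a tensor pipeline would propagate NaNs for negative values rather than fail.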

.. py:class:: Regridder(weight_file, variables: list[str], data_types: list[str] = None, reshape_to_xy=True, flip_axis=None)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Regrids variables using a weights file generated by the ESMF library.

   :param weight_file: path to the ESMF weights file
   :param variables: list of variable keys to regrid (e.g. ``['ERA5/prognostic/3D/T']``)
   :param data_types: list of data types to process (default: ``['input', 'target']``)
   :param reshape_to_xy: whether to reshape the flattened regridded array back onto the destination xy grid
   :param flip_axis: axes to flip before regridding (e.g. ``[-1, -2]``)
   :type flip_axis: list, tuple, or None


   .. py:attribute:: variables


   .. py:attribute:: data_types
      :value: ['input', 'target']



   .. py:attribute:: reshape_to_xy
      :value: True



   .. py:attribute:: flip_axis
      :value: None



   .. py:attribute:: n_a


   .. py:attribute:: n_b


   .. py:attribute:: dst_shape


   .. py:attribute:: _W
      :value: None



   .. py:attribute:: _W_device
      :value: None



   .. py:method:: _get_W(device)


   .. py:method:: _regrid(x: torch.Tensor) -> torch.Tensor


   .. py:method:: forward(batch: dict) -> dict

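ESMF weight files store the interpolation as a sparse matrix in coordinate form: 1-based destination indices ``row``, 1-based source indices ``col`` and weights ``S``, for ``n_a`` source points and ``n_b`` destination points. The sketch below applies such triples to a flattened source field and reshapes the result. It is a pure-Python illustration of the matrix product that ``_regrid`` presumably performs; the documented ``_W`` and ``_W_device`` attributes suggest the real class caches a weight tensor per device, which this sketch skips. The helper name, the row-major ``dst_shape`` reshape and the toy weights are assumptions, and reading the NetCDF file is omitted.

```python
def apply_esmf_weights(row, col, S, n_b, x_flat, dst_shape=None):
    """Compute y = W @ x for a sparse ESMF weight matrix given as COO triples.

    row, col: 1-based destination/source indices, as stored in ESMF weight files.
    S: interpolation weights. n_b: number of destination points.
    dst_shape: optional (ny, nx) to reshape the output to (row-major, assumed).
    """
    y = [0.0] * n_b
    for r, c, s in zip(row, col, S):
        y[r - 1] += s * x_flat[c - 1]  # convert 1-based indices to 0-based
    if dst_shape is None:
        return y
    ny, nx = dst_shape
    if ny * nx != n_b:
        raise ValueError("dst_shape does not match n_b")
    return [y[i * nx:(i + 1) * nx] for i in range(ny)]


# A 2x2 source grid (n_a = 4), flattened row-major.
x = [1.0, 2.0, 3.0, 4.0]
# Destination point 1 averages all four source cells; point 2 copies source cell 1.
row = [1, 1, 1, 1, 2]
col = [1, 2, 3, 4, 1]
S = [0.25, 0.25, 0.25, 0.25, 1.0]
regridded = apply_esmf_weights(row, col, S, n_b=2, x_flat=x, dst_shape=(1, 2))
```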

.. py:class:: ConcatToTensor(to_device: bool = True)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   End-of-chain preblock that concatenates a nested batch dict of tensors
   into a single input tensor (and optionally a target tensor).

   Expects a batch dict of the form::

       batch[data_type][source][var_name] -> torch.Tensor

   where tensor shapes are (batch, n_levels, time, lat, lon) and concatenation
   is performed along dim=1 (channel). Input tensors are sorted by
   ``_channel_sort_key`` before concatenation so the channel order matches
   the canonical variable schema regardless of insertion order in the batch.

   ``metadata`` keys are passed through as-is (not concatenated).

   In addition to the tensors, two channel maps are attached to metadata under
   ``metadata["_channel_map"]``:

   * ``"input"``: every variable and its slice in the concatenated input tensor.
   * ``"output"``: prognostic and diagnostic variables only, with slices
     reindexed from 0 to match the ``y_pred`` channel ordering.

   Each entry has the form::

       var_key -> {"slice": slice(start, end), "orig_shape": (n_levels, T)}

   Returns either::

       (input_tensor, metadata)                    # if no "target" data_type present
       (input_tensor, target_tensor, metadata)     # if "target" is present

   Example config::

       type: "concatenate_to_tensor"
       args:
         to_device: true   # set false to skip .to(device) in apply_preblocks


   .. py:attribute:: to_device
      :value: True



   .. py:method:: forward(batch)

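The ``"output"`` channel map can be derived from the ``"input"`` map by keeping only prognostic and diagnostic entries and packing their slices from 0. A minimal sketch, assuming the variable type is the second segment of the ``source/var_type/var_shape/var_name`` key and that the kept variables retain their relative input order; ``build_output_map`` is a hypothetical helper and the shapes are made up.

```python
def build_output_map(input_map):
    """Derive an output channel map from an input channel map (hypothetical helper)."""
    output_map, start = {}, 0
    for key, entry in input_map.items():
        var_type = key.split("/")[1]  # assumed layout: source/var_type/var_shape/var_name
        if var_type not in ("prognostic", "diagnostic"):
            continue  # e.g. static channels are inputs only and never predicted
        width = entry["slice"].stop - entry["slice"].start
        output_map[key] = {"slice": slice(start, start + width), "orig_shape": entry["orig_shape"]}
        start += width
    return output_map


input_map = {
    "ERA5/prognostic/3D/T": {"slice": slice(0, 4), "orig_shape": (4, 1)},
    "ERA5/static/2D/Z": {"slice": slice(4, 5), "orig_shape": (1, 1)},
    "ERA5/diagnostic/2D/TP": {"slice": slice(5, 6), "orig_shape": (1, 1)},
}
output_map = build_output_map(input_map)
```

Because the static ``Z`` channel is dropped, ``TP`` moves from ``slice(5, 6)`` in the input tensor to ``slice(4, 5)`` in the prediction.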

.. py:class:: ERA5Normalizer(mean_path: str, std_path: str, levels: list[int] | None = None)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   Normalizes per-variable ERA5 tensors using pre-computed mean/std files.

   Normalization: ``(x - mean) / std`` applied per variable. Variables not
   found in the statistics file are passed through unchanged.

   :param mean_path: Path to NetCDF file containing per-variable means.
   :param std_path: Path to NetCDF file containing per-variable standard deviations.
   :param levels: Optional list of 1-indexed model levels to select from the
                  full 137-level stats (e.g. [60, 90, 120, 137] for a 4-level
                  smoke test).  When omitted, all levels in the stats file are
                  used.


   .. py:attribute:: _mean
      :type:  dict[str, torch.Tensor]


   .. py:attribute:: _std
      :type:  dict[str, torch.Tensor]


   .. py:method:: _normalize_tensor(key: str, tensor: torch.Tensor) -> torch.Tensor

      Normalize *tensor* using the variable name extracted from *key*.



   .. py:method:: forward(batch: dict) -> dict

      Normalize all input/target tensors, returning a new batch dict.


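The normalization and the 1-indexed ``levels`` selection can be sketched as follows. This is an illustration under stated assumptions, not the class's code: lists stand in for tensors, a 4-level statistics list stands in for the 137-level file, and the variable name is assumed to be the last segment of the batch key. ``select_levels`` and ``normalize`` are hypothetical helpers.

```python
def select_levels(per_level_stats, levels):
    """Pick 1-indexed model levels from a full per-level statistics list."""
    return [per_level_stats[lev - 1] for lev in levels]


def normalize(key, values, mean, std):
    """Apply (x - mean) / std level by level; pass through unknown variables.

    Assumption: the variable name is the last segment of the batch key.
    """
    name = key.split("/")[-1]
    if name not in mean:
        return values  # no statistics for this variable: leave it unchanged
    return [(x - m) / s for x, m, s in zip(values, mean[name], std[name])]


full_mean = {"T": [200.0, 210.0, 220.0, 230.0]}
full_std = {"T": [10.0, 10.0, 5.0, 5.0]}
levels = [2, 4]  # 1-indexed, as with the ``levels`` argument
mean = {k: select_levels(v, levels) for k, v in full_mean.items()}
std = {k: select_levels(v, levels) for k, v in full_std.items()}

t_norm = normalize("ERA5/prognostic/3D/T", [220.0, 240.0], mean, std)
lsm = normalize("ERA5/static/2D/LSM", [1.0], mean, std)
```

Selecting levels once when the statistics are loaded keeps the per-batch work to a single subtract-and-divide per variable.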

.. py:data:: _BRIDGESCALER_AVAILABLE
   :value: True


.. py:data:: PREBLOCK_REGISTRY

.. py:data:: _VALID_SECTIONS

.. py:function:: _build_preblock_section(section_cfg: dict) -> torch.nn.ModuleDict

.. py:function:: build_preblocks(preblock_cfg: dict | None = None, phase: str = 'per_step') -> torch.nn.ModuleDict

   Instantiate preblocks for a single phase from a two-section config.

   Config format::

       preblocks:
         ic_only:          # run once at t=0 on the raw batch (e.g. static regrid)
           regrid_static:
             type: regrid
             args: ...
         per_step:         # run every rollout step (e.g. log_transform, concat)
           log_transform:
             type: log_transform
           concat:
             type: concat

   Typical usage: build once per phase and store the groups separately::

       ic_preblocks   = build_preblocks(cfg, phase="ic_only")
       step_preblocks = build_preblocks(cfg, phase="per_step")

       # t=0: run both in sequence
       ic_preprocessed    = apply_preblocks(ic_preblocks, batch, device=device)
       preprocessed_batch = apply_preblocks(step_preblocks, ic_preprocessed, device=device)

       # t>0: run per_step only
       preprocessed_batch = apply_preblocks(step_preblocks, rollout_batch, device=device)

   :param preblock_cfg: the full ``preblocks`` config dict (both sections).
   :param phase: which section to build, either ``"ic_only"`` or ``"per_step"``.

   :returns: ``nn.ModuleDict`` of instantiated blocks for the requested phase.

   :raises ValueError: if the config contains keys other than ``"ic_only"`` / ``"per_step"``,
       or if ``phase`` is not one of those values.

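The section validation and registry lookup described above can be sketched without torch. This is a hypothetical ``build_phase`` that returns a plain dict instead of an ``nn.ModuleDict``; the placeholder ``_Block`` class, the toy registry and the ``w.nc`` filename are illustrative only, and treating a ``None`` config as "build nothing" is an assumption.

```python
_VALID_SECTIONS = ("ic_only", "per_step")  # the two documented section names


def build_phase(preblock_cfg, phase, registry):
    """Hypothetical sketch of build_preblocks returning a plain dict."""
    preblock_cfg = preblock_cfg or {}  # assumption: a missing config builds nothing
    unknown = set(preblock_cfg) - set(_VALID_SECTIONS)
    if unknown:
        raise ValueError(f"unknown preblock sections: {sorted(unknown)}")
    if phase not in _VALID_SECTIONS:
        raise ValueError(f"phase must be one of {_VALID_SECTIONS}, got {phase!r}")
    blocks = {}
    for name, spec in (preblock_cfg.get(phase) or {}).items():
        block_cls = registry[spec["type"]]  # look up the class by its config "type"
        blocks[name] = block_cls(**(spec.get("args") or {}))
    return blocks


class _Block:
    """Placeholder block that just records its constructor arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


registry = {"regrid": _Block, "log_transform": _Block, "concat": _Block}
cfg = {
    "ic_only": {"regrid_static": {"type": "regrid", "args": {"weight_file": "w.nc"}}},
    "per_step": {"log_transform": {"type": "log_transform"}, "concat": {"type": "concat"}},
}
step_blocks = build_phase(cfg, "per_step", registry)
ic_blocks = build_phase(cfg, "ic_only", registry)
```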

.. py:function:: _run_preblock_group(group: torch.nn.ModuleDict, batch: dict, device=None)

   Sequentially applies a group of preblocks, returning the transformed batch.

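Because ``nn.ModuleDict`` iterates in insertion order, blocks run in the order they were added, presumably the order they appear in the config. The sketch below uses plain callables in place of modules to show that reordering a group changes the result; ``run_group`` is a hypothetical stand-in.

```python
def run_group(group, batch):
    """Apply each block in insertion order, feeding each output to the next."""
    for block in group.values():
        batch = block(batch)
    return batch


scale_then_shift = {"scale": lambda x: x * 2, "shift": lambda x: x + 1}
shift_then_scale = {"shift": lambda x: x + 1, "scale": lambda x: x * 2}
a = run_group(scale_then_shift, 3)  # (3 * 2) + 1
b = run_group(shift_then_scale, 3)  # (3 + 1) * 2
```

This is why a concatenating block belongs at the end of the ``per_step`` section: any transform listed after it would receive the concatenated output rather than the nested dict.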

.. py:function:: apply_preblocks(preblocks: torch.nn.ModuleDict, batch: dict, device=None) -> dict

   Apply a preblock group built by ``build_preblocks``.

   :param preblocks: ``nn.ModuleDict`` built by ``build_preblocks`` for a single phase.
   :param batch: nested variable dict from the dataset (or a prior preblock pass).
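   :param device: if given, output tensors are moved to this device after concatenation
       (skipped when ``ConcatToTensor`` is configured with ``to_device: false``).

   :returns: when the concat preblock has run, ``{"x": tensor, "y": tensor, "metadata": ...}``;
             otherwise, the transformed nested batch dict (pre-concat).
   :rtype: dict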

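A sketch of how the final packaging step might turn the ``ConcatToTensor`` tuple into the documented ``{"x", "y", "metadata"}`` dict and honor ``to_device``. ``package_output`` and ``FakeTensor`` are hypothetical, and returning ``y = None`` when the batch has no target is an assumption; the source only documents the dict shape for the case with targets.

```python
def package_output(out, device=None, to_device=True):
    """Hypothetical sketch of how apply_preblocks could package concat output.

    out: either the nested batch dict (no concat ran) or a ConcatToTensor
    tuple, (x, metadata) or (x, y, metadata).
    """
    if not isinstance(out, tuple):
        return out  # pre-concat: return the nested dict unchanged
    if len(out) == 3:
        x, y, metadata = out
    else:
        x, metadata = out
        y = None  # assumption: no target tensor without a "target" data type
    if device is not None and to_device:
        x = x.to(device)
        y = y.to(device) if y is not None else None
    return {"x": x, "y": y, "metadata": metadata}


class FakeTensor:
    """Stand-in exposing only .to(), so the sketch runs without torch."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


packed = package_output((FakeTensor(), FakeTensor(), {"_channel_map": {}}), device="cuda:0")
unpacked = package_output((FakeTensor(), {}), device="cuda:0", to_device=False)
```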

