credit.postblock
================

.. py:module:: credit.postblock


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/postblock/base/index
   /autoapi/credit/postblock/gen1/index
   /autoapi/credit/postblock/geopotential/index
   /autoapi/credit/postblock/reconstruct/index
   /autoapi/credit/postblock/scaler/index
   /autoapi/credit/postblock/wet_mask_samudra/index


Attributes
----------

.. autoapisummary::

   credit.postblock.POSTBLOCK_REGISTRY
   credit.postblock._VALID_SECTIONS


Classes
-------

.. autoapisummary::

   credit.postblock.Reconstruct
   credit.postblock.WetMaskBlock
   credit.postblock.BridgeScalerTransformer
   credit.postblock.TracerFixer
   credit.postblock.GlobalMassFixer
   credit.postblock.GlobalWaterFixer
   credit.postblock.GlobalEnergyFixer
   credit.postblock.GeopotentialDiagnostic


Functions
---------

.. autoapisummary::

   credit.postblock._build_postblock_section
   credit.postblock.build_postblocks
   credit.postblock.apply_postblocks


Package Contents
----------------

.. py:class:: Reconstruct(*args: Any, **kwargs: Any)

   Bases: :py:obj:`credit.postblock.base.BasePostblock`


   Splits ``batch_dict["y_pred"]`` into a nested variable dict at ``batch_dict["y_processed"]``.

   Slices are read from ``batch_dict["metadata"]["target"]["_channel_map"]``, built
   by ``ConcatToTensor`` and covering only prognostic and diagnostic variables.
   Each slice is unflattened from ``(B, n_levels * n_time, H, W)`` back to
   ``(B, n_levels, n_time, H, W)``. ``y_pred`` is left untouched. All other
   keys in ``batch_dict`` pass through unchanged.


   .. py:method:: forward(batch_dict: dict) -> dict

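The slicing and unflattening step can be illustrated with NumPy, standing in for torch tensors. This is a minimal sketch, not the ``Reconstruct`` implementation. The ``channel_map`` layout (variable key mapped to ``(start, stop, n_levels, n_time)``) is a hypothetical stand-in for ``_channel_map``, and the sketch assumes channels are ordered level-major and time-minor.

```python
import numpy as np

# Hypothetical channel map: var_key -> (start, stop, n_levels, n_time).
# The real ``_channel_map`` built by ConcatToTensor may store this differently.
channel_map = {
    "era5/prognostic/3d/T": (0, 6, 3, 2),   # 3 levels x 2 time steps
    "era5/prognostic/2d/SP": (6, 8, 1, 2),  # single-level field, 2 time steps
}


def reconstruct(y_pred, channel_map):
    """Split (B, C, H, W) into {source: {var_key: (B, n_levels, n_time, H, W)}}."""
    out = {}
    batch, _, height, width = y_pred.shape
    for var_key, (start, stop, n_levels, n_time) in channel_map.items():
        source = var_key.split("/")[0]
        flat = y_pred[:, start:stop]  # (B, n_levels * n_time, H, W)
        # Assumes level-major ordering: channel index = level * n_time + t
        out.setdefault(source, {})[var_key] = flat.reshape(
            batch, n_levels, n_time, height, width
        )
    return out


y_pred = np.arange(2 * 8 * 4 * 5, dtype=np.float32).reshape(2, 8, 4, 5)
y_processed = reconstruct(y_pred, channel_map)
print(y_processed["era5"]["era5/prognostic/3d/T"].shape)  # (2, 3, 2, 4, 5)
```

As in the block itself, ``y_pred`` is only read; the nested result is a new dict.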

.. py:class:: WetMaskBlock(conf, key: str = 'prediction')

   Bases: :py:obj:`torch.nn.Module`


   Post-processing layer that applies a wet (ocean) mask to predictions.

   Land points are set to 0 and ocean points keep their values. The block has
   no trainable parameters, but because masked land points carry no gradient,
   it focuses learning on ocean regions.


   .. py:attribute:: key
      :value: 'prediction'



   .. py:method:: forward(batch_dict: dict) -> dict

      Apply wet mask to ``batch_dict[self.key]`` (land=0, ocean preserved).

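Masking by elementwise multiplication illustrates how a parameter-free block still shapes training. In the NumPy sketch below (NumPy standing in for torch), the 3x4 mask is hypothetical: 1 marks ocean, 0 marks land, and it broadcasts over the batch and channel dimensions. How ``WetMaskBlock`` builds its mask from ``conf`` is not shown.

```python
import numpy as np

# Hypothetical 3x4 wet mask: 1 = ocean, 0 = land
wet_mask = np.array([[1, 1, 0, 0],
                     [1, 1, 1, 0],
                     [1, 1, 1, 1]], dtype=np.float32)

prediction = np.random.default_rng(0).normal(size=(2, 5, 3, 4)).astype(np.float32)

# Broadcast (H, W) over (B, C, H, W): land -> 0, ocean unchanged
masked = prediction * wet_mask

# Chain rule: for a loss L, dL/d(prediction) = dL/d(masked) * wet_mask.
upstream_grad = np.ones_like(masked)  # e.g. dL/d(masked) for L = masked.sum()
grad_wrt_prediction = upstream_grad * wet_mask
```

The gradient reaching every land point is the upstream gradient times zero, so land values never influence the update.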


.. py:class:: BridgeScalerTransformer(scaler_path: str, variables: list[str], method: str, key: str = 'y_processed')

   Bases: :py:obj:`credit.postblock.base.BasePostblock`


   Scaling postblock using a fitted bridgescaler dict.

   Applies per-variable scaling (or its inverse) to the nested prediction dict
   at ``batch_dict[key]``, which has the form
   ``batch_dict[key][source][var_key]`` where ``var_key`` is
   ``"source/field_type/dim/varname"`` (e.g. ``"era5/prognostic/3d/T"``).

   Defaults to operating on ``"y_processed"``, the nested dict written by
   ``Reconstruct``. Use ``method="inverse_transform"`` to convert normalized
   predictions back to physical units before any physics fixers run.

   The scaler dict must have been fit with ``bridgescaler.scale_var_dict``
   using the same nested structure and saved with ``bridgescaler.save_scaler_dict``.

   Example config::

       type: "bridgescaler_transform"
       args:
           scaler_path: "/path/to/scaler.json"
           variables:
               - "era5/prognostic/3d/T"
               - "era5/prognostic/3d/U"
           method: "inverse_transform"


   .. py:attribute:: variables


   .. py:attribute:: method


   .. py:attribute:: scaler_path


   .. py:attribute:: key
      :value: 'y_processed'



   .. py:attribute:: scaler


   .. py:method:: forward(batch_dict: dict) -> dict

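The traversal over the nested prediction dict can be sketched as follows. To keep the example self-contained, it replaces the fitted bridgescaler objects with a hypothetical per-variable ``(mean, std)`` table and an inline z-score transform; the real block delegates the arithmetic to bridgescaler. The sketch also assumes that variables not listed in ``variables`` pass through unchanged and that ``"transform"`` is the forward method name.

```python
import numpy as np

# Hypothetical stand-in for a fitted scaler dict: var_key -> (mean, std).
# The real block loads bridgescaler objects from ``scaler_path`` instead.
scaler_stats = {
    "era5/prognostic/3d/T": (250.0, 20.0),
    "era5/prognostic/3d/U": (5.0, 10.0),
}


def apply_scaling(batch_dict, variables, method="inverse_transform", key="y_processed"):
    """Scale listed variables in batch_dict[key][source][var_key], in place."""
    for fields in batch_dict[key].values():  # one nested dict per source
        for var_key, values in list(fields.items()):
            if var_key not in variables:
                continue  # unlisted variables pass through unchanged
            mean, std = scaler_stats[var_key]
            if method == "inverse_transform":
                fields[var_key] = values * std + mean  # normalized -> physical
            elif method == "transform":
                fields[var_key] = (values - mean) / std  # physical -> normalized
            else:
                raise ValueError(f"unknown method: {method!r}")
    return batch_dict


# Arrays shaped (B, n_levels, n_time, H, W), as written by Reconstruct
batch = {"y_processed": {"era5": {
    "era5/prognostic/3d/T": np.zeros((1, 2, 1, 3, 3)),
    "era5/prognostic/3d/Q": np.ones((1, 2, 1, 3, 3)),
}}}
apply_scaling(batch, variables=["era5/prognostic/3d/T"])
```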

.. py:class:: TracerFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   Fixes tracer fields by replacing values below a given threshold with the
   threshold itself (i.e. ``tracer[tracer < thres] = thres``).

   :param post_conf: config dictionary that includes all specs for the tracer fixer.
   :type post_conf: dict


   .. py:attribute:: tracer_indices


   .. py:attribute:: tracer_thres


   .. py:attribute:: tracer_thres_max


   .. py:method:: forward(x)

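The replacement ``tracer[tracer < thres] = thres`` is equivalent to an elementwise maximum. The NumPy sketch below assumes per-tracer thresholds aligned with ``tracer_indices`` along the channel axis, and treats an upper bound (suggested by the ``tracer_thres_max`` attribute) as optional. The function name and argument layout are hypothetical, not the module's config keys.

```python
import numpy as np


def fix_tracers(x, tracer_indices, tracer_thres, tracer_thres_max=None):
    """Clamp selected channels of x (B, C, ...) to [thres, thres_max]."""
    x = x.copy()
    for i, ind in enumerate(tracer_indices):
        # Replace values below the threshold with the threshold
        x[:, ind] = np.maximum(x[:, ind], tracer_thres[i])
        if tracer_thres_max is not None:
            # Optional upper bound (assumption based on ``tracer_thres_max``)
            x[:, ind] = np.minimum(x[:, ind], tracer_thres_max[i])
    return x


x = np.array([[[-1e-6, 2e-3], [5.0, -3.0]]])  # (B=1, C=2, 2 grid points)
fixed = fix_tracers(x, tracer_indices=[0], tracer_thres=[0.0])
```

Only channel 0 is treated as a tracer here, so channel 1 keeps its negative value.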

.. py:class:: GlobalMassFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   Applies global mass conservation fixes to both the dry-air mass and the water budget.
   Correction ratios applied during model runs keep the global dry-air mass and the global
   water budget conserved. Specific total water and precipitation are corrected to close
   the budget. All corrections use float32 PyTorch tensors.

   :param post_conf: config dictionary that includes all specs for the global mass fixer.
   :type post_conf: dict


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:method:: forward(x)

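A correction ratio of this kind can be sketched as a multiplicative rescaling that restores an area-weighted global integral to a target value. This is a generic illustration under stated assumptions (a regular latitude-longitude grid, cos(latitude) area weights, and a known target), not the fixer's actual budget equations, which also involve dry-air mass and precipitation terms.

```python
import numpy as np


def global_integral(field, lat_deg):
    """Area-weighted sum of a (lat, lon) field using cos(latitude) weights."""
    weights = np.cos(np.deg2rad(lat_deg))[:, None]
    return float(np.sum(field * weights))


def apply_correction_ratio(field, lat_deg, target):
    """Scale field so its area-weighted global integral equals target."""
    ratio = target / global_integral(field, lat_deg)
    return field * ratio, ratio


lat = np.linspace(-80.0, 80.0, 9)
total_water = np.random.default_rng(1).uniform(1e-4, 2e-2, size=(9, 16))
# Hypothetical target: the budget calls for 2% more global water
target = 1.02 * global_integral(total_water, lat)

corrected, ratio = apply_correction_ratio(total_water, lat, target)
```

A positive multiplicative ratio preserves the spatial pattern and keeps a non-negative field non-negative, which an additive offset would not guarantee.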

.. py:class:: GlobalWaterFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   Applies a global water-budget conservation fix. It uses the specific humidity channel
   range (``q_ind_start`` to ``q_ind_end``) together with the precipitation (``precip_ind``)
   and evaporation (``evapor_ind``) channel indices.

   :param post_conf: config dictionary that includes all specs for the global water fixer.
   :type post_conf: dict


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:attribute:: precip_ind


   .. py:attribute:: evapor_ind


   .. py:method:: forward(x)

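One common way to close a global water budget is to require that the change in globally integrated water equals evaporation minus precipitation over the step, and to absorb the residual into precipitation. The sketch below shows that generic closure with hypothetical, already area-integrated scalars in consistent units; it is not taken from ``GlobalWaterFixer`` itself, and the sign convention (both fluxes as non-negative rates) is an assumption.

```python
def close_water_budget(water_before, water_after, evaporation, precipitation, dt):
    """Return the precipitation-correction ratio that enforces
    water_after - water_before == (evaporation - precipitation * ratio) * dt.

    All inputs are global integrals in consistent units; evaporation and
    precipitation are non-negative rates (sign-convention assumption).
    """
    required_precip = evaporation - (water_after - water_before) / dt
    return required_precip / precipitation


ratio = close_water_budget(water_before=100.0, water_after=99.0,
                           evaporation=2.0, precipitation=2.5, dt=1.0)
```

Here the atmosphere lost one unit of water while only 0.5 units of net precipitation were predicted, so precipitation is scaled up by 20%.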

.. py:class:: GlobalEnergyFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   Applies global energy conservation fixes. The output ensures that the change in the
   global sum of total atmospheric energy is balanced by the radiation and energy fluxes at
   the top of the atmosphere and at the surface. Air temperature is modified to close the
   budget. All corrections use float32 PyTorch tensors.

   :param post_conf: config dictionary that includes all specs for the global energy fixer.
   :type post_conf: dict


   .. py:attribute:: T_ind_start


   .. py:attribute:: T_ind_end


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:attribute:: U_ind_start


   .. py:attribute:: U_ind_end


   .. py:attribute:: V_ind_start


   .. py:attribute:: V_ind_end


   .. py:attribute:: TOA_solar_ind


   .. py:attribute:: TOA_OLR_ind


   .. py:attribute:: surf_solar_ind


   .. py:attribute:: surf_LR_ind


   .. py:attribute:: surf_SH_ind


   .. py:attribute:: surf_LH_ind


   .. py:method:: forward(x)

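The temperature adjustment can be sketched generically: compare the change in globally integrated total energy with the net energy input from top-of-atmosphere and surface fluxes over the step, then spread the residual uniformly as a temperature increment. The inputs below are hypothetical scalars; the sketch assumes the residual is absorbed entirely by the c_p T term, and the fixer's own flux sign conventions and energy terms may differ.

```python
CP_DRY_AIR = 1004.0  # J kg^-1 K^-1, approximate specific heat of dry air at constant pressure


def temperature_increment(energy_before, energy_after, net_flux_in, dt, air_mass):
    """Uniform temperature change (K) such that the corrected energy satisfies
    energy_after_fixed - energy_before == net_flux_in * dt.

    energy_* in J (global integrals), net_flux_in in W (positive into the
    atmosphere, assumed convention), dt in s, air_mass in kg.
    """
    residual = net_flux_in * dt - (energy_after - energy_before)
    return residual / (CP_DRY_AIR * air_mass)


e_before = 2.5e24            # J, hypothetical global total energy at step start
e_after = e_before + 3.0e19  # J, hypothetical predicted energy after the step
flux = 5.1e14                # W, roughly 1 W m^-2 over Earth's surface area
dt = 6 * 3600.0              # s, one 6-hour step
mass = 5.1e18                # kg, approximate mass of the atmosphere
dT = temperature_increment(e_before, e_after, flux, dt, mass)
```

Because the predicted energy gain exceeds what the fluxes supply, the increment is a small cooling.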

.. py:class:: GeopotentialDiagnostic(output_name: str = 'ARCO_ERA5/derived_diagnostic/3d/geopotential', dataset_name: str = 'ARCO_ERA5', chunk_size: int = 1000, data_keys: Iterable[str] = ('prediction', 'target'), surface_geopotential_var: str = 'ARCO_ERA5/static/2d/geopotential_at_surface', surface_pressure_var: str = 'ARCO_ERA5/prognostic/2d/surface_pressure', temperature_var: str = 'ARCO_ERA5/prognostic/3d/temperature', specific_humidity_var: str = 'ARCO_ERA5/prognostic/3d/specific_humidity', flip_vertical: bool = True, level_info_file: str = 'ERA5_Lev_Info.nc', model_a_half_var: str = 'a_half', model_b_half_var: str = 'b_half', static_source_key: str = 'ic_raw', levels: list[int] | None = None)

   Bases: :py:obj:`torch.nn.Module`


   ``GeopotentialDiagnostic`` is a ``torch.nn.Module`` that computes geopotential as a
   derived diagnostic field.

   It combines surface geopotential, surface pressure, temperature, and specific humidity
   to calculate geopotential fields. Inputs are read from the entries of the data
   dictionary named in ``data_keys``, and the hybrid sigma-pressure coefficients that
   describe the model levels are read from an auxiliary file (``level_info_file``).

   .. attribute:: output_name

      The key used in the dataset to store the computed
      geopotential diagnostic output.

      :type: str

   .. attribute:: dataset_name

      The name of the dataset from which input variables
      will be retrieved.

      :type: str

   .. attribute:: chunk_size

      The chunk size used for vectorized computations
      to optimize memory usage during processing.

      :type: int

   .. attribute:: data_keys

      The keys in the input data dictionary that
      will be processed (e.g., "prediction", "target").

      :type: Iterable[str]

   .. attribute:: surface_geopotential_var

      The key for the surface geopotential variable
      in the dataset.

      :type: str

   .. attribute:: surface_pressure_var

      The key for the surface pressure variable
      in the dataset.

      :type: str

   .. attribute:: temperature_var

      The key for the temperature variable in the dataset.

      :type: str

   .. attribute:: specific_humidity_var

      The key for the specific humidity variable
      in the dataset.

      :type: str

   .. attribute:: flip_vertical

      Whether to flip the vertical dimension of the input tensors. Default: ``True``.

      :type: bool

   .. attribute:: level_info_file

      The filename of the auxiliary metadata file that
      stores information about model levels.

      :type: str

   .. attribute:: model_a_half_var

      The variable name for the ``a`` (pressure) hybrid sigma-pressure coefficient in
      the level information file.

      :type: str

   .. attribute:: model_b_half_var

      The variable name for the ``b`` (sigma) hybrid sigma-pressure coefficient in
      the level information file.

      :type: str


   .. py:attribute:: output_name
      :value: 'ARCO_ERA5/derived_diagnostic/3d/geopotential'



   .. py:attribute:: dataset_name
      :value: 'ARCO_ERA5'



   .. py:attribute:: chunk_size
      :value: 1000



   .. py:attribute:: data_keys
      :value: ('prediction', 'target')



   .. py:attribute:: surface_geopotential_var
      :value: 'ARCO_ERA5/static/2d/geopotential_at_surface'



   .. py:attribute:: surface_pressure_var
      :value: 'ARCO_ERA5/prognostic/2d/surface_pressure'



   .. py:attribute:: temperature_var
      :value: 'ARCO_ERA5/prognostic/3d/temperature'



   .. py:attribute:: specific_humidity_var
      :value: 'ARCO_ERA5/prognostic/3d/specific_humidity'



   .. py:attribute:: flip_vertical
      :value: True



   .. py:attribute:: level_info_file


   .. py:attribute:: model_a_half_var
      :value: 'a_half'



   .. py:attribute:: model_b_half_var
      :value: 'b_half'



   .. py:attribute:: static_source_key
      :value: 'ic_raw'



   .. py:attribute:: levels
      :value: None



   .. py:method:: forward(data_dict: dict)

      Rearranges the dimensions of each data type in ``data_dict``, computes geopotential
      with the ``geopotential`` helper function, and adds the result to the data dictionary.

      :param data_dict: Input dictionary keyed by data type (e.g. ``"prediction"``,
                        ``"target"``). Each entry is expected to hold the required variables
                        (e.g. temperature, specific humidity) under their configured keys.
      :type data_dict: dict

      :returns: The updated data dictionary, with the computed fields added to the
                relevant dataset and the original structure preserved.
      :rtype: dict

      :raises ValueError: If any required data type is not found in ``data_dict``.

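The underlying physics is hydrostatic (hypsometric) integration upward from the surface on hybrid sigma-pressure half levels, where ``p_half = a_half + b_half * p_s``. The NumPy sketch below shows that standard calculation for a single column and returns geopotential on half levels. It assumes top-to-bottom level ordering (as in ERA5 model levels), a nonzero model-top pressure, and the approximation T_v = T(1 + 0.608 q). It is not the module's ``geopotential`` function, which works on batched tensors and in chunks.

```python
import numpy as np

R_DRY = 287.05          # J kg^-1 K^-1, gas constant for dry air
VIRTUAL_FACTOR = 0.608  # approx. Rv / Rd - 1


def half_level_geopotential(phi_surface, p_surface, temperature, q, a_half, b_half):
    """Geopotential (m^2 s^-2) on the n_levels + 1 half levels of one column.

    temperature, q: shape (n_levels,), ordered top -> bottom.
    a_half (Pa), b_half: shape (n_levels + 1,), ordered top -> bottom.
    """
    p_half = a_half + b_half * p_surface  # half-level pressures (Pa)
    t_virtual = temperature * (1.0 + VIRTUAL_FACTOR * q)
    # Layer thickness: d(phi) = R_d * T_v * ln(p_lower / p_upper)
    thickness = R_DRY * t_virtual * np.log(p_half[1:] / p_half[:-1])
    phi = np.empty_like(p_half)
    phi[-1] = phi_surface  # bottom half level is the surface
    # Sum thicknesses from the surface upward to each half level
    phi[:-1] = phi_surface + np.cumsum(thickness[::-1])[::-1]
    return phi


a_half = np.array([1000.0, 5000.0, 10000.0, 0.0])  # Pa, hypothetical 3-level grid
b_half = np.array([0.0, 0.2, 0.5, 1.0])
ps = 100000.0
T = np.full(3, 250.0)
q = np.zeros(3)
phi = half_level_geopotential(0.0, ps, T, q, a_half, b_half)
```

For this dry, isothermal column the sum telescopes, so each half level sits at R_d * T * ln(p_s / p_half) above the surface.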


.. py:data:: POSTBLOCK_REGISTRY

.. py:data:: _VALID_SECTIONS

.. py:function:: _build_postblock_section(section_cfg: dict) -> torch.nn.ModuleDict

.. py:function:: build_postblocks(postblock_cfg: dict | None = None, phase: str = 'per_step') -> torch.nn.ModuleDict

   Instantiate postblocks for a single phase from a two-section config.

   Config format::

       postblocks:
         per_step:          # run after every forward pass in the rollout loop
           reconstruct:
             type: reconstruct
           inverse_scale:
             type: bridgescaler_transform
             args:
               method: inverse_transform
               scaler_path: /path/to/scaler.json
         post_rollout:      # run once after all rollout steps complete
           mass_fixer:
             type: global_mass_fixer
             args: ...

   Typical usage: build once per phase and store the results separately::

       step_postblocks    = build_postblocks(cfg, phase="per_step")
       rollout_postblocks = build_postblocks(cfg, phase="post_rollout")

       # inside rollout loop, after each forward pass:
       full_data_dict = apply_postblocks(step_postblocks, full_data_dict)

       # once after rollout loop completes:
       full_data_dict = apply_postblocks(rollout_postblocks, full_data_dict)

   :param postblock_cfg: the full ``postblocks`` config dict (both sections).
   :param phase: which section to build — ``"per_step"`` or ``"post_rollout"``.

   :returns: ``nn.ModuleDict`` of instantiated blocks for the requested phase.

   :raises ValueError: if the config contains keys other than ``"per_step"`` / ``"post_rollout"``,
       or if ``phase`` is not one of those values.

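Judging from the config format above, ``POSTBLOCK_REGISTRY`` maps each block's ``type`` string to a postblock class, ``_build_postblock_section`` instantiates one section from it, and ``_VALID_SECTIONS`` holds the two allowed section names. The flow can be sketched in plain Python. Only the section names and the ``type``/``args`` layout come from this page; the registry contents, the stand-in ``Scale`` class, and returning an empty result for a missing config are hypothetical, and plain ``dict`` objects stand in for ``nn.ModuleDict``.

```python
VALID_SECTIONS = ("per_step", "post_rollout")


class Scale:
    """Hypothetical postblock that multiplies batch_dict[key] by a factor."""

    def __init__(self, factor, key="y_pred"):
        self.factor = factor
        self.key = key

    def __call__(self, batch_dict):
        batch_dict[self.key] = batch_dict[self.key] * self.factor
        return batch_dict


REGISTRY = {"scale": Scale}  # hypothetical type string -> class map


def build_section(section_cfg):
    """Instantiate {name: block} from {name: {"type": ..., "args": {...}}}."""
    blocks = {}
    for name, block_cfg in section_cfg.items():
        block_type = block_cfg["type"]
        if block_type not in REGISTRY:
            raise ValueError(f"unknown postblock type: {block_type!r}")
        blocks[name] = REGISTRY[block_type](**(block_cfg.get("args") or {}))
    return blocks


def build_phase(postblock_cfg=None, phase="per_step"):
    """Validate the two-section config and build the requested phase."""
    if phase not in VALID_SECTIONS:
        raise ValueError(f"phase must be one of {VALID_SECTIONS}, got {phase!r}")
    postblock_cfg = postblock_cfg or {}  # assumption: None means no postblocks
    unknown = set(postblock_cfg) - set(VALID_SECTIONS)
    if unknown:
        raise ValueError(f"unexpected postblock sections: {sorted(unknown)}")
    return build_section(postblock_cfg.get(phase) or {})


cfg = {
    "per_step": {"double": {"type": "scale", "args": {"factor": 2}}},
    "post_rollout": {"halve": {"type": "scale", "args": {"factor": 0.5}}},
}
step_blocks = build_phase(cfg, phase="per_step")
rollout_blocks = build_phase(cfg, phase="post_rollout")
```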

.. py:function:: apply_postblocks(postblocks: torch.nn.ModuleDict, batch_dict: dict) -> dict

   Apply a postblock group built by ``build_postblocks``.

   :param postblocks: ``nn.ModuleDict`` built by ``build_postblocks`` for a single phase.
   :param batch_dict: dict containing at minimum ``"y_pred"`` and ``"metadata"``.

   :returns: The same ``batch_dict`` after all blocks in the group have run.

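Assuming ``apply_postblocks`` iterates the ``ModuleDict`` in its stored order (``nn.ModuleDict`` preserves insertion order), the blocks in a group run in config order, each receiving the ``batch_dict`` returned by the previous one. A plain-Python sketch of that loop, using an ordinary ``dict`` and two hypothetical blocks:

```python
def apply_blocks(blocks, batch_dict):
    """Run each block in insertion order, threading batch_dict through."""
    for block in blocks.values():
        batch_dict = block(batch_dict)
    return batch_dict


def add_one(batch_dict):
    # Hypothetical block: increments y_pred
    batch_dict["y_pred"] = batch_dict["y_pred"] + 1
    return batch_dict


def double(batch_dict):
    # Hypothetical block: doubles y_pred
    batch_dict["y_pred"] = batch_dict["y_pred"] * 2
    return batch_dict


result = apply_blocks({"add_one": add_one, "double": double}, {"y_pred": 3, "metadata": {}})
```

Order matters: listing ``double`` before ``add_one`` would give 7 instead of 8, which is why, for example, an inverse-scaling block should precede physics fixers in the config.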

