credit.postblock.gen1
=====================

.. py:module:: credit.postblock.gen1

.. autoapi-nested-parse::

   postblock.py
   -------------------------------------------------------
   Content:
       - PostBlock
       - TracerFixer
       - GlobalMassFixer
       - GlobalWaterFixer
       - GlobalEnergyFixer
       - GlobalEnergyFixerUpDown



Attributes
----------

.. autoapisummary::

   credit.postblock.gen1.PI
   credit.postblock.gen1.logger


Classes
-------

.. autoapisummary::

   credit.postblock.gen1.PostBlock
   credit.postblock.gen1.TracerFixer
   credit.postblock.gen1.GlobalMassFixer
   credit.postblock.gen1.GlobalWaterFixer
   credit.postblock.gen1.GlobalEnergyFixer
   credit.postblock.gen1.GlobalEnergyFixerUpDown


Functions
---------

.. autoapisummary::

   credit.postblock.gen1.concat_fix


Module Contents
---------------

.. py:data:: PI
   :value: 3.141592653589793


.. py:data:: logger

.. py:class:: PostBlock(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   Post-processing block that collects the operations enabled in ``post_conf`` (for example,
   the tracer, mass, water and energy fixers defined in this module) in ``operations`` and
   applies them to the model output in ``forward``.

   :param post_conf: config dictionary that includes the specs for all post-processing operations.
   :type post_conf: dict
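
   The chaining pattern can be sketched as below, assuming each enabled operation is an
   ``nn.Module`` kept in ``operations`` and applied in order. ``AddOne``, ``ClampMin`` and
   ``SketchPostBlock`` are hypothetical stand-ins rather than CREDIT classes, and the sketch
   passes a plain tensor, which may differ from the input that ``forward`` actually receives.

   .. code-block:: python

      import torch
      import torch.nn as nn


      class AddOne(nn.Module):
          """Hypothetical stand-in for a fixer: adds 1 to every value."""

          def forward(self, x):
              return x + 1.0


      class ClampMin(nn.Module):
          """Hypothetical stand-in for a fixer: raises values below 0 to 0."""

          def forward(self, x):
              return torch.clamp(x, min=0.0)


      class SketchPostBlock(nn.Module):
          """Hypothetical container that applies its operations in order."""

          def __init__(self, ops):
              super().__init__()
              # nn.ModuleList registers each op as a submodule
              self.operations = nn.ModuleList(ops)

          def forward(self, x):
              for op in self.operations:
                  x = op(x)
              return x


      block = SketchPostBlock([AddOne(), ClampMin()])
      out = block(torch.tensor([-3.0, -0.5, 2.0]))
      # out == tensor([0.0, 0.5, 3.0])

   Registering the operations in an ``nn.ModuleList`` rather than a plain list lets calls such
   as ``.to(device)`` reach every submodule. Order matters: with the two operations swapped,
   the same input gives ``[1.0, 1.0, 3.0]``.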


   .. py:attribute:: operations


   .. py:method:: forward(x)


.. py:class:: TracerFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   This module fixes tracer values by raising any value below a given threshold to that
   threshold (e.g., ``tracer[tracer < thres] = thres``).

   :param post_conf: config dictionary that includes all specs for the tracer fixer.
   :type post_conf: dict
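
   A minimal sketch of the thresholding above, assuming the tracers are channels along dim 1
   of a ``(batch, var, time, lat, lon)`` tensor. The helper ``fix_tracers`` and its arguments
   are hypothetical, and only the lower threshold is shown (an upper bound, which
   ``tracer_thres_max`` may provide, is not modelled). It uses ``torch.clamp`` out of place
   instead of masked assignment, so the input tensor is left unchanged.

   .. code-block:: python

      import torch


      def fix_tracers(y, tracer_indices, thres):
          """Hypothetical helper: raise tracer values below ``thres`` to ``thres``.

          ``y`` has shape (batch, var, time, lat, lon); channels not listed in
          ``tracer_indices`` are returned unchanged.
          """
          channels = []
          for i in range(y.shape[1]):
              ch = y[:, i : i + 1, ...]
              if i in tracer_indices:
                  # same result as ch[ch < thres] = thres, without in-place writes
                  ch = torch.clamp(ch, min=thres)
              channels.append(ch)
          return torch.cat(channels, dim=1)


      y = torch.tensor([-1.0, 0.5, -2.0]).reshape(1, 3, 1, 1, 1)
      fixed = fix_tracers(y, tracer_indices=[0, 2], thres=0.0)
      # fixed.flatten() == tensor([0.0, 0.5, 0.0]); y is unchanged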


   .. py:attribute:: tracer_indices


   .. py:attribute:: tracer_thres


   .. py:attribute:: tracer_thres_max


   .. py:method:: forward(x)


.. py:class:: GlobalMassFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   This module applies global mass conservation fixes for both dry air and the water budget.
   Correction ratios applied during model runs ensure that the global dry air mass and the
   global water budget are conserved. Specific total water and precipitation are corrected
   to close the budget. All corrections are done using float32 PyTorch tensors.

   :param post_conf: config dictionary that includes all specs for the global mass fixer.
   :type post_conf: dict
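
   The idea of a single global correction ratio can be illustrated as follows. This is a
   simplified sketch, not the fixer's dry-air or water formula: it assumes a cosine-latitude
   area weight on a ``(lat, lon)`` grid and a known target global mean, and ``global_mean``
   and ``ratio_fix`` are hypothetical names.

   .. code-block:: python

      import torch


      def global_mean(field, lat_deg):
          """Area-weighted mean over (lat, lon) using cos(latitude) weights."""
          w = torch.cos(torch.deg2rad(lat_deg)).reshape(-1, 1)  # (lat, 1)
          w = w.expand_as(field)
          return (field * w).sum() / w.sum()


      def ratio_fix(field, lat_deg, target_mean):
          """Scale ``field`` by one global ratio so its weighted mean hits the target."""
          ratio = target_mean / global_mean(field, lat_deg)
          return field * ratio


      lat = torch.tensor([-45.0, 0.0, 45.0])
      q = torch.tensor([[1.0, 1.0], [2.0, 2.0], [1.0, 1.0]])  # (lat, lon)
      q_fixed = ratio_fix(q, lat, target_mean=torch.tensor(3.0))
      # global_mean(q_fixed, lat) is 3.0 up to float32 rounding

   Because the field is multiplied by one scalar, its spatial pattern and sign are preserved;
   only the global total changes.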


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:method:: forward(x)


.. py:class:: GlobalWaterFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   This module applies a global water budget fix. It works on the specific total water
   channels (``q_ind_start`` to ``q_ind_end``), precipitation (``precip_ind``), and
   evaporation (``evapor_ind``) of the model output.

   :param post_conf: config dictionary that includes all specs for the global water fixer.
   :type post_conf: dict
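
   A minimal sketch of one way to close a global water budget by rescaling precipitation.
   The budget form, the units (kg m⁻² accumulated over one step, evaporation positive into
   the atmosphere) and the helper ``fix_precip`` are assumptions for illustration, not the
   fixer's documented method.

   .. code-block:: python

      import torch


      def fix_precip(dW, E, P, w):
          """Hypothetical water-budget closure by rescaling precipitation.

          Assumed budget (kg m-2 over one step, area weights ``w``):
              mean(dW) = mean(E) - mean(P)
          P is scaled by one global ratio so that the budget holds.
          """
          def wmean(f):
              return (f * w).sum() / w.sum()

          target_P = wmean(E) - wmean(dW)
          ratio = target_P / wmean(P)
          return P * ratio


      w = torch.tensor([0.5, 1.0, 0.5])    # e.g. cos(latitude) weights
      dW = torch.tensor([0.0, -1.0, 0.0])  # change in total column water
      E = torch.tensor([1.0, 2.0, 1.0])    # evaporation
      P = torch.tensor([4.0, 4.0, 4.0])    # predicted precipitation
      P_fix = fix_precip(dW, E, P, w)
      # P_fix == tensor([2.0, 2.0, 2.0])

   A multiplicative ratio like this is undefined when the global mean precipitation is zero
   and turns negative if ``mean(E) - mean(dW)`` is negative, so a real fixer needs guards for
   those cases.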


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:attribute:: precip_ind


   .. py:attribute:: evapor_ind


   .. py:method:: forward(x)


.. py:class:: GlobalEnergyFixer(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   This module applies global energy conservation fixes. The output ensures that the change in
   the global sum of total atmospheric energy is balanced by the radiation and energy fluxes at
   the top of the atmosphere and at the surface. Air temperature is modified to close the
   budget. All corrections are done using float32 PyTorch tensors.

   :param post_conf: config dictionary that includes all specs for the global energy fixer.
   :type post_conf: dict
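
   The energy budget above can be sketched as a single uniform temperature shift. Everything
   here is an assumption for illustration: the constants ``CP`` and ``GRAV``, the simplified
   energy change per kelvin (only the ``c_p T`` term, times column mass ``ps / g``), and the
   helper ``energy_fix_T``. ``net_flux`` stands for the net energy entering the column (net
   downward flux at the top of the atmosphere minus net downward energy flux into the surface).

   .. code-block:: python

      import torch

      CP = 1004.0  # assumed specific heat of dry air at constant pressure, J kg-1 K-1
      GRAV = 9.81  # assumed gravitational acceleration, m s-2


      def energy_fix_T(T, dE_col, net_flux, ps, w, dt):
          """Hypothetical global energy closure by one uniform temperature shift.

          T:        air temperature, shape (level, lat), K
          dE_col:   change in column total energy over the step, shape (lat,), J m-2
          net_flux: net energy input to the column, shape (lat,), W m-2
          ps:       surface pressure, shape (lat,), Pa (column mass is ps / GRAV)
          w:        area weights, shape (lat,)
          dt:       step length, s
          """
          def wmean(f):
              return (f * w).sum() / w.sum()

          # global energy gained beyond what the boundary fluxes supplied (J m-2)
          residual = wmean(dE_col) - wmean(net_flux) * dt
          # simplified: a uniform shift dT changes column energy by CP * dT * ps / GRAV
          dT = -residual / (CP * wmean(ps / GRAV))
          return T + dT


      T = torch.tensor([[250.0, 260.0], [280.0, 290.0]])
      w = torch.tensor([1.0, 1.0])
      ps = torch.tensor([1.0e5, 1.0e5])
      dE_col = torch.tensor([2.0e5, 2.0e5])  # the column gained 2e5 J m-2 ...
      net_flux = torch.tensor([0.0, 0.0])    # ... with zero net boundary input
      T_fix = energy_fix_T(T, dE_col, net_flux, ps, w, dt=3600.0)

   Here the column gained 2e5 J m⁻² with no net boundary input, so every temperature drops by
   the same amount, roughly 0.02 K.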


   .. py:attribute:: T_ind_start


   .. py:attribute:: T_ind_end


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:attribute:: U_ind_start


   .. py:attribute:: U_ind_end


   .. py:attribute:: V_ind_start


   .. py:attribute:: V_ind_end


   .. py:attribute:: TOA_solar_ind


   .. py:attribute:: TOA_OLR_ind


   .. py:attribute:: surf_solar_ind


   .. py:attribute:: surf_LR_ind


   .. py:attribute:: surf_SH_ind


   .. py:attribute:: surf_LH_ind


   .. py:method:: forward(x)


.. py:class:: GlobalEnergyFixerUpDown(post_conf)

   Bases: :py:obj:`torch.nn.Module`


   Global energy conservation fixer using explicit up/down flux decomposition.

   Identical correction logic to ``GlobalEnergyFixer`` but uses separate downwelling
   and upwelling flux indices rather than pre-computed net fluxes.  The net TOA and
   surface imbalances are formed as:

   .. code-block:: text

       R_T  = (DSWRFtoa  - USWRFtoa  - ULWRFtoa) / N_seconds
       F_S  = (FSDS_J    - FSUS      + FLDS_J    - FLUS - SHF - LHF) / N_seconds

   where ``*_J`` variables are in J/m² (energy over the timestep) and ``SHF``/``LHF``
   are positive-upward surface turbulent heat fluxes also in J/m².
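
   The two net fluxes can be formed directly from the up/down components, following the
   formulas above. The helper ``net_imbalances`` and the dictionary keyed by variable name are
   hypothetical (the class reads these fields by index from ``y_pred``), and the flux values
   are illustrative numbers for an assumed 1-hour step.

   .. code-block:: python

      import torch


      def net_imbalances(f, n_seconds):
          """Form the TOA and surface net fluxes (W m-2) from up/down components.

          ``f`` maps the variable names used above to tensors in J m-2 accumulated
          over the step; SHF and LHF are positive upward.
          """
          R_T = (f["DSWRFtoa"] - f["USWRFtoa"] - f["ULWRFtoa"]) / n_seconds
          F_S = (f["FSDS_J"] - f["FSUS"] + f["FLDS_J"] - f["FLUS"]
                 - f["SHF"] - f["LHF"]) / n_seconds
          return R_T, F_S


      dt = 3600.0  # assumed 1-hour step
      fluxes = {k: torch.tensor(v * dt) for k, v in {
          "DSWRFtoa": 340.0, "USWRFtoa": 100.0, "ULWRFtoa": 240.0,
          "FSDS_J": 185.0, "FSUS": 24.0, "FLDS_J": 342.0, "FLUS": 398.0,
          "SHF": 20.0, "LHF": 80.0,
      }.items()}
      R_T, F_S = net_imbalances(fluxes, dt)
      # R_T == 0.0, F_S == 5.0 (W m-2)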

   :param post_conf: config dictionary.  The sub-key ``global_energy_fixer_updown``
                     must be present and contain all specs listed below.
   :type post_conf: dict

   Config keys (under ``global_energy_fixer_updown``):
       - ``activate`` / ``activate_outside_model`` / ``simple_demo``
       - ``midpoint``, ``denorm``, ``surface_geopotential_name``
       - ``T_inds``, ``q_inds``, ``U_inds``, ``V_inds``
       - ``sp_inds``  (required when ``grid_type == 'sigma'``)
       - ``TOA_down_solar_ind``  — DSWRFtoa index in y_pred
       - ``TOA_up_solar_ind``   — USWRFtoa index in y_pred
       - ``TOA_up_OLR_ind``     — ULWRFtoa index in y_pred
       - ``surf_down_solar_ind`` — FSDS_J index in y_pred
       - ``surf_up_solar_ind``  — FSUS index in y_pred
       - ``surf_down_LW_ind``   — FLDS_J index in y_pred
       - ``surf_up_LW_ind``     — FLUS index in y_pred
       - ``surf_SH_ind``        — SHF  index in y_pred  (positive-upward)
       - ``surf_LH_ind``        — LHF  index in y_pred  (positive-upward)
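
   A hypothetical ``post_conf`` fragment with the keys listed above, plus a small check for
   missing keys. All values, including the channel indices and the
   ``surface_geopotential_name`` string, are placeholders; storing the ``*_inds`` entries as
   lists of channel indices is an assumption, and ``missing_keys`` is not part of CREDIT.

   .. code-block:: python

      # Hypothetical set of required keys, taken from the list above.
      REQUIRED_KEYS = [
          "activate", "activate_outside_model", "simple_demo",
          "midpoint", "denorm", "surface_geopotential_name",
          "T_inds", "q_inds", "U_inds", "V_inds",
          "TOA_down_solar_ind", "TOA_up_solar_ind", "TOA_up_OLR_ind",
          "surf_down_solar_ind", "surf_up_solar_ind",
          "surf_down_LW_ind", "surf_up_LW_ind",
          "surf_SH_ind", "surf_LH_ind",
      ]


      def missing_keys(post_conf, grid_type):
          """Hypothetical check: return required keys absent from the sub-config."""
          sub = post_conf.get("global_energy_fixer_updown", {})
          required = REQUIRED_KEYS + (["sp_inds"] if grid_type == "sigma" else [])
          return [k for k in required if k not in sub]


      # Placeholder values only; indices refer to channels of y_pred.
      post_conf = {
          "global_energy_fixer_updown": {
              "activate": True,
              "activate_outside_model": False,
              "simple_demo": False,
              "midpoint": True,
              "denorm": True,
              "surface_geopotential_name": "placeholder_name",
              "T_inds": [0, 1, 2],
              "q_inds": [3, 4, 5],
              "U_inds": [6, 7, 8],
              "V_inds": [9, 10, 11],
              "sp_inds": [12],
              "TOA_down_solar_ind": 13,
              "TOA_up_solar_ind": 14,
              "TOA_up_OLR_ind": 15,
              "surf_down_solar_ind": 16,
              "surf_up_solar_ind": 17,
              "surf_down_LW_ind": 18,
              "surf_up_LW_ind": 19,
              "surf_SH_ind": 20,
              "surf_LH_ind": 21,
          }
      }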


   .. py:attribute:: T_ind_start


   .. py:attribute:: T_ind_end


   .. py:attribute:: q_ind_start


   .. py:attribute:: q_ind_end


   .. py:attribute:: U_ind_start


   .. py:attribute:: U_ind_end


   .. py:attribute:: V_ind_start


   .. py:attribute:: V_ind_end


   .. py:attribute:: TOA_down_solar_ind


   .. py:attribute:: TOA_up_solar_ind


   .. py:attribute:: TOA_up_OLR_ind


   .. py:attribute:: surf_down_solar_ind


   .. py:attribute:: surf_up_solar_ind


   .. py:attribute:: surf_down_LW_ind


   .. py:attribute:: surf_up_LW_ind


   .. py:attribute:: surf_SH_ind


   .. py:attribute:: surf_LH_ind


   .. py:method:: forward(x)


.. py:function:: concat_fix(y_pred, q_pred_correct, q_ind_start, q_ind_end, N_vars)

   This function uses ``torch.concat`` to replace a specific subset of variable channels in ``y_pred``.

   Given ``q_pred = y_pred[:, q_ind_start:q_ind_end, ...]`` and ``q_pred_correct``, this function
   performs ``y_pred[:, q_ind_start:q_ind_end, ...] = q_pred_correct`` without in-place
   modification, so the autograd graph of ``y_pred`` is preserved. It also handles the
   ``q_ind_start == q_ind_end`` case.

   All input tensors must be 5-D, with dimensions ``(batch, level-or-var, time, lat, lon)``.

   :param y_pred: Original y_pred tensor of shape (batch, var, time, lat, lon).
   :type y_pred: torch.Tensor
   :param q_pred_correct: Corrected q_pred tensor.
   :type q_pred_correct: torch.Tensor
   :param q_ind_start: Index where q_pred starts in y_pred.
   :type q_ind_start: int
   :param q_ind_end: Index where q_pred ends in y_pred.
   :type q_ind_end: int
   :param N_vars: Total number of variables in ``y_pred`` (i.e., ``y_pred.shape[1]``).
   :type N_vars: int

   :returns: Concatenated y_pred with corrected q_pred.
   :rtype: torch.Tensor
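
   A minimal sketch of the documented behaviour, assuming ``q_ind_end`` is exclusive (as in
   the ``y_pred[:, q_ind_start:q_ind_end, ...]`` slice above). ``concat_fix_sketch`` is a
   hypothetical re-implementation, not the library function, and it uses ``torch.cat``, of
   which ``torch.concat`` is an alias. The example checks that the result matches in-place
   assignment on a detached copy and that gradients still reach ``y_pred``.

   .. code-block:: python

      import torch


      def concat_fix_sketch(y_pred, q_pred_correct, q_ind_start, q_ind_end):
          """Hypothetical re-implementation of the behaviour documented above.

          ``q_ind_end`` is treated as exclusive, matching
          ``y_pred[:, q_ind_start:q_ind_end, ...]``.
          """
          parts = [
              y_pred[:, :q_ind_start, ...],  # channels before the block (may be empty)
              q_pred_correct,                # corrected channels
              y_pred[:, q_ind_end:, ...],    # channels after the block (may be empty)
          ]
          # torch.cat builds a new tensor, so no in-place write touches y_pred
          return torch.cat(parts, dim=1)


      torch.manual_seed(0)
      y_pred = torch.randn(2, 5, 1, 3, 4, requires_grad=True)
      q_corr = 0.5 * y_pred[:, 1:3, ...]
      out = concat_fix_sketch(y_pred, q_corr, 1, 3)
      out.sum().backward()
      # y_pred.grad is 0.5 on channels 1 and 2 and 1.0 elsewhere

   How the library function treats the ``q_ind_start == q_ind_end`` case is not covered by
   this sketch.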


