credit.postblock.reconstruct
============================

.. py:module:: credit.postblock.reconstruct

.. autoapi-nested-parse::

   reconstruct.py
   --------------
   Reconstruct: the first postblock, which splits the flat ``batch_dict["y_pred"]``
   tensor into a nested per-variable dict and writes the result to
   ``batch_dict["y_processed"]``.

   It reads ``batch_dict["metadata"]["target"]["_channel_map"]``, built by
   ``ConcatToTensor``, to determine which channels belong to each variable.

   Input
   -----
   batch_dict : dict
       Must contain:

       - ``"y_pred"``: flat model output tensor, shape ``(B, C, H, W)`` or ``(B, C, T, H, W)``
       - ``"metadata"``: metadata dict with ``["target"]["_channel_map"]``

       All other keys pass through unchanged.

   Output
   ------
   The same ``batch_dict`` with ``"y_processed"`` added as a nested dict::

       batch_dict["y_processed"][source][var_key]
           -> tensor of shape (B, n_levels, n_time, H, W)

   ``"y_pred"`` is left intact (grad-attached) for use in loss computation.
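
   A minimal sketch of this transformation follows. It assumes PyTorch and a
   *hypothetical* ``_channel_map`` layout of
   ``{source: {var_key: (start, stop, n_levels, n_time)}}``. The real structure
   written by ``ConcatToTensor`` may differ, and only the 4-D ``(B, C, H, W)``
   case is handled. ``reconstruct_sketch`` and the source and variable names in
   the usage lines are illustrative only and are not part of CREDIT.

   .. code-block:: python

      import torch


      def reconstruct_sketch(batch_dict: dict) -> dict:
          """Hypothetical stand-in for Reconstruct.forward (4-D y_pred only).

          ASSUMPTION: _channel_map is {source: {var_key: (start, stop, n_levels, n_time)}}.
          """
          y_pred = batch_dict["y_pred"]  # (B, C, H, W)
          channel_map = batch_dict["metadata"]["target"]["_channel_map"]
          y_processed = {}
          for source, variables in channel_map.items():
              y_processed[source] = {}
              for var_key, (start, stop, n_levels, n_time) in variables.items():
                  flat = y_pred[:, start:stop]  # (B, n_levels * n_time, H, W)
                  # Slicing and reshape are differentiable, so gradients still
                  # reach y_pred through these per-variable tensors.
                  y_processed[source][var_key] = flat.reshape(
                      flat.shape[0], n_levels, n_time, *flat.shape[2:]
                  )
          batch_dict["y_processed"] = y_processed  # y_pred itself is not modified
          return batch_dict


      # Usage with made-up sources/variables: 3 levels x 2 times of "T", 1 x 2 of "SP".
      batch = {
          "y_pred": torch.randn(2, 8, 4, 5, requires_grad=True),
          "metadata": {"target": {"_channel_map": {
              "upper_air": {"T": (0, 6, 3, 2)},
              "surface": {"SP": (6, 8, 1, 2)},
          }}},
      }
      batch = reconstruct_sketch(batch)
      print(batch["y_processed"]["upper_air"]["T"].shape)  # torch.Size([2, 3, 2, 4, 5])

   Keeping ``y_pred`` unchanged lets the loss use the flat tensor directly,
   while later postblocks can address outputs by source and variable name.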



Classes
-------

.. autoapisummary::

   credit.postblock.reconstruct.Reconstruct


Module Contents
---------------

.. py:class:: Reconstruct(*args: Any, **kwargs: Any)

   Bases: :py:obj:`credit.postblock.base.BasePostblock`


   Splits ``batch_dict["y_pred"]`` into a nested variable dict at ``batch_dict["y_processed"]``.

   Slices are read from ``batch_dict["metadata"]["target"]["_channel_map"]``, which
   is built by ``ConcatToTensor`` and covers only the prognostic and diagnostic variables.
   Each slice is unflattened from ``(B, n_levels * n_time, H, W)`` back to
   ``(B, n_levels, n_time, H, W)``. ``y_pred`` is left untouched. All other
   keys in ``batch_dict`` pass through unchanged.
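
   Because the target shape is ``(B, n_levels, n_time, H, W)``, the row-major
   unflatten treats level as the outer index and time as the inner one: flattened
   channel ``c`` of a variable maps to level ``c // n_time`` and time step
   ``c % n_time``. The check below illustrates this with PyTorch and arbitrary
   toy sizes, which are assumed for illustration and not taken from any CREDIT
   configuration.

   .. code-block:: python

      import torch

      # Toy sizes chosen for illustration only.
      B, n_levels, n_time, H, W = 1, 3, 2, 2, 2

      # Fill each flattened channel with its own index so the mapping is visible.
      flat = (
          torch.arange(n_levels * n_time, dtype=torch.float32)
          .view(1, -1, 1, 1)
          .expand(B, -1, H, W)
      )  # (B, n_levels * n_time, H, W)

      # Unflatten the channel dimension back to (n_levels, n_time).
      nested = flat.reshape(B, n_levels, n_time, H, W)

      # Row-major order: channel c lands at (level, time) = divmod(c, n_time).
      for c in range(n_levels * n_time):
          level, t = divmod(c, n_time)
          assert nested[0, level, t, 0, 0].item() == c

   The round trip is only correct if ``ConcatToTensor`` flattened each variable
   in the same level-major order.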


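   .. py:method:: forward(batch_dict: dict) -> dict

      Split ``batch_dict["y_pred"]`` into per-variable tensors under
      ``batch_dict["y_processed"]`` and return the updated ``batch_dict``.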


