credit.datasets.channel_layout
==============================

.. py:module:: credit.datasets.channel_layout

.. autoapi-nested-parse::

   channel_layout.py
   -----------------
   Derive per-group channel slices directly from the variables config and
   provide a standalone update function for multi-step rollout.

   The concat order matches the canonical rank defined by FIELD_TYPE_RANK
   (prognostic < dynamic_forcing < static < diagnostic), with vars_3D before
   vars_2D within each group.  This is the same rank used by
   credit.preblock.concat.ConcatToTensor, so slices returned here always
   correspond to the actual channel positions in the model input tensor.
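
   As an illustration of that ordering rule, the sketch below flattens a
   grouped variable list into concat order.  It is a minimal standalone
   example, not the CREDIT code.  The rank values, the shape of the groups
   mapping, the helper name variable_order and the variable names are all
   hypothetical.  Only the relative order of field types and the
   3D-before-2D rule come from the text above.

   .. code-block:: python

      # Hypothetical rank values; only their relative order is documented.
      RANK = {"prognostic": 0, "dynamic_forcing": 1, "static": 2, "diagnostic": 3}


      def variable_order(groups):
          """Flatten a hypothetical {field_type: {"vars_3D": [...], "vars_2D": [...]}}
          mapping into concat order: field types by RANK, then 3D before 2D."""
          ordered = []
          for field_type in sorted(groups, key=RANK.__getitem__):
              spec = groups[field_type]
              ordered.extend(spec.get("vars_3D", []))
              ordered.extend(spec.get("vars_2D", []))
          return ordered


      # Keys deliberately listed out of rank order, as a config file might list them.
      groups = {
          "static": {"vars_2D": ["LSM"]},
          "prognostic": {"vars_3D": ["T", "U"], "vars_2D": ["SP"]},
          "dynamic_forcing": {"vars_2D": ["TSI"]},
      }
      print(variable_order(groups))  # ['T', 'U', 'SP', 'TSI', 'LSM']

   Sorting on the rank rather than iterating over the config keys is what
   keeps the layout independent of how the config happens to list its groups.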

   Usage
   -----

   .. code-block:: python

      from credit.datasets.channel_layout import build_channel_layout, update_x

      # once at trainer / rollout init
      slices, n_pred = build_channel_layout(conf)

      # at every t > 1 step
      x = update_x(x, x_dynfrc, y_pred, slices)



Attributes
----------

.. autoapisummary::

   credit.datasets.channel_layout.FIELD_TYPE_RANK
   credit.datasets.channel_layout._DIAGNOSTIC
   credit.datasets.channel_layout._DATASET_DRIVEN
   credit.datasets.channel_layout._MODEL_PREDICTED


Functions
---------

.. autoapisummary::

   credit.datasets.channel_layout.build_channel_layout
   credit.datasets.channel_layout.update_x


Module Contents
---------------

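.. py:data:: FIELD_TYPE_RANK
   :type: dict[str, int]

   Canonical concat rank of each field type (prognostic < dynamic_forcing <
   static < diagnostic), shared with credit.preblock.concat.ConcatToTensor.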

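.. py:data:: _DIAGNOSTIC

   Group names whose channels are never part of the model input x;
   build_channel_layout() excludes them from the returned slices.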

.. py:data:: _DATASET_DRIVEN

   Group names whose channels update_x() refreshes from the dataset
   (dynamic_forcing).

.. py:data:: _MODEL_PREDICTED

   Group names whose channels update_x() replaces with the model output
   (prognostic).

.. py:function:: build_channel_layout(conf)

   Return (slices, n_pred) derived from the variables config.

   Slices are ordered by FIELD_TYPE_RANK (matching ConcatToTensor), so
   slice offsets correspond directly to channel positions in the tensor.

   :param conf: Full CREDIT config dict.  Reads the first entry under
                conf["data"]["source"] for levels and variable groups.
   :type conf: dict

   :returns: * **slices** (*dict[str, slice]*) -- Mapping from group name to its slice in the channel dimension of x.
               Groups in _DIAGNOSTIC are excluded (they are never in x).
               Order matches FIELD_TYPE_RANK, not config key order.
             * **n_pred** (*int*) -- Total number of predicted (prognostic) channels.
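
   A minimal sketch of how such slices could be derived is shown below.  It
   is illustrative only and is not the CREDIT implementation.  The following
   are all hypothetical:

   * the function name sketch_channel_layout
   * the shape of the groups mapping
   * the rank values
   * the channel widths, where each 3D variable occupies one channel per level
     and each 2D variable occupies a single channel

   .. code-block:: python

      # Hypothetical stand-ins for FIELD_TYPE_RANK and _DIAGNOSTIC.
      RANK = {"prognostic": 0, "dynamic_forcing": 1, "static": 2, "diagnostic": 3}
      EXCLUDED = {"diagnostic"}  # diagnostic groups never appear in x


      def sketch_channel_layout(groups, n_levels):
          """Return (slices, n_pred) for a hypothetical groups mapping."""
          slices = {}
          offset = 0
          for name in sorted(groups, key=RANK.__getitem__):
              if name in EXCLUDED:
                  continue
              spec = groups[name]
              # Assumed widths: n_levels channels per 3D variable, one per 2D variable.
              width = n_levels * len(spec.get("vars_3D", [])) + len(spec.get("vars_2D", []))
              slices[name] = slice(offset, offset + width)
              offset += width
          prog = slices.get("prognostic", slice(0, 0))
          n_pred = prog.stop - prog.start
          return slices, n_pred


      groups = {
          "static": {"vars_2D": ["LSM", "Z_SFC"]},
          "diagnostic": {"vars_2D": ["PRECIP"]},
          "prognostic": {"vars_3D": ["U", "V", "T", "Q"], "vars_2D": ["SP", "T2M"]},
          "dynamic_forcing": {"vars_2D": ["TSI"]},
      }
      slices, n_pred = sketch_channel_layout(groups, n_levels=16)
      # slices == {"prognostic": slice(0, 66), "dynamic_forcing": slice(66, 67),
      #            "static": slice(67, 69)}; n_pred == 66

   The loop visits groups in rank order and skips diagnostics.  As a result,
   the running offset follows the tensor layout even though this config lists
   static first.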


.. py:function:: update_x(x_prev, x_dynfrc, y_pred, slices)

   Build the next-step input tensor for autoregressive rollout.

   Replaces the dataset-driven channels (dynamic_forcing) with x_dynfrc and
   the model-predicted channels (prognostic) with y_pred.  x_prev is cloned
   first, so it is not modified in place.  All other channels, such as
   static, are carried forward unchanged.

   :param x_prev: Full input tensor from the previous step.
   :type x_prev: Tensor  [B, C, ...]
   :param x_dynfrc: New dynamic-forcing channels from the dataset, in the same relative
                    order as the dataset-driven groups appear in slices.
   :type x_dynfrc: Tensor  [B, C_dyn, ...]
   :param y_pred: Model output, in the same relative order as the predicted groups
                  appear in slices.
   :type y_pred: Tensor  [B, C_pred, ...]
   :param slices: Group-to-slice mapping returned by build_channel_layout().
   :type slices: dict[str, slice]

   :returns: Updated input tensor ready for the next forward pass.
   :rtype: Tensor  [B, C, ...]
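
   The update can be sketched as follows.  This is a NumPy stand-in written
   for illustration, not the CREDIT implementation.  Here ndarray.copy() plays
   the role of Tensor.clone().  The function name update_x_sketch is
   hypothetical.  The DATASET_DRIVEN and MODEL_PREDICTED tuples are also
   hypothetical, standing in for _DATASET_DRIVEN and _MODEL_PREDICTED.

   .. code-block:: python

      import numpy as np

      # Hypothetical stand-ins for _DATASET_DRIVEN and _MODEL_PREDICTED.
      DATASET_DRIVEN = ("dynamic_forcing",)
      MODEL_PREDICTED = ("prognostic",)


      def update_x_sketch(x_prev, x_dynfrc, y_pred, slices):
          """Copy x_prev, then overwrite the dataset-driven and predicted slices."""
          x = x_prev.copy()  # torch code would use x_prev.clone(); x_prev stays untouched
          dyn_off = pred_off = 0  # running offsets into x_dynfrc and y_pred
          for name, sl in slices.items():  # slices are in channel order
              width = sl.stop - sl.start
              if name in DATASET_DRIVEN:
                  x[:, sl] = x_dynfrc[:, dyn_off:dyn_off + width]
                  dyn_off += width
              elif name in MODEL_PREDICTED:
                  x[:, sl] = y_pred[:, pred_off:pred_off + width]
                  pred_off += width
              # any other group (e.g. static) keeps the values copied from x_prev
          return x


      slices = {"prognostic": slice(0, 3), "dynamic_forcing": slice(3, 4), "static": slice(4, 5)}
      x_prev = np.full((2, 5, 4, 4), 5.0)    # [B, C, H, W]
      x_dynfrc = np.full((2, 1, 4, 4), 7.0)  # [B, C_dyn, H, W]
      y_pred = np.ones((2, 3, 4, 4))         # [B, C_pred, H, W]
      x_next = update_x_sketch(x_prev, x_dynfrc, y_pred, slices)

   The sketch uses running offsets.  It therefore works only if x_dynfrc and
   y_pred list their groups in the same relative order as slices, which is
   the contract stated in the parameter descriptions above.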


