credit.preblock.concat
======================

.. py:module:: credit.preblock.concat

.. autoapi-nested-parse::

   concat.py
   ---------
   ConcatToTensor: end-of-chain preblock that collapses a nested batch dict into a
   flat ``(x, metadata)`` or ``(x, y, metadata)`` tuple. Used by
   ``build_preblocks`` and ``apply_preblocks``.

   Channel concat order is fully determined by the variable key structure
   ``{source}/{field_type}/{dim}/{varname}``:

     1. field_type rank: prognostic < dynamic_forcing < static < diagnostic
     2. dim rank: 3d < 2d
     3. within each (field_type, dim) bucket, the original insertion order is
        preserved because Python's sort is stable; this matches the order of the
        variable lists in the config.
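
   The ordering rules above can be illustrated with a minimal sketch. The rank
   tables, ``channel_sort_key`` and the example variable keys are hypothetical,
   and the real ``_channel_sort_key`` may be implemented differently. The sketch
   ranks only ``field_type`` and ``dim`` and ignores ``varname``, so variables that
   share a bucket are left in insertion order by the stable sort rather than
   being reordered alphabetically.

   .. code-block:: python

      # Hypothetical rank tables mirroring the ordering rules above.
      FIELD_TYPE_RANK = {"prognostic": 0, "dynamic_forcing": 1, "static": 2, "diagnostic": 3}
      DIM_RANK = {"3d": 0, "2d": 1}


      def channel_sort_key(item):
          """Rank a (var_key, tensor) pair by field_type, then by dim."""
          var_key, _tensor = item
          # var_key is "source/field_type/dim/varname"; varname is deliberately not ranked.
          _source, field_type, dim, _varname = var_key.split("/", 3)
          return (FIELD_TYPE_RANK[field_type], DIM_RANK[dim])


      # Hypothetical keys in config (insertion) order; None stands in for tensors.
      variables = {
          "era5/static/2d/lsm": None,
          "era5/prognostic/2d/t2m": None,
          "era5/diagnostic/2d/tp": None,
          "era5/prognostic/3d/U": None,
          "era5/dynamic_forcing/2d/tsi": None,
          "era5/prognostic/3d/V": None,
      }

      # sorted() is stable, so U stays ahead of V inside the (prognostic, 3d) bucket.
      ordered = [key for key, _ in sorted(variables.items(), key=channel_sort_key)]
      # ordered == ["era5/prognostic/3d/U", "era5/prognostic/3d/V",
      #             "era5/prognostic/2d/t2m", "era5/dynamic_forcing/2d/tsi",
      #             "era5/static/2d/lsm", "era5/diagnostic/2d/tp"]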



Attributes
----------

.. autoapisummary::

   credit.preblock.concat._PREDICTABLE_FIELD_TYPES


Classes
-------

.. autoapisummary::

   credit.preblock.concat.ConcatToTensor


Functions
---------

.. autoapisummary::

   credit.preblock.concat._channel_sort_key


Module Contents
---------------

.. py:data:: _PREDICTABLE_FIELD_TYPES

.. py:function:: _channel_sort_key(item) -> tuple

   Sort key for the ``(var_key, tensor)`` pairs yielded by ``variables.items()``.

   var_key has the form ``source/field_type/dim/varname``.


.. py:class:: ConcatToTensor(to_device: bool = True)

   Bases: :py:obj:`credit.preblock.base.BasePreblock`


   End-of-chain preblock that concatenates a nested batch dict of tensors
   into a single input tensor (and optionally a target tensor).

   Expects a batch dict of the form::

       batch[data_type][source][var_name] -> torch.Tensor

   where each tensor has shape ``(batch, n_levels, time, lat, lon)`` and
   concatenation is performed along dim=1 (channel). Variables are sorted with
   ``_channel_sort_key`` before concatenation, so the channel order matches the
   canonical variable schema regardless of insertion order in the batch.

   ``metadata`` keys are passed through as-is (not concatenated).

   In addition to the tensors, two channel maps are attached to metadata under
   ``metadata["_channel_map"]``:

   * ``"input"``: every variable and its slice in the concatenated input tensor.
   * ``"output"``: prognostic and diagnostic variables only, with slices
     reindexed from 0 to match the ``y_pred`` channel ordering.

   Each entry has the form::

       var_key -> {"slice": slice(start, end), "orig_shape": (n_levels, T)}
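
   A minimal sketch of how the two maps relate, assuming the variables are
   already in concatenation order and that each variable contributes its
   ``n_levels`` channels along dim=1. ``build_channel_maps``, ``PREDICTABLE`` and
   the example keys and shapes are hypothetical; only shapes are needed, so no
   tensors are built.

   .. code-block:: python

      # Assumption: the output map covers exactly these field types.
      PREDICTABLE = {"prognostic", "diagnostic"}


      def build_channel_maps(shapes):
          """Build input/output channel maps from per-variable shapes.

          ``shapes`` maps var_key -> (batch, n_levels, time, lat, lon) and is
          assumed to already be in concatenation order.
          """
          input_map, output_map = {}, {}
          in_start = out_start = 0
          for var_key, shape in shapes.items():
              n_levels, n_time = shape[1], shape[2]
              entry_shape = (n_levels, n_time)
              # Assumption: each variable occupies n_levels channels along dim=1.
              input_map[var_key] = {"slice": slice(in_start, in_start + n_levels),
                                    "orig_shape": entry_shape}
              in_start += n_levels
              field_type = var_key.split("/")[1]
              if field_type in PREDICTABLE:
                  # Output slices use their own counter, so they start again at 0.
                  output_map[var_key] = {"slice": slice(out_start, out_start + n_levels),
                                         "orig_shape": entry_shape}
                  out_start += n_levels
          return {"input": input_map, "output": output_map}


      shapes = {
          "era5/prognostic/3d/U": (2, 16, 1, 32, 64),
          "era5/prognostic/2d/t2m": (2, 1, 1, 32, 64),
          "era5/static/2d/lsm": (2, 1, 1, 32, 64),
          "era5/diagnostic/2d/tp": (2, 1, 1, 32, 64),
      }
      channel_map = build_channel_maps(shapes)
      # "tp" sits at input channel 18 but output channel 17, because "lsm" is skipped.

   Keeping a separate counter for the output map is what lets ``"output"``
   slices index ``y_pred`` directly, even though static and forcing channels sit
   between prognostic and diagnostic channels in the input.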

   Returns either::

       (input_tensor, metadata)                    # if no "target" data_type present
       (input_tensor, target_tensor, metadata)     # if "target" is present
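
   Because the tuple length depends on whether a ``"target"`` data_type is
   present, a caller has to handle both forms. One way to do that is sketched
   below; ``unpack_preblock_output`` is a hypothetical helper, not part of this
   module, and strings stand in for tensors.

   .. code-block:: python

      def unpack_preblock_output(out):
          """Return (x, y, metadata); y is None when no "target" data_type was present."""
          if len(out) == 3:
              x, y, metadata = out
          elif len(out) == 2:
              x, metadata = out
              y = None
          else:
              raise ValueError(f"expected a 2- or 3-tuple, got length {len(out)}")
          return x, y, metadata


      # Two-element result: no target tensor, so y comes back as None.
      x, y, metadata = unpack_preblock_output(("x", {"_channel_map": {}}))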

   Example config::

       type: "concatenate_to_tensor"
       args:
         to_device: true   # set false to skip .to(device) in apply_preblocks


   .. py:attribute:: to_device
      :value: True

      If ``False``, ``apply_preblocks`` skips the ``.to(device)`` call for this
      preblock's output.



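   .. py:method:: forward(batch)

      Concatenate ``batch`` as described above and return either
      ``(input_tensor, metadata)`` or ``(input_tensor, target_tensor, metadata)``.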


