credit.preblock.concat#

ConcatToTensor: end-of-chain preblock that collapses a nested batch dict into a flat (x, y, metadata) tuple, or (x, metadata) when the batch has no target. Used by build_preblocks/apply_preblocks.

Channel concat order is fully determined by the variable key structure {source}/{field_type}/{dim}/{varname}:

  1. field_type rank: prognostic < dynamic_forcing < static < diagnostic

  2. dim rank: 3d < 2d

  3. within each (field_type, dim) bucket: original insertion order is preserved (Python sort is stable), which matches config list order.
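The three ordering rules above can be sketched in plain Python. This is an illustration under stated assumptions, not the module's actual `_channel_sort_key`: the rank tables, the `era5` source name, and the variable names are hypothetical, and placeholder values stand in for tensors.

```python
# Hypothetical rank tables mirroring rules 1 and 2 above.
FIELD_TYPE_RANK = {"prognostic": 0, "dynamic_forcing": 1, "static": 2, "diagnostic": 3}
DIM_RANK = {"3d": 0, "2d": 1}


def channel_sort_key(item):
    """Return (field_type rank, dim rank) for a (var_key, tensor) pair."""
    var_key, _tensor = item
    # var_key is assumed to look like source/field_type/dim/varname.
    _source, field_type, dim, _varname = var_key.split("/")
    return (FIELD_TYPE_RANK[field_type], DIM_RANK[dim])


# Placeholder values stand in for tensors; only the keys affect ordering.
variables = {
    "era5/diagnostic/2d/precip": None,
    "era5/prognostic/2d/SP": None,
    "era5/static/2d/LSM": None,
    "era5/prognostic/3d/U": None,
    "era5/dynamic_forcing/2d/tsi": None,
    "era5/prognostic/3d/V": None,
}

# sorted() is stable, so U stays ahead of V (rule 3: insertion order in a bucket).
ordered_keys = [key for key, _ in sorted(variables.items(), key=channel_sort_key)]
for key in ordered_keys:
    print(key)
```

Because the key omits the variable name, two variables in the same (field_type, dim) bucket keep the order in which they were inserted, which is what ties the channel order to the config list order.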

Attributes#

Classes#

ConcatToTensor

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Functions#

_channel_sort_key(→ tuple)

Sort key for items from variables.items(): (var_key, tensor).

Module Contents#

credit.preblock.concat._PREDICTABLE_FIELD_TYPES#

credit.preblock.concat._channel_sort_key(item) → tuple#

Sort key for items from variables.items(): (var_key, tensor).

var_key has the form source/field_type/dim/varname.

class credit.preblock.concat.ConcatToTensor(to_device: bool = True)#

Bases: credit.preblock.base.BasePreblock

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Expects a batch dict of the form:

batch[data_type][source][var_name] -> torch.Tensor

where each tensor has shape (batch, n_levels, time, lat, lon). Tensors are concatenated along dim=1 (channel). Input tensors are sorted by _channel_sort_key before concatenation, so the channel order matches the canonical variable schema regardless of insertion order in the batch.

metadata keys are passed through as-is (not concatenated).

In addition to the tensors, two channel maps are attached to metadata under metadata["_channel_map"]:

  • "input" — every variable and its slice in the concatenated input tensor.

  • "output" — prognostic + diagnostic variables only, with slices reindexed from 0 to match y_pred channel ordering.

Each entry has the form:

var_key -> {"slice": slice(start, end), "orig_shape": (n_levels, T)}
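To make the two maps concrete, here is a minimal sketch that builds them from shapes alone. It assumes each variable occupies n_levels consecutive channels along dim=1 and that the predictable field types are prognostic and diagnostic; the function name `build_channel_maps`, the `PREDICTABLE` tuple, and the example keys are hypothetical, and the real module may compute channel widths differently.

```python
# Assumed to mirror _PREDICTABLE_FIELD_TYPES; the real contents are not shown here.
PREDICTABLE = ("prognostic", "diagnostic")


def build_channel_maps(ordered_shapes):
    """Build input/output channel maps from already-sorted (var_key, (n_levels, T)) pairs."""
    input_map, output_map = {}, {}
    in_offset = 0
    out_offset = 0
    for var_key, (n_levels, n_time) in ordered_shapes:
        # Every variable gets a slice in the concatenated input tensor.
        input_map[var_key] = {
            "slice": slice(in_offset, in_offset + n_levels),
            "orig_shape": (n_levels, n_time),
        }
        in_offset += n_levels

        # Only predictable variables appear in the output map, reindexed from 0.
        field_type = var_key.split("/")[1]
        if field_type in PREDICTABLE:
            output_map[var_key] = {
                "slice": slice(out_offset, out_offset + n_levels),
                "orig_shape": (n_levels, n_time),
            }
            out_offset += n_levels
    return {"input": input_map, "output": output_map}


channel_map = build_channel_maps([
    ("era5/prognostic/3d/U", (16, 1)),
    ("era5/prognostic/2d/SP", (1, 1)),
    ("era5/static/2d/LSM", (1, 1)),
    ("era5/diagnostic/2d/precip", (1, 1)),
])
print(channel_map["input"]["era5/diagnostic/2d/precip"]["slice"])
print(channel_map["output"]["era5/diagnostic/2d/precip"]["slice"])
```

In this sketch the static variable shifts the diagnostic slice by one channel in the input map but not in the output map, which is why the output slices have to be reindexed from 0.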

Returns either:

(input_tensor, metadata)                    # if no "target" data_type present
(input_tensor, target_tensor, metadata)     # if "target" is present
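Because the return arity depends on whether a target is present, calling code may want to normalise it. The helper below is a hypothetical sketch (`unpack_preblock_output` is not part of the module) that always yields three values, with the target set to None when absent.

```python
def unpack_preblock_output(out):
    """Normalise a 2- or 3-tuple into (input_tensor, target_tensor_or_None, metadata)."""
    if len(out) == 3:
        input_tensor, target_tensor, metadata = out
    elif len(out) == 2:
        input_tensor, metadata = out
        target_tensor = None  # no "target" data_type was present in the batch
    else:
        raise ValueError(f"expected a tuple of length 2 or 3, got {len(out)}")
    return input_tensor, target_tensor, metadata


# Strings stand in for tensors so the example runs without torch.
x, y, meta = unpack_preblock_output(("x_tensor", {"_channel_map": {}}))
print(y is None)
```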

Example config:

type: "concatenate_to_tensor"
args:
  to_device: true   # set false to skip .to(device) in apply_preblocks

to_device = True#

forward(batch)#