credit.preblock.concat
ConcatToTensor: end-of-chain preblock that collapses a nested batch dict into a flat (x, y, metadata) tuple, or (x, metadata) when the batch has no target. Used by build_preblocks/apply_preblocks.
Channel concatenation order is fully determined by the variable key structure {source}/{field_type}/{dim}/{varname}:
- field_type rank: prognostic < dynamic_forcing < static < diagnostic
- dim rank: 3d < 2d
- within each (field_type, dim) bucket, the original insertion order is preserved (Python's sort is stable), which matches the config list order.
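The ordering rules above can be sketched as a sort key. This is a minimal illustration, not the module's own _channel_sort_key: it assumes field_type is the primary criterion and dim the secondary one, ignores source (the rules do not rank it), and uses hypothetical variable names.

```python
from typing import Any, Tuple

# Rank tables taken from the ordering rules; this is a hypothetical sketch,
# not the library's code, and unknown field types/dims raise KeyError.
FIELD_TYPE_RANK = {"prognostic": 0, "dynamic_forcing": 1, "static": 2, "diagnostic": 3}
DIM_RANK = {"3d": 0, "2d": 1}


def channel_sort_key(item: Tuple[str, Any]) -> Tuple[int, int]:
    """Sort key for a (var_key, tensor) pair; var_key is source/field_type/dim/varname."""
    var_key, _tensor = item
    _source, field_type, dim, _varname = var_key.split("/")
    # Only the (field_type, dim) bucket is ranked; ties keep insertion order
    # because sorted() is stable.
    return (FIELD_TYPE_RANK[field_type], DIM_RANK[dim])


# Hypothetical variables; the tensor values are irrelevant to the ordering.
variables = {
    "era5/diagnostic/2d/tp": None,
    "era5/prognostic/2d/t2m": None,
    "era5/static/2d/z_sfc": None,
    "era5/prognostic/3d/U": None,
    "era5/prognostic/3d/V": None,
    "era5/dynamic_forcing/2d/tsi": None,
}

ordered = [key for key, _ in sorted(variables.items(), key=channel_sort_key)]
print(ordered)  # U, V (3d prognostic, insertion order kept), then t2m, tsi, z_sfc, tp
```

Keeping the key to a rank tuple (rather than including varname) is what lets the stable sort preserve config order inside each bucket.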
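Attributes
- _PREDICTABLE_FIELD_TYPES
Classes
- ConcatToTensor: End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
Functions
- _channel_sort_key(item): Sort key for items from variables.items(): (var_key, tensor).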
Module Contents
- credit.preblock.concat._PREDICTABLE_FIELD_TYPES
- credit.preblock.concat._channel_sort_key(item) -> tuple
Sort key for items from variables.items(), i.e. (var_key, tensor) pairs. var_key has the form source/field_type/dim/varname.
- class credit.preblock.concat.ConcatToTensor(to_device: bool = True)
Bases: credit.preblock.base.BasePreblock
End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
Expects a batch dict of the form:
batch[data_type][source][var_name] -> torch.Tensor
where tensor shapes are (batch, n_levels, time, lat, lon) and concatenation is performed along dim=1 (channel). Input tensors are sorted by _channel_sort_key before concatenation, so the channel order matches the canonical variable schema regardless of insertion order in the batch. The metadata keys are passed through as-is (not concatenated).
In addition to the tensors, two channel maps are attached to metadata under metadata["_channel_map"]:
- "input": every variable and its slice in the concatenated input tensor.
- "output": prognostic + diagnostic variables only, with slices reindexed from 0 to match the y_pred channel ordering.
Each entry has the form:
var_key -> {"slice": slice(start, end), "orig_shape": (n_levels, T)}
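A minimal sketch of how the two channel maps could be derived, assuming each variable contributes n_levels channels along dim=1 (as the (batch, n_levels, time, lat, lon) layout implies) and that both maps come from one list already sorted by the channel sort key. build_channel_maps, OUTPUT_FIELD_TYPES, and the variable names are hypothetical; the actual class may build the "output" map from the target variables instead.

```python
from typing import Dict, List, Tuple

# Hypothetical (var_key, (n_levels, T)) pairs, assumed already sorted by the
# channel sort key. Names and shapes are illustrative only.
sorted_vars: List[Tuple[str, Tuple[int, int]]] = [
    ("era5/prognostic/3d/U", (4, 1)),
    ("era5/prognostic/2d/t2m", (1, 1)),
    ("era5/static/2d/z_sfc", (1, 1)),
    ("era5/diagnostic/2d/tp", (1, 1)),
]

# Assumption: the field types that appear in the "output" map.
OUTPUT_FIELD_TYPES = {"prognostic", "diagnostic"}


def build_channel_maps(items: List[Tuple[str, Tuple[int, int]]]) -> Dict[str, Dict[str, dict]]:
    """Hypothetical helper: build "input"/"output" channel maps from sorted items."""
    input_map: Dict[str, dict] = {}
    output_map: Dict[str, dict] = {}
    in_start = 0
    out_start = 0
    for var_key, (n_levels, n_times) in items:
        # Every variable occupies n_levels consecutive channels of the input tensor.
        input_map[var_key] = {
            "slice": slice(in_start, in_start + n_levels),
            "orig_shape": (n_levels, n_times),
        }
        in_start += n_levels
        field_type = var_key.split("/")[1]
        if field_type in OUTPUT_FIELD_TYPES:
            # Output slices restart at 0 and skip forcing/static channels.
            output_map[var_key] = {
                "slice": slice(out_start, out_start + n_levels),
                "orig_shape": (n_levels, n_times),
            }
            out_start += n_levels
    return {"input": input_map, "output": output_map}


maps = build_channel_maps(sorted_vars)
print(maps["input"]["era5/diagnostic/2d/tp"]["slice"])   # slice(6, 7, None)
print(maps["output"]["era5/diagnostic/2d/tp"]["slice"])  # slice(5, 6, None)
```

With maps like these, a caller could recover one variable from a prediction with y_pred[:, maps["output"][var_key]["slice"]], since the output slices skip the static and forcing channels that only exist in the input.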
Returns either:
(input_tensor, metadata)  # if no "target" data_type is present
(input_tensor, target_tensor, metadata)  # if "target" is present
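Because the tuple length depends on whether a "target" data_type is present, a caller may want to normalize both forms. The helper below is a hypothetical sketch and not part of credit.preblock.concat.

```python
from typing import Any, Optional, Tuple


def unpack_concat_output(out: tuple) -> Tuple[Any, Optional[Any], dict]:
    """Hypothetical helper: normalize the return value to (x, y, metadata).

    y is None when the batch carried no "target" data_type.
    """
    if len(out) == 2:
        x, metadata = out
        return x, None, metadata
    if len(out) == 3:
        x, y, metadata = out
        return x, y, metadata
    raise ValueError(f"expected a 2- or 3-tuple, got length {len(out)}")
```

Returning None for y keeps training and inference call sites on a single code path.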
Example config:
type: "concatenate_to_tensor"
args:
  to_device: true  # set false to skip .to(device) in apply_preblocks
- to_device = True
- forward(batch)