credit.datasets.channel_layout
Derive per-group channel slices directly from the variables config and provide a standalone update function for multi-step rollout.
The concat order matches the canonical rank defined by FIELD_TYPE_RANK (prognostic < dynamic_forcing < static < diagnostic), with vars_3D before vars_2D within each group. This is the same rank used by credit.preblock.concat.ConcatToTensor, so slices returned here always correspond to the actual channel positions in the model input tensor.
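The ordering rule can be sketched as a sort key. This is a minimal illustration, not the module's code. It assumes each group name is itself a FIELD_TYPE_RANK key and that a group is a dict with optional "vars_3D" and "vars_2D" lists. The rank values are hypothetical; only their relative order comes from this page.

```python
# Hypothetical rank values; only the relative order is documented.
FIELD_TYPE_RANK = {
    "prognostic": 0,
    "dynamic_forcing": 1,
    "static": 2,
    "diagnostic": 3,
}


def canonical_order(groups):
    """Return (group, kind) pairs in concat order.

    Groups are sorted by FIELD_TYPE_RANK (not by config key order), and
    vars_3D is emitted before vars_2D inside each group.
    """
    order = []
    for name in sorted(groups, key=FIELD_TYPE_RANK.__getitem__):
        for kind in ("vars_3D", "vars_2D"):
            if groups[name].get(kind):
                order.append((name, kind))
    return order


# Config key order deliberately differs from the canonical rank.
# Variable names are illustrative only.
example = {
    "static": {"vars_2D": ["orography"]},
    "prognostic": {"vars_3D": ["U", "V"], "vars_2D": ["SP"]},
    "dynamic_forcing": {"vars_2D": ["tsi"]},
}
print(canonical_order(example))
# [('prognostic', 'vars_3D'), ('prognostic', 'vars_2D'), ('dynamic_forcing', 'vars_2D'), ('static', 'vars_2D')]
```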
Usage
from credit.datasets.channel_layout import build_channel_layout, update_x

# once, at trainer / rollout init
slices, n_pred = build_channel_layout(conf)

# at every step t > 1
x = update_x(x, x_dynfrc, y_pred, slices)
Attributes
- FIELD_TYPE_RANK
- _DIAGNOSTIC
- _DATASET_DRIVEN
- _MODEL_PREDICTED
Functions
- build_channel_layout(conf) – Return (slices, n_pred) derived from the variables config.
- update_x(x_prev, x_dynfrc, y_pred, slices) – Build the next-step input tensor for autoregressive rollout.
Module Contents
- credit.datasets.channel_layout.FIELD_TYPE_RANK: dict[str, int]
Canonical concat rank for each field type: prognostic < dynamic_forcing < static < diagnostic.
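- credit.datasets.channel_layout._DIAGNOSTIC
Groups that never appear in x; build_channel_layout() excludes them from slices.
- credit.datasets.channel_layout._DATASET_DRIVEN
Groups whose channels are refreshed from the dataset at each rollout step (e.g. dynamic_forcing).
- credit.datasets.channel_layout._MODEL_PREDICTED
Groups whose channels are replaced by model output at each rollout step (e.g. prognostic).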
- credit.datasets.channel_layout.build_channel_layout(conf)
Return (slices, n_pred) derived from the variables config.
Slices are ordered by FIELD_TYPE_RANK (matching ConcatToTensor), so slice offsets correspond directly to channel positions in the tensor.
- Parameters:
conf (dict) – Full CREDIT config dict. Reads levels and variable groups from the first entry under conf["data"]["source"].
- Returns:
slices (dict[str, slice]) – Mapping from group name to its slice in the channel dimension of x. Groups in _DIAGNOSTIC are excluded (they are never in x). Order matches FIELD_TYPE_RANK, not config key order.
n_pred (int) – Total number of predicted (prognostic) channels.
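The slice computation can be sketched as follows. This is a minimal sketch, not the library's implementation, so the function is named build_channel_layout_sketch. Several details are assumptions:

- The config schema: a "levels" list and a "variables" mapping from group name to vars_3D / vars_2D lists.
- The rule that each 3D variable contributes one channel per level and each 2D variable one channel.
- The contents of _DIAGNOSTIC and _MODEL_PREDICTED.

Check these against a real CREDIT config before relying on the numbers.

```python
# Hypothetical constants; only the relative rank order is documented.
FIELD_TYPE_RANK = {"prognostic": 0, "dynamic_forcing": 1, "static": 2, "diagnostic": 3}
_DIAGNOSTIC = {"diagnostic"}       # assumed contents
_MODEL_PREDICTED = {"prognostic"}  # assumed contents


def build_channel_layout_sketch(conf):
    """Return (slices, n_pred) for an assumed config layout."""
    # "First entry under conf['data']['source']": dicts keep insertion order.
    source = next(iter(conf["data"]["source"].values()))
    n_levels = len(source["levels"])
    groups = source["variables"]

    slices, offset = {}, 0
    for name in sorted(groups, key=FIELD_TYPE_RANK.__getitem__):
        if name in _DIAGNOSTIC:
            continue  # diagnostic groups never appear in x
        group = groups[name]
        # vars_3D first (n_levels channels each), then vars_2D (1 channel each)
        width = n_levels * len(group.get("vars_3D", []))
        width += len(group.get("vars_2D", []))
        slices[name] = slice(offset, offset + width)
        offset += width

    n_pred = sum(s.stop - s.start for g, s in slices.items() if g in _MODEL_PREDICTED)
    return slices, n_pred


# Illustrative config: source name, levels, and variable names are made up.
conf = {"data": {"source": {"ERA5": {
    "levels": [500, 700, 850],
    "variables": {
        "diagnostic": {"vars_2D": ["precip"]},
        "static": {"vars_2D": ["orography", "land_sea_mask"]},
        "dynamic_forcing": {"vars_2D": ["tsi"]},
        "prognostic": {"vars_3D": ["U", "V"], "vars_2D": ["SP"]},
    },
}}}}

slices, n_pred = build_channel_layout_sketch(conf)
print(slices)  # {'prognostic': slice(0, 7, None), 'dynamic_forcing': slice(7, 8, None), 'static': slice(8, 10, None)}
print(n_pred)  # 7
```

Because diagnostic has the highest rank, skipping it never shifts the offsets of the groups that do appear in x.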
- credit.datasets.channel_layout.update_x(x_prev, x_dynfrc, y_pred, slices)
Build the next-step input tensor for autoregressive rollout.
Replaces the dataset-driven channels (dynamic_forcing) and model-predicted channels (prognostic) of x_prev. All other channels (e.g. static) are carried forward unchanged in a clone of x_prev, so x_prev itself is not modified.
- Parameters:
x_prev (Tensor [B, C, ...]) – Full input tensor from the previous step.
x_dynfrc (Tensor [B, C_dyn, ...]) – New dynamic-forcing channels from the dataset, in the same relative order as the dataset-driven groups appear in slices.
y_pred (Tensor [B, C_pred, ...]) – Model output, in the same relative order as the predicted groups appear in slices.
slices (dict[str, slice]) – From build_channel_layout().
- Returns:
Updated input tensor ready for the next forward pass.
- Return type:
Tensor [B, C, ...]
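The replacement logic described for update_x can be sketched as below. This is a minimal sketch, named update_x_sketch because it is not the library's code. The memberships of _DATASET_DRIVEN and _MODEL_PREDICTED are assumptions based on the parenthetical examples in the description. The sketch uses only channel slicing and a copy, so it runs on torch tensors (via .clone()) and on NumPy arrays (via .copy()). NumPy is used here so the example runs without PyTorch.

```python
import numpy as np

_DATASET_DRIVEN = {"dynamic_forcing"}  # assumed contents
_MODEL_PREDICTED = {"prognostic"}      # assumed contents


def update_x_sketch(x_prev, x_dynfrc, y_pred, slices):
    """Return a new input with dynamic-forcing and prognostic channels replaced."""
    # Copy first so x_prev is left untouched; static channels ride along as-is.
    # torch tensors expose .clone(); NumPy arrays expose .copy().
    x = x_prev.clone() if hasattr(x_prev, "clone") else x_prev.copy()
    dyn_off = pred_off = 0
    for name, slc in slices.items():  # dict order == channel order in x
        width = slc.stop - slc.start
        if name in _DATASET_DRIVEN:
            x[:, slc] = x_dynfrc[:, dyn_off:dyn_off + width]
            dyn_off += width
        elif name in _MODEL_PREDICTED:
            x[:, slc] = y_pred[:, pred_off:pred_off + width]
            pred_off += width
    return x


# Toy layout: 7 prognostic, 1 dynamic-forcing, 2 static channels on a 2x2 grid.
slices = {"prognostic": slice(0, 7), "dynamic_forcing": slice(7, 8), "static": slice(8, 10)}
x_prev = np.zeros((1, 10, 2, 2))
x_prev[:, 8:10] = 5.0                  # static values that must survive
x_dynfrc = np.full((1, 1, 2, 2), 2.0)  # new forcing from the dataset
y_pred = np.full((1, 7, 2, 2), 1.0)    # model prediction

x_next = update_x_sketch(x_prev, x_dynfrc, y_pred, slices)
```

In a rollout loop, x_next would serve as x_prev for the following step. Iterating slices in insertion order is what lets x_dynfrc and y_pred be consumed sequentially. This relies on slices being ordered by FIELD_TYPE_RANK.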