credit.trainers.rollout_utils

rollout_utils.py — utilities for autoregressive multi-step rollout.

Attributes

Functions

assemble_rollout_batch(full_data_dict, curr_batch) → dict

Assemble a batch dict for the rollout preblock pass at autoregressive step t > 0.

Module Contents

credit.trainers.rollout_utils.logger

credit.trainers.rollout_utils.assemble_rollout_batch(full_data_dict: dict, curr_batch: dict) → dict

Assemble a batch dict for the rollout preblock pass at autoregressive step t > 0.

Constructs a dataset-schema batch by routing each variable from the appropriate source:

  • prognostic / diagnostic channels: from full_data_dict["y_processed"] — the postblock-processed prediction from the previous step.

  • dynamic_forcing channels: from curr_batch["input"] — the current step’s time-varying forcing loaded by the dataset.

  • static (and any other non-predicted) channels: from full_data_dict["ic_preprocessed"]["input"] — the t=0 raw batch after IC-only preblocks, so statics are already on the model grid.

The assembled dict is then passed to apply_preblocks(step_preblocks, ...), which handles the per-step operations (log_transform, concat). curr_batch["target"] is forwarded as well, so the preblocks normalize the training target in the same pass.
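The routing rule above can be sketched in plain Python. This is an illustrative sketch under stated assumptions, not the CREDIT implementation: the helper name route_rollout_inputs, the source name "era5", and the variable keys are hypothetical, plain lists stand in for tensors, and the sketch assumes a channel's type can be inferred from which dict contains its key. The real function may classify channels differently, for example by key name.

```python
def route_rollout_inputs(full_data_dict: dict, curr_batch: dict) -> dict:
    """Hypothetical sketch of the per-variable routing described above."""
    y_processed = full_data_dict["y_processed"]            # previous-step prediction
    ic_input = full_data_dict["ic_preprocessed"]["input"]  # authoritative key list
    forcing = curr_batch.get("input", {})                  # current-step forcing

    assembled = {}
    for source, ic_vars in ic_input.items():
        predicted = y_processed.get(source, {})
        dynamic = forcing.get(source, {})
        assembled[source] = {}
        for var_key, ic_value in ic_vars.items():
            if var_key in predicted:
                # Prognostic / diagnostic: take the previous step's output.
                assembled[source][var_key] = predicted[var_key]
            elif var_key in dynamic:
                # Dynamic forcing: take the value loaded for this step.
                assembled[source][var_key] = dynamic[var_key]
            else:
                # Static or other: reuse the t=0 preprocessed tensor.
                assembled[source][var_key] = ic_value

    batch = {"input": assembled}
    if "target" in curr_batch:
        batch["target"] = curr_batch["target"]  # forwarded unchanged
    return batch


# Hypothetical data: one source with one channel of each kind.
full_data_dict = {
    "y_processed": {"era5": {"T": [1.0]}},
    "ic_preprocessed": {
        "input": {"era5": {"T": [0.0], "solar": [5.0], "orography": [100.0]}}
    },
}
curr_batch = {
    "input": {"era5": {"solar": [7.0]}},
    "target": {"era5": {"T": [1.5]}},
}
batch = route_rollout_inputs(full_data_dict, curr_batch)
```

In this sketch the key order of the assembled input follows "ic_preprocessed", which mirrors its role as the authoritative variable key list.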

Parameters:
  • full_data_dict – the rollout state dict. Must contain:

      • "y_processed": nested {source: {var_key: tensor}} from the previous step’s postblock chain (output of Reconstruct + fixers).

      • "ic_preprocessed": the t=0 raw batch after IC-only preblocks, providing the authoritative variable key list and static tensors.

  • curr_batch – current step’s raw batch from the dataset. Provides dynamic forcing fields and the training target.

Returns:

dict with keys "input" (nested source→var dict) and "target" (from curr_batch), ready for apply_preblocks(step_preblocks, ...).

Raises:

TypeError – if full_data_dict["y_processed"] is not a dict, which usually means Reconstruct was not included in the postblock chain.
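A guard of the following shape would produce the documented error. It is a minimal sketch, assuming the check is a simple isinstance test; the helper name require_nested_prediction and the message wording are hypothetical and are not taken from the CREDIT source.

```python
def require_nested_prediction(full_data_dict: dict) -> dict:
    """Hypothetical guard: return y_processed only if it is a nested dict."""
    y_processed = full_data_dict.get("y_processed")
    if not isinstance(y_processed, dict):
        raise TypeError(
            "full_data_dict['y_processed'] must be a nested dict, got "
            f"{type(y_processed).__name__}; is Reconstruct in the postblock chain?"
        )
    return y_processed


# A flat list stands in for a prediction that was never reconstructed.
try:
    require_nested_prediction({"y_processed": [0.1, 0.2]})
except TypeError as err:
    message = str(err)
else:
    message = ""
```

Failing early with a message that names Reconstruct points the user at the likely configuration mistake instead of a later key lookup error.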