credit.trainers.rollout_utils
rollout_utils.py — utilities for autoregressive multi-step rollout.
Attributes
- logger
Functions
- assemble_rollout_batch(full_data_dict, curr_batch): Assemble a batch dict for the rollout preblock pass at autoregressive step t > 0.
Module Contents
- credit.trainers.rollout_utils.logger
- credit.trainers.rollout_utils.assemble_rollout_batch(full_data_dict: dict, curr_batch: dict) -> dict
Assemble a batch dict for the rollout preblock pass at autoregressive step t > 0.
Constructs a dataset-schema batch by routing each variable from the appropriate source:
- prognostic / diagnostic channels: from full_data_dict["y_processed"], the postblock-processed prediction from the previous step.
- dynamic_forcing channels: from curr_batch["input"], the current step’s time-varying forcing loaded by the dataset.
- static (and any other non-predicted) channels: from full_data_dict["ic_preprocessed"]["input"], the t=0 raw batch after IC-only preblocks, so statics are already on the model grid.
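The routing rule above can be sketched in a few lines. This is a minimal illustration, not the library's implementation: the helper name route_channels_sketch, the source name "era5", and the variable names are hypothetical, strings stand in for tensors, and the sketch assumes a variable's category can be inferred from which dict contains its key (the real function may rely on dataset configuration instead).

```python
def route_channels_sketch(full_data_dict: dict, curr_batch: dict) -> dict:
    """Hypothetical re-statement of the routing rule, for illustration only."""
    y_processed = full_data_dict["y_processed"]
    ic_input = full_data_dict["ic_preprocessed"]["input"]
    forcing = curr_batch["input"]

    assembled = {}
    # The IC batch supplies the authoritative list of sources and variable keys.
    for source, ic_vars in ic_input.items():
        assembled[source] = {}
        for var_key, ic_value in ic_vars.items():
            if var_key in y_processed.get(source, {}):
                # Assumed prognostic / diagnostic: use the previous prediction.
                assembled[source][var_key] = y_processed[source][var_key]
            elif var_key in forcing.get(source, {}):
                # Assumed dynamic forcing: use the current step's batch.
                assembled[source][var_key] = forcing[source][var_key]
            else:
                # Static or otherwise non-predicted: reuse the t=0 value.
                assembled[source][var_key] = ic_value
    # The target is forwarded unchanged so preblocks can process it in the same pass.
    return {"input": assembled, "target": curr_batch["target"]}


# Toy values stand in for tensors; all names here are made up.
state = {
    "y_processed": {"era5": {"temperature": "pred_t1"}},
    "ic_preprocessed": {
        "input": {
            "era5": {
                "temperature": "ic_t0",
                "solar": "ic_solar",
                "orography": "ic_oro",
            }
        }
    },
}
batch = {
    "input": {"era5": {"solar": "forcing_t1"}},
    "target": {"era5": {"temperature": "truth_t2"}},
}
result = route_channels_sketch(state, batch)
```

In this toy case, temperature comes from the previous prediction, solar from the current batch, and orography from the t=0 batch.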
The assembled dict is passed to apply_preblocks(step_preblocks, ...), which handles per-step operations (log_transform, concat). curr_batch["target"] is forwarded so that the preblocks normalize the training target in the same pass.
- Parameters:
full_data_dict – the rollout state dict. Must contain:
- "y_processed": nested {source: {var_key: tensor}} from the previous step’s postblock chain (output of Reconstruct + fixers).
- "ic_preprocessed": t=0 raw batch after IC-only preblocks, providing the authoritative variable key list and static tensors.
curr_batch – current step’s raw batch from the dataset. Provides dynamic forcing fields and the training target.
- Returns:
dict with keys "input" (nested source→var dict) and "target" (from curr_batch), ready for apply_preblocks(step_preblocks, ...).
- Raises:
TypeError – if full_data_dict["y_processed"] is not a dict, which usually means Reconstruct was not included in the postblock chain.
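The failure mode described under Raises can be illustrated with a small guard. This is a sketch under assumptions, not the library's code: check_y_processed_sketch is a hypothetical helper, the error message wording is invented, and a list stands in for the raw tensor that would appear if Reconstruct were missing.

```python
def check_y_processed_sketch(full_data_dict: dict) -> dict:
    """Hypothetical guard mirroring the documented TypeError condition."""
    y_processed = full_data_dict["y_processed"]
    if not isinstance(y_processed, dict):
        # A non-dict value suggests the postblock chain never rebuilt the
        # nested {source: {var_key: tensor}} structure.
        raise TypeError(
            "y_processed must be a nested dict, got "
            f"{type(y_processed).__name__}; is Reconstruct in the postblock chain?"
        )
    return y_processed


# Toy inputs: a well-formed state and one where y_processed is not a dict.
good = {"y_processed": {"era5": {"temperature": "pred_t1"}}}
bad = {"y_processed": ["raw", "tensor", "stand-in"]}

checked = check_y_processed_sketch(good)
try:
    check_y_processed_sketch(bad)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Checking the type up front points the caller at the likely configuration problem rather than at an unrelated indexing error further down.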