credit.postblock.reconstruct


Reconstruct: the first postblock, which splits the flat batch_dict["y_pred"] tensor into a nested per-variable dict and writes the result to batch_dict["y_processed"].

It reads batch_dict["metadata"]["target"]["_channel_map"], built by ConcatToTensor, to determine which channels belong to which variables.
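The exact structure of the channel map is not documented on this page. As a minimal sketch, assume each entry records a half-open channel range [start, stop) per source and variable; that entry layout, the helper name owner_of_channel, the source names ("atmosphere", "surface"), and the variable keys are all hypothetical. Finding which variable owns a given channel is then a scan over those ranges:

```python
# HYPOTHETICAL layout for batch_dict["metadata"]["target"]["_channel_map"]:
#   {source: {var_key: {"start": int, "stop": int, "n_levels": int, "n_time": int}}}
# The structure actually built by ConcatToTensor may differ.


def owner_of_channel(channel_map: dict, channel: int) -> tuple[str, str]:
    """Return the (source, var_key) whose [start, stop) range contains `channel`."""
    for source, variables in channel_map.items():
        for var_key, entry in variables.items():
            if entry["start"] <= channel < entry["stop"]:
                return source, var_key
    raise KeyError(f"channel {channel} is not covered by the channel map")


# Illustrative map: two 4-level upper-air variables and one single-level field.
channel_map = {
    "atmosphere": {
        "T": {"start": 0, "stop": 4, "n_levels": 4, "n_time": 1},
        "U": {"start": 4, "stop": 8, "n_levels": 4, "n_time": 1},
    },
    "surface": {
        "t2m": {"start": 8, "stop": 9, "n_levels": 1, "n_time": 1},
    },
}

print(owner_of_channel(channel_map, 5))  # ('atmosphere', 'U')
print(owner_of_channel(channel_map, 8))  # ('surface', 't2m')
```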

Input

batch_dict : dict
Must contain:

"y_pred": flat model output tensor, shape (B, C, H, W) or (B, C, T, H, W)
"metadata": metadata dict with ["target"]["_channel_map"]

All other keys pass through unchanged.
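A hypothetical pre-flight check for the input contract above could look like this; validate_reconstruct_input is not part of CREDIT. It relies only on the tensor exposing .shape, so it accepts a torch.Tensor, a numpy.ndarray, or (as in the usage line) a stand-in object:

```python
from types import SimpleNamespace


def validate_reconstruct_input(batch_dict: dict) -> None:
    """HYPOTHETICAL helper: raise if batch_dict lacks what Reconstruct needs."""
    if "y_pred" not in batch_dict:
        raise KeyError('batch_dict is missing "y_pred"')
    ndim = len(batch_dict["y_pred"].shape)
    if ndim not in (4, 5):  # (B, C, H, W) or (B, C, T, H, W)
        raise ValueError(f'"y_pred" must have 4 or 5 dims, got {ndim}')
    try:
        batch_dict["metadata"]["target"]["_channel_map"]
    except (KeyError, TypeError) as exc:
        raise KeyError(
            'batch_dict is missing ["metadata"]["target"]["_channel_map"]'
        ) from exc


# A SimpleNamespace with a .shape attribute stands in for a real tensor here.
validate_reconstruct_input(
    {
        "y_pred": SimpleNamespace(shape=(2, 9, 4, 8)),
        "metadata": {"target": {"_channel_map": {}}},
    }
)
```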

Output

The same batch_dict with "y_processed" added as a nested dict:

batch_dict["y_processed"][source][var_key] -> tensor of shape (B, n_levels, n_time, H, W)

"y_pred" is left intact (still attached to the autograd graph) for use in loss computation.

Classes

Reconstruct

Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].

Module Contents

class credit.postblock.reconstruct.Reconstruct(*args: Any, **kwargs: Any)

Bases: credit.postblock.base.BasePostblock

Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].

Channel slices are read from batch_dict["metadata"]["target"]["_channel_map"], which is built by ConcatToTensor and covers only prognostic and diagnostic variables. Each slice is unflattened from (B, n_levels * n_time, H, W) back to (B, n_levels, n_time, H, W). y_pred itself is left untouched, and all other keys in batch_dict pass through unchanged.
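The per-variable step can be sketched as a channel slice followed by a reshape. This assumes each variable's channels were flattened level-major (time varying fastest), which is the ordering a plain reshape back to (B, n_levels, n_time, H, W) implies. unflatten_slice is a hypothetical name, and the sketch covers only the 4D (B, C, H, W) case:

```python
import numpy as np


def unflatten_slice(y_pred, start: int, stop: int, n_levels: int, n_time: int):
    """HYPOTHETICAL helper: slice channels [start, stop) of a (B, C, H, W)
    tensor and reshape them to (B, n_levels, n_time, H, W).

    Assumes level-major channel order (time varies fastest). Works with numpy
    arrays and torch tensors alike, since both support basic slicing and
    .reshape(*shape).
    """
    b, _, h, w = y_pred.shape
    if stop - start != n_levels * n_time:
        raise ValueError("slice width must equal n_levels * n_time")
    return y_pred[:, start:stop].reshape(b, n_levels, n_time, h, w)


# Toy input: 8 channels on a 2x2 grid, each channel filled with its own index.
y_pred = np.arange(8).reshape(1, 8, 1, 1) * np.ones((1, 8, 2, 2))

# Illustrative variable occupying channels [2, 8): 3 levels x 2 time steps.
out = unflatten_slice(y_pred, start=2, stop=8, n_levels=3, n_time=2)
print(out.shape)  # (1, 3, 2, 2, 2)
print(out[0, 1, 0, 0, 0])  # 4.0, i.e. channel 2 + 1 * 2 + 0
```

Under that assumption, channel start + level * n_time + time lands at position (level, time). With PyTorch, the same reshape can also be written as y_pred[:, start:stop].unflatten(1, (n_levels, n_time)).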

forward(batch_dict: dict) -> dict
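A minimal forward in the spirit described above might look like the sketch below. It is not the CREDIT implementation: ReconstructSketch is a hypothetical class that does not subclass BasePostblock, and it assumes a hypothetical channel-map entry layout of {"start", "stop", "n_levels", "n_time"}. It adds "y_processed" to the incoming dict, only reads "y_pred", and leaves every other key alone:

```python
import numpy as np


class ReconstructSketch:
    """HYPOTHETICAL stand-in for Reconstruct.forward; not the CREDIT code.

    Assumes each channel-map entry looks like
    {"start": int, "stop": int, "n_levels": int, "n_time": int}.
    """

    def forward(self, batch_dict: dict) -> dict:
        y_pred = batch_dict["y_pred"]  # read only; never reassigned or modified
        channel_map = batch_dict["metadata"]["target"]["_channel_map"]
        b, _, h, w = y_pred.shape
        y_processed = {}
        for source, variables in channel_map.items():
            y_processed[source] = {}
            for var_key, e in variables.items():
                chunk = y_pred[:, e["start"]:e["stop"]]  # (B, n_levels * n_time, H, W)
                y_processed[source][var_key] = chunk.reshape(
                    b, e["n_levels"], e["n_time"], h, w
                )
        batch_dict["y_processed"] = y_processed  # all other keys stay as they were
        return batch_dict


# Illustrative batch: 8 channels split across two hypothetical variables.
channel_map = {
    "atmosphere": {"T": {"start": 0, "stop": 6, "n_levels": 3, "n_time": 2}},
    "surface": {"t2m": {"start": 6, "stop": 8, "n_levels": 1, "n_time": 2}},
}
batch = {
    "y_pred": np.zeros((4, 8, 16, 32)),
    "metadata": {"target": {"_channel_map": channel_map}},
    "x": "passes through",
}
out = ReconstructSketch().forward(batch)
print(out["y_processed"]["atmosphere"]["T"].shape)  # (4, 3, 2, 16, 32)
print(out["y_processed"]["surface"]["t2m"].shape)  # (4, 1, 2, 16, 32)
```

Slicing and reshaping produce new tensors (views where possible), so when y_pred is a torch.Tensor the entries of "y_processed" stay connected to the autograd graph while "y_pred" itself remains available for the loss.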