credit.postblock.reconstruct#
Reconstruct: first postblock that splits the flat batch_dict["y_pred"]
tensor into a nested variable dict, writing the result to batch_dict["y_processed"].
Reads batch_dict["metadata"]["target"]["_channel_map"] built by
ConcatToTensor to know which channels correspond to which variables.
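The layout of `_channel_map` is not spelled out on this page. As a rough illustration of the bookkeeping such a map has to carry, the sketch below assigns each variable a contiguous channel range of width `n_levels * n_time`. The variable names, the `"start"`/`"stop"`/`"n_levels"`/`"n_time"` keys, and the helper `build_channel_map` are all hypothetical and are not taken from `ConcatToTensor`.

```python
# Hypothetical variable specs: (source, var_key, n_levels, n_time).
# These names and the resulting map layout are illustrative assumptions,
# not the actual format written by ConcatToTensor.
SPECS = [
    ("era5", "temperature", 3, 1),
    ("era5", "surface_pressure", 1, 1),
    ("era5", "precipitation", 1, 1),
]


def build_channel_map(specs):
    """Assign each variable a contiguous channel range, in listing order."""
    channel_map = {}
    offset = 0
    for source, var_key, n_levels, n_time in specs:
        width = n_levels * n_time  # channels this variable occupies when flattened
        channel_map.setdefault(source, {})[var_key] = {
            "start": offset,
            "stop": offset + width,
            "n_levels": n_levels,
            "n_time": n_time,
        }
        offset += width
    return channel_map, offset


channel_map, n_channels = build_channel_map(SPECS)
```

Under these assumptions the ranges tile the channel axis without gaps, so `n_channels` equals the `C` dimension of the flat tensor.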
Input#
- batch_dict : dict
- Must contain:
  - "y_pred": flat model output tensor, shape (B, C, H, W) or (B, C, T, H, W)
  - "metadata": metadata dict with ["target"]["_channel_map"]
- All other keys pass through unchanged.
Output#
The same batch_dict with "y_processed" added as a nested dict:
- batch_dict["y_processed"][source][var_key] -> tensor of shape (B, n_levels, n_time, H, W)
"y_pred" is left intact (grad-attached) for use in loss computation.
Classes#
Reconstruct | Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].
Module Contents#
- class credit.postblock.reconstruct.Reconstruct(*args: Any, **kwargs: Any)#
Bases: credit.postblock.base.BasePostblock
Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].
Slices are read from batch_dict["metadata"]["target"]["_channel_map"], built by ConcatToTensor and covering only prognostic + diagnostic variables. Each slice is unflattened from (B, n_levels * n_time, H, W) back to (B, n_levels, n_time, H, W). y_pred is left untouched. All other keys in batch_dict pass through unchanged.
- forward(batch_dict: dict) -> dict#
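The split-and-unflatten step described above can be sketched in a few lines of PyTorch. This is a minimal stand-in, not the actual `Reconstruct.forward`: the function `reconstruct_sketch` is hypothetical, it handles only the 4D `(B, C, H, W)` case, and it assumes each `_channel_map` entry carries `"start"`, `"stop"`, `"n_levels"` and `"n_time"` keys, which this page does not confirm.

```python
import torch


def reconstruct_sketch(batch_dict):
    """Illustrative stand-in for Reconstruct.forward (4D y_pred only)."""
    y_pred = batch_dict["y_pred"]  # (B, C, H, W)
    # Assumed layout: {source: {var_key: {"start", "stop", "n_levels", "n_time"}}}
    channel_map = batch_dict["metadata"]["target"]["_channel_map"]
    B, _, H, W = y_pred.shape

    y_processed = {}
    for source, variables in channel_map.items():
        y_processed[source] = {}
        for var_key, info in variables.items():
            # Slice this variable's channels: (B, n_levels * n_time, H, W)
            chunk = y_pred[:, info["start"]:info["stop"]]
            # Unflatten the channel axis: (B, n_levels, n_time, H, W)
            y_processed[source][var_key] = chunk.reshape(
                B, info["n_levels"], info["n_time"], H, W
            )

    # y_pred itself is not modified; only a new key is added.
    batch_dict["y_processed"] = y_processed
    return batch_dict


# Hypothetical channel map covering 7 channels.
channel_map = {
    "era5": {
        "temperature": {"start": 0, "stop": 6, "n_levels": 3, "n_time": 2},
        "surface_pressure": {"start": 6, "stop": 7, "n_levels": 1, "n_time": 1},
    }
}
y_pred = torch.randn(2, 7, 4, 5, requires_grad=True)
batch = {"y_pred": y_pred, "metadata": {"target": {"_channel_map": channel_map}}}
out = reconstruct_sketch(batch)
```

Because `reshape` is row-major, this sketch treats the flat channel axis as level-major: within a variable, local channel `c` lands at `(level, time) = divmod(c, n_time)`. Whether the real implementation orders levels and times the same way is an assumption here.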