credit.preblock.norm
ERA5Normalizer: normalizes per-variable ERA5 tensors using mean/std NetCDF (.nc) files.
Operates on the raw batch structure from MultiSourceDataset:
batch["era5"]["input"]["era5/field_type/3d/varname"] = (B, n_levels, T, H, W)
batch["era5"]["input"]["era5/field_type/2d/varname"] = (B, 1, T, H, W)
Variables absent from the mean/std file are passed through unchanged.
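The sketch below makes the key convention concrete. It builds a toy batch in the layout described above and splits each key into its parts. The helper parse_era5_key is hypothetical. The field type and variable names ("prognostic", "T", "t2m") are placeholders for illustration and are not taken from MultiSourceDataset.

```python
import torch


def parse_era5_key(key: str) -> tuple[str, str, str]:
    """Hypothetical helper: split 'era5/<field_type>/<2d|3d>/<varname>'."""
    # Unpacking raises ValueError if the key does not have exactly four parts.
    source, field_type, dim, varname = key.split("/")
    if source != "era5" or dim not in ("2d", "3d"):
        raise ValueError(f"unexpected ERA5 key: {key!r}")
    return field_type, dim, varname


# Toy batch in the documented layout; "prognostic", "T" and "t2m" are
# placeholder names, not necessarily what MultiSourceDataset produces.
B, n_levels, T, H, W = 2, 4, 1, 8, 16
batch = {
    "era5": {
        "input": {
            "era5/prognostic/3d/T": torch.randn(B, n_levels, T, H, W),
            "era5/prognostic/2d/t2m": torch.randn(B, 1, T, H, W),
        }
    }
}

for key, x in batch["era5"]["input"].items():
    field_type, dim, varname = parse_era5_key(key)
    # 3d variables carry one channel per level; 2d variables a singleton channel.
    expected = (B, n_levels if dim == "3d" else 1, T, H, W)
    assert tuple(x.shape) == expected, (key, tuple(x.shape))
    print(f"{varname}: {dim}, shape {tuple(x.shape)}")
```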
Registered in the preblock registry as "era5_normalizer" so it can be
included via the config’s preblocks: section:
    preblocks:
      norm:
        type: era5_normalizer
        args:
          mean_path: /path/to/mean.nc
          std_path: /path/to/std.nc
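The registry itself is not documented on this page. The sketch below shows one common way a name-to-class registry of this kind can work.

- It uses a plain dict, a hypothetical register_preblock decorator and a hypothetical build_preblocks function.
- ERA5NormalizerStub stands in for the real class.
- None of these names should be read as CREDIT's actual API.

The config dict mirrors what the YAML above would look like after parsing, for example with yaml.safe_load.

```python
from __future__ import annotations

from typing import Any, Callable

# Hypothetical registry: maps the config "type" string to a preblock class.
PREBLOCK_REGISTRY: dict[str, type] = {}


def register_preblock(name: str) -> Callable[[type], type]:
    """Hypothetical decorator that records a class under a config name."""
    def decorator(cls: type) -> type:
        if name in PREBLOCK_REGISTRY:
            raise KeyError(f"preblock {name!r} already registered")
        PREBLOCK_REGISTRY[name] = cls
        return cls
    return decorator


@register_preblock("era5_normalizer")
class ERA5NormalizerStub:
    """Stand-in for ERA5Normalizer that only records its constructor arguments."""
    def __init__(self, mean_path: str, std_path: str, levels: list[int] | None = None):
        self.mean_path = mean_path
        self.std_path = std_path
        self.levels = levels


def build_preblocks(config: dict[str, Any]) -> dict[str, Any]:
    """Instantiate each entry of config["preblocks"] as cls(**args)."""
    blocks = {}
    for name, spec in config.get("preblocks", {}).items():
        cls = PREBLOCK_REGISTRY[spec["type"]]  # KeyError for unknown types
        blocks[name] = cls(**spec.get("args", {}))
    return blocks


# Same structure as the YAML preblocks: section, after parsing.
config = {
    "preblocks": {
        "norm": {
            "type": "era5_normalizer",
            "args": {"mean_path": "/path/to/mean.nc", "std_path": "/path/to/std.nc"},
        }
    }
}
blocks = build_preblocks(config)
print(type(blocks["norm"]).__name__, blocks["norm"].mean_path)
```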
Attributes

logger

Classes

ERA5Normalizer | Normalizes per-variable ERA5 tensors using pre-computed mean/std files.
Module Contents
- credit.preblock.norm.logger
- class credit.preblock.norm.ERA5Normalizer(mean_path: str, std_path: str, levels: list[int] | None = None)
Bases: credit.preblock.base.BasePreblock

Normalizes per-variable ERA5 tensors using pre-computed mean/std files.

Normalization: (x - mean) / std, applied per variable. Variables not found in the statistics file are passed through unchanged.

Parameters:
mean_path – Path to NetCDF file containing per-variable means.
std_path – Path to NetCDF file containing per-variable standard deviations.
levels – Optional list of 1-indexed model levels to select from the full 137-level stats (e.g. [60, 90, 120, 137] for a 4-level smoke test). When omitted, all levels in the stats file are used.
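Because levels uses 1-indexed level numbers, selecting from a stats array means subtracting one before indexing. The sketch below rests on these assumptions:

- Each 3d variable's statistics form a 1-D tensor over all 137 levels.
- Level 1 sits at index 0.
- The real file layout and dimension names may differ.
- select_levels is a hypothetical helper.

```python
from __future__ import annotations

import torch


def select_levels(stats: torch.Tensor, levels: list[int] | None) -> torch.Tensor:
    """Hypothetical helper: pick 1-indexed levels from a stats tensor (level axis first)."""
    if levels is None:
        return stats  # no selection: keep every level in the file
    n_levels = stats.shape[0]
    if any(not 1 <= lev <= n_levels for lev in levels):
        raise IndexError(f"levels must lie in [1, {n_levels}]")
    # Convert 1-indexed model levels to 0-indexed tensor positions.
    index = torch.as_tensor([lev - 1 for lev in levels], dtype=torch.long)
    return stats.index_select(0, index)


# Toy 137-level mean in which each entry equals its own 1-indexed level number,
# so the selected values should read back as 60, 90, 120, 137.
mean_t = torch.arange(1, 138, dtype=torch.float32)
print(select_levels(mean_t, [60, 90, 120, 137]))
```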
- _mean: dict[str, torch.Tensor]
- _std: dict[str, torch.Tensor]
- _normalize_tensor(key: str, tensor: torch.Tensor) -> torch.Tensor
Normalize tensor using the variable name extracted from key.
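The sketch below shows what per-variable normalization could look like. It is written as a free function, normalize_tensor (hypothetical), that takes the mean/std dicts explicitly instead of reading the _mean/_std attributes. It assumes that:

- The variable name is the last "/"-separated component of the key.
- Each 3d variable's stats hold one value per selected level.
- std is nonzero.

```python
import torch


def normalize_tensor(
    key: str,
    tensor: torch.Tensor,
    mean: dict[str, torch.Tensor],
    std: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Hypothetical stand-in for _normalize_tensor: (x - mean) / std per variable."""
    varname = key.rsplit("/", 1)[-1]
    if varname not in mean or varname not in std:
        return tensor  # no statistics for this variable: pass through unchanged
    # Reshape per-level stats to (1, n_levels, 1, 1, 1) so they broadcast over
    # (B, n_levels, T, H, W); a single 2d value becomes (1, 1, 1, 1, 1).
    m = mean[varname].to(tensor).reshape(1, -1, 1, 1, 1)
    s = std[varname].to(tensor).reshape(1, -1, 1, 1, 1)
    return (tensor - m) / s


mean = {"T": torch.tensor([0.0, 1.0, 2.0])}
std = {"T": torch.tensor([1.0, 2.0, 4.0])}
x = torch.ones(2, 3, 1, 2, 2)  # (B, n_levels, T, H, W) with 3 levels
y = normalize_tensor("era5/prognostic/3d/T", x, mean, std)
# Per level: (1 - 0) / 1 = 1, (1 - 1) / 2 = 0, (1 - 2) / 4 = -0.25
```

Reshaping the stats, rather than permuting the data, keeps the normalized tensor in the same (B, n_levels, T, H, W) layout it arrived in.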
- forward(batch: dict) -> dict
Normalize all input/target tensors, returning a new batch dict.
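forward is documented as returning a new batch dict rather than modifying its input. The sketch below illustrates that pattern under these assumptions:

- Targets live under batch["era5"]["target"] with the same key layout as inputs. Only the input layout is documented above.
- normalize_batch and its normalize_fn argument are hypothetical. normalize_fn plays the role of _normalize_tensor.

```python
from typing import Callable

import torch


def normalize_batch(
    batch: dict, normalize_fn: Callable[[str, torch.Tensor], torch.Tensor]
) -> dict:
    """Hypothetical sketch: return a new batch with ERA5 input/target tensors normalized."""
    new_batch = dict(batch)  # shallow copy: non-ERA5 entries are shared, not cloned
    if "era5" not in batch:
        return new_batch
    era5 = dict(batch["era5"])
    for split in ("input", "target"):  # the "target" layout is an assumption
        if split in era5:
            era5[split] = {key: normalize_fn(key, x) for key, x in era5[split].items()}
    new_batch["era5"] = era5
    return new_batch


batch = {
    "era5": {"input": {"era5/prognostic/2d/t2m": torch.full((1, 1, 1, 2, 2), 3.0)}},
    "meta": {"step": 0},
}
out = normalize_batch(batch, lambda key, x: x - 1.0)  # toy normalize_fn
```

Only the dict levels are copied. The original batch therefore keeps its unnormalized tensors, while entries outside "era5", such as metadata, are shared between the two dicts.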