credit.preblock#
Submodules#
Attributes#
Classes#
LogTransform | Applies a log transformation to specified variables in a batch dict.
SqrtTransform | Applies a sqrt transformation to specified variables in a batch dict.
Regridder | Regridding layer using weights file provided by the ESMF library.
ConcatToTensor | End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
ERA5Normalizer | Normalizes per-variable ERA5 tensors using pre-computed mean/std files.
Functions#
_build_preblock_section |
build_preblocks | Instantiate preblocks for a single phase from a two-section config.
_run_preblock_group | Sequentially applies a group of preblocks, returning the transformed batch.
apply_preblocks | Apply a preblock group built by build_preblocks.
Package Contents#
- class credit.preblock.LogTransform(variables: list[str], data_types: list[str] = None, base: str = 'e', eps: float = 1e-08)#
Bases: credit.preblock.base.BasePreblock
Applies a log transformation to specified variables in a batch dict.
- Expected dict structure:
batch[source][data_type]['source/var_type/var_shape/var_name']
- Config example:
type: "log_transform"
args:
  variables:
    - 'ERA5/prognostic/3D/Q'
  data_types:  # optional, defaults to ['input', 'target']
    - 'input'
    - 'target'
  base: 'e'  # optional, default 'e'. Options: 'e', '2', '10'
  eps: 1.0e-8  # optional, default 1e-8
- variables#
- data_types = ['input', 'target']#
- eps#
- forward(batch: dict) dict#
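The transformation described for LogTransform can be sketched in plain Python. This is a minimal illustration, not the library's implementation: it assumes `eps` is added to the value before the logarithm is taken (the source does not say how `eps` is applied), and it uses lists of floats as stand-ins for tensors. The helper name `log_transform` and the `T` variable key are hypothetical.

```python
import math


def log_transform(batch, variables, data_types=("input", "target"), base="e", eps=1e-8):
    """Return a new nested dict with log(x + eps) applied to the selected variables.

    Assumption: eps is added before the log; the real preblock may differ.
    """
    log_fn = {"e": math.log, "2": math.log2, "10": math.log10}[base]
    out = {}
    for source, by_type in batch.items():
        out[source] = {}
        for data_type, by_var in by_type.items():
            selected_type = data_type in data_types
            out[source][data_type] = {
                key: (
                    [log_fn(v + eps) for v in values]
                    if selected_type and key in variables
                    else values  # untouched variables pass through unchanged
                )
                for key, values in by_var.items()
            }
    return out


# Lists of floats stand in for tensors in this sketch.
batch = {
    "ERA5": {
        "input": {
            "ERA5/prognostic/3D/Q": [1.0, math.e],
            "ERA5/prognostic/3D/T": [280.0],  # hypothetical key, not transformed
        }
    }
}
result = log_transform(batch, variables=["ERA5/prognostic/3D/Q"])
```

Only the listed variable is transformed; every other entry is returned as-is, which mirrors the per-variable selection in the config example.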
- class credit.preblock.SqrtTransform(variables: list[str], data_types: list[str] = None)#
Bases: credit.preblock.base.BasePreblock
Applies a sqrt transformation to specified variables in a batch dict.
- Expected dict structure:
batch[source][data_type]['source/var_type/var_shape/var_name']
- Config example:
type: "sqrt_transform"
args:
  variables:
    - 'ERA5/prognostic/3D/Q'
  data_types:  # optional, defaults to ['input', 'target']
    - 'input'
    - 'target'
- variables#
- data_types = ['input', 'target']#
- forward(batch: dict) dict#
- class credit.preblock.Regridder(weight_file, variables: list[str], data_types: list[str] = None, reshape_to_xy=True, flip_axis=None)#
Bases: credit.preblock.base.BasePreblock
Regridding layer using weights file provided by the ESMF library.
- Parameters:
weight_file – path to weights file
variables – list of variable keys to regrid (e.g. ['era5/prognostic/3d/T'])
data_types – list of data types to process (default: ['input', 'target'])
reshape_to_xy – whether to reshape the flattened array back to xy coordinates
flip_axis (list, tuple, or None) – axes to flip before regridding (e.g. [-1, -2])
- variables#
- data_types = ['input', 'target']#
- reshape_to_xy = True#
- flip_axis = None#
- n_a#
- n_b#
- dst_shape#
- _W = None#
- _W_device = None#
- _get_W(device)#
- _regrid(x: torch.Tensor) torch.Tensor#
- forward(batch: dict) dict#
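Regridding with a weights file amounts to a sparse matrix-vector product over the flattened grid. The sketch below assumes the usual ESMF weight-file layout: sparse triplets `row`, `col`, `S` with 1-based indices, a source grid of size `n_a` and a destination grid of size `n_b`. Whether Regridder stores or applies its weights exactly this way is an assumption; the helper names are hypothetical and plain lists stand in for tensors.

```python
def apply_sparse_weights(row, col, S, src, n_b):
    """Compute dst = W @ src, where W[row[k], col[k]] = S[k] (1-based indices)."""
    dst = [0.0] * n_b
    for r, c, s in zip(row, col, S):
        dst[r - 1] += s * src[c - 1]  # shift 1-based indices to 0-based
    return dst


def reshape_to_xy(flat, n_y, n_x):
    """Reshape a flattened destination field back to (n_y, n_x) rows."""
    return [flat[j * n_x:(j + 1) * n_x] for j in range(n_y)]


# Toy weights: 4 source cells (n_a = 4) averaged pairwise into 2 destination cells (n_b = 2).
row = [1, 1, 2, 2]
col = [1, 2, 3, 4]
S = [0.5, 0.5, 0.5, 0.5]
src = [1.0, 3.0, 10.0, 20.0]

dst = apply_sparse_weights(row, col, S, src, n_b=2)
grid = reshape_to_xy(dst, n_y=1, n_x=2)
```

The second helper corresponds to the `reshape_to_xy` option: the product yields a flat vector of length `n_b`, which is then folded back into the destination grid shape.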
- class credit.preblock.ConcatToTensor(to_device: bool = True)#
Bases: credit.preblock.base.BasePreblock
End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
Expects a batch dict of the form:
batch[data_type][source][var_name] -> torch.Tensor
where tensor shapes are (batch, n_levels, time, lat, lon) and concatenation is performed along dim=1 (channel). Input tensors are sorted by _channel_sort_key before concatenation so the channel order matches the canonical variable schema regardless of insertion order in the batch. metadata keys are passed through as-is (not concatenated).
In addition to the tensors, two channel maps are attached to metadata under metadata["_channel_map"]:
- "input": every variable and its slice in the concatenated input tensor.
- "output": prognostic + diagnostic variables only, with slices reindexed from 0 to match y_pred channel ordering.
Each entry has the form:
var_key -> {"slice": slice(start, end), "orig_shape": (n_levels, T)}
Returns either:
(input_tensor, metadata)                 # if no "target" data_type present
(input_tensor, target_tensor, metadata)  # if "target" is present
Example config:
type: "concatenate_to_tensor"
args:
  to_device: true  # set false to skip .to(device) in apply_preblocks
- to_device = True#
- forward(batch)#
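The two channel maps described above can be illustrated with a small sketch. It makes several assumptions that the source does not confirm: each variable occupies `n_levels * T` channels, the variable type is the second component of a 'source/var_type/var_shape/var_name' key, and the input list is already in canonical order (no `_channel_sort_key` is applied here). The function name and the `mask` and `SP` keys are hypothetical.

```python
def build_channel_maps(variables):
    """Build "input" and "output" channel maps from (var_key, n_levels, n_times) tuples."""
    input_map, output_map = {}, {}
    in_start = 0
    out_start = 0
    for var_key, n_levels, n_times in variables:
        width = n_levels * n_times  # assumption: levels and time are flattened into channels
        orig_shape = (n_levels, n_times)
        input_map[var_key] = {
            "slice": slice(in_start, in_start + width),
            "orig_shape": orig_shape,
        }
        in_start += width
        var_type = var_key.split("/")[1]  # assumption: key is source/var_type/var_shape/var_name
        if var_type in ("prognostic", "diagnostic"):
            # Output slices are reindexed from 0, skipping non-predicted variables.
            output_map[var_key] = {
                "slice": slice(out_start, out_start + width),
                "orig_shape": orig_shape,
            }
            out_start += width
    return {"input": input_map, "output": output_map}


variables = [
    ("ERA5/prognostic/3D/Q", 4, 1),
    ("ERA5/static/2D/mask", 1, 1),     # hypothetical static variable
    ("ERA5/prognostic/2D/SP", 1, 1),   # hypothetical surface variable
]
channel_map = build_channel_maps(variables)
```

In this toy case the static variable appears only in the "input" map, so the last prognostic variable sits at channel 5 of the input but channel 4 of the output.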
- class credit.preblock.ERA5Normalizer(mean_path: str, std_path: str, levels: list[int] | None = None)#
Bases: credit.preblock.base.BasePreblock
Normalizes per-variable ERA5 tensors using pre-computed mean/std files.
Normalization: (x - mean) / std applied per variable. Variables not found in the statistics file are passed through unchanged.
- Parameters:
mean_path – Path to NetCDF file containing per-variable means.
std_path – Path to NetCDF file containing per-variable standard deviations.
levels – Optional list of 1-indexed model levels to select from the full 137-level stats (e.g. [60, 90, 120, 137] for a 4-level smoke test). When omitted, all levels in the stats file are used.
- _mean: dict[str, torch.Tensor]#
- _std: dict[str, torch.Tensor]#
- _normalize_tensor(key: str, tensor: torch.Tensor) torch.Tensor#
Normalize tensor using the variable name extracted from key.
- forward(batch: dict) dict#
Normalize all input/target tensors, returning a new batch dict.
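A minimal sketch of the normalization and level selection described above, under stated simplifications: the batch is a flat dict rather than the nested input/target structure, the statistics are in-memory lists rather than NetCDF files, and lists of floats stand in for tensors. All helper names and the sample statistics are hypothetical.

```python
def select_levels(stats, levels=None):
    """Pick 1-indexed levels from per-level statistics; None keeps every level."""
    if levels is None:
        return list(stats)
    return [stats[i - 1] for i in levels]  # 1-indexed level -> 0-indexed position


def normalize_batch(batch, mean, std):
    """Apply (x - mean) / std per variable; variables without stats pass through."""
    out = {}
    for key, values in batch.items():
        if key in mean and key in std:
            out[key] = [(x - m) / s for x, m, s in zip(values, mean[key], std[key])]
        else:
            out[key] = values
    return out


# Fake 137-level statistics: mean equals the level number, std is constant.
full_mean = [float(level) for level in range(1, 138)]
full_std = [2.0] * 137
levels = [60, 90, 120, 137]

mean = {"T": select_levels(full_mean, levels)}
std = {"T": select_levels(full_std, levels)}

batch = {"T": [62.0, 90.0, 118.0, 141.0], "unknown_var": [5.0]}
normalized = normalize_batch(batch, mean, std)
```

The input dict is not modified, which matches the description of `forward` returning a new batch dict.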
- credit.preblock._BRIDGESCALER_AVAILABLE = True#
- credit.preblock.PREBLOCK_REGISTRY#
- credit.preblock._VALID_SECTIONS#
- credit.preblock._build_preblock_section(section_cfg: dict) torch.nn.ModuleDict#
- credit.preblock.build_preblocks(preblock_cfg: dict | None = None, phase: str = 'per_step') torch.nn.ModuleDict#
Instantiate preblocks for a single phase from a two-section config.
Config format:
preblocks:
  ic_only:  # run once at t=0 on the raw batch (e.g. static regrid)
    regrid_static:
      type: regrid
      args: ...
  per_step:  # run every rollout step (e.g. log_transform, concat)
    log_transform:
      type: log_transform
    concat:
      type: concat
Typical usage (build once per phase, store separately):
ic_preblocks = build_preblocks(cfg, phase="ic_only")
step_preblocks = build_preblocks(cfg, phase="per_step")

# t=0: run both in sequence
ic_preprocessed = apply_preblocks(ic_preblocks, batch, device=device)
preprocessed_batch = apply_preblocks(step_preblocks, ic_preprocessed, device=device)

# t>0: run per_step only
preprocessed_batch = apply_preblocks(step_preblocks, rollout_batch, device=device)
- Parameters:
preblock_cfg – the full preblocks config dict (both sections).
phase – which section to build: "ic_only" or "per_step".
- Returns:
nn.ModuleDict of instantiated blocks for the requested phase.
- Raises:
ValueError – if the config contains keys other than "ic_only" / "per_step", or if phase is not one of those values.
- credit.preblock._run_preblock_group(group: torch.nn.ModuleDict, batch: dict, device=None)#
Sequentially applies a group of preblocks, returning the transformed batch.
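The build-then-apply flow described for build_preblocks and _run_preblock_group can be sketched with a toy registry. Everything here is illustrative: `Scale`, `Offset`, `TOY_REGISTRY`, `build_section` and `run_group` are hypothetical stand-ins, a plain dict replaces `nn.ModuleDict`, and the real PREBLOCK_REGISTRY contents are not shown in this page.

```python
VALID_SECTIONS = ("ic_only", "per_step")


class Scale:
    """Toy block: multiply every value in a flat batch dict by a factor."""

    def __init__(self, factor=1.0):
        self.factor = factor

    def __call__(self, batch):
        return {key: value * self.factor for key, value in batch.items()}


class Offset:
    """Toy block: add a constant to every value in a flat batch dict."""

    def __init__(self, amount=0.0):
        self.amount = amount

    def __call__(self, batch):
        return {key: value + self.amount for key, value in batch.items()}


# Hypothetical registry mapping config "type" strings to block classes.
TOY_REGISTRY = {"scale": Scale, "offset": Offset}


def build_section(preblock_cfg=None, phase="per_step"):
    """Instantiate the blocks of one phase, in config order."""
    if phase not in VALID_SECTIONS:
        raise ValueError(f"phase must be one of {VALID_SECTIONS}, got {phase!r}")
    preblock_cfg = preblock_cfg or {}
    unknown = set(preblock_cfg) - set(VALID_SECTIONS)
    if unknown:
        raise ValueError(f"unknown preblock sections: {sorted(unknown)}")
    section = preblock_cfg.get(phase) or {}
    return {
        name: TOY_REGISTRY[spec["type"]](**spec.get("args", {}))
        for name, spec in section.items()
    }


def run_group(group, batch):
    """Apply each block in insertion order, feeding each output to the next block."""
    for block in group.values():
        batch = block(batch)
    return batch


cfg = {
    "ic_only": {"shift": {"type": "offset", "args": {"amount": 1.0}}},
    "per_step": {
        "double": {"type": "scale", "args": {"factor": 2.0}},
        "shift": {"type": "offset", "args": {"amount": 0.5}},
    },
}
step_blocks = build_section(cfg, phase="per_step")
result = run_group(step_blocks, {"q": 3.0})  # (3.0 * 2.0) + 0.5
```

Because blocks run in config order, swapping "double" and "shift" in the config would change the result, which is why ordering within a section matters.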
- credit.preblock.apply_preblocks(preblocks: torch.nn.ModuleDict, batch: dict, device=None) dict#
Apply a preblock group built by build_preblocks.
- Parameters:
preblocks – nn.ModuleDict built by build_preblocks for a single phase.
batch – nested variable dict from the dataset (or a prior preblock pass).
device – move output tensors here after concat.
- Returns:
When concat has run: {"x": tensor, "y": tensor, "metadata": ...}. Otherwise: the transformed nested batch dict (pre-concat).
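The two-phase pattern (run "ic_only" once at t=0, then "per_step" on every step) can be sketched as a small rollout loop. This is an illustration under assumptions: the block functions, `run_group` and `rollout` are hypothetical, lists of callables replace `nn.ModuleDict`, and device handling and concatenation are omitted.

```python
def add_static_field(batch):
    """Stand-in for an ic_only block, e.g. a static regrid done once at t=0."""
    return {**batch, "static": 1.0}


def double_values(batch):
    """Stand-in for a per_step block applied at every rollout step."""
    return {key: value * 2.0 for key, value in batch.items()}


def run_group(group, batch):
    """Apply each block in order, passing each output to the next block."""
    for block in group:
        batch = block(batch)
    return batch


def rollout(ic_blocks, step_blocks, initial_batch, later_batches):
    """Preprocess a rollout: both groups at t=0, per_step only afterwards."""
    outputs = [run_group(step_blocks, run_group(ic_blocks, initial_batch))]
    for batch in later_batches:
        outputs.append(run_group(step_blocks, batch))
    return outputs


outputs = rollout(
    ic_blocks=[add_static_field],
    step_blocks=[double_values],
    initial_batch={"q": 1.0},
    later_batches=[{"q": 2.0}, {"q": 3.0}],
)
```

Only the first output carries the "static" entry, since the ic_only group is skipped for later steps.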