credit.preblock#

Submodules#

Attributes#

Classes#

LogTransform

Applies a log transformation to specified variables in a batch dict.

SqrtTransform

Applies a sqrt transformation to specified variables in a batch dict.

Regridder

Regridding layer using weights file provided by the ESMF library.

ConcatToTensor

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

ERA5Normalizer

Normalizes per-variable ERA5 tensors using pre-computed mean/std files.

Functions#

_build_preblock_section(→ torch.nn.ModuleDict)

build_preblocks(→ torch.nn.ModuleDict)

Instantiate preblocks for a single phase from a two-section config.

_run_preblock_group(group, batch[, device])

Sequentially applies a group of preblocks, returning the transformed batch.

apply_preblocks(→ dict)

Apply a preblock group built by build_preblocks.

Package Contents#

class credit.preblock.LogTransform(variables: list[str], data_types: list[str] = None, base: str = 'e', eps: float = 1e-08)#

Bases: credit.preblock.base.BasePreblock

Applies a log transformation to specified variables in a batch dict.

Expected dict structure:

batch[source][data_type]['source/var_type/var_shape/var_name']

Config example:

type: "log_transform"
args:
  variables:
    - 'ERA5/prognostic/3D/Q'
  data_types:   # optional, defaults to ['input', 'target']
    - 'input'
    - 'target'
  base: 'e'     # optional, default 'e'. Options: 'e', '2', '10'
  eps: 1.0e-8   # optional, default 1e-8
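As a rough illustration of what such a transform does to the nested batch dict, the sketch below walks batch[source][data_type][key] and transforms only the selected variables. It assumes the transform is log(x + eps) in the chosen base, which the documentation above does not state explicitly; the function name log_transform and the use of plain Python lists in place of tensors are likewise stand-ins for illustration, not the library's implementation.

```python
import math


def log_transform(batch, variables, data_types=("input", "target"), base="e", eps=1e-8):
    """Return a new batch with log(x + eps) applied to the selected variables.

    Assumption: the real preblock may place eps differently; this is only a sketch.
    """
    log_fn = {"e": math.log, "2": math.log2, "10": math.log10}[base]
    out = {}
    for source, by_type in batch.items():
        out[source] = {}
        for data_type, by_var in by_type.items():
            out[source][data_type] = {
                key: ([log_fn(v + eps) for v in values]
                      if data_type in data_types and key in variables
                      else values)  # untouched variables pass through
                for key, values in by_var.items()
            }
    return out


batch = {
    "ERA5": {
        "input": {
            "ERA5/prognostic/3D/Q": [1.0, 10.0, 100.0],
            "ERA5/prognostic/3D/T": [250.0, 260.0, 270.0],
        }
    }
}
result = log_transform(batch, ["ERA5/prognostic/3D/Q"], base="10", eps=0.0)

# With the default eps, a zero input still yields a finite value.
safe = log_transform({"ERA5": {"input": {"ERA5/prognostic/3D/Q": [0.0]}}},
                     ["ERA5/prognostic/3D/Q"])
```

Only the Q entry changes; T is returned as-is because it is not listed in variables.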

variables#
data_types = ['input', 'target']#
eps#
forward(batch: dict) dict#
class credit.preblock.SqrtTransform(variables: list[str], data_types: list[str] = None)#

Bases: credit.preblock.base.BasePreblock

Applies a sqrt transformation to specified variables in a batch dict.

Expected dict structure:

batch[source][data_type]['source/var_type/var_shape/var_name']

Config example:

type: "sqrt_transform"
args:
  variables:
    - 'ERA5/prognostic/3D/Q'
  data_types:   # optional, defaults to ['input', 'target']
    - 'input'
    - 'target'

variables#
data_types = ['input', 'target']#
forward(batch: dict) dict#
class credit.preblock.Regridder(weight_file, variables: list[str], data_types: list[str] = None, reshape_to_xy=True, flip_axis=None)#

Bases: credit.preblock.base.BasePreblock

Regridding layer using weights file provided by the ESMF library.

Parameters:
  • weight_file – Path to the weights file.

  • variables – List of variable keys to regrid (e.g. ['era5/prognostic/3d/T']).

  • data_types – List of data types to process (default: ['input', 'target']).

  • reshape_to_xy – Whether to reshape the flattened array back to xy coordinates.

  • flip_axis (list, tuple, or None) – Axes to flip before regridding (e.g. [-1, -2]).
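To make the regridding step concrete, here is a minimal sketch of applying sparse regridding weights to a flattened source field and reshaping the result. It assumes the weights are stored as 1-based (row, col, S) triplets, as is conventional for ESMF weight files, and that the destination grid is unflattened in row-major order; the helper names apply_weights and reshape_2d are hypothetical and do not correspond to the class's _get_W or _regrid methods.

```python
def apply_weights(src_flat, rows, cols, weights, n_b):
    """Compute dst[row] += S * src[col] for every nonzero weight.

    rows/cols are 1-based indices into the destination (n_b cells)
    and source (n_a cells) grids respectively.
    """
    dst = [0.0] * n_b
    for r, c, s in zip(rows, cols, weights):
        dst[r - 1] += s * src_flat[c - 1]
    return dst


def reshape_2d(flat, shape):
    """Unflatten a row-major list into nested rows (assumed ordering)."""
    ny, nx = shape
    return [flat[j * nx:(j + 1) * nx] for j in range(ny)]


# Toy case: 4 source cells -> 2 destination cells, each averaging one pair.
src = [1.0, 3.0, 5.0, 7.0]
rows = [1, 1, 2, 2]
cols = [1, 2, 3, 4]
S = [0.5, 0.5, 0.5, 0.5]

dst = apply_weights(src, rows, cols, S, n_b=2)
grid = reshape_2d(dst, (1, 2))
```

In a tensor setting the same operation is usually expressed as a sparse matrix product, with the weight matrix cached per device.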

variables#
data_types = ['input', 'target']#
reshape_to_xy = True#
flip_axis = None#
n_a#
n_b#
dst_shape#
_W = None#
_W_device = None#
_get_W(device)#
_regrid(x: torch.Tensor) torch.Tensor#
forward(batch: dict) dict#
class credit.preblock.ConcatToTensor(to_device: bool = True)#

Bases: credit.preblock.base.BasePreblock

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Expects a batch dict of the form:

batch[data_type][source][var_name] -> torch.Tensor

where tensor shapes are (batch, n_levels, time, lat, lon) and concatenation is performed along dim=1 (channel). Input tensors are sorted by _channel_sort_key before concatenation so the channel order matches the canonical variable schema regardless of insertion order in the batch.

metadata keys are passed through as-is (not concatenated).

In addition to the tensors, two channel maps are attached to metadata under metadata["_channel_map"]:

  • "input" — every variable and its slice in the concatenated input tensor.

  • "output" — prognostic + diagnostic variables only, with slices reindexed from 0 to match y_pred channel ordering.

Each entry has the form:

var_key -> {"slice": slice(start, end), "orig_shape": (n_levels, T)}
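The two channel maps can be sketched as follows. This assumes that each variable contributes n_levels channels along dim=1, that variable keys follow the 'source/var_type/var_shape/var_name' pattern so the variable type is the second path component, and that variables arrive already sorted; build_channel_maps and the 'dynamic_forcing' type name are hypothetical, used only for illustration.

```python
def build_channel_maps(n_levels_by_var, n_time=1):
    """Build 'input' and 'output' channel maps from an ordered var -> n_levels dict."""
    input_map, output_map = {}, {}
    in_start = out_start = 0
    for key, n_levels in n_levels_by_var.items():
        orig_shape = (n_levels, n_time)
        input_map[key] = {
            "slice": slice(in_start, in_start + n_levels),
            "orig_shape": orig_shape,
        }
        in_start += n_levels
        # Only prognostic and diagnostic variables appear in the output map,
        # and their slices restart from 0.
        if key.split("/")[1] in ("prognostic", "diagnostic"):
            output_map[key] = {
                "slice": slice(out_start, out_start + n_levels),
                "orig_shape": orig_shape,
            }
            out_start += n_levels
    return {"input": input_map, "output": output_map}


maps = build_channel_maps({
    "ERA5/prognostic/3D/T": 4,
    "ERA5/dynamic_forcing/2D/tsi": 1,  # hypothetical forcing variable
    "ERA5/prognostic/2D/SP": 1,
})
```

Here SP occupies channel 5 of the input tensor but channel 4 of the output, because the forcing variable is skipped when the output slices are reindexed.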

Returns either:

(input_tensor, metadata)                    # if no "target" data_type present
(input_tensor, target_tensor, metadata)     # if "target" is present

Example config:

type: "concatenate_to_tensor"
args:
  to_device: true   # set false to skip .to(device) in apply_preblocks
to_device = True#
forward(batch)#
class credit.preblock.ERA5Normalizer(mean_path: str, std_path: str, levels: list[int] | None = None)#

Bases: credit.preblock.base.BasePreblock

Normalizes per-variable ERA5 tensors using pre-computed mean/std files.

Normalization: (x - mean) / std applied per variable. Variables not found in the statistics file are passed through unchanged.

Parameters:
  • mean_path – Path to NetCDF file containing per-variable means.

  • std_path – Path to NetCDF file containing per-variable standard deviations.

  • levels – Optional list of 1-indexed model levels to select from the full 137-level stats (e.g. [60, 90, 120, 137] for a 4-level smoke test). When omitted, all levels in the stats file are used.
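The normalization and level selection described above can be sketched in a few lines. The example assumes the variable name is the last path component of the key and that statistics are stored as one value per model level; plain lists replace tensors and NetCDF files, and the helper names select_levels and normalize are hypothetical.

```python
def select_levels(stats, levels=None):
    """Pick 1-indexed model levels from per-level statistics (all levels if None)."""
    if levels is None:
        return list(stats)
    return [stats[lev - 1] for lev in levels]


def normalize(batch_vars, mean, std, levels=None):
    """Apply (x - mean) / std per variable; unknown variables pass through."""
    out = {}
    for key, values in batch_vars.items():
        name = key.split("/")[-1]  # assumed way of extracting the variable name
        if name not in mean or name not in std:
            out[key] = values
            continue
        m = select_levels(mean[name], levels)
        s = select_levels(std[name], levels)
        out[key] = [(x - mi) / si for x, mi, si in zip(values, m, s)]
    return out


# Toy statistics with three levels; the batch carries levels 1 and 3 only.
mean = {"T": [220.0, 250.0, 280.0]}
std = {"T": [10.0, 20.0, 5.0]}
batch_vars = {
    "ERA5/prognostic/3D/T": [230.0, 290.0],
    "ERA5/prognostic/3D/Q": [0.001, 0.002],  # no stats available, passed through
}
normed = normalize(batch_vars, mean, std, levels=[1, 3])
```

Because levels is 1-indexed, levels=[1, 3] selects the first and third entries of the statistics.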

_mean: dict[str, torch.Tensor]#
_std: dict[str, torch.Tensor]#
_normalize_tensor(key: str, tensor: torch.Tensor) torch.Tensor#

Normalize tensor using the variable name extracted from key.

forward(batch: dict) dict#

Normalize all input/target tensors, returning a new batch dict.

credit.preblock._BRIDGESCALER_AVAILABLE = True#
credit.preblock.PREBLOCK_REGISTRY#
credit.preblock._VALID_SECTIONS#
credit.preblock._build_preblock_section(section_cfg: dict) torch.nn.ModuleDict#
credit.preblock.build_preblocks(preblock_cfg: dict | None = None, phase: str = 'per_step') torch.nn.ModuleDict#

Instantiate preblocks for a single phase from a two-section config.

Config format:

preblocks:
  ic_only:          # run once at t=0 on the raw batch (e.g. static regrid)
    regrid_static:
      type: regrid
      args: ...
  per_step:         # run every rollout step (e.g. log_transform, concat)
    log_transform:
      type: log_transform
    concat:
      type: concat

Typical usage — build once per phase, store separately:

ic_preblocks   = build_preblocks(cfg, phase="ic_only")
step_preblocks = build_preblocks(cfg, phase="per_step")

# t=0: run both in sequence
ic_preprocessed    = apply_preblocks(ic_preblocks, batch, device=device)
preprocessed_batch = apply_preblocks(step_preblocks, ic_preprocessed, device=device)

# t>0: run per_step only
preprocessed_batch = apply_preblocks(step_preblocks, rollout_batch, device=device)
Parameters:
  • preblock_cfg – the full preblocks config dict (both sections).

  • phase – which section to build — "ic_only" or "per_step".

Returns:

nn.ModuleDict of instantiated blocks for the requested phase.

Raises:

ValueError – if the config contains keys other than "ic_only" / "per_step", or if phase is not one of those values.
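The validation and instantiation behaviour described above can be approximated with a small registry-based builder. This is a sketch under stated assumptions: REGISTRY, Scale, Offset and build_preblocks_sketch are hypothetical stand-ins, a plain dict replaces nn.ModuleDict, and the real PREBLOCK_REGISTRY contents are not shown in this page.

```python
VALID_SECTIONS = ("ic_only", "per_step")


class Scale:
    """Hypothetical block: multiplies every value by a factor."""
    def __init__(self, factor=1.0):
        self.factor = factor

    def __call__(self, batch):
        return {k: v * self.factor for k, v in batch.items()}


class Offset:
    """Hypothetical block: adds a constant to every value."""
    def __init__(self, amount=0.0):
        self.amount = amount

    def __call__(self, batch):
        return {k: v + self.amount for k, v in batch.items()}


REGISTRY = {"scale": Scale, "offset": Offset}  # hypothetical type -> class map


def build_preblocks_sketch(preblock_cfg=None, phase="per_step"):
    """Instantiate the blocks of one phase, preserving config order."""
    preblock_cfg = preblock_cfg or {}
    unknown = set(preblock_cfg) - set(VALID_SECTIONS)
    if unknown:
        raise ValueError(f"unknown preblock sections: {sorted(unknown)}")
    if phase not in VALID_SECTIONS:
        raise ValueError(f"phase must be one of {VALID_SECTIONS}, got {phase!r}")
    blocks = {}
    for name, spec in (preblock_cfg.get(phase) or {}).items():
        blocks[name] = REGISTRY[spec["type"]](**spec.get("args", {}))
    return blocks


cfg = {
    "ic_only": {},
    "per_step": {
        "double": {"type": "scale", "args": {"factor": 2.0}},
        "shift": {"type": "offset"},
    },
}
step_blocks = build_preblocks_sketch(cfg, phase="per_step")

try:
    build_preblocks_sketch(cfg, phase="every_step")
except ValueError as err:
    phase_error = str(err)
```

Building each phase separately keeps the one-time ic_only work out of the rollout loop.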

credit.preblock._run_preblock_group(group: torch.nn.ModuleDict, batch: dict, device=None)#

Sequentially applies a group of preblocks, returning the transformed batch.
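Sequential application amounts to threading the batch through each block in insertion order. The sketch below uses a plain dict of callables in place of nn.ModuleDict and omits device handling; run_group and the two toy blocks are hypothetical.

```python
def run_group(group, batch):
    """Apply each block in insertion order, feeding each output to the next block."""
    for _name, block in group.items():
        batch = block(batch)
    return batch


# Two toy blocks; the order of application changes the result.
add_one = lambda b: {k: v + 1 for k, v in b.items()}
double = lambda b: {k: v * 2 for k, v in b.items()}

forward = run_group({"add_one": add_one, "double": double}, {"q": 3})
reverse = run_group({"double": double, "add_one": add_one}, {"q": 3})
```

Since the blocks run in order, the ordering of entries in the config section determines the outcome.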

credit.preblock.apply_preblocks(preblocks: torch.nn.ModuleDict, batch: dict, device=None) dict#

Apply a preblock group built by build_preblocks.

Parameters:
  • preblocks – nn.ModuleDict built by build_preblocks for a single phase.

  • batch – nested variable dict from the dataset (or a prior preblock pass).

  • device – move output tensors here after concat.

Returns:

When concat has run: {"x": tensor, "y": tensor, "metadata": ...}. Otherwise: the transformed nested batch dict (pre-concat).