credit.preblock

Submodules

Attributes

Classes

LogTransform

Applies a log transformation to specified variables in a batch dict.

SqrtTransform

Applies a square-root transformation to specified variables in a batch dict.

BridgeScalerTransformer

Scaling preblock using a fitted bridgescaler dict.

Regridder

Regridding layer that uses a weights file provided by the ESMF library.

ConcatToTensor

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Functions

build_preblocks(preblock_cfg) → torch.nn.ModuleDict

Instantiates all preblocks from the config's 'preblocks' section.

apply_preblocks(preblocks, batch)

Sequentially applies transform preblocks (dict→dict), then concatenates to tensors.

Package Contents

class credit.preblock.LogTransform(variables: list[str], data_types: list[str] = None, base: str = 'e', eps: float = 1e-08)

Bases: credit.preblock.base.BasePreblock

Applies a log transformation to specified variables in a batch dict.

Expected dict structure:

batch[source][data_type]['source/var_type/var_shape/var_name']

Config example:

type: "log_transform"
args:
    variables:
        - "ERA5/prognostic/3D/Q"
    data_types:  # optional, defaults to ["input", "target"]
        - "input"
        - "target"
    base: "e"    # optional, default "e"; options: "e", "2", "10"
    eps: 1.0e-8  # optional, default 1e-8

variables
data_types = ['input', 'target']
eps
forward(batch: dict) → dict
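A minimal sketch of the traversal that `forward` performs on the nested batch dict, written as a standalone function. The name `log_transform` is hypothetical, and two details are assumptions not stated above: that `eps` is added before taking the logarithm, and that the listed keys are overwritten in place.

```python
import torch

# Map the documented `base` options to torch log functions.
_LOG_FNS = {"e": torch.log, "2": torch.log2, "10": torch.log10}


def log_transform(batch: dict, variables: list[str],
                  data_types: list[str] = None,
                  base: str = "e", eps: float = 1e-8) -> dict:
    """Hypothetical stand-in for LogTransform.forward.

    Assumption: eps is added before the log to avoid log(0).
    """
    data_types = data_types or ["input", "target"]
    log_fn = _LOG_FNS[base]
    for by_type in batch.values():                # batch[source]
        for data_type in data_types:              # batch[source][data_type]
            tensors = by_type.get(data_type, {})
            for key in variables:                 # full 'source/var_type/var_shape/var_name' key
                if key in tensors:
                    tensors[key] = log_fn(tensors[key] + eps)
    return batch


batch = {"ERA5": {
    "input": {"ERA5/prognostic/3D/Q": torch.tensor([1.0, 10.0, 100.0]),
              "ERA5/prognostic/3D/T": torch.tensor([280.0])},
    "target": {"ERA5/prognostic/3D/Q": torch.tensor([1000.0])},
}}
out = log_transform(batch, ["ERA5/prognostic/3D/Q"], base="10", eps=0.0)
```

Variables not listed in `variables` (here `T`) pass through unchanged.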
class credit.preblock.SqrtTransform(variables: list[str], data_types: list[str] = None)

Bases: credit.preblock.base.BasePreblock

Applies a square-root transformation to specified variables in a batch dict.

Expected dict structure:

batch[source][data_type]['source/var_type/var_shape/var_name']

Config example:

type: "sqrt_transform"
args:
    variables:
        - "ERA5/prognostic/3D/Q"
    data_types:  # optional, defaults to ["input", "target"]
        - "input"
        - "target"

variables
data_types = ['input', 'target']
forward(batch: dict) → dict
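A minimal sketch of the same traversal for the square root, here using the optional `data_types` argument to restrict the transform to inputs. `sqrt_transform` is a hypothetical name, and this reference does not say whether the real class guards against negative values, so the sketch does not.

```python
import torch


def sqrt_transform(batch: dict, variables: list[str],
                   data_types: list[str] = None) -> dict:
    """Hypothetical stand-in for SqrtTransform.forward (no clamping of negatives)."""
    data_types = data_types or ["input", "target"]
    for by_type in batch.values():                # batch[source]
        for data_type in data_types:              # only the requested data types
            tensors = by_type.get(data_type, {})
            for key in variables:
                if key in tensors:
                    tensors[key] = torch.sqrt(tensors[key])
    return batch


# Transform inputs only; the target entry is left untouched.
batch = {"ERA5": {
    "input": {"ERA5/prognostic/3D/Q": torch.tensor([4.0, 9.0])},
    "target": {"ERA5/prognostic/3D/Q": torch.tensor([16.0])},
}}
out = sqrt_transform(batch, ["ERA5/prognostic/3D/Q"], data_types=["input"])
```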
class credit.preblock.BridgeScalerTransformer(scaler_path: str, variables: list[str], method: str)

Bases: credit.preblock.base.BasePreblock

Scaling preblock using a fitted bridgescaler dict.

Applies per-variable z-score scaling (or its inverse) to tensors in a nested batch dict of the form batch[source][data_type][var_key].

The scaler dict must have been fit with bridgescaler.scale_var_dict using the same nested structure and saved with bridgescaler.save_scaler_dict.

Example config:

type: "bridgescaler_transform"
args:
    scaler_path: "/path/to/scaler.json"
    variables:
        - "era5/prognostic/3d/T"
        - "era5/prognostic/3d/U"
    method: "transform"

variables
method
scaler_path
scaler
forward(batch: dict) → dict
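The per-variable z-score scaling described above can be sketched without bridgescaler itself. This is not the bridgescaler API: a plain `stats` dict of `(mean, std)` pairs stands in for the fitted scaler loaded from `scaler_path`, the function name `zscore_scale` is hypothetical, `"inverse_transform"` is an assumed name for the inverse method, and the sketch assumes scaling applies to every data type under each source.

```python
import torch


def zscore_scale(batch: dict, variables: list[str], stats: dict,
                 method: str = "transform") -> dict:
    """Hypothetical stand-in for BridgeScalerTransformer.forward.

    `stats` maps each var_key to (mean, std) in place of the fitted scaler;
    "inverse_transform" is an assumed method name.
    """
    for by_type in batch.values():          # batch[source]
        for tensors in by_type.values():    # batch[source][data_type]
            for key in variables:
                if key not in tensors:
                    continue
                mean, std = stats[key]
                if method == "transform":
                    tensors[key] = (tensors[key] - mean) / std
                elif method == "inverse_transform":
                    tensors[key] = tensors[key] * std + mean
                else:
                    raise ValueError(f"unknown method: {method!r}")
    return batch


keys = ["era5/prognostic/3d/T"]
stats = {"era5/prognostic/3d/T": (280.0, 10.0)}
batch = {"era5": {"input": {"era5/prognostic/3d/T": torch.tensor([270.0, 290.0])}}}
z = zscore_scale(batch, keys, stats)["era5"]["input"]["era5/prognostic/3d/T"]
restored = zscore_scale(batch, keys, stats, method="inverse_transform")
```

Because each scaled entry is replaced by a new tensor, `z` keeps the scaled values after the inverse call restores the dict.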
class credit.preblock.Regridder(weight_file, variables: list[str], data_types: list[str] = None, reshape_to_xy=True, flip_axis=None)

Bases: credit.preblock.base.BasePreblock

Regridding layer that uses a weights file provided by the ESMF library.

Parameters:

weight_file: path to the weights file
variables: list of variable keys to regrid (e.g. ['era5/prognostic/3d/T'])
data_types: list of data types to process (default: ['input', 'target'])
reshape_to_xy: whether to reshape the flattened array back to xy coordinates
flip_axis (list, tuple, or None): axes to flip before regridding (e.g. [-1, -2])

variables
data_types = ['input', 'target']
reshape_to_xy = True
flip_axis = None
n_a
n_b
dst_shape
_W = None
_W_device = None
_get_W(device)
_regrid(x: torch.Tensor) → torch.Tensor
forward(batch: dict) → dict
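ESMF offline weight files typically store the regridding matrix as sparse triplets (1-based destination indices `row`, source indices `col`, weights `S`) over `n_a` source and `n_b` destination points, which matches the `n_a` and `n_b` attributes above. The sketch below builds such a matrix from in-memory triplets instead of reading the file, then applies it to the last two (spatial) dimensions. `build_weights` and `regrid` are hypothetical names; the real class appears to cache the matrix per device (`_W`, `_W_device`), which this sketch skips.

```python
import torch


def build_weights(row, col, S, n_a: int, n_b: int) -> torch.Tensor:
    """Build a sparse (n_b, n_a) matrix from ESMF-style triplets (1-based indices)."""
    indices = torch.stack([torch.as_tensor(row) - 1, torch.as_tensor(col) - 1])
    values = torch.as_tensor(S, dtype=torch.float32)
    return torch.sparse_coo_tensor(indices, values, (n_b, n_a)).coalesce()


def regrid(x: torch.Tensor, W: torch.Tensor, dst_shape: tuple,
           flip_axis=None, reshape_to_xy: bool = True) -> torch.Tensor:
    """Hypothetical stand-in for Regridder._regrid: apply W to the last two dims of x."""
    if flip_axis is not None:
        x = torch.flip(x, dims=list(flip_axis))
    lead = x.shape[:-2]
    flat = x.reshape(-1, x.shape[-2] * x.shape[-1])          # (N, n_a)
    out = torch.sparse.mm(W, flat.T.contiguous()).T          # (N, n_b)
    out = out.reshape(*lead, -1)                             # (..., n_b)
    if reshape_to_xy:
        out = out.reshape(*lead, *dst_shape)                 # (..., *dst_shape)
    return out


# Toy 2x2 -> 1x2 regrid: each destination point averages one source column.
W = build_weights(row=[1, 1, 2, 2], col=[1, 3, 2, 4], S=[0.5] * 4, n_a=4, n_b=2)
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).reshape(1, 1, 1, 2, 2)
y = regrid(x, W, dst_shape=(1, 2))
```

Flipping before regridding changes which source points each weight row sees, so `flip_axis` must match the orientation the weights file was generated for.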
class credit.preblock.ConcatToTensor(*args: Any, **kwargs: Any)

Bases: credit.preblock.base.BasePreblock

End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).

Expects a batch dict of the form:

batch[source][data_type][var_name] -> torch.Tensor

where tensor shapes are (batch, channel, time, lon, lat) and concatenation is performed along dim=1 (channel). Traversal order follows key insertion order: for each source, all var_names under a data_type are concatenated, then the next source, and so on.

Metadata keys are passed through as-is (they are not concatenated).

Returns either:

(input_tensor, metadata)                    # if no "target" data_type present
(input_tensor, target_tensor, metadata)     # if "target" is present

Example config:

type: "concatenate_to_tensor"
args: {}
forward(batch)
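A minimal sketch of the concatenation order described above. The function name `concat_to_tensor` is hypothetical, and so is the placement of metadata: this reference does not say where metadata lives in the batch, so the sketch assumes a top-level `"metadata"` key.

```python
import torch


def concat_to_tensor(batch: dict):
    """Hypothetical stand-in for ConcatToTensor.forward.

    Assumption: metadata sits under a top-level "metadata" key.
    """
    metadata = batch.get("metadata", {})
    parts = {"input": [], "target": []}
    for source, by_type in batch.items():              # sources in insertion order
        if source == "metadata":
            continue
        for data_type, tensors in by_type.items():
            if data_type in parts:
                parts[data_type].extend(tensors.values())  # var_names in insertion order
    # dim=1 is the channel dim of (batch, channel, time, lon, lat)
    x = torch.cat(parts["input"], dim=1)
    if parts["target"]:
        return x, torch.cat(parts["target"], dim=1), metadata
    return x, metadata


shape = (2, 1, 1, 4, 3)
batch = {
    "era5": {"input": {"T": torch.zeros(shape), "U": torch.ones(shape)},
             "target": {"T": torch.zeros(shape)}},
    "metadata": {"datetime": "2020-01-01T00"},
}
x, y, meta = concat_to_tensor(batch)
```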
credit.preblock.PREBLOCK_REGISTRY
credit.preblock.build_preblocks(preblock_cfg: dict) → torch.nn.ModuleDict

Instantiates all preblocks from the config's 'preblocks' section.

Parameters:

preblock_cfg

The full preblocks dict from the config, e.g.:

    {
        'era5_log_transform': {'type': 'log_transform', 'args': {…}},
        'era5_z_transform': {'type': 'z_transform', 'args': {…}},
    }

Returns:

nn.ModuleDict of instantiated preblocks, ordered as in config.

credit.preblock.apply_preblocks(preblocks: torch.nn.ModuleDict, batch: dict)

Sequentially applies transform preblocks (dict→dict), then concatenates to tensors.

Concatenation is always performed last and is not configurable.
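A minimal sketch of this two-phase flow: every preblock maps a dict to a dict in config order, and a fixed concatenation step runs after the loop. `AddOne` and `concat_inputs` are hypothetical helpers, the latter a simplified stand-in for `ConcatToTensor` that ignores targets and metadata.

```python
import torch
from torch import nn


class AddOne(nn.Module):
    """Hypothetical dict -> dict preblock: adds 1 to every tensor in a nested batch."""

    def forward(self, batch: dict) -> dict:
        return {src: {dt: {k: v + 1 for k, v in tensors.items()}
                      for dt, tensors in by_type.items()}
                for src, by_type in batch.items()}


def concat_inputs(batch: dict) -> torch.Tensor:
    """Simplified stand-in for ConcatToTensor: concatenate all 'input' tensors on dim=1."""
    return torch.cat([t for by_type in batch.values()
                      for t in by_type["input"].values()], dim=1)


def apply_preblocks(preblocks: nn.ModuleDict, batch: dict) -> torch.Tensor:
    for block in preblocks.values():   # transform preblocks, in config order
        batch = block(batch)
    return concat_inputs(batch)        # concatenation always runs last


preblocks = nn.ModuleDict({"first": AddOne(), "second": AddOne()})
batch = {"era5": {"input": {"T": torch.zeros(1, 1, 1, 2, 2),
                            "U": torch.zeros(1, 2, 1, 2, 2)}}}
x = apply_preblocks(preblocks, batch)
```

Hard-coding the final step keeps the transform chain purely dict→dict, so no configured preblock can accidentally receive tensors instead of a dict.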