credit.preblock
Submodules
Attributes
- PREBLOCK_REGISTRY
Classes
- LogTransform: Applies a log transformation to specified variables in a batch dict.
- SqrtTransform: Applies a square-root transformation to specified variables in a batch dict.
- BridgeScalerTransformer: Scaling preblock using a fitted bridgescaler dict.
- Regridder: Regridding layer using a weights file produced by the ESMF library.
- ConcatToTensor: End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
Functions
- build_preblocks: Instantiates all preblocks from the config's 'preblocks' section.
- apply_preblocks: Sequentially applies transform preblocks (dict→dict), then concatenates to tensors.
Package Contents
- class credit.preblock.LogTransform(variables: list[str], data_types: list[str] = None, base: str = 'e', eps: float = 1e-08)
Bases: credit.preblock.base.BasePreblock
Applies a log transformation to specified variables in a batch dict.
- Expected dict structure:
  batch[source][data_type]['source/var_type/var_shape/var_name']
- Config example:
    type: "log_transform"
    args:
      variables:
        - 'ERA5/prognostic/3D/Q'
      data_types:  # optional, defaults to ['input', 'target']
        - 'input'
        - 'target'
      base: 'e'  # optional, default 'e'; options: 'e', '2', '10'
      eps: 1.0e-8  # optional, default 1e-8
- variables
- data_types = ['input', 'target']
- eps
- forward(batch: dict) -> dict
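The traversal over the nested batch dict can be sketched in plain Python. This is a minimal, hypothetical sketch rather than the library's implementation: lists of floats stand in for torch tensors, adding `eps` before the log (instead of, say, clamping) is an assumption, and the base is applied by dividing the natural log by ln(base).

```python
import math


def log_transform(batch, variables, data_types=None, base="e", eps=1e-8):
    """Hypothetical sketch of LogTransform.forward; lists stand in for tensors."""
    if data_types is None:
        data_types = ["input", "target"]
    # Dividing by ln(base) converts the natural log to the requested base.
    ln_base = {"e": 1.0, "2": math.log(2.0), "10": math.log(10.0)}[base]
    for by_type in batch.values():
        if not isinstance(by_type, dict):
            continue  # skip entries that are not per-source dicts
        for data_type in data_types:
            fields = by_type.get(data_type, {})
            for key in variables:
                if key in fields:
                    # Assumption: eps is added so that zeros map to a finite value.
                    fields[key] = [math.log(v + eps) / ln_base for v in fields[key]]
    return batch


batch = {"ERA5": {"input": {"ERA5/prognostic/3D/Q": [1.0, 100.0]}}}
log_transform(batch, ["ERA5/prognostic/3D/Q"], base="10", eps=0.0)
# values become approximately [0.0, 2.0]; the missing "target" branch is simply skipped
```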
- class credit.preblock.SqrtTransform(variables: list[str], data_types: list[str] = None)
Bases: credit.preblock.base.BasePreblock
Applies a square-root transformation to specified variables in a batch dict.
- Expected dict structure:
  batch[source][data_type]['source/var_type/var_shape/var_name']
- Config example:
    type: "sqrt_transform"
    args:
      variables:
        - 'ERA5/prognostic/3D/Q'
      data_types:  # optional, defaults to ['input', 'target']
        - 'input'
        - 'target'
- variables
- data_types = ['input', 'target']
- forward(batch: dict) -> dict
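SqrtTransform walks the same batch[source][data_type][key] structure as LogTransform; only the elementwise function differs. One hypothetical way to express that shared traversal is a helper that takes the function as an argument (`apply_elementwise` is an illustrative name, not part of credit.preblock; lists again stand in for tensors). The usage lines also show how `data_types` restricts which branches are touched.

```python
import math


def apply_elementwise(batch, variables, fn, data_types=None):
    """Hypothetical helper: apply fn to every value of the listed variables."""
    if data_types is None:
        data_types = ["input", "target"]
    for by_type in batch.values():
        if not isinstance(by_type, dict):
            continue  # skip entries that are not per-source dicts
        for data_type in data_types:
            fields = by_type.get(data_type, {})
            for key in variables:
                if key in fields:
                    fields[key] = [fn(v) for v in fields[key]]
    return batch


batch = {
    "ERA5": {
        "input": {"ERA5/prognostic/3D/Q": [4.0, 9.0]},
        "target": {"ERA5/prognostic/3D/Q": [16.0]},
    }
}
# Only the "input" branch is transformed; "target" is left as-is.
apply_elementwise(batch, ["ERA5/prognostic/3D/Q"], math.sqrt, data_types=["input"])
```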
- class credit.preblock.BridgeScalerTransformer(scaler_path: str, variables: list[str], method: str)
Bases: credit.preblock.base.BasePreblock
Scaling preblock using a fitted bridgescaler dict.
Applies per-variable z-score scaling (or its inverse) to tensors in a nested batch dict of the form batch[source][data_type][var_key]. The scaler dict must have been fit with bridgescaler.scale_var_dict using the same nested structure and saved with bridgescaler.save_scaler_dict.
Example config:
    type: "bridgescaler_transform"
    args:
      scaler_path: "/path/to/scaler.json"
      variables:
        - "era5/prognostic/3d/T"
        - "era5/prognostic/3d/U"
      method: "transform"
- variables
- method
- scaler_path
- scaler
- forward(batch: dict) -> dict
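Per-variable z-score scaling amounts to (x - mean) / std on the way in and x * std + mean for the inverse. The sketch below is hypothetical: a plain `stats` dict of (mean, std) pairs stands in for the fitted bridgescaler objects (it does not call the bridgescaler API), and `"inverse_transform"` is an assumed name for the inverse method, since only `"transform"` appears in the config above. Because the class signature has no `data_types` argument, the sketch scales the listed variables under every data type.

```python
def zscore_scale(batch, variables, stats, method="transform"):
    """Hypothetical z-score scaling of batch[source][data_type][var_key].

    stats maps var_key -> (mean, std) and stands in for a fitted scaler dict.
    "inverse_transform" is an assumed method name.
    """
    if method not in ("transform", "inverse_transform"):
        raise ValueError(f"unknown method: {method!r}")
    for by_type in batch.values():
        if not isinstance(by_type, dict):
            continue
        for fields in by_type.values():
            if not isinstance(fields, dict):
                continue  # not a data_type branch
            for key in variables:
                if key not in fields:
                    continue
                mean, std = stats[key]
                if method == "transform":
                    fields[key] = [(v - mean) / std for v in fields[key]]
                else:  # assumed inverse: undo the z-score
                    fields[key] = [v * std + mean for v in fields[key]]
    return batch


stats = {"era5/prognostic/3d/T": (280.0, 10.0)}
batch = {"era5": {"input": {"era5/prognostic/3d/T": [270.0, 290.0]}}}
zscore_scale(batch, ["era5/prognostic/3d/T"], stats)
# -> [-1.0, 1.0]
zscore_scale(batch, ["era5/prognostic/3d/T"], stats, method="inverse_transform")
# -> back to [270.0, 290.0]
```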
- class credit.preblock.Regridder(weight_file, variables: list[str], data_types: list[str] = None, reshape_to_xy=True, flip_axis=None)
Bases: credit.preblock.base.BasePreblock
Regridding layer using a weights file produced by the ESMF library.
- Parameters:
  weight_file – path to the weights file
  variables – list of variable keys to regrid (e.g. ['era5/prognostic/3d/T'])
  data_types – list of data types to process (default: ['input', 'target'])
  reshape_to_xy – whether to reshape the flattened array back to xy coordinates
  flip_axis (list, tuple, or None) – axes to flip before regridding (e.g. [-1, -2])
- variables
- data_types = ['input', 'target']
- reshape_to_xy = True
- flip_axis = None
- n_a
- n_b
- dst_shape
- _W = None
- _W_device = None
- _get_W(device)
- _regrid(x: torch.Tensor) -> torch.Tensor
- forward(batch: dict) -> dict
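An ESMF weights file describes a sparse matrix W that maps a flattened source field of length n_a to a destination field of length n_b, so regridding is y = W x. In ESMF's offline weight files the nonzeros are typically stored as 1-based `row`, `col` and `S` arrays; the `_W`, `_W_device` and `_get_W(device)` members suggest the class builds this matrix lazily and caches it per device, though that is an inference. The plain-Python sketch below shows only the arithmetic for a single 2-D field: `flip_2d` is a hypothetical helper standing in for `flip_axis`, a row-major reshape stands in for `reshape_to_xy`, and reading the file and the torch sparse product are left out.

```python
def flip_2d(grid, axes):
    """Hypothetical helper: reverse a 2-D nested list along -2 (rows) and/or -1 (columns)."""
    if -2 in axes:
        grid = grid[::-1]
    if -1 in axes:
        grid = [row[::-1] for row in grid]
    return grid


def regrid(src_grid, rows, cols, weights, n_b, dst_shape=None, flip_axis=None):
    """Sketch of y = W @ x with W given as 1-based (row, col, S) triplets."""
    if flip_axis:
        src_grid = flip_2d(src_grid, flip_axis)
    x = [v for row in src_grid for v in row]  # flatten row-major to length n_a
    y = [0.0] * n_b
    for r, c, s in zip(rows, cols, weights):
        y[r - 1] += s * x[c - 1]  # shift 1-based indices to 0-based
    if dst_shape is None:
        return y
    n_rows, n_cols = dst_shape
    # Assumption: dst_shape is (rows, cols) and the destination is row-major.
    return [y[i * n_cols:(i + 1) * n_cols] for i in range(n_rows)]


# Toy example: average each row of a 2x2 source grid into one destination cell.
src = [[1.0, 2.0], [3.0, 4.0]]
out = regrid(src, rows=[1, 1, 2, 2], cols=[1, 2, 3, 4], weights=[0.5] * 4,
             n_b=2, dst_shape=(1, 2))
# out == [[1.5, 3.5]]
```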
- class credit.preblock.ConcatToTensor(*args: Any, **kwargs: Any)
Bases: credit.preblock.base.BasePreblock
End-of-chain preblock that concatenates a nested batch dict of tensors into a single input tensor (and optionally a target tensor).
Expects a batch dict of the form:
batch[source][data_type][var_name] -> torch.Tensor
where tensor shapes are (batch, channel, time, lon, lat) and concatenation is performed along dim=1 (channel). Traversal order follows key insertion order: for each source, all var_names under a data_type are concatenated, then the next source, and so on.
metadata keys are passed through as-is (not concatenated). Returns either:
    (input_tensor, metadata)                 # if no "target" data_type is present
    (input_tensor, target_tensor, metadata)  # if "target" is present
Example config:
    type: "concatenate_to_tensor"
    args: {}
- forward(batch)
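The traversal order and channel concatenation can be mimicked with nested lists. In this hypothetical sketch each tensor is a list indexed as [batch][channel] (the time, lon and lat dimensions are dropped for brevity), and metadata is assumed to live under a top-level "metadata" key; where the real class finds metadata is not stated above.

```python
def concat_to_tensor(batch):
    """Hypothetical sketch of ConcatToTensor on nested lists shaped [batch][channel]."""
    metadata = batch.get("metadata")  # assumption: top-level "metadata" key
    parts = {"input": [], "target": []}
    for source, by_type in batch.items():
        if source == "metadata":
            continue  # passed through as-is, never concatenated
        for data_type, fields in by_type.items():
            if data_type in parts:
                # Insertion order: all variables of this source before the next source.
                parts[data_type].extend(fields.values())

    def cat_channels(tensors):
        # Concatenate along dim=1: for each batch index, join the channel lists.
        return [[c for t in tensors for c in t[b]] for b in range(len(tensors[0]))]

    x = cat_channels(parts["input"])
    if parts["target"]:
        return x, cat_channels(parts["target"]), metadata
    return x, metadata


batch = {
    "ERA5": {
        "input": {"T": [[1, 2]], "U": [[3]]},
        "target": {"T": [[10, 20]]},
    },
    "static": {"input": {"Z": [[5]]}},
    "metadata": {"time": "2020-01-01"},
}
x, y, meta = concat_to_tensor(batch)
# x == [[1, 2, 3, 5]], y == [[10, 20]], meta == {"time": "2020-01-01"}
```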
- credit.preblock.PREBLOCK_REGISTRY
- credit.preblock.build_preblocks(preblock_cfg: dict) -> torch.nn.ModuleDict
Instantiates all preblocks from the config’s ‘preblocks’ section.
- Parameters:
  preblock_cfg – the full preblocks dict from the config, e.g.:
    {
        'era5_log_transform': {'type': 'log_transform', 'args': {…}},
        'era5_z_transform': {'type': 'z_transform', 'args': {…}},
    }
- Returns:
nn.ModuleDict of instantiated preblocks, ordered as in config.
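A registry-driven builder usually looks up each entry's `type` string and calls the matching class with `args` as keyword arguments. The sketch below assumes `PREBLOCK_REGISTRY` is such a type-to-class mapping; the two stand-in classes and their type strings are hypothetical, and a plain dict (which preserves insertion order) replaces `torch.nn.ModuleDict` so the example runs without torch.

```python
class ScaleBy:  # hypothetical stand-in preblock
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, batch):
        return {k: v * self.factor for k, v in batch.items()}


class AddConstant:  # hypothetical stand-in preblock
    def __init__(self, value=0.0):
        self.value = value

    def __call__(self, batch):
        return {k: v + self.value for k, v in batch.items()}


# Assumption: the registry maps config "type" strings to preblock classes.
PREBLOCK_REGISTRY = {"scale_by": ScaleBy, "add_constant": AddConstant}


def build_preblocks(preblock_cfg):
    """Instantiate preblocks in config order from {name: {"type": ..., "args": {...}}}."""
    blocks = {}
    for name, spec in preblock_cfg.items():
        block_type = spec["type"]
        if block_type not in PREBLOCK_REGISTRY:
            raise KeyError(f"unknown preblock type {block_type!r} for {name!r}")
        blocks[name] = PREBLOCK_REGISTRY[block_type](**spec.get("args", {}))
    return blocks


cfg = {
    "double": {"type": "scale_by", "args": {"factor": 2}},
    "shift": {"type": "add_constant", "args": {"value": 1}},
}
blocks = build_preblocks(cfg)
# list(blocks) == ["double", "shift"]
```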
- credit.preblock.apply_preblocks(preblocks: torch.nn.ModuleDict, batch: dict)
Sequentially applies transform preblocks (dict→dict), then concatenates to tensors.
Concatenation is always performed last and is not configurable.
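The ordering contract (transforms in config order, concatenation last) can be illustrated with a short sketch. It is hypothetical: the concatenation step is passed in as a `concat` argument so the example is self-contained, whereas the documented function takes only `preblocks` and `batch` and always concatenates internally.

```python
def apply_preblocks(preblocks, batch, concat):
    """Run dict->dict preblocks in order, then apply the concatenation step last."""
    for block in preblocks.values():
        batch = block(batch)
    return concat(batch)


def to_tuple(batch):
    """Stand-in for ConcatToTensor: collect values into a tuple in key order."""
    return tuple(batch.values())


blocks = {
    "shift": lambda b: {k: v + 1 for k, v in b.items()},
    "double": lambda b: {k: v * 2 for k, v in b.items()},
}
result = apply_preblocks(blocks, {"T": 3, "U": 4}, to_tuple)
# result == (8, 10): (3 + 1) * 2 and (4 + 1) * 2
```

Swapping the order of "shift" and "double" would give (7, 9), which is why the order of entries in the config matters.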