credit.postblock

Submodules#

Attributes#

Classes#

Reconstruct

Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].

WetMaskBlock

Post-processing layer that applies wet mask to ocean predictions.

BridgeScalerTransformer

Scaling postblock using a fitted bridgescaler dict.

TracerFixer

This module fixes tracer values by replacing values below a given threshold with that threshold.

GlobalMassFixer

This module applies global mass conservation fixes for both dry air and water budget.

GlobalWaterFixer

GlobalEnergyFixer

This module applies global energy conservation fixes.

GeopotentialDiagnostic

GeopotentialDiagnostic is a neural network module used for computing geopotential diagnostics using multi-dimensional input data.

Functions#

_build_postblock_section(→ torch.nn.ModuleDict)

build_postblocks(→ torch.nn.ModuleDict)

Instantiate postblocks for a single phase from a two-section config.

apply_postblocks(→ dict)

Apply a postblock group built by build_postblocks.

Package Contents#

class credit.postblock.Reconstruct(*args: Any, **kwargs: Any)#

Bases: credit.postblock.base.BasePostblock

Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].

Slices are read from batch_dict["metadata"]["target"]["_channel_map"], built by ConcatToTensor and covering only prognostic + diagnostic variables. Each slice is unflattened from (B, n_levels * n_time, H, W) back to (B, n_levels, n_time, H, W). y_pred is left untouched. All other keys in batch_dict pass through unchanged.
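The unflattening step can be sketched as follows. This is a minimal illustration, not the CREDIT implementation: the layout of each _channel_map entry (a channel slice plus a level count) and the function name reconstruct are assumptions made for the example, as is nesting the output by the first component of the variable key.

```python
import torch


def reconstruct(batch_dict: dict) -> dict:
    """Split y_pred into per-variable tensors; y_pred itself is not modified."""
    y_pred = batch_dict["y_pred"]  # (B, C, H, W)
    # Hypothetical channel-map layout: var_key -> (channel slice, n_levels).
    channel_map = batch_dict["metadata"]["target"]["_channel_map"]
    processed = {}
    for var_key, (channels, n_levels) in channel_map.items():
        flat = y_pred[:, channels]  # (B, n_levels * n_time, H, W)
        n_time = flat.shape[1] // n_levels
        # Unflatten the channel axis back into (n_levels, n_time).
        var = flat.reshape(flat.shape[0], n_levels, n_time, *flat.shape[2:])
        source = var_key.split("/")[0]  # assumed: nest by the source prefix
        processed.setdefault(source, {})[var_key] = var
    # All other keys pass through unchanged.
    return {**batch_dict, "y_processed": processed}


y_pred = torch.arange(2 * 5 * 3 * 4, dtype=torch.float32).reshape(2, 5, 3, 4)
batch = {
    "y_pred": y_pred,
    "metadata": {"target": {"_channel_map": {
        "era5/prognostic/3d/T": (slice(0, 4), 2),   # 2 levels x 2 time steps
        "era5/prognostic/2d/SP": (slice(4, 5), 1),  # 1 level x 1 time step
    }}},
}
out = reconstruct(batch)
```

Here the 3D variable comes back with shape (B, 2, 2, H, W) and the single-level variable with shape (B, 1, 1, H, W).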

forward(batch_dict: dict) → dict#

class credit.postblock.WetMaskBlock(conf, key: str = 'prediction')#

Bases: torch.nn.Module

Post-processing layer that applies a wet mask to ocean predictions. It has zero trainable parameters, but the mask influences gradients.

Masks predictions so that land points are set to 0 and ocean points keep their values. This encourages the model to focus learning on ocean regions.

key = 'prediction'#
forward(batch_dict: dict) → dict#

Apply wet mask to batch_dict[self.key] (land=0, ocean preserved).
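The effect of the mask on gradients can be seen in a small sketch. It assumes the mask is applied by elementwise multiplication with a 0/1 tensor; the mask values and tensor shapes below are made up for illustration, and the actual block may load and apply its mask differently.

```python
import torch

# Hypothetical 2x3 wet mask: 1 = ocean, 0 = land.
wet_mask = torch.tensor([[1.0, 1.0, 0.0],
                         [0.0, 1.0, 1.0]])

prediction = torch.full((1, 1, 2, 3), 5.0, requires_grad=True)  # (B, C, H, W)
masked = prediction * wet_mask  # broadcasts over the batch and channel axes

# Any loss computed on the masked output sends zero gradient to land points.
masked.sum().backward()
land_grad = prediction.grad[0, 0][wet_mask == 0]
ocean_grad = prediction.grad[0, 0][wet_mask == 1]
```

Land points receive no gradient, so errors there cannot drive parameter updates, which is the sense in which the mask "influences gradients" despite having no trainable parameters.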

class credit.postblock.BridgeScalerTransformer(scaler_path: str, variables: list[str], method: str, key: str = 'y_processed')#

Bases: credit.postblock.base.BasePostblock

Scaling postblock using a fitted bridgescaler dict.

Applies per-variable scaling (or its inverse) to the nested prediction dict at batch_dict[key], which has the form batch_dict[key][source][var_key] where var_key is "source/field_type/dim/varname" (e.g. "era5/prognostic/3d/T").

Defaults to operating on "y_processed" — the nested dict written by Reconstruct. Use method="inverse_transform" to convert normalized predictions back to physical units before physics fixers.

The scaler dict must have been fit with bridgescaler.scale_var_dict using the same nested structure and saved with bridgescaler.save_scaler_dict.
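The nested key layout and per-variable selection can be sketched without bridgescaler itself. The example below assumes simple z-score scaling with a hypothetical stats dict of (mean, std) pairs, and uses plain Python lists in place of tensors; parse_var_key and inverse_scale are illustrative names, not part of bridgescaler or CREDIT.

```python
def parse_var_key(var_key: str) -> dict:
    """Split "source/field_type/dim/varname" into its four parts."""
    source, field_type, dim, varname = var_key.split("/")
    return {"source": source, "field_type": field_type, "dim": dim, "varname": varname}


def inverse_scale(batch_dict: dict, stats: dict, variables: list, key: str = "y_processed") -> dict:
    """Undo z-score scaling for the selected variables only."""
    for var_key in variables:
        source = parse_var_key(var_key)["source"]
        mean, std = stats[var_key]  # hypothetical (mean, std) pair per variable
        values = batch_dict[key][source][var_key]
        batch_dict[key][source][var_key] = [v * std + mean for v in values]
    return batch_dict


batch = {"y_processed": {"era5": {
    "era5/prognostic/3d/T": [0.0, 1.0, -1.0],
    "era5/prognostic/3d/U": [0.5],
}}}
stats = {"era5/prognostic/3d/T": (250.0, 20.0)}
inverse_scale(batch, stats, ["era5/prognostic/3d/T"])
```

Only the listed variable is converted back to physical units; variables not named in the list are left as they were.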

Example config:

type: "bridgescaler_transform"
args:
    scaler_path: "/path/to/scaler.json"
    variables:
        - "era5/prognostic/3d/T"
        - "era5/prognostic/3d/U"
    method: "inverse_transform"
variables#
method#
scaler_path#
key = 'y_processed'#
scaler#
forward(batch_dict: dict) → dict#

class credit.postblock.TracerFixer(post_conf)#

Bases: torch.nn.Module

This module fixes tracer values by replacing values below a given threshold with that threshold (e.g., tracer[tracer<thres] = thres).

Parameters:

post_conf (dict) – config dictionary that includes all specs for the tracer fixer.
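The replacement rule can be sketched on a toy tensor. The threshold values are hypothetical, and the upper-bound step is an assumption suggested by the tracer_thres_max attribute; its actual behavior is not described here.

```python
import torch

x = torch.tensor([[-0.5, 0.2], [1.5, 0.0]])
thres = 0.0      # hypothetical lower bound
thres_max = 1.0  # hypothetical upper bound

fixed = x.clone()
fixed[fixed < thres] = thres          # the replacement described above
fixed[fixed > thres_max] = thres_max  # assumed analogous upper-bound rule

# The two masked assignments together are equivalent to a clamp.
clamped = torch.clamp(x, min=thres, max=thres_max)
```

Working on a clone keeps the input tensor unchanged, which makes the before/after comparison easy to inspect.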

tracer_indices#
tracer_thres#
tracer_thres_max#
forward(x)#

class credit.postblock.GlobalMassFixer(post_conf)#

Bases: torch.nn.Module

This module applies global mass conservation fixes for both the dry air and water budgets. The output ensures that the global dry air mass and the global water budget are conserved through correction ratios applied during model runs. The variables specific total water and precipitation are corrected to close the budget. All corrections are done using float32 PyTorch tensors.

Parameters:

post_conf (dict) – config dictionary that includes all specs for the global mass fixer.
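The idea of a correction ratio can be illustrated on a single 2D field. This is only a sketch of the ratio concept under simplified assumptions (one field, hypothetical cell areas, no vertical integration); it does not reproduce the dry air or water budget terms used by the module.

```python
import torch


def global_sum(field: torch.Tensor, area: torch.Tensor) -> torch.Tensor:
    """Area-weighted global sum over the last two (lat, lon) axes."""
    return (field * area).sum(dim=(-2, -1))


area = torch.tensor([[1.0, 2.0], [2.0, 1.0]])          # hypothetical cell areas
mass_ref = torch.tensor([[10.0, 10.0], [10.0, 10.0]])  # state before the model step
mass_pred = torch.tensor([[11.0, 9.0], [12.0, 10.0]])  # prediction that has drifted

# One multiplicative ratio restores the global total of the reference state.
ratio = global_sum(mass_ref, area) / global_sum(mass_pred, area)
mass_fixed = mass_pred * ratio
```

A single global ratio preserves the spatial pattern of the prediction while matching its global total to the reference.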

q_ind_start#
q_ind_end#
forward(x)#

class credit.postblock.GlobalWaterFixer(post_conf)#

Bases: torch.nn.Module

q_ind_start#
q_ind_end#
precip_ind#
evapor_ind#
forward(x)#

class credit.postblock.GlobalEnergyFixer(post_conf)#

Bases: torch.nn.Module

This module applies global energy conservation fixes. The output ensures that the global sum of total energy in the atmosphere is balanced by radiation and energy fluxes at the top of the atmosphere and the surface. The air temperature variable is modified to close the budget. All corrections are done using float32 PyTorch tensors.

Parameters:

post_conf (dict) – config dictionary that includes all specs for the global energy fixer.

T_ind_start#
T_ind_end#
q_ind_start#
q_ind_end#
U_ind_start#
U_ind_end#
V_ind_start#
V_ind_end#
TOA_solar_ind#
TOA_OLR_ind#
surf_solar_ind#
surf_LR_ind#
surf_SH_ind#
surf_LH_ind#
forward(x)#

class credit.postblock.GeopotentialDiagnostic(output_name: str = 'ARCO_ERA5/derived_diagnostic/3d/geopotential', dataset_name: str = 'ARCO_ERA5', chunk_size: int = 1000, data_keys: Iterable[str] = ('prediction', 'target'), surface_geopotential_var: str = 'ARCO_ERA5/static/2d/geopotential_at_surface', surface_pressure_var: str = 'ARCO_ERA5/prognostic/2d/surface_pressure', temperature_var: str = 'ARCO_ERA5/prognostic/3d/temperature', specific_humidity_var: str = 'ARCO_ERA5/prognostic/3d/specific_humidity', flip_vertical: bool = True, level_info_file: str = 'ERA5_Lev_Info.nc', model_a_half_var: str = 'a_half', model_b_half_var: str = 'b_half', static_source_key: str = 'ic_raw', levels: list[int] | None = None)#

Bases: torch.nn.Module

GeopotentialDiagnostic is a neural network module used for computing geopotential diagnostics using multi-dimensional input data.

This class processes geophysical variables such as surface geopotential, surface pressure, temperature, and specific humidity to calculate geopotential fields. The input data is expected to conform to a specific format, and the class makes use of auxiliary metadata files that describe model-specific level information.

output_name#

The key used in the dataset to store the computed geopotential diagnostic output.

Type:

str

dataset_name#

The name of the dataset from which input variables will be retrieved.

Type:

str

chunk_size#

The chunk size used for vectorized computations to optimize memory usage during processing.

Type:

int

data_keys#

The keys in the input data dictionary that will be processed (e.g., “prediction”, “target”).

Type:

Iterable[str]

surface_geopotential_var#

The key for the surface geopotential variable in the dataset.

Type:

str

surface_pressure_var#

The key for the surface pressure variable in the dataset.

Type:

str

temperature_var#

The key for the temperature variable in the dataset.

Type:

str

specific_humidity_var#

The key for the specific humidity variable in the dataset.

Type:

str

flip_vertical#

Whether to flip the vertical dimension of the input tensors. Defaults to True.

Type:

bool

level_info_file#

The filename of the auxiliary metadata file that stores information about model levels.

Type:

str

model_a_half_var#

The variable name for the a (pressure) coefficient of the hybrid sigma-pressure coordinate in the level information file.

Type:

str

model_b_half_var#

The variable name for the b (sigma) coefficient of the hybrid sigma-pressure coordinate in the level information file.

Type:

str
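For orientation, the a and b coefficients of a hybrid sigma-pressure coordinate define the pressure at each half level as p = a + b * p_s, where p_s is the surface pressure. The sketch below applies that standard relation to a made-up three-half-level column; the coefficient values, their units (a in Pa), and the top-to-surface ordering are assumptions for the example and say nothing about how the class reads the level information file.

```python
def half_level_pressure(a_half, b_half, surface_pressure):
    """Pressure at each half level: p = a + b * p_s (a in Pa, b dimensionless)."""
    return [a + b * surface_pressure for a, b in zip(a_half, b_half)]


# Hypothetical column with three half levels, ordered from model top to surface.
a_half = [0.0, 5000.0, 0.0]
b_half = [0.0, 0.5, 1.0]
p_half = half_level_pressure(a_half, b_half, surface_pressure=100000.0)
```

With b = 1 and a = 0 at the lowest half level, the pressure there equals the surface pressure, and with both coefficients at 0 the top half level sits at zero pressure.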

output_name = 'ARCO_ERA5/derived_diagnostic/3d/geopotential'#
dataset_name = 'ARCO_ERA5'#
chunk_size = 1000#
data_keys = ('prediction', 'target')#
surface_geopotential_var = 'ARCO_ERA5/static/2d/geopotential_at_surface'#
surface_pressure_var = 'ARCO_ERA5/prognostic/2d/surface_pressure'#
temperature_var = 'ARCO_ERA5/prognostic/3d/temperature'#
specific_humidity_var = 'ARCO_ERA5/prognostic/3d/specific_humidity'#
flip_vertical = True#
level_info_file#
model_a_half_var = 'a_half'#
model_b_half_var = 'b_half'#
static_source_key = 'ic_raw'#
levels = None#
forward(data_dict: dict)#

Processes a dictionary of input data, rearranges dimensions, computes derived quantities using a custom function geopotential, and updates the data dictionary with the results.

Parameters:

data_dict (dict) – Input dictionary containing data corresponding to various data types. The data for each type is expected to be organized into specified attributes (e.g., temperature, specific humidity).

Returns:

Updated data dictionary, where new computed fields are added to the relevant dataset, preserving the original structure.

Return type:

dict

Raises:

ValueError – If any required data type is not found in the input data_dict.

credit.postblock.POSTBLOCK_REGISTRY#
credit.postblock._VALID_SECTIONS#
credit.postblock._build_postblock_section(section_cfg: dict) → torch.nn.ModuleDict#
credit.postblock.build_postblocks(postblock_cfg: dict | None = None, phase: str = 'per_step') → torch.nn.ModuleDict#

Instantiate postblocks for a single phase from a two-section config.

Config format:

postblocks:
  per_step:          # run after every forward pass in the rollout loop
    reconstruct:
      type: reconstruct
    inverse_scale:
      type: bridgescaler_transform
      args:
        method: inverse_transform
        scaler_path: /path/to/scaler.json
  post_rollout:      # run once after all rollout steps complete
    mass_fixer:
      type: global_mass_fixer
      args: ...

Typical usage — build once per phase, store separately:

step_postblocks    = build_postblocks(cfg, phase="per_step")
rollout_postblocks = build_postblocks(cfg, phase="post_rollout")

# inside rollout loop, after each forward pass:
full_data_dict = apply_postblocks(step_postblocks, full_data_dict)

# once after rollout loop completes:
apply_postblocks(rollout_postblocks, full_data_dict)

Parameters:
  • postblock_cfg – the full postblocks config dict (both sections).

  • phase – which section to build — "per_step" or "post_rollout".

Returns:

nn.ModuleDict of instantiated blocks for the requested phase.

Raises:

ValueError – if the config contains keys other than "per_step" / "post_rollout", or if phase is not one of those values.
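The validation and registry lookup described above can be sketched in plain Python. Everything named here is hypothetical: REGISTRY and the Identity placeholder stand in for POSTBLOCK_REGISTRY and the real postblock classes, a plain dict stands in for nn.ModuleDict, and passing args as keyword arguments is an assumption about how blocks are constructed.

```python
VALID_SECTIONS = ("per_step", "post_rollout")


class Identity:
    """Placeholder block that stores its args and returns the batch unchanged."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, batch_dict):
        return batch_dict


# Hypothetical registry mapping config "type" strings to block classes.
REGISTRY = {"reconstruct": Identity, "bridgescaler_transform": Identity}


def build_section(postblock_cfg=None, phase="per_step"):
    """Validate the two-section config and instantiate one phase."""
    postblock_cfg = postblock_cfg or {}
    unknown = set(postblock_cfg) - set(VALID_SECTIONS)
    if unknown:
        raise ValueError(f"unknown postblock sections: {sorted(unknown)}")
    if phase not in VALID_SECTIONS:
        raise ValueError(f"phase must be one of {VALID_SECTIONS}, got {phase!r}")
    blocks = {}
    for name, spec in (postblock_cfg.get(phase) or {}).items():
        blocks[name] = REGISTRY[spec["type"]](**spec.get("args", {}))
    return blocks


cfg = {
    "per_step": {
        "reconstruct": {"type": "reconstruct"},
        "inverse_scale": {
            "type": "bridgescaler_transform",
            "args": {"method": "inverse_transform"},
        },
    },
    "post_rollout": {},
}
blocks = build_section(cfg, phase="per_step")
```

Checking the section keys before looking at the requested phase means a misspelled section name fails loudly instead of silently producing an empty group.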

credit.postblock.apply_postblocks(postblocks: torch.nn.ModuleDict, batch_dict: dict) → dict#

Apply a postblock group built by build_postblocks.

Parameters:
  • postblocks – nn.ModuleDict built by build_postblocks for a single phase.

  • batch_dict – dict containing at minimum "y_pred" and "metadata".

Returns:

The same batch_dict after all blocks in the group have run.
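A minimal sketch of this kind of sequential application, assuming blocks run in insertion order and each receives the previous block's output. The function apply_blocks and the two toy blocks are hypothetical, and a plain dict of callables stands in for the nn.ModuleDict.

```python
def apply_blocks(postblocks: dict, batch_dict: dict) -> dict:
    """Run each block in insertion order, feeding its output to the next."""
    for block in postblocks.values():
        batch_dict = block(batch_dict)
    return batch_dict


def add_processed(batch_dict: dict) -> dict:
    # Toy stand-in for a block that writes a new key.
    batch_dict["y_processed"] = {"copied_from": "y_pred"}
    return batch_dict


def tag_units(batch_dict: dict) -> dict:
    # Toy stand-in for a block that depends on an earlier block's output.
    batch_dict["units"] = "physical" if "y_processed" in batch_dict else "unknown"
    return batch_dict


batch = {"y_pred": [0.0], "metadata": {}}
result = apply_blocks({"reconstruct": add_processed, "inverse_scale": tag_units}, batch)
```

Because the second block relies on a key written by the first, the order of entries in the config determines whether the pipeline works.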