credit.postblock
Submodules
Attributes
- POSTBLOCK_REGISTRY
- _VALID_SECTIONS
Classes
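- Reconstruct: Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].
- WetMaskBlock: Post-processing layer that applies a wet mask to ocean predictions.
- BridgeScalerTransformer: Scaling postblock using a fitted bridgescaler dict.
- TracerFixer: Fixes tracer values by replacing values below a given threshold with that threshold.
- GlobalMassFixer: Applies global mass conservation fixes for both the dry air and water budgets.
- GlobalWaterFixer: No class-specific docstring is provided.
- GlobalEnergyFixer: Applies global energy conservation fixes so that the global sum of total atmospheric energy is balanced by radiation and energy fluxes at the top of the atmosphere and the surface.
- GeopotentialDiagnostic: Computes geopotential diagnostics from multi-dimensional input data.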
Functions
- _build_postblock_section
- build_postblocks: Instantiate postblocks for a single phase from a two-section config.
- apply_postblocks: Apply a postblock group built by build_postblocks.
Package Contents
- class credit.postblock.Reconstruct(*args: Any, **kwargs: Any)
Bases:
credit.postblock.base.BasePostblock
Splits batch_dict["y_pred"] into a nested variable dict at batch_dict["y_processed"].
Slices are read from batch_dict["metadata"]["target"]["_channel_map"], which is built by ConcatToTensor and covers only prognostic and diagnostic variables. Each slice is unflattened from (B, n_levels * n_time, H, W) back to (B, n_levels, n_time, H, W). y_pred is left untouched, and all other keys in batch_dict pass through unchanged.
- forward(batch_dict: dict) → dict
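The slicing and unflattening step can be sketched as follows. This is a hypothetical helper, not the Reconstruct implementation: it assumes each _channel_map entry maps a "source/field_type/dim/varname" key to a (start, stop, n_levels, n_time) tuple over the channel axis, and that levels vary more slowly than time inside the flattened channel axis. The real map layout may differ.

```python
import torch


def unflatten_channels(y_pred: torch.Tensor, channel_map: dict) -> dict:
    """Hypothetical sketch of Reconstruct's slicing logic.

    ASSUMPTION: channel_map maps "source/field_type/dim/varname" to a
    (start, stop, n_levels, n_time) tuple over the channel axis of y_pred.
    """
    nested: dict = {}
    for var_key, (start, stop, n_levels, n_time) in channel_map.items():
        source = var_key.split("/")[0]
        flat = y_pred[:, start:stop]  # (B, n_levels * n_time, H, W)
        b, _, h, w = flat.shape
        # Levels are assumed to be the outer (slower-varying) factor.
        nested.setdefault(source, {})[var_key] = flat.reshape(b, n_levels, n_time, h, w)
    return nested


y_pred = torch.randn(2, 5, 3, 4)
channel_map = {
    "era5/prognostic/3d/T": (0, 4, 2, 2),
    "era5/diagnostic/2d/tp": (4, 5, 1, 1),
}
y_processed = unflatten_channels(y_pred, channel_map)
print(y_processed["era5"]["era5/prognostic/3d/T"].shape)  # torch.Size([2, 2, 2, 3, 4])
```

Because reshape only builds a new view or copy, y_pred itself is left unchanged, matching the behavior described above.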
- class credit.postblock.WetMaskBlock(conf, key: str = 'prediction')
Bases:
torch.nn.Module
Post-processing layer that applies a wet mask to ocean predictions. It has zero trainable parameters, but the mask still influences gradients.
Masks predictions so that land points are set to 0 and ocean points keep their values. This encourages the model to focus learning on ocean regions.
- key = 'prediction'
- forward(batch_dict: dict) → dict
Apply the wet mask to batch_dict[self.key] (land = 0, ocean preserved).
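A minimal sketch of the masking behavior, assuming the wet mask is a 0/1 tensor (1 over ocean) that broadcasts against the prediction; how WetMaskBlock actually loads the mask from conf is not reproduced. It also shows why a parameter-free mask still shapes training: land points receive zero gradient.

```python
import torch


def apply_wet_mask(pred: torch.Tensor, wet_mask: torch.Tensor) -> torch.Tensor:
    # wet_mask is ASSUMED to hold 1.0 over ocean and 0.0 over land and to
    # broadcast over pred's leading (batch, level, time) dimensions.
    return pred * wet_mask


wet_mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # (H, W)
pred = torch.full((1, 2, 2), 3.0, requires_grad=True)  # (B, H, W)
masked = apply_wet_mask(pred, wet_mask)
masked.sum().backward()
print(masked.detach())  # land points are 0, ocean points keep 3.0
print(pred.grad)        # gradient is 1 over ocean, 0 over land
```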
- class credit.postblock.BridgeScalerTransformer(scaler_path: str, variables: list[str], method: str, key: str = 'y_processed')
Bases:
credit.postblock.base.BasePostblock
Scaling postblock using a fitted bridgescaler dict.
Applies per-variable scaling (or its inverse) to the nested prediction dict at batch_dict[key], which has the form batch_dict[key][source][var_key], where var_key is "source/field_type/dim/varname" (e.g. "era5/prognostic/3d/T").
Defaults to operating on "y_processed", the nested dict written by Reconstruct. Use method="inverse_transform" to convert normalized predictions back to physical units before physics fixers.
The scaler dict must have been fit with bridgescaler.scale_var_dict using the same nested structure and saved with bridgescaler.save_scaler_dict.
Example config:
    type: "bridgescaler_transform"
    args:
      scaler_path: "/path/to/scaler.json"
      variables:
        - "era5/prognostic/3d/T"
        - "era5/prognostic/3d/U"
      method: "inverse_transform"
- variables
- method
- scaler_path
- key = 'y_processed'
- scaler
- forward(batch_dict: dict) → dict
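How the nested dict is walked can be sketched without bridgescaler itself. In this sketch, StandardScalerSketch is a hypothetical mean/std scaler standing in for a fitted bridgescaler object, and scalers is a plain dict keyed by var_key; the real scaler dict format and its loading from scaler_path are not reproduced. The method string is dispatched with getattr, mirroring the transform / inverse_transform choice in the config.

```python
import torch


class StandardScalerSketch:
    """Hypothetical stand-in for a fitted bridgescaler scaler."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean, self.std = mean, std

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean


def scale_nested(batch_dict, scalers, variables, method, key="y_processed"):
    nested = batch_dict[key]
    for var_key in variables:
        source = var_key.split("/")[0]  # "era5/prognostic/3d/T" -> "era5"
        scale_fn = getattr(scalers[var_key], method)
        nested[source][var_key] = scale_fn(nested[source][var_key])
    return batch_dict


batch = {"y_processed": {"era5": {"era5/prognostic/3d/T": torch.tensor([0.0, 1.0])}}}
scalers = {"era5/prognostic/3d/T": StandardScalerSketch(mean=280.0, std=10.0)}
scale_nested(batch, scalers, ["era5/prognostic/3d/T"], "inverse_transform")
print(batch["y_processed"]["era5"]["era5/prognostic/3d/T"])  # tensor([280., 290.])
```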
- class credit.postblock.TracerFixer(post_conf)
Bases:
torch.nn.Module
This module fixes tracer values by replacing values below a given threshold with that threshold (e.g., tracer[tracer < thres] = thres).
- Parameters:
post_conf (dict) – config dictionary that includes all specs for the tracer fixer.
- tracer_indices
- tracer_thres
- tracer_thres_max
- forward(x)
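The clipping rule can be sketched with torch.clamp. This sketch assumes x is laid out as (B, C, ...) with tracer_indices indexing the channel axis, and treats tracer_thres_max as an optional per-tracer upper bound; both the layout and the role of the upper bound are assumptions, not statements from the docstring.

```python
import torch


def fix_tracers(x, tracer_indices, tracer_thres, tracer_thres_max=None):
    """Hypothetical sketch: clip selected channels to [thres, thres_max]."""
    x = x.clone()  # avoid modifying the caller's tensor
    for i, ind in enumerate(tracer_indices):
        upper = None if tracer_thres_max is None else tracer_thres_max[i]
        # Equivalent to tracer[tracer < thres] = thres, plus an optional cap.
        x[:, ind] = torch.clamp(x[:, ind], min=tracer_thres[i], max=upper)
    return x


x = torch.tensor([[[-1.0, 0.5], [2.0, -3.0]]])  # (B=1, C=2, N=2)
print(fix_tracers(x, tracer_indices=[0], tracer_thres=[0.0]))
```

Only the listed channels are touched; channel 1 keeps its negative value in this example.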
- class credit.postblock.GlobalMassFixer(post_conf)
Bases:
torch.nn.Module
This module applies global mass conservation fixes for both the dry air and water budgets. Correction ratios applied during model runs ensure that global dry air mass and the global water budget are conserved in the output. The specific total water and precipitation variables are corrected to close the budget. All corrections are done using float32 PyTorch tensors.
- Parameters:
post_conf (dict) – config dictionary that includes all specs for the global mass fixer.
- q_ind_start
- q_ind_end
- forward(x)
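The idea of a correction ratio can be illustrated on a single 2D field: scale the prediction so that its area-weighted global sum matches a target total. This is a simplified illustration, not the GlobalMassFixer algorithm. It assumes a precomputed cell-area weight array (area, hypothetical) and a known per-sample target total.

```python
import torch


def global_sum(field: torch.Tensor, area: torch.Tensor) -> torch.Tensor:
    # Area-weighted sum over the last two (lat, lon) dimensions.
    return (field * area).sum(dim=(-2, -1))


def ratio_fix(pred: torch.Tensor, area: torch.Tensor, target_total: torch.Tensor) -> torch.Tensor:
    """Scale each sample so its global sum equals target_total (shape (B,))."""
    ratio = target_total / global_sum(pred, area)  # (B,)
    return pred * ratio[..., None, None]


area = torch.ones(2, 3)  # hypothetical cell areas
pred = torch.ones(1, 2, 3)
fixed = ratio_fix(pred, area, target_total=torch.tensor([12.0]))
print(global_sum(fixed, area))  # tensor([12.])
```

A multiplicative ratio keeps the spatial pattern of the prediction and only rescales its global total.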
- class credit.postblock.GlobalWaterFixer(post_conf)
Bases:
torch.nn.Module
No class-specific docstring is provided. The class exposes the total water (q_ind_start, q_ind_end), precipitation (precip_ind), and evaporation (evapor_ind) index attributes listed below.
- q_ind_start
- q_ind_end
- precip_ind
- evapor_ind
- forward(x)
- class credit.postblock.GlobalEnergyFixer(post_conf)
Bases:
torch.nn.Module
This module applies global energy conservation fixes. The output ensures that the global sum of total energy in the atmosphere is balanced by radiation and energy fluxes at the top of the atmosphere and the surface. The air temperature variable is modified to close the budget. All corrections are done using float32 PyTorch tensors.
- Parameters:
post_conf (dict) – config dictionary that includes all specs for the global energy fixer.
- T_ind_start
- T_ind_end
- q_ind_start
- q_ind_end
- U_ind_start
- U_ind_end
- V_ind_start
- V_ind_end
- TOA_solar_ind
- TOA_OLR_ind
- surf_solar_ind
- surf_LR_ind
- surf_SH_ind
- surf_LH_ind
- forward(x)
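As a simplified illustration of closing an energy budget through temperature, the sketch below adds a spatially uniform temperature increment whose heat content, c_p times the sum of layer mass times the increment, equals a given energy residual. The constant, the layer_mass input, and the uniform distribution of the correction are assumptions for illustration; the docstring does not specify how GlobalEnergyFixer distributes its correction.

```python
import torch

CP_DRY_AIR = 1004.0  # J kg^-1 K^-1, approximate specific heat of dry air


def uniform_temperature_fix(T: torch.Tensor, layer_mass: torch.Tensor, energy_residual: float) -> torch.Tensor:
    """Hypothetical sketch: add dT so that CP_DRY_AIR * sum(layer_mass * dT) == energy_residual."""
    dT = energy_residual / (CP_DRY_AIR * layer_mass.sum())
    return T + dT


T = torch.full((2, 3), 250.0)  # (level, column), K
layer_mass = torch.ones(2, 3)  # kg, hypothetical
T_fixed = uniform_temperature_fix(T, layer_mass, energy_residual=CP_DRY_AIR * 6 * 0.5)
print(T_fixed[0, 0].item())  # 250.5
```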
- class credit.postblock.GeopotentialDiagnostic(output_name: str = 'ARCO_ERA5/derived_diagnostic/3d/geopotential', dataset_name: str = 'ARCO_ERA5', chunk_size: int = 1000, data_keys: Iterable[str] = ('prediction', 'target'), surface_geopotential_var: str = 'ARCO_ERA5/static/2d/geopotential_at_surface', surface_pressure_var: str = 'ARCO_ERA5/prognostic/2d/surface_pressure', temperature_var: str = 'ARCO_ERA5/prognostic/3d/temperature', specific_humidity_var: str = 'ARCO_ERA5/prognostic/3d/specific_humidity', flip_vertical: bool = True, level_info_file: str = 'ERA5_Lev_Info.nc', model_a_half_var: str = 'a_half', model_b_half_var: str = 'b_half', static_source_key: str = 'ic_raw', levels: list[int] | None = None)
Bases:
torch.nn.Module
GeopotentialDiagnostic is a neural network module that computes geopotential diagnostics from multi-dimensional input data.
This class processes geophysical variables such as surface geopotential, surface pressure, temperature, and specific humidity to calculate geopotential fields. The input data is expected to conform to a specific format, and the class makes use of auxiliary metadata files that describe model-specific level information.
- output_name
The key used in the dataset to store the computed geopotential diagnostic output.
- Type:
str
- dataset_name
The name of the dataset from which input variables will be retrieved.
- Type:
str
- chunk_size
The chunk size used for vectorized computations to optimize memory usage during processing.
- Type:
int
- data_keys
The keys in the input data dictionary that will be processed (e.g., “prediction”, “target”).
- Type:
Iterable[str]
- surface_geopotential_var
The key for the surface geopotential variable in the dataset.
- Type:
str
- surface_pressure_var
The key for the surface pressure variable in the dataset.
- Type:
str
- temperature_var
The key for the temperature variable in the dataset.
- Type:
str
- specific_humidity_var
The key for the specific humidity variable in the dataset.
- Type:
str
- flip_vertical
Whether to flip the vertical dimension of the input tensors. Default: True.
- Type:
bool
- level_info_file
The filename of the auxiliary metadata file that stores information about model levels.
- Type:
str
- model_a_half_var
The variable name for the a (pressure) hybrid sigma-pressure coefficient in the level information file.
- Type:
str
- model_b_half_var
The variable name for the b (sigma) hybrid sigma-pressure coefficient in the level information file.
- Type:
str
- output_name = 'ARCO_ERA5/derived_diagnostic/3d/geopotential'
- dataset_name = 'ARCO_ERA5'
- chunk_size = 1000
- data_keys = ('prediction', 'target')
- surface_geopotential_var = 'ARCO_ERA5/static/2d/geopotential_at_surface'
- surface_pressure_var = 'ARCO_ERA5/prognostic/2d/surface_pressure'
- temperature_var = 'ARCO_ERA5/prognostic/3d/temperature'
- specific_humidity_var = 'ARCO_ERA5/prognostic/3d/specific_humidity'
- flip_vertical = True
- level_info_file
- model_a_half_var = 'a_half'
- model_b_half_var = 'b_half'
- static_source_key = 'ic_raw'
- levels = None
- forward(data_dict: dict)
Processes a dictionary of input data, rearranges dimensions, computes derived quantities using a custom function geopotential, and updates the data dictionary with the results.
- Parameters:
data_dict (dict) – Input dictionary containing data corresponding to various data types. The data for each type is expected to be organized into specified attributes (e.g., temperature, specific humidity).
- Returns:
Updated data dictionary, where new computed fields are added to the relevant dataset, preserving the original structure.
- Return type:
dict
- Raises:
ValueError – If any required data type is not found in the input data_dict.
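The core of a geopotential computation on hybrid sigma-pressure levels can be sketched with the standard hydrostatic relation. Half-level pressures are p = a_half + b_half * p_s, and geopotential is accumulated upward from the surface as phi(k) = phi(k+1) + R_d * Tv_k * ln(p(k+1) / p(k)) across each layer k, with virtual temperature Tv = T * (1 + 0.608 q). This is a sketch, not the package's geopotential function: it assumes levels ordered top to bottom and a strictly positive top half-level pressure (ECMWF's model-level definition puts the top half level at 0 Pa, which needs special handling), and it returns half-level rather than full-level values.

```python
import math

import torch

R_D = 287.06        # J kg^-1 K^-1, dry-air gas constant (approximate)
EPS_FACTOR = 0.608  # R_v / R_d - 1 (approximate)


def half_level_geopotential(phi_s, ps, T, q, a_half, b_half):
    """Hypothetical sketch.

    phi_s, ps: (H, W); T, q: (L, H, W) ordered top -> bottom;
    a_half, b_half: (L + 1,) ordered top -> bottom, with a_half in Pa.
    Returns geopotential (m^2 s^-2) on the L + 1 half levels, top -> bottom.
    """
    p_half = a_half[:, None, None] + b_half[:, None, None] * ps  # (L + 1, H, W)
    Tv = T * (1.0 + EPS_FACTOR * q)  # virtual temperature on full levels
    phi = [phi_s]  # start at the bottom half level (the surface)
    for k in range(T.shape[0] - 1, -1, -1):
        # Hydrostatic step across layer k, from half level k + 1 up to half level k.
        phi.append(phi[-1] + R_D * Tv[k] * torch.log(p_half[k + 1] / p_half[k]))
    return torch.stack(phi[::-1], dim=0)


phi = half_level_geopotential(
    phi_s=torch.zeros(1, 1),
    ps=torch.full((1, 1), 100000.0),
    T=torch.full((1, 1, 1), 250.0),
    q=torch.zeros(1, 1, 1),
    a_half=torch.tensor([50000.0, 0.0]),
    b_half=torch.tensor([0.0, 1.0]),
)
expected_top = R_D * 250.0 * math.log(2.0)  # one layer spanning 500 hPa to 1000 hPa
```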
- credit.postblock.POSTBLOCK_REGISTRY
- credit.postblock._VALID_SECTIONS
- credit.postblock._build_postblock_section(section_cfg: dict) → torch.nn.ModuleDict
- credit.postblock.build_postblocks(postblock_cfg: dict | None = None, phase: str = 'per_step') → torch.nn.ModuleDict
Instantiate postblocks for a single phase from a two-section config.
Config format:
    postblocks:
      per_step:          # run after every forward pass in the rollout loop
        reconstruct:
          type: reconstruct
        inverse_scale:
          type: bridgescaler_transform
          args:
            method: inverse_transform
            scaler_path: /path/to/scaler.json
      post_rollout:      # run once after all rollout steps complete
        mass_fixer:
          type: global_mass_fixer
          args: ...
Typical usage: build once per phase and store the groups separately:
    step_postblocks = build_postblocks(cfg, phase="per_step")
    rollout_postblocks = build_postblocks(cfg, phase="post_rollout")

    # inside the rollout loop, after each forward pass:
    full_data_dict = apply_postblocks(step_postblocks, full_data_dict)

    # once after the rollout loop completes:
    apply_postblocks(rollout_postblocks, full_data_dict)
- Parameters:
postblock_cfg – the full postblocks config dict (both sections).
phase – which section to build: "per_step" or "post_rollout".
- Returns:
nn.ModuleDict of instantiated blocks for the requested phase.
- Raises:
ValueError – if the config contains keys other than "per_step"/"post_rollout", or if phase is not one of those values.
- credit.postblock.apply_postblocks(postblocks: torch.nn.ModuleDict, batch_dict: dict) → dict
Apply a postblock group built by build_postblocks.
- Parameters:
postblocks – nn.ModuleDict built by build_postblocks for a single phase.
batch_dict – dict containing at minimum "y_pred" and "metadata".
- Returns:
The same batch_dict after all blocks in the group have run.
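Applying a group can be thought of as threading batch_dict through each block in insertion order, which nn.ModuleDict preserves. The minimal sketch below assumes each block's forward takes and returns batch_dict; AddOne and TimesTwo are hypothetical blocks chosen so that the result depends on the order.

```python
import torch
import torch.nn as nn


class AddOne(nn.Module):
    def forward(self, batch_dict: dict) -> dict:
        batch_dict["y_pred"] = batch_dict["y_pred"] + 1
        return batch_dict


class TimesTwo(nn.Module):
    def forward(self, batch_dict: dict) -> dict:
        batch_dict["y_pred"] = batch_dict["y_pred"] * 2
        return batch_dict


def apply_postblocks_sketch(postblocks: nn.ModuleDict, batch_dict: dict) -> dict:
    for block in postblocks.values():  # ModuleDict iterates in insertion order
        batch_dict = block(batch_dict)
    return batch_dict


blocks = nn.ModuleDict({"add_one": AddOne(), "times_two": TimesTwo()})
out = apply_postblocks_sketch(blocks, {"y_pred": torch.tensor([1.0]), "metadata": {}})
print(out["y_pred"])  # tensor([4.])
```

Swapping the two entries in the config would give 3.0 instead of 4.0, so the order of blocks in each section matters.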