How to add a new postblock#
In this example we will be going over adding a new postblock, Foo to CREDIT.
The postblocks are named such because they are applied after the forward pass of the model itself, but before the loss calculations. For a more detailed explaination of the currently support post-blocks see Sha et al. 2025.
Reference#
Sha, Yingkai, et al. “Improving AI weather prediction models using global mass and energy conservation schemes.” arXiv preprint arXiv:2501.05648 (2025).
Currently supported postblocks#
Non-Negative Filter#
This example shows how to prevent 7 variables from going below zero
tracer_fixer:
activate: True
denorm: True
tracer_name: ['Qtot','PRECT','U10','CLDTOT','CLDHGH','CLDLOW','CLDMED']
tracer_thres: [0, 0, 0, 0, 0, 0, 0]
Global Mass Fixer#
This example shows how to fix the model global mass
global_mass_fixer:
activate: True
activate_outside_model: True
fix_level_num: 14 #specify presure or sigma level
simple_demo: False
denorm: True
grid_type: 'sigma' #specify presure or sigma level
midpoint: True
lon_lat_level_name: ['lon2d', 'lat2d', 'hyai', 'hybi']
surface_pressure_name: ['PS']
specific_total_water_name: ['Qtot']
Global Water Fixer#
This example shows how to fix the model water balance
global_water_fixer:
activate: True
activate_outside_model: True
simple_demo: False
denorm: True
grid_type: 'sigma' #specify presure or sigma level
midpoint: True
lon_lat_level_name: ['lon2d', 'lat2d', 'hyai', 'hybi']
surface_pressure_name: ['PS']
specific_total_water_name: ['Qtot']
precipitation_name: ['PRECT']
evaporation_name: ['QFLX']
Global Energy Fixer#
global_energy_fixer:
activate: True
activate_outside_model: True
simple_demo: False
denorm: True
grid_type: 'sigma'
midpoint: True
lon_lat_level_name: ['lon2d', 'lat2d', 'hyai', 'hybi']
surface_pressure_name: ['PS']
air_temperature_name: ['T']
specific_total_water_name: ['Qtot']
u_wind_name: ['U']
v_wind_name: ['V']
surface_geopotential_name: ['PHIS']
TOA_net_radiation_flux_name: ['FSNT', 'FLNT']
surface_net_radiation_flux_name: ['FSNS', 'FLNS']
surface_energy_flux_name: ['SHFLX', 'LHFLX',]
Create code for new postblock#
One can add a new class to credit/postblock.py or define a new module and import it into credit/postblock.py. See credit/skebs.py for an example of the latter.
The parser will add the data and model fields from the main config to post_conf. Inside the class Foo you will be able to access these.
from torch import nn
class Foo(nn.Module):
def __init__(self, post_conf):
super().__init__()
self.bar = post_conf["foo"]["bar"]
# accessing data or model conf
lead_time_periods = post_conf["data"]["lead_time_periods"]
def forward(self, x):
# x will be a dictionary of the previous state x and y_pred
# of the model up to this point
# both tensors will be in the transformed space
y_pred = x["y_pred"]
x_prev = x["x"]
# do stuff ...
x["y_pred"] = y_pred # pack back into the dictionary
return x
Define config fields#
Inside of your config you will need to add a new field for your postblock.
model:
...
post_conf:
...
foo:
activate: True
bar: 1.0
...
Add to postblock module#
Inside credit/postblock.py, append your postblock to the list of postblock operations self.operations, the order that you want it.
from credit.skebs import SKEBS
import logger
from torch import nn
class PostBlock(nn.Module):
def __init__(self, post_conf):
...
# SKEBS
if post_conf["skebs"]["activate"]:
logger.info("SKEBS registered")
opt = SKEBS(post_conf)
self.operations.append(opt)
...
if post_conf["foo"]["activate"]:
logger.info("foo registered")
opt = Foo(post_conf)
self.operations.append(opt)
...