credit.datasets#
Submodules#
- credit.datasets.MRMS
- credit.datasets._file_utils
- credit.datasets.count_channels
- credit.datasets.datamap
- credit.datasets.downscaling_dataset
- credit.datasets.era5
- credit.datasets.era5_multistep
- credit.datasets.era5_multistep_batcher
- credit.datasets.era5_singlestep
- credit.datasets.les_singlestep
- credit.datasets.load_dataset_and_dataloader
- credit.datasets.mrms_download
- credit.datasets.multi_source
- credit.datasets.om4_multistep_batcher
- credit.datasets.realtime_predict
- credit.datasets.sequential_multistep
- credit.datasets.wrf_singlestep
- credit.datasets.wrfmultistep
Classes#
- MultiSourceDataset: PyTorch Dataset that combines multiple source datasets.
- ERA5Dataset: PyTorch Dataset for processed ERA5 data with nested input/target structure.
- ARCOERA5Dataset: PyTorch Dataset for Google Cloud ARCO ERA5 data with nested input/target structure.
- MRMSDataset: PyTorch Dataset for MRMS data with nested input/target structure.
Package Contents#
- class credit.datasets.MultiSourceDataset(config: dict, return_target: bool = False)#
Bases: torch.utils.data.Dataset
PyTorch Dataset that combines multiple source datasets.
Instantiates one sub-dataset per source key found in config["source"], computes the intersection of their valid timestamps, and delegates each __getitem__ call to all active sub-datasets.
See module docstring for full output structure and usage examples.
- datasets#
Ordered mapping of lowercase source name to its Dataset instance (e.g.
{"era5": ERA5Dataset, "mrms": MRMSDataset}).
- datetimes#
DatetimeIndex of timestamps valid for all active sources (intersection of each source’s own
datetimes).
- static_metadata#
Per-source static metadata aggregated from each sub-dataset’s static_metadata attribute. Example:
    {
        "era5": {"levels": [1000, 850, 500, 300], "datetime_fmt": "unix_ns"},
        "mrms": {"datetime_fmt": "unix_ns"},
    }
- datasets: dict[str, torch.utils.data.Dataset]#
- datetimes: pandas.DatetimeIndex#
- static_metadata: dict[str, dict]#
- __len__() int#
- __getitem__(args: tuple) dict#
Return a dict of per-source sample dicts.
The (t, i) tuple is passed unchanged to every active sub-dataset. Each sub-dataset applies its own field-type / step-index logic (e.g. ERA5 and MRMS both skip prognostic disk reads at i > 0).
- Parameters:
args – (t, i) where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.
- Returns:
Dict keyed by lowercase source name. Each value is the sub-dataset’s own sample dict: {"input": {...}, "target": {...}, "metadata": {...}}
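The delegation described above can be sketched without the real sub-datasets. In this minimal illustration, `StubSource` and `fan_out` are hypothetical names and not part of `credit.datasets`; the stub only mimics the documented behaviour of loading prognostic input at step 0.

```python
class StubSource:
    """Hypothetical stand-in for a sub-dataset (not a credit.datasets class)."""

    def __init__(self, name):
        self.name = name

    def __getitem__(self, args):
        t, i = args
        sample = {"input": {}, "metadata": {"t": t, "step": i}}
        # Mimic the documented rule: prognostic input is read only at step 0.
        if i == 0:
            sample["input"][f"{self.name}/prognostic"] = "loaded"
        return sample


def fan_out(datasets, args):
    """Pass the same (t, i) tuple to every source and key results by source name."""
    return {name: ds[args] for name, ds in datasets.items()}


datasets = {"era5": StubSource("era5"), "mrms": StubSource("mrms")}
first = fan_out(datasets, (0, 0))  # step 0: prognostic input present
later = fan_out(datasets, (0, 1))  # step 1: prognostic input skipped
```

Each value in the returned dict keeps the sub-dataset's own layout, so downstream code can treat every source uniformly.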
- _intersect_timestamps() pandas.DatetimeIndex#
Return timestamps common to all active source datasets.
- Returns:
Sorted DatetimeIndex containing only timestamps present in every source’s datetimes index. Returns an empty DatetimeIndex when no sources are configured.
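As a rough illustration of the intersection semantics, the sketch below uses plain `datetime` lists instead of `pandas.DatetimeIndex`. The function name `intersect_timestamps` and the sample data are assumptions made for the example; the library's actual implementation may differ.

```python
from datetime import datetime, timedelta


def intersect_timestamps(per_source):
    """Return the sorted timestamps present in every source's list."""
    if not per_source:
        return []  # no sources configured -> empty result
    common = None
    for stamps in per_source.values():
        common = set(stamps) if common is None else common & set(stamps)
    return sorted(common)


# Two hypothetical sources on a 6-hourly grid, offset by one step.
era5 = [datetime(2024, 6, 1) + timedelta(hours=6 * k) for k in range(4)]
mrms = [datetime(2024, 6, 1, 6) + timedelta(hours=6 * k) for k in range(4)]
common = intersect_timestamps({"era5": era5, "mrms": mrms})
```

Only the three timestamps shared by both sources survive, which is why adding a source can shorten the combined dataset.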
- class credit.datasets.ERA5Dataset(config: dict, return_target: bool = False)#
Bases: torch.utils.data.Dataset
PyTorch Dataset for processed ERA5 data with nested input/target structure.
See module docstring for full description of output format and file naming.
Example YAML configuration:
    data:
      source:
        ERA5:
          level_coord: "level"
          levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
          variables:
            prognostic:
              vars_3D: ['T', 'U', 'V', 'Q']
              vars_2D: ['SP', 't2m']
              path: "/data/era5_*.zarr"
              filename_time_format: "%Y"  # annual (default)
            dynamic_forcing:
              vars_2D: ['tsi']
              path: "/data/solar_*.nc"
              filename_time_format: "%Y_%m"  # monthly
            static:
              vars_2D: ['Z_GDS4_SFC', 'LSM']
              path: "/data/lsm.nc"  # single file; filename_time_format not needed
            diagnostic: null
      start_datetime: "2017-01-01"
      end_datetime: "2019-12-31"
      timestep: "6h"
      forecast_len: 1
- Assumptions:
A “time” dimension / coordinate is present for non-static fields.
A level coordinate (name given by level_coord) represents the vertical axis of 3D variables.
Dimension order: (time, level, latitude, longitude) for 3D; (time, latitude, longitude) for 2D; (latitude, longitude) for static.
- source_name: str = 'era5'#
- level_coord: str#
- levels: list[int]#
- return_target: bool = False#
- static_metadata: dict#
- dt#
- num_forecast_steps: int#
- start_datetime#
- end_datetime#
- datetimes: pandas.DatetimeIndex#
- file_dict: dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]#
- var_dict: dict[str, dict[str, list[str]]]#
- __len__() int#
- __getitem__(args: tuple) dict#
Return a nested input/target sample dict.
- Parameters:
args – (t, i) where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler. When i == 0, prognostic and static fields are loaded in addition to dynamic forcing.
- Returns:
Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "era5/{field_type}/{dim}/{varname}".
- _register_field(field_type: str, d: dict | None) None#
Validate and register one field type from the config variables block.
Populates self.file_dict and self.var_dict for field_type.
- Parameters:
field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".
d – Field-type config dict, or None / null to disable the field.
- Raises:
KeyError – If field_type is not a recognised field type.
ValueError – If d defines neither vars_3D nor vars_2D.
- _build_timestamps() pandas.DatetimeIndex#
Return valid initialisation timestamps for the dataset.
- Returns:
DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.
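A minimal sketch of this timestamp construction, using only the standard library. It assumes the forecast horizon equals `horizon_steps * step` and that the end bound is inclusive; both are assumptions made for the example, since the page does not spell them out. `build_timestamps` and `horizon_steps` are hypothetical names.

```python
from datetime import datetime, timedelta


def build_timestamps(start, end, step, horizon_steps):
    """List initialisation times from start to (end - horizon), inclusive."""
    last_init = end - horizon_steps * step  # assumed definition of the horizon
    stamps = []
    t = start
    while t <= last_init:
        stamps.append(t)
        t += step
    return stamps


stamps = build_timestamps(
    start=datetime(2017, 1, 1),
    end=datetime(2017, 1, 2),
    step=timedelta(hours=6),
    horizon_steps=1,
)
```

Trimming the tail this way keeps every initialisation time far enough from the end of the record that its target steps still exist.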
- _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) None#
Open the dataset for field_type at time t and populate sample.
Keys written are "era5/{field_type}/3d/{varname}" for 3D variables and "era5/{field_type}/2d/{varname}" for 2D variables.
- Parameters:
field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".
t – Timestamp to select.
sample – Dict to write variable tensors into (modified in place). Tensor shapes (no batch dimension):
3D variable: (n_levels, 1, lat, lon)
2D variable: (1, 1, lat, lon)
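The key layout above can be illustrated with a small helper that derives the sample keys from one field-type config block. `sample_keys` is a hypothetical function written for this example and is not part of the library.

```python
def sample_keys(source_name, field_type, var_cfg):
    """Build '{source}/{field_type}/{dim}/{varname}' keys for one field type."""
    keys = []
    for dim, cfg_key in (("3d", "vars_3D"), ("2d", "vars_2D")):
        # A missing or null entry simply contributes no keys.
        for varname in var_cfg.get(cfg_key) or []:
            keys.append(f"{source_name}/{field_type}/{dim}/{varname}")
    return keys


prognostic_cfg = {"vars_3D": ["T", "U", "V", "Q"], "vars_2D": ["SP", "t2m"]}
keys = sample_keys("era5", "prognostic", prognostic_cfg)
```

With the prognostic block from the example configuration, this yields four 3D keys followed by two 2D keys.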
- static _to_cftime(ts: pandas.Timestamp, calendar: str) cftime.datetime#
Convert a pandas Timestamp to a cftime.datetime.
- Parameters:
ts – Pandas Timestamp to convert.
calendar – cftime calendar string read from the dataset (e.g. "noleap", "gregorian", "proleptic_gregorian").
- Returns:
cftime.datetime with the specified calendar.
- class credit.datasets.ARCOERA5Dataset(data_config: dict, return_target: bool = False)#
Bases: torch.utils.data.Dataset
PyTorch Dataset for Google Cloud ARCO ERA5 data with nested input/target structure.
See module docstring for full description of output format and file naming.
Example YAML configuration:
    data:
      source:
        ARCO_ERA5:
          level_coord: "hybrid"
          levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
          variables:
            prognostic:
              vars_3D: ["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity"]
              vars_2D: ["surface_pressure"]
            dynamic_forcing:
              vars_2D: ["toa_incident_solar_radiation"]
            static:
              vars_2D: ["land_sea_mask"]
            diagnostic:
              vars_2D: ["total_precipitation"]
      start_datetime: "2017-01-01"
      end_datetime: "2019-12-31"
      timestep: "6h"
      forecast_len: 1
- Assumptions:
A “time” dimension / coordinate is present for non-static fields.
A level coordinate (name given by level_coord) represents the vertical axis of 3D variables.
Dimension order: (time, level, latitude, longitude) for 3D; (time, latitude, longitude) for 2D; (latitude, longitude) for static.
- pressure_lev_era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'#
- model_lev_era5_path = 'gs://gcp-public-data-arco-era5/ar/model-level-1h-0p25deg.zarr-v1'#
- model_lev_vars = ['divergence', 'fraction_of_cloud_cover', 'geopotential', 'ozone_mass_mixing_ratio',...#
- source_name: str = 'arco_era5'#
- level_coord: str#
- return_target: bool = False#
- static_metadata: dict#
- dt#
- num_forecast_steps: int#
- start_datetime#
- end_datetime#
- datetimes: pandas.DatetimeIndex#
- var_dict: dict[str, dict[str, list[str]]]#
- fs = None#
- mod_level_store = None#
- pres_level_store = None#
- __len__() int#
- __getitem__(args: tuple) dict#
Return a nested input/target sample dict.
- Parameters:
args – (t, i) where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler. When i == 0, prognostic and static fields are loaded in addition to dynamic forcing.
- Returns:
Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "arco_era5/{field_type}/{dim}/{varname}".
- _init_fs()#
- _register_field(field_type: str, d: dict | None) None#
Validate and register one field type from the config variables block.
Populates self.file_dict and self.var_dict for field_type.
- Parameters:
field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".
d – Field-type config dict, or None / null to disable the field.
- Raises:
KeyError – If field_type is not a recognised field type.
ValueError – If d defines neither vars_3D nor vars_2D.
- _build_timestamps() pandas.DatetimeIndex#
Return valid initialisation timestamps for the dataset.
- Returns:
DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.
- _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) None#
Open the dataset for field_type at time t and populate sample.
Keys written are "era5/{field_type}/3d/{varname}" for 3D variables and "era5/{field_type}/2d/{varname}" for 2D variables.
- Parameters:
field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".
t – Timestamp to select.
sample – Dict to write variable tensors into (modified in place). Tensor shapes (no batch dimension):
3D variable: (n_levels, 1, lat, lon)
2D variable: (1, 1, lat, lon)
- static _to_cftime(ts: pandas.Timestamp, calendar: str) cftime.datetime#
Convert a pandas Timestamp to a cftime.datetime.
- Parameters:
ts – Pandas Timestamp to convert.
calendar – cftime calendar string read from the dataset (e.g. "noleap", "gregorian", "proleptic_gregorian").
- Returns:
cftime.datetime with the specified calendar.
- class credit.datasets.MRMSDataset(config: dict, return_target: bool = False)#
Bases: torch.utils.data.Dataset
PyTorch Dataset for MRMS data with nested input/target structure.
Field types follow ERA5 conventions: prognostic variables appear in both input (at step 0) and target; dynamic_forcing appears in input at every step; diagnostic appears in target only. At step i > 0 the model’s own prognostic predictions are fed back, so no disk read occurs for prognostic fields at those steps.
Supports loading directly from AWS S3 (remote mode) or from local NetCDF / Zarr files (local mode). Spatial subsetting via extent is applied at load time on the native MRMS grid.
See module docstring for full description of output format and file naming.
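The field-type rules described above can be summarised as a small routing function. This is a sketch of the documented semantics only; `fields_to_load` is a hypothetical helper, not a method of `MRMSDataset`.

```python
def fields_to_load(i, return_target):
    """Return which field types feed "input" and "target" at step i."""
    inputs = ["dynamic_forcing"]  # loaded from disk at every step
    if i == 0:
        # Prognostic input is read only at the first step; later steps
        # reuse the model's own predictions instead of a disk read.
        inputs.insert(0, "prognostic")
    # Diagnostic fields are target-only and never appear in the input.
    targets = ["prognostic", "diagnostic"] if return_target else []
    return {"input": inputs, "target": targets}


step0 = fields_to_load(0, return_target=True)
step2 = fields_to_load(2, return_target=False)
```

At step 2 without targets, only dynamic forcing is read, which is what keeps autoregressive rollouts cheap on I/O.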
Example YAML configuration (local mode):
    data:
      source:
        MRMS:
          mode: "local"
          variables:
            prognostic:  # input at step 0 + target
              vars_2D:
                - "MultiSensor_QPE_01H_Pass2_00.00"
              path: "/data/MRMS_*.nc"
              filename_time_format: "%Y%m%d-%H%M%S"
            dynamic_forcing:  # input every step
              vars_2D:
                - "MultiSensor_QPE_06H_Pass2_00.00"
              path: "/data/MRMS_*.nc"
              filename_time_format: "%Y%m%d-%H%M%S"
          extent: [-130, -60, 20, 55]  # [min_lon, max_lon, min_lat, max_lat]
      start_datetime: "2024-06-01"
      end_datetime: "2024-07-01"
      timestep: "6h"
      forecast_len: 0
Example YAML configuration (remote mode):
    data:
      source:
        MRMS:
          mode: "remote"
          region: "CONUS"
          variables:
            prognostic:
              vars_2D:
                - "MultiSensor_QPE_01H_Pass2_00.00"
          extent: [-130, -60, 20, 55]
- Assumptions:
Local files have time, lat, lon dimensions/coordinates.
Longitude coordinates are in the 0–360 convention (both local and remote).
extent is specified as [min_lon, max_lon, min_lat, max_lat] in either the -180 to 180 or the 0 to 360 convention; it is normalised to 0–360 internally.
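One common way to normalise longitudes to the 0–360 convention is a modulo operation, sketched below. The page does not say how the library performs this step, so `normalise_extent` is a hypothetical helper and the modulo approach is an assumption.

```python
def normalise_extent(extent):
    """Map [min_lon, max_lon, min_lat, max_lat] longitudes into [0, 360)."""
    min_lon, max_lon, min_lat, max_lat = extent
    # Python's % returns a non-negative result for a positive modulus,
    # so -130 becomes 230 and values already in range are unchanged.
    return [min_lon % 360, max_lon % 360, min_lat, max_lat]


conus = normalise_extent([-130, -60, 20, 55])
already_ok = normalise_extent([230, 300, 20, 55])
```

Note that this simple form maps a longitude of exactly 360 to 0 and does not handle boxes that straddle the prime meridian; a real implementation would need to treat those cases explicitly.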
- source_name: str = 'mrms'#
- return_target: bool = False#
- mode: str#
- region: str#
- extent: list[float] | None#
- static_metadata: dict#
- dt#
- num_forecast_steps: int#
- start_datetime#
- end_datetime#
- datetimes: pandas.DatetimeIndex#
- file_dict: dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]#
- var_dict: dict[str, dict[str, list[str]]]#
- __len__() int#
- __getitem__(args: tuple) dict#
Return a nested input/target sample dict.
Prognostic fields are loaded into input only at step i == 0 (consistent with ERA5 autoregressive rollout semantics). Dynamic forcing is loaded at every step. Diagnostic fields never appear in input.
- Parameters:
args – (t, i) where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.
- Returns:
Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "mrms/{field_type}/2d/{varname}".
- _register_field(field_type: str, d: dict | None) None#
Validate and register one field type from the config variables block.
- Parameters:
field_type – One of "prognostic", "diagnostic", "dynamic_forcing".
d – Field-type config dict, or None / null to disable the field.
- Raises:
KeyError – If field_type is not a recognised MRMS field type.
ValueError – If d defines no vars_2D.
- _build_timestamps() pandas.DatetimeIndex#
Return valid initialisation timestamps for the dataset.
- Returns:
DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.
- _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) None#
Load all variables for field_type at time t into sample.
Dispatches to local or remote loading based on self.mode.
- Parameters:
field_type – Registered field type (e.g. "prognostic").
t – Timestamp to load.
sample – Dict to write variable tensors into (modified in place). Tensor shape (no batch dimension): (1, 1, lat, lon).
- _load_local_var(field_type: str, vname: str, t: pandas.Timestamp)#
Load a single variable from a local NetCDF or Zarr file.
- Parameters:
field_type – Field type key used to look up file intervals.
vname – Variable name within the dataset.
t – Timestamp to select.
- Returns:
2-D numpy array (lat, lon) after optional extent subsetting.
- Raises:
KeyError – If no files are registered for field_type.
- _load_remote_var(vname: str, t: pandas.Timestamp)#
Stream a single variable from the MRMS S3 bucket.
Imports s3fs and pygrib lazily so they are only required when remote mode is actually used.
- Parameters:
vname – MRMS variable name (used in the S3 path).
t – Timestamp to fetch.
- Returns:
2-D numpy array (lat, lon) after optional extent subsetting.