credit.datasets

Classes#

  • MultiSourceDataset – PyTorch Dataset that combines multiple source datasets.

  • ERA5Dataset – PyTorch Dataset for processed ERA5 data with nested input/target structure.

  • ARCOERA5Dataset – PyTorch Dataset for Google Cloud ARCO ERA5 data with nested input/target structure.

  • MRMSDataset – PyTorch Dataset for MRMS data with nested input/target structure.

Package Contents#

class credit.datasets.MultiSourceDataset(config: dict, return_target: bool = False)#

Bases: torch.utils.data.Dataset

PyTorch Dataset that combines multiple source datasets.

Instantiates one sub-dataset per source key found in config["source"], computes the intersection of their valid timestamps, and delegates each __getitem__ call to all active sub-datasets.

See module docstring for full output structure and usage examples.

datasets#

Ordered mapping of lowercase source name to its Dataset instance (e.g. {"era5": ERA5Dataset, "mrms": MRMSDataset}).

datetimes#

DatetimeIndex of timestamps valid for all active sources (intersection of each source’s own datetimes).

static_metadata#

Per-source static metadata aggregated from each sub-dataset’s static_metadata attribute. Example:

{
    "era5": {"levels": [1000, 850, 500, 300], "datetime_fmt": "unix_ns"},
    "mrms": {"datetime_fmt": "unix_ns"},
}

datasets: dict[str, torch.utils.data.Dataset]#
datetimes: pandas.DatetimeIndex#
static_metadata: dict[str, dict]#
__len__() → int#
__getitem__(args: tuple) → dict#

Return a dict of per-source sample dicts.

The (t, i) tuple is passed unchanged to every active sub-dataset. Each sub-dataset applies its own field-type / step-index logic (e.g. ERA5 and MRMS both skip prognostic disk reads at i > 0).

Parameters:

args – (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.

Returns:

Dict keyed by lowercase source name. Each value is the sub-dataset’s own sample dict:

{"input": {...}, "target": {...}, "metadata": {...}}
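The delegation described above can be pictured with a minimal sketch. ToySource and ToyMultiSource are hypothetical stand-ins, not classes from credit.datasets, and plain values replace the tensors the real sub-datasets return.

```python
class ToySource:
    """Hypothetical stand-in for a sub-dataset such as ERA5Dataset."""

    def __init__(self, name):
        self.name = name

    def __getitem__(self, args):
        t, i = args  # timestamp and within-sequence step index
        return {"input": {f"{self.name}/step": i}, "metadata": {"datetime": t}}


class ToyMultiSource:
    """Passes the same (t, i) tuple to every source and keys results by name."""

    def __init__(self, source_names):
        # Assumption: config keys such as "ERA5" are lowercased, matching
        # the documented `datasets` mapping.
        self.datasets = {n.lower(): ToySource(n.lower()) for n in source_names}

    def __getitem__(self, args):
        return {name: ds[args] for name, ds in self.datasets.items()}


sample = ToyMultiSource(["ERA5", "MRMS"])[(1_700_000_000_000_000_000, 0)]
print(sorted(sample))  # ['era5', 'mrms']
```

Each value in the outer dict keeps the sub-dataset's own layout, so downstream code can treat every source independently.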

_intersect_timestamps() → pandas.DatetimeIndex#

Return timestamps common to all active source datasets.

Returns:

Sorted DatetimeIndex containing only timestamps present in every source’s datetimes index. Returns an empty DatetimeIndex when no sources are configured.
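As an illustration of the intersection, here is a plain-Python sketch that uses lists of datetime objects in place of pandas.DatetimeIndex. The function name intersect_timestamps and the example data are hypothetical; the real method operates on each sub-dataset's datetimes attribute.

```python
from datetime import datetime, timedelta
from functools import reduce


def intersect_timestamps(per_source):
    """Return the sorted timestamps present in every source's list."""
    if not per_source:
        return []  # mirrors the documented empty result when no sources exist
    common = reduce(set.intersection, (set(ts) for ts in per_source.values()))
    return sorted(common)


start = datetime(2024, 6, 1)
era5 = [start + timedelta(hours=6 * k) for k in range(4)]     # 00, 06, 12, 18 UTC
mrms = [start + timedelta(hours=6 * k) for k in range(1, 5)]  # 06, 12, 18, 00 next day
common = intersect_timestamps({"era5": era5, "mrms": mrms})
print(len(common))  # 3
```

Only the three timestamps shared by both sources survive, which is why the combined dataset can be shorter than any single source.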

class credit.datasets.ERA5Dataset(config: dict, return_target: bool = False)#

Bases: torch.utils.data.Dataset

PyTorch Dataset for processed ERA5 data with nested input/target structure.

See module docstring for full description of output format and file naming.

Example YAML configuration:

data:
  source:
    ERA5:
      level_coord: "level"
      levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
      variables:
        prognostic:
          vars_3D: ['T', 'U', 'V', 'Q']
          vars_2D: ['SP', 't2m']
          path: "/data/era5_*.zarr"
          filename_time_format: "%Y"        # annual (default)
        dynamic_forcing:
          vars_2D: ['tsi']
          path: "/data/solar_*.nc"
          filename_time_format: "%Y_%m"     # monthly
        static:
          vars_2D: ['Z_GDS4_SFC', 'LSM']
          path: "/data/lsm.nc"
          # single file — filename_time_format not needed
        diagnostic: null

  start_datetime: "2017-01-01"
  end_datetime: "2019-12-31"
  timestep: "6h"
  forecast_len: 1

Assumptions:
  1. A “time” dimension / coordinate is present for non-static fields.

  2. A level coordinate (name given by level_coord) represents the vertical axis of 3D variables.

  3. Dimension order: (time, level, latitude, longitude) for 3D; (time, latitude, longitude) for 2D; (latitude, longitude) for static.

source_name: str = 'era5'#
level_coord: str#
levels: list[int]#
return_target: bool = False#
static_metadata: dict#
dt#
num_forecast_steps: int#
start_datetime#
end_datetime#
datetimes: pandas.DatetimeIndex#
file_dict: dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]#
var_dict: dict[str, dict[str, list[str]]]#
__len__() → int#
__getitem__(args: tuple) → dict#

Return a nested input/target sample dict.

Parameters:

args – (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler. When i == 0, prognostic and static fields are loaded in addition to dynamic forcing.

Returns:

Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "era5/{field_type}/{dim}/{varname}".

_register_field(field_type: str, d: dict | None) → None#

Validate and register one field type from the config variables block.

Populates self.file_dict and self.var_dict for field_type.

Parameters:
  • field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".

  • d – Field-type config dict, or None / null to disable the field.

Raises:
  • KeyError – If field_type is not a recognised field type.

  • ValueError – If d defines neither vars_3D nor vars_2D.
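The validation rules listed above can be sketched as follows. This is an illustration only: the standalone register_field function, the FIELD_TYPES constant, and storing None for a disabled field are assumptions, and file discovery (file_dict) is left out.

```python
FIELD_TYPES = ("prognostic", "dynamic_forcing", "static", "diagnostic")


def register_field(var_dict, field_type, d):
    """Validate one field-type block and record its variable names."""
    if field_type not in FIELD_TYPES:
        raise KeyError(f"unrecognised field type: {field_type!r}")
    if d is None:
        var_dict[field_type] = None  # assumption: disabled fields are stored as None
        return
    vars_3d = d.get("vars_3D") or []
    vars_2d = d.get("vars_2D") or []
    if not vars_3d and not vars_2d:
        raise ValueError(f"{field_type!r} defines neither vars_3D nor vars_2D")
    var_dict[field_type] = {"vars_3D": list(vars_3d), "vars_2D": list(vars_2d)}


var_dict = {}
register_field(var_dict, "prognostic", {"vars_3D": ["T", "U"], "vars_2D": ["SP"]})
register_field(var_dict, "diagnostic", None)
print(var_dict["prognostic"]["vars_2D"])  # ['SP']
```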

_build_timestamps() → pandas.DatetimeIndex#

Return valid initialisation timestamps for the dataset.

Returns:

DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.
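A rough sketch of this computation with the standard library is shown below. It assumes that the forecast horizon equals the number of forecast steps times the timestep and that both endpoints are inclusive; neither detail is confirmed by this page, and build_timestamps is a hypothetical helper.

```python
from datetime import datetime, timedelta


def build_timestamps(start, end, step, n_forecast_steps):
    """List initialisation times from start to (end - horizon), spaced by step."""
    horizon = step * n_forecast_steps  # assumption: horizon = steps x timestep
    last_init = end - horizon
    times = []
    t = start
    while t <= last_init:  # assumption: the last valid init time is included
        times.append(t)
        t += step
    return times


times = build_timestamps(
    datetime(2017, 1, 1), datetime(2017, 1, 2), timedelta(hours=6), 1
)
print(len(times))  # 4
```

Trimming the end of the range ensures that every initialisation time still has a target available inside the configured period.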

_extract_field(field_type: str, t: pandas.Timestamp, sample: dict) → None#

Open the dataset for field_type at time t and populate sample.

Keys written are "era5/{field_type}/3d/{varname}" for 3D variables and "era5/{field_type}/2d/{varname}" for 2D variables.

Parameters:
  • field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".

  • t – Timestamp to select.

  • sample – Dict to write variable tensors into (modified in place). Tensor shapes (no batch dimension):

    • 3D variable: (n_levels, 1, lat, lon)

    • 2D variable: (1, 1, lat, lon)
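To make the key convention concrete, the sketch below builds the keys for one field type and splits one back into its parts. sample_keys is a hypothetical helper written for this illustration; it only produces key strings and does not load any data.

```python
def sample_keys(source_name, field_type, var_spec):
    """Build 'source/field_type/dim/varname' keys for one field type."""
    keys = [f"{source_name}/{field_type}/3d/{v}" for v in var_spec.get("vars_3D", [])]
    keys += [f"{source_name}/{field_type}/2d/{v}" for v in var_spec.get("vars_2D", [])]
    return keys


keys = sample_keys("era5", "prognostic", {"vars_3D": ["T", "U"], "vars_2D": ["SP"]})
print(keys)
# ['era5/prognostic/3d/T', 'era5/prognostic/3d/U', 'era5/prognostic/2d/SP']

# A key can be split back into its four components.
source, field_type, dim, varname = keys[0].split("/")
```

Because the source and field type are encoded in the key, samples from several sources can share one flat dict without name collisions.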

static _to_cftime(ts: pandas.Timestamp, calendar: str) → cftime.datetime#

Convert a pandas Timestamp to a cftime.datetime.

Parameters:
  • ts – Pandas Timestamp to convert.

  • calendar – cftime calendar string read from the dataset (e.g. "noleap", "gregorian", "proleptic_gregorian").

Returns:

cftime.datetime with the specified calendar.

class credit.datasets.ARCOERA5Dataset(data_config: dict, return_target: bool = False)#

Bases: torch.utils.data.Dataset

PyTorch Dataset for Google Cloud ARCO ERA5 data with nested input/target structure.

See module docstring for full description of output format and file naming.

Example YAML configuration:

data:
  source:
    ARCO_ERA5:
      level_coord: "hybrid"
      levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
      variables:
        prognostic:
          vars_3D: ["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity"]
          vars_2D: ["surface_pressure"]
        dynamic_forcing:
          vars_2D: ["toa_incident_solar_radiation"]
        static:
          vars_2D: ["land_sea_mask"]
        diagnostic:
          vars_2D: ["total_precipitation"]

  start_datetime: "2017-01-01"
  end_datetime: "2019-12-31"
  timestep: "6h"
  forecast_len: 1

Assumptions:
  1. A “time” dimension / coordinate is present for non-static fields.

  2. A level coordinate (name given by level_coord) represents the vertical axis of 3D variables.

  3. Dimension order: (time, level, latitude, longitude) for 3D; (time, latitude, longitude) for 2D; (latitude, longitude) for static.

pressure_lev_era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'#
model_lev_era5_path = 'gs://gcp-public-data-arco-era5/ar/model-level-1h-0p25deg.zarr-v1'#
model_lev_vars = ['divergence', 'fraction_of_cloud_cover', 'geopotential', 'ozone_mass_mixing_ratio',...#
source_name: str = 'arco_era5'#
level_coord: str#
return_target: bool = False#
static_metadata: dict#
dt#
num_forecast_steps: int#
start_datetime#
end_datetime#
datetimes: pandas.DatetimeIndex#
var_dict: dict[str, dict[str, list[str]]]#
fs = None#
mod_level_store = None#
pres_level_store = None#
__len__() → int#
__getitem__(args: tuple) → dict#

Return a nested input/target sample dict.

Parameters:

args – (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler. When i == 0, prognostic and static fields are loaded in addition to dynamic forcing.

Returns:

Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "arco_era5/{field_type}/{dim}/{varname}".

_init_fs()#
_register_field(field_type: str, d: dict | None) → None#

Validate and register one field type from the config variables block.

Populates self.file_dict and self.var_dict for field_type.

Parameters:
  • field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".

  • d – Field-type config dict, or None / null to disable the field.

Raises:
  • KeyError – If field_type is not a recognised field type.

  • ValueError – If d defines neither vars_3D nor vars_2D.

_build_timestamps() → pandas.DatetimeIndex#

Return valid initialisation timestamps for the dataset.

Returns:

DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.

_extract_field(field_type: str, t: pandas.Timestamp, sample: dict) → None#

Open the dataset for field_type at time t and populate sample.

Keys written are "arco_era5/{field_type}/3d/{varname}" for 3D variables and "arco_era5/{field_type}/2d/{varname}" for 2D variables.

Parameters:
  • field_type – One of "prognostic", "dynamic_forcing", "static", "diagnostic".

  • t – Timestamp to select.

  • sample – Dict to write variable tensors into (modified in place). Tensor shapes (no batch dimension):

    • 3D variable: (n_levels, 1, lat, lon)

    • 2D variable: (1, 1, lat, lon)

static _to_cftime(ts: pandas.Timestamp, calendar: str) → cftime.datetime#

Convert a pandas Timestamp to a cftime.datetime.

Parameters:
  • ts – Pandas Timestamp to convert.

  • calendar – cftime calendar string read from the dataset (e.g. "noleap", "gregorian", "proleptic_gregorian").

Returns:

cftime.datetime with the specified calendar.

class credit.datasets.MRMSDataset(config: dict, return_target: bool = False)#

Bases: torch.utils.data.Dataset

PyTorch Dataset for MRMS data with nested input/target structure.

Field types follow ERA5 conventions: prognostic variables appear in both input (at step 0) and target; dynamic_forcing appears in input at every step; diagnostic appears in target only. At step i > 0 the model’s own prognostic predictions are fed back — no disk read occurs for prognostic fields at those steps.

Supports loading directly from AWS S3 (remote mode) or from local NetCDF / Zarr files (local mode). Spatial subsetting via extent is applied at load time on the native MRMS grid.

See module docstring for full description of output format and file naming.

Example YAML configuration (local mode):

data:
  source:
    MRMS:
      mode: "local"
      variables:
        prognostic:                         # input at step 0 + target
          vars_2D:
            - "MultiSensor_QPE_01H_Pass2_00.00"
          path: "/data/MRMS_*.nc"
          filename_time_format: "%Y%m%d-%H%M%S"
        dynamic_forcing:                    # input every step
          vars_2D:
            - "MultiSensor_QPE_06H_Pass2_00.00"
          path: "/data/MRMS_*.nc"
          filename_time_format: "%Y%m%d-%H%M%S"
      extent: [-130, -60, 20, 55]   # [min_lon, max_lon, min_lat, max_lat]

  start_datetime: "2024-06-01"
  end_datetime:   "2024-07-01"
  timestep:       "6h"
  forecast_len:   0

Example YAML configuration (remote mode):

data:
  source:
    MRMS:
      mode: "remote"
      region: "CONUS"
      variables:
        prognostic:
          vars_2D:
            - "MultiSensor_QPE_01H_Pass2_00.00"
      extent: [-130, -60, 20, 55]

Assumptions:
  1. Local files have time, lat, lon dimensions/coordinates.

  2. Longitude coordinates are in the 0–360 convention (both local and remote).

  3. extent is specified as [min_lon, max_lon, min_lat, max_lat], with longitudes in either the -180 to 180 or the 0 to 360 convention; it is normalised to 0–360 internally.
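One simple way to picture the longitude normalisation in assumption 3 is a modulo operation. This is a sketch, not the library's code: normalise_extent is a hypothetical helper, and how the real implementation treats a longitude of exactly 360 or an extent that crosses the prime meridian is not stated here.

```python
def normalise_extent(extent):
    """Map the longitudes of [min_lon, max_lon, min_lat, max_lat] onto 0-360."""
    min_lon, max_lon, min_lat, max_lat = extent
    # Python's % always returns a non-negative result for a positive modulus.
    return [min_lon % 360, max_lon % 360, min_lat, max_lat]


conus = normalise_extent([-130, -60, 20, 55])
print(conus)  # [230, 300, 20, 55]
```

An extent already given in the 0 to 360 convention passes through unchanged, so both input styles end up in the same form.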

source_name: str = 'mrms'#
return_target: bool = False#
mode: str#
region: str#
extent: list[float] | None#
static_metadata: dict#
dt#
num_forecast_steps: int#
start_datetime#
end_datetime#
datetimes: pandas.DatetimeIndex#
file_dict: dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]#
var_dict: dict[str, dict[str, list[str]]]#
__len__() → int#
__getitem__(args: tuple) → dict#

Return a nested input/target sample dict.

Prognostic fields are loaded into input only at step i == 0 (consistent with ERA5 autoregressive rollout semantics). Dynamic forcing is loaded at every step. Diagnostic fields never appear in input.

Parameters:

args – (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.

Returns:

Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "mrms/{field_type}/2d/{varname}".
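The step-dependent loading rules described for this class can be summarised in a small sketch. fields_to_load is a hypothetical helper that only returns field-type names; it does not read any data, and the ordering of the lists is arbitrary.

```python
def fields_to_load(i, return_target=False):
    """Return which field types feed "input" and "target" at step i."""
    input_fields = ["dynamic_forcing"]        # read at every step
    if i == 0:
        input_fields.insert(0, "prognostic")  # read from disk only at step 0
    # Diagnostic fields are target-only; prognostic fields are also targets.
    target_fields = ["prognostic", "diagnostic"] if return_target else []
    return {"input": input_fields, "target": target_fields}


print(fields_to_load(0, return_target=True))
# {'input': ['prognostic', 'dynamic_forcing'], 'target': ['prognostic', 'diagnostic']}
print(fields_to_load(1))
# {'input': ['dynamic_forcing'], 'target': []}
```

At steps after the first, the prognostic input is expected to come from the model's own previous prediction rather than from disk.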

_register_field(field_type: str, d: dict | None) → None#

Validate and register one field type from the config variables block.

Parameters:
  • field_type – One of "prognostic", "diagnostic", "dynamic_forcing".

  • d – Field-type config dict, or None / null to disable the field.

Raises:
  • KeyError – If field_type is not a recognised MRMS field type.

  • ValueError – If d defines no vars_2D.

_build_timestamps() → pandas.DatetimeIndex#

Return valid initialisation timestamps for the dataset.

Returns:

DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.

_extract_field(field_type: str, t: pandas.Timestamp, sample: dict) → None#

Load all variables for field_type at time t into sample.

Dispatches to local or remote loading based on self.mode.

Parameters:
  • field_type – Registered field type (e.g. "prognostic").

  • t – Timestamp to load.

  • sample – Dict to write variable tensors into (modified in place). Tensor shape (no batch dimension): (1, 1, lat, lon).

_load_local_var(field_type: str, vname: str, t: pandas.Timestamp)#

Load a single variable from a local NetCDF or Zarr file.

Parameters:
  • field_type – Field type key used to look up file intervals.

  • vname – Variable name within the dataset.

  • t – Timestamp to select.

Returns:

2-D numpy array (lat, lon) after optional extent subsetting.

Raises:

KeyError – If no files are registered for field_type.
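The file lookup behind this method can be sketched with the (start, end, path) tuples that file_dict stores. The example below is an assumption-laden illustration: find_file, the inclusive interval test, the LookupError for an uncovered timestamp, and the file paths are all hypothetical.

```python
from datetime import datetime


def find_file(intervals, t):
    """Return the path whose (start, end) interval contains t."""
    if not intervals:
        raise KeyError("no files registered for this field type")
    for start, end, path in intervals:
        if start <= t <= end:  # assumption: both endpoints are inclusive
            return path
    # Assumption: behaviour for an uncovered timestamp is not documented.
    raise LookupError(f"no file covers {t}")


intervals = [
    (datetime(2024, 6, 1, 0), datetime(2024, 6, 1, 5), "/data/MRMS_20240601-000000.nc"),
    (datetime(2024, 6, 1, 6), datetime(2024, 6, 1, 11), "/data/MRMS_20240601-060000.nc"),
]
path = find_file(intervals, datetime(2024, 6, 1, 7))
print(path)  # /data/MRMS_20240601-060000.nc
```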

_load_remote_var(vname: str, t: pandas.Timestamp)#

Stream a single variable from the MRMS S3 bucket.

Imports s3fs and pygrib lazily so they are only required when remote mode is actually used.

Parameters:
  • vname – MRMS variable name (used in the S3 path).

  • t – Timestamp to fetch.

Returns:

2-D numpy array (lat, lon) after optional extent subsetting.
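The lazy-import behaviour mentioned for remote mode follows a common Python pattern, sketched below. The require helper and load_remote function are hypothetical, and the standard-library json module stands in for s3fs and pygrib so that the example runs without optional dependencies.

```python
import importlib


def require(module_name, purpose):
    """Import a module on first use, with a clearer error if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(f"{module_name} is required for {purpose}") from exc


def load_remote(vname):
    # Imported inside the function, not at module top level, so code paths
    # that never call this function do not need the dependency installed.
    json = require("json", "remote mode")  # 'json' stands in for s3fs / pygrib
    return json.dumps({"variable": vname})


print(load_remote("MultiSensor_QPE_01H_Pass2_00.00"))
```

With this pattern, users who work only in local mode never hit an ImportError for packages they do not use.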