credit.datasets.multi_source

MultiSourceDataset: primary entry point for all data loading.

Wraps one or more registered source datasets and returns a dict nested by source name. Only sources whose keys appear under config["source"] are instantiated; absent sources are silently skipped.

Sample structure returned by __getitem__:

{
    "input":    {<user_provided_name>: {"<user_provided_name>/prognostic/3d/T": tensor, ...}, ...},
    "target":   {<user_provided_name>: {"<user_provided_name>/prognostic/3d/T": tensor, ...}, ...},  # return_target only
    "metadata": {<user_provided_name>: {"input_datetime": int, "target_datetime": int}, ...},
}
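As an illustration of this layout, the sketch below walks a hand-built sample and splits each field key into its source prefix and variable path. The source name Example_ERA5, the second variable key, and the placeholder values (floats standing in for tensors, small integers standing in for datetimes) are assumptions made for the example, not output from the library.

```python
# Hand-built stand-in for one sample; real field values would be tensors.
sample = {
    "input": {
        "Example_ERA5": {
            "Example_ERA5/prognostic/3d/T": 281.5,
            "Example_ERA5/prognostic/3d/U": 4.2,
        },
    },
    "metadata": {
        "Example_ERA5": {"input_datetime": 0, "target_datetime": 3600},
    },
}


def list_fields(sample, data_type="input"):
    """Return (source_name, variable_path) pairs for one data type."""
    fields = []
    for source_name, source_fields in sample.get(data_type, {}).items():
        for key in source_fields:
            # Keys are prefixed with the source name: "<name>/<rest of path>".
            prefix, _, variable_path = key.partition("/")
            fields.append((prefix, variable_path))
    return fields


fields = list_fields(sample)
```

Because "target" is only present when return_target is set, the helper reads each data type with a default of an empty dict rather than indexing it directly.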

Usage:

from credit.datasets.multi_source import MultiSourceDataset
from credit.samplers import DistributedMultiStepBatchSampler
from torch.utils.data import DataLoader

# `config` is assumed to be the full CREDIT configuration, already loaded as a dict.
dataset = MultiSourceDataset(config["data"], return_target=True)
sampler = DistributedMultiStepBatchSampler(
    dataset, batch_size=4, shuffle=True, num_replicas=1, rank=0
)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

Extending with a new source:

# In _SOURCE_REGISTRY, add:
"NEW_SOURCE": ("credit.datasets.new_source", "NewSourceDataset"),

# The dataset class must accept (config, return_target) and expose
# a ``datetimes`` attribute (pd.DatetimeIndex).

Attributes

Classes

MultiSourceDataset

CREDIT Dataset that combines multiple source datasets.

Functions

make_single_source_subconfig(→ dict[str, Any])

Return a modified config dict containing only the specified source.

route_to_dataset_class(→ type)

Return the appropriate Dataset class based on the "dataset_type" field in the source config.

Module Contents

credit.datasets.multi_source.logger
credit.datasets.multi_source._SOURCE_REGISTRY: dict[str, tuple[str, str]]
credit.datasets.multi_source.make_single_source_subconfig(config: dict[str, Any], user_dataset_name: str) → dict[str, Any]

Return a modified config dict containing only the specified source.

This is used internally to instantiate each sub-dataset with a config containing just its own source config block, to avoid confusion with multisource config fields (e.g. HRRR vs HRRR_NAT vs HRRR_SUBH).

Parameters:
  • config – Original multisource config dict.

  • user_dataset_name – Unique dataset name specified by the user in config["source"] (e.g. "Example_ERA5").

Returns:

New config dict containing only the specified source’s config block.
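A minimal sketch of this behaviour, assuming the function deep-copies the config and reduces config["source"] to a single entry. The actual implementation may differ; the helper name, the "timestep" value, and the "dataset_type" values below are hypothetical.

```python
import copy
from typing import Any


def single_source_subconfig_sketch(
    config: dict[str, Any], user_dataset_name: str
) -> dict[str, Any]:
    """Hypothetical stand-in: keep only one entry of config["source"]."""
    if user_dataset_name not in config["source"]:
        raise KeyError(f"unknown source: {user_dataset_name!r}")
    sub_config = copy.deepcopy(config)  # leave the caller's config untouched
    sub_config["source"] = {
        user_dataset_name: sub_config["source"][user_dataset_name]
    }
    return sub_config


# Illustrative config with two user-named sources (values are made up).
config = {
    "timestep": "1h",
    "source": {
        "Example_ERA5": {"dataset_type": "ERA5"},
        "Example_MRMS": {"dataset_type": "MRMS"},
    },
}
sub = single_source_subconfig_sketch(config, "Example_ERA5")
```

Copying before filtering keeps the global fields available to the sub-dataset while guaranteeing that it sees exactly one source block.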

credit.datasets.multi_source.route_to_dataset_class(source_cfg: dict[str, Any]) → type

Return the appropriate Dataset class based on the "dataset_type" field in the source config.

The module containing the class is imported lazily on first call so that optional heavy dependencies are not loaded unless this source type is used.

Parameters:

source_cfg – Config dict for a single source (e.g. config["source"]["Example_ERA5"]).

Returns:

Dataset class corresponding to the "dataset_type" field.

Raises:

ValueError – If the "dataset_type" field is missing or does not correspond to a registered dataset.
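The lazy-import pattern described here can be sketched with importlib. This is an illustration only: the registry contents, the function name, and the use of standard-library classes as stand-ins for dataset classes are all assumptions, chosen so the example runs without CREDIT installed.

```python
import importlib

# Hypothetical registry in the (module path, class name) shape of
# _SOURCE_REGISTRY; standard-library classes stand in for dataset classes.
_REGISTRY_SKETCH = {
    "ORDERED": ("collections", "OrderedDict"),
    "FRACTION": ("fractions", "Fraction"),
}


def route_sketch(source_cfg: dict) -> type:
    """Resolve source_cfg["dataset_type"] to a class, importing on demand."""
    dataset_type = source_cfg.get("dataset_type")
    if dataset_type not in _REGISTRY_SKETCH:
        raise ValueError(
            f"missing or unregistered dataset_type: {dataset_type!r}"
        )
    module_path, class_name = _REGISTRY_SKETCH[dataset_type]
    # The import happens here, on first use, not when this module is loaded.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


cls = route_sketch({"dataset_type": "FRACTION"})
```

Storing module paths as strings means a source type whose dependencies are not installed only fails when that type is actually requested.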

class credit.datasets.multi_source.MultiSourceDataset(config: dict[str, Any], return_target: bool = False)

Bases: credit.datasets.base_dataset.AbstractBaseDataset

CREDIT Dataset that combines multiple source datasets.

Instantiates one sub-dataset per source key found in config["source"], computes the intersection of their valid timestamps, and delegates each __getitem__ call to all active sub-datasets.

See module docstring for full output structure and usage examples.

Note that this class inherits from AbstractBaseDataset rather than BaseDataset.

datasets

Ordered mapping of lowercase source name to its Dataset instance (e.g. {"era5": ERA5Dataset, "mrms": MRMSDataset}).

datetimes

DatetimeIndex of timestamps valid for all active sources (intersection of each source’s own datetimes).

static_metadata

Per-source static metadata aggregated from each sub-dataset’s static_metadata attribute.

datasets: dict[str, credit.datasets.base_dataset.BaseDataset]
datetimes: pandas.DatetimeIndex
static_metadata: dict[str, dict[str, Any]]
__len__() → int
__getitem__(args: tuple[pandas.Timestamp, int]) → dict[str, dict[str, Any]]

Return a dict of per-source sample dicts.

Parameters:

args – Tuple (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.

Returns:

Dict keyed by data type, each value being a dict of source name to that source’s data:

{"input": {"era5": {...}, ...}, "target": {...}, "metadata": {...}}
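The regrouping from per-source samples into this layout can be sketched as follows. It assumes each sub-dataset returns a dict keyed by data type ("input", "target", "metadata"), which this page does not state explicitly; the function name, the MRMS variable key, and the sample values are hypothetical.

```python
def nest_by_data_type(per_source_samples: dict) -> dict:
    """Turn {source: {data_type: payload}} into {data_type: {source: payload}}."""
    nested = {}
    for source_name, sample in per_source_samples.items():
        for data_type, payload in sample.items():
            nested.setdefault(data_type, {})[source_name] = payload
    return nested


# Placeholder per-source samples; floats stand in for tensors.
per_source = {
    "era5": {
        "input": {"era5/prognostic/3d/T": 1.0},
        "metadata": {"input_datetime": 0},
    },
    "mrms": {
        "input": {"mrms/diagnostic/2d/precip": 2.0},
        "metadata": {"input_datetime": 0},
    },
}
combined = nest_by_data_type(per_source)
```

Since no source in this example supplies a "target" entry, the combined dict has no "target" key, which mirrors the return_target-only behaviour noted in the module description.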

_build_master_clock(config: dict[str, Any]) → pandas.DatetimeIndex

Build the master sampling clock from the global config.

The clock is anchored to the global start_datetime, end_datetime, and timestep. For each source:

  • Normal sources (no temporal_mode): the clock is filtered to timestamps that exist exactly in that source’s native datetimes. A warning is emitted when the source’s native timestep differs from the master clock timestep and temporal_mode is not set.

  • Persist sources (temporal_mode: persist): the clock is only clipped to the source’s coverage range. Fine-resolution master-clock ticks are later snapped, inside BaseDataset.__getitem__, to the last native timestamp at or before each tick.
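The two filtering rules can be sketched as below. To stay dependency-free the example uses plain datetime objects and bisect instead of pandas.DatetimeIndex, so it is not the library's code. The function names, the inclusive end points, and the reading of "last native timestamp" as "latest one at or before the tick" are assumptions.

```python
from bisect import bisect_right
from datetime import datetime, timedelta


def make_clock(start, end, step):
    """Regularly spaced ticks from start to end, inclusive."""
    ticks = []
    t = start
    while t <= end:
        ticks.append(t)
        t += step
    return ticks


def filter_normal(clock, native):
    """Normal source: keep only ticks that exist exactly in the native datetimes."""
    native_set = set(native)
    return [t for t in clock if t in native_set]


def clip_persist(clock, native):
    """Persist source: keep ticks inside the source's coverage range."""
    return [t for t in clock if native[0] <= t <= native[-1]]


def snap_to_last_native(t, native):
    """Return the latest native timestamp at or before t (native must be sorted)."""
    idx = bisect_right(native, t) - 1
    if idx < 0:
        raise ValueError("timestamp precedes the source's coverage")
    return native[idx]


start = datetime(2020, 1, 1, 0)
# Hourly master clock, 00:00 through 06:00.
clock = make_clock(start, start + timedelta(hours=6), timedelta(hours=1))
# A 3-hourly source covering 00:00 and 03:00 only.
native_3h = make_clock(start, start + timedelta(hours=3), timedelta(hours=3))

normal_clock = filter_normal(clock, native_3h)    # ticks the source has exactly
persist_clock = clip_persist(clock, native_3h)    # every tick within coverage
snapped = snap_to_last_native(datetime(2020, 1, 1, 2), native_3h)
```

In this example the normal rule leaves only the two ticks that the 3-hourly source actually contains, while the persist rule keeps every hourly tick from 00:00 to 03:00 and defers the choice of native timestamp to lookup time.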