credit.datasets.multi_source
MultiSourceDataset: primary entry point for all data loading.
Wraps one or more registered source datasets and returns a dict nested by
source name. Only sources whose keys appear under config["source"] are
instantiated; absent sources are silently skipped.
Sample structure returned by __getitem__:
    {
        "input": {<user_provided_name>: {"<user_provided_name>/prognostic/3d/T": tensor, ...}, ...},
        "target": {<user_provided_name>: {"<user_provided_name>/prognostic/3d/T": tensor, ...}, ...},  # return_target only
        "metadata": {<user_provided_name>: {"input_datetime": int, "target_datetime": int}, ...},
    }
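The nested layout above is easiest to consume by walking one data type at a time. The following is a minimal sketch, not part of the library: `list_source_variables` is a hypothetical helper, `"era5"` stands in for a user-provided source name, the `.../3d/U` key is invented for illustration, and plain floats replace tensors so it runs without torch.

```python
from typing import Any


def list_source_variables(sample: dict[str, dict[str, Any]]) -> dict[str, list[str]]:
    """Hypothetical helper: map each source name to the sorted variable keys in its "input" block."""
    return {source: sorted(variables) for source, variables in sample["input"].items()}


# Placeholder sample: floats stand in for tensors, "era5" for a user-provided name,
# and the datetime integers are arbitrary.
sample = {
    "input": {"era5": {"era5/prognostic/3d/T": 0.0, "era5/prognostic/3d/U": 0.0}},
    "metadata": {"era5": {"input_datetime": 0, "target_datetime": 1}},
}
print(list_source_variables(sample))  # {'era5': ['era5/prognostic/3d/T', 'era5/prognostic/3d/U']}
# Per the structure comment above, "target" is only present when return_target=True.
print("target" in sample)  # False
```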
Usage:
    from credit.datasets.multi_source import MultiSourceDataset
    from credit.samplers import DistributedMultiStepBatchSampler
    from torch.utils.data import DataLoader

    dataset = MultiSourceDataset(config["data"], return_target=True)
    sampler = DistributedMultiStepBatchSampler(
        dataset, batch_size=4, shuffle=True, num_replicas=1, rank=0
    )
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)
Extending with a new source:
    # In _SOURCE_REGISTRY, add:
    "NEW_SOURCE": ("credit.datasets.new_source", "NewSourceDataset"),
    # The dataset class must accept (config, return_target) and expose
    # a ``datetimes`` attribute (pd.DatetimeIndex).
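A skeleton satisfying the two stated requirements might look like the sketch below. It is hypothetical: the flat `start_datetime`/`end_datetime`/`timestep` keys are assumptions, and a real source would presumably also subclass `BaseDataset` and implement `__getitem__`, since `MultiSourceDataset` delegates item access to each sub-dataset.

```python
from typing import Any

import pandas as pd


class NewSourceDataset:
    """Hypothetical minimal source dataset: accepts (config, return_target) and exposes datetimes."""

    def __init__(self, config: dict[str, Any], return_target: bool = False) -> None:
        self.config = config
        self.return_target = return_target
        # Assumption: timestamps come from flat start/end/timestep keys. A real
        # source would derive them from the files it indexes.
        self.datetimes = pd.date_range(
            config["start_datetime"], config["end_datetime"], freq=config["timestep"]
        )

    def __len__(self) -> int:
        return len(self.datetimes)


ds = NewSourceDataset(
    {"start_datetime": "2020-01-01 00:00", "end_datetime": "2020-01-01 18:00", "timestep": "6h"},
    return_target=True,
)
print(len(ds), ds.datetimes[-1])  # 4 2020-01-01 18:00:00
```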
Attributes
logger
_SOURCE_REGISTRY
Classes
MultiSourceDataset | CREDIT Dataset that combines multiple source datasets.
Functions
make_single_source_subconfig | Return a modified config dict containing only the specified source.
route_to_dataset_class | Return the appropriate Dataset class based on the "dataset_type" field in the source config.
Module Contents
- credit.datasets.multi_source.logger
- credit.datasets.multi_source._SOURCE_REGISTRY: dict[str, tuple[str, str]]
- credit.datasets.multi_source.make_single_source_subconfig(config: dict[str, Any], user_dataset_name: str) -> dict[str, Any]
Return a modified config dict containing only the specified source.
This is used internally to instantiate each sub-dataset with a config containing just its own source config block, to avoid confusion with multisource config fields (e.g. HRRR vs HRRR_NAT vs HRRR_SUBH).
- Parameters:
config – Original multisource config dict.
user_dataset_name – Unique dataset name specified by the user in config["source"] (e.g. "Example_ERA5").
- Returns:
New config dict containing only the specified source’s config block.
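One plausible reading of this behaviour is sketched below. It is not the library's implementation: the function name is hypothetical, it assumes the sub-config keeps the `config["source"]` nesting with a single entry, and the `dataset_type` values are invented.

```python
import copy
from typing import Any


def make_single_source_subconfig_sketch(config: dict[str, Any], user_dataset_name: str) -> dict[str, Any]:
    """Hypothetical sketch: copy the config and keep one entry under config["source"]."""
    # Deep-copy so the caller's multisource config is never mutated.
    sub = copy.deepcopy(config)
    # Raises KeyError if user_dataset_name is not a configured source.
    sub["source"] = {user_dataset_name: sub["source"][user_dataset_name]}
    return sub


config = {
    "timestep": "6h",
    "source": {
        "Example_ERA5": {"dataset_type": "ERA5"},
        "Example_MRMS": {"dataset_type": "MRMS"},
    },
}
sub = make_single_source_subconfig_sketch(config, "Example_ERA5")
print(list(sub["source"]))     # ['Example_ERA5']
print(list(config["source"]))  # ['Example_ERA5', 'Example_MRMS'] (original untouched)
```

Deep-copying keeps every sub-dataset's config independent, so one source cannot accidentally modify shared global fields seen by another.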
- credit.datasets.multi_source.route_to_dataset_class(source_cfg: dict[str, Any]) -> type
Return the appropriate Dataset class based on the "dataset_type" field in the source config.
The module containing the class is imported lazily on first call so that optional heavy dependencies are not loaded unless this source type is used.
- Parameters:
source_cfg – Config dict for a single source (e.g. config["source"]["Example_ERA5"]).
- Returns:
Dataset class corresponding to the "dataset_type" field.
- Raises:
ValueError – If the "dataset_type" field is missing or does not correspond to a registered dataset.
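The lazy-import pattern described here can be sketched with `importlib.import_module`. This is an illustration under assumptions, not the module's code: `_DEMO_REGISTRY` and `route_to_class_sketch` are hypothetical, stdlib classes stand in for dataset modules so the sketch runs anywhere, and it assumes `dataset_type` values are the registry keys.

```python
import importlib
from typing import Any

# Hypothetical demo registry: stdlib classes stand in for dataset modules so the
# sketch runs anywhere. Entries follow the (module path, class name) shape used
# by _SOURCE_REGISTRY.
_DEMO_REGISTRY: dict[str, tuple[str, str]] = {
    "ORDERED_DICT": ("collections", "OrderedDict"),
    "FRACTION": ("fractions", "Fraction"),
}


def route_to_class_sketch(source_cfg: dict[str, Any]) -> type:
    dataset_type = source_cfg.get("dataset_type")
    # A missing key yields None, which is never a registry key, so both cases raise.
    if dataset_type not in _DEMO_REGISTRY:
        raise ValueError(f"Missing or unregistered dataset_type: {dataset_type!r}")
    module_path, class_name = _DEMO_REGISTRY[dataset_type]
    # The module is imported only here; Python caches it in sys.modules afterwards.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


cls = route_to_class_sketch({"dataset_type": "FRACTION"})
print(cls(1, 3))  # 1/3
```

Storing module paths as strings, rather than importing classes at the top of the file, is what keeps heavy optional dependencies out of processes that never use that source.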
- class credit.datasets.multi_source.MultiSourceDataset(config: dict[str, Any], return_target: bool = False)
Bases: credit.datasets.base_dataset.AbstractBaseDataset
CREDIT Dataset that combines multiple source datasets.
Instantiates one sub-dataset per source key found in config["source"], computes the intersection of their valid timestamps, and delegates each __getitem__ call to all active sub-datasets. See the module docstring for the full output structure and usage examples.
Note that we inherit from AbstractBaseDataset _rather_ than BaseDataset.
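Intersecting the sources' timestamps can be sketched with `pandas.Index.intersection` folded over every source. The source names, cadences, and coverage windows below are invented for illustration, and `common_datetimes` is a hypothetical helper.

```python
from functools import reduce

import pandas as pd


def common_datetimes(indexes: list[pd.DatetimeIndex]) -> pd.DatetimeIndex:
    """Hypothetical helper: timestamps present in every source, sorted."""
    return reduce(lambda a, b: a.intersection(b), indexes).sort_values()


# Invented native timestamps: a 6-hourly source and an hourly source that starts later.
era5_times = pd.date_range("2020-01-01 00:00", "2020-01-02 00:00", freq="6h")
mrms_times = pd.date_range("2020-01-01 12:00", "2020-01-03 00:00", freq="1h")

common = common_datetimes([era5_times, mrms_times])
print([t.strftime("%Y-%m-%d %H:%M") for t in common])
# ['2020-01-01 12:00', '2020-01-01 18:00', '2020-01-02 00:00']
```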
- datasets
Ordered mapping of lowercase source name to its Dataset instance (e.g. {"era5": ERA5Dataset, "mrms": MRMSDataset}).
- datetimes
DatetimeIndex of timestamps valid for all active sources (intersection of each source's own datetimes).
- static_metadata
Per-source static metadata aggregated from each sub-dataset's static_metadata attribute.
- datasets: dict[str, credit.datasets.base_dataset.BaseDataset]
- datetimes: pandas.DatetimeIndex
- static_metadata: dict[str, dict[str, Any]]
- __len__() -> int
- __getitem__(args: tuple[pandas.Timestamp, int]) -> dict[str, dict[str, Any]]
Return a dict of per-source sample dicts.
- Parameters:
args – (t, i) where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.
- Returns:
Dict keyed by data type, each value being a dict mapping source name to that source's data:
{"input": {"era5": {...}, ...}, "target": {...}, "metadata": {...}}
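The regrouping from per-source results into the data-type-first layout can be sketched as follows. This assumes each sub-dataset returns a flat `{"input", "target", "metadata"}` dict for the same `(t, i)` pair. Both `_FakeSource` and `regroup_by_data_type` are hypothetical stand-ins, not CREDIT classes.

```python
from typing import Any


class _FakeSource:
    """Hypothetical stand-in sub-dataset returning a flat dict (assumed shape)."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __getitem__(self, args: tuple[int, int]) -> dict[str, dict[str, Any]]:
        t, _i = args  # the step index is unused in this stand-in
        return {
            "input": {f"{self.name}/prognostic/3d/T": t},
            "target": {f"{self.name}/prognostic/3d/T": t + 1},
            "metadata": {"input_datetime": t, "target_datetime": t + 1},
        }


def regroup_by_data_type(datasets: dict[str, Any], args: tuple[int, int]) -> dict[str, dict[str, Any]]:
    """Call every sub-dataset with the same (t, i) and nest results as data type -> source."""
    out: dict[str, dict[str, Any]] = {}
    for source_name, ds in datasets.items():
        for data_type, payload in ds[args].items():
            out.setdefault(data_type, {})[source_name] = payload
    return out


sample = regroup_by_data_type({"era5": _FakeSource("era5"), "mrms": _FakeSource("mrms")}, (0, 0))
print(list(sample))           # ['input', 'target', 'metadata']
print(list(sample["input"]))  # ['era5', 'mrms']
```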
- _build_master_clock(config: dict[str, Any]) -> pandas.DatetimeIndex
Build the master sampling clock from the global config.
The clock is anchored to the global start_datetime, end_datetime, and timestep. For each source:
  - Normal sources (no temporal_mode): the clock is filtered to timestamps that exist exactly in that source's native datetimes. A warning is emitted when the source's native timestep differs from the master clock timestep and temporal_mode is not set.
  - Persist sources (temporal_mode: persist): the clock is only clipped to the source's coverage range. Fine-resolution master-clock ticks are snapped to the last native timestamp inside BaseDataset.__getitem__.
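The two per-source rules can be sketched with pandas. This is a hypothetical illustration, not `_build_master_clock` itself: the `(native_index, temporal_mode)` input layout is an assumption, the timestep-mismatch warning is omitted, and `snap_persist` only mimics the snapping rule that the text places in `BaseDataset.__getitem__`, using `pandas.Index.asof`.

```python
from typing import Optional

import pandas as pd


def build_master_clock_sketch(
    start: str, end: str, timestep: str,
    sources: dict[str, tuple[pd.DatetimeIndex, Optional[str]]],
) -> pd.DatetimeIndex:
    """Hypothetical sketch: sources maps name -> (native timestamps, temporal_mode or None)."""
    clock = pd.date_range(start, end, freq=timestep)
    for native, mode in sources.values():
        if mode == "persist":
            # Persist: only clip to the source's coverage range.
            clock = clock[(clock >= native.min()) & (clock <= native.max())]
        else:
            # Normal: keep ticks that exist exactly in the native timestamps.
            clock = clock[clock.isin(native)]
    return clock


def snap_persist(native: pd.DatetimeIndex, t: pd.Timestamp) -> pd.Timestamp:
    """Last native timestamp at or before t (requires a sorted index)."""
    return native.asof(t)


hourly = pd.date_range("2020-01-01 00:00", "2020-01-01 12:00", freq="1h")
sparse = pd.DatetimeIndex(["2020-01-01 02:00", "2020-01-01 08:00"])
clock = build_master_clock_sketch(
    "2020-01-01 00:00", "2020-01-01 06:00", "1h",
    {"hourly_src": (hourly, None), "sparse_src": (sparse, "persist")},
)
print(len(clock), clock[0])                                   # 5 2020-01-01 02:00:00
print(snap_persist(sparse, pd.Timestamp("2020-01-01 05:00")))  # 2020-01-01 02:00:00
```

The persist source never removes interior ticks; it only trims the ends, so a coarse source can still be paired with every fine-resolution tick inside its coverage.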