credit.datasets.multi_source#
MultiSourceDataset: primary entry point for all data loading.
Wraps one or more registered source datasets (ERA5, MRMS, …) and returns a
dict nested by source name. Only sources whose keys appear under
config["source"] are instantiated; absent sources are silently skipped.
The wrapper is always used as the entry point — even when only one source
is configured.
Sample structure returned by __getitem__:
    {
        "era5": {
            "input": {"era5/prognostic/3d/T": tensor, ...},
            "target": {"era5/prognostic/3d/T": tensor, ...},  # return_target only
            "metadata": {"input_datetime": int, "target_datetime": int},
        },
        "mrms": {
            "input": {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor, ...},
            "target": {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor, ...},
            "metadata": {"input_datetime": int, "target_datetime": int},
        },
    }
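As a hedged illustration of the structure above, the following sketch walks a hand-built sample with the same nesting. The strings stand in for tensors, the timestamp integers are made up, and `iter_fields` is a hypothetical helper that is not part of `credit`.

```python
from typing import Any, Iterator, Tuple

# Hand-built stand-in for one sample; real values would be tensors.
sample = {
    "era5": {
        "input": {"era5/prognostic/3d/T": "tensor"},
        "target": {"era5/prognostic/3d/T": "tensor"},
        "metadata": {"input_datetime": 0, "target_datetime": 3600},
    },
    "mrms": {
        "input": {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": "tensor"},
        "target": {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": "tensor"},
        "metadata": {"input_datetime": 0, "target_datetime": 3600},
    },
}


def iter_fields(sample: dict, split: str) -> Iterator[Tuple[str, str, Any]]:
    """Yield (source, field_key, value) for one split ("input" or "target").

    Sources that lack the requested split (for example no "target" when
    return_target is False) are skipped.
    """
    for source, parts in sample.items():
        for key, value in parts.get(split, {}).items():
            yield source, key, value


# Collect every input field key across all sources.
input_keys = [key for _, key, _ in iter_fields(sample, "input")]
```

Because each field key already carries its source prefix, the keys stay unique when collected across sources.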
Pre-block pipeline (applied after the DataLoader):
    MultiSourceDataset
      → ERA5ScalePreBlock   (scale on native ERA5 grid)
      → MRMSScalePreBlock   (scale on native MRMS grid)
      → MRMSRegridPreBlock  (regrid MRMS → ERA5 grid)
      → MergePreBlock       (flatten nested dict → single input/target dict)
      → Model
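The final flattening step can be pictured with a minimal sketch. This is not the actual MergePreBlock: `merge_sources` is a hypothetical function, scaling and regridding are omitted, and dropping the per-source metadata is an assumption made only to keep the example short.

```python
def merge_sources(sample: dict) -> dict:
    """Flatten {source: {"input": ..., "target": ...}} into one input/target dict."""
    merged = {"input": {}, "target": {}}
    for parts in sample.values():
        for split in merged:
            # Field keys in the example structure are prefixed with the
            # source name, so merging into one flat dict does not collide.
            merged[split].update(parts.get(split, {}))
    return merged


# Small nested sample with placeholder values instead of tensors.
nested = {
    "era5": {"input": {"era5/a": 1}, "target": {"era5/a": 2}, "metadata": {}},
    "mrms": {"input": {"mrms/b": 3}, "metadata": {}},
}
merged = merge_sources(nested)
```

Here `merged["input"]` holds fields from both sources, while `merged["target"]` holds only the ERA5 field because the stand-in MRMS entry has no target.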
Usage:
    from credit.datasets.multi_source import MultiSourceDataset
    from credit.samplers import DistributedMultiStepBatchSampler
    from torch.utils.data import DataLoader

    # config is assumed to be an already-loaded configuration dict.
    dataset = MultiSourceDataset(config["data"], return_target=True)
    sampler = DistributedMultiStepBatchSampler(
        dataset, batch_size=4, shuffle=True, num_replicas=1, rank=0
    )
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)
Extending with a new source:
    # In _SOURCE_REGISTRY, add:
    "NewSource": NewSourceDataset,
    # The dataset class must accept (config, return_target) and expose
    # a ``datetimes`` attribute (pd.DatetimeIndex).
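The contract above can be sketched without any dependencies. Everything here is illustrative: `NewSourceDataset`, `build_datasets`, and the `start` / `n_hours` config keys are hypothetical, a plain list of `datetime` objects stands in for a pd.DatetimeIndex, and passing each source's own config sub-dict to its class is an assumption about how the wrapper calls it.

```python
from datetime import datetime, timedelta


class NewSourceDataset:
    """Hypothetical source dataset following the (config, return_target) contract."""

    def __init__(self, config: dict, return_target: bool = False):
        self.config = config
        self.return_target = return_target
        # Stand-in for a pd.DatetimeIndex: hourly timestamps built from config.
        start = datetime.fromisoformat(config["start"])
        self.datetimes = [start + timedelta(hours=h) for h in range(config["n_hours"])]


# Stand-in for the module-level registry, which maps source names to classes.
_SOURCE_REGISTRY = {"NewSource": NewSourceDataset}


def build_datasets(config: dict, return_target: bool = False) -> dict:
    """Instantiate only the registered sources that appear in config["source"]."""
    datasets = {}
    for name, cls in _SOURCE_REGISTRY.items():
        if name not in config["source"]:
            continue  # absent sources are skipped without error
        # Assumption: each class receives its own sub-config.
        datasets[name.lower()] = cls(config["source"][name], return_target)
    return datasets


datasets = build_datasets({"source": {"NewSource": {"start": "2020-01-01", "n_hours": 3}}})
```

Keying the result by the lowercased registry name mirrors the documented `datasets` attribute, which maps lowercase source names to Dataset instances.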
Attributes#
logger
_SOURCE_REGISTRY
Classes#
MultiSourceDataset | PyTorch Dataset that combines multiple source datasets.
Module Contents#
- credit.datasets.multi_source.logger#
- credit.datasets.multi_source._SOURCE_REGISTRY: dict[str, type]#
- class credit.datasets.multi_source.MultiSourceDataset(config: dict, return_target: bool = False)#
Bases: torch.utils.data.Dataset
PyTorch Dataset that combines multiple source datasets.
Instantiates one sub-dataset per source key found in config["source"], computes the intersection of their valid timestamps, and delegates each __getitem__ call to all active sub-datasets.
See module docstring for full output structure and usage examples.
- datasets#
Ordered mapping of lowercase source name to its Dataset instance (e.g.
{"era5": ERA5Dataset, "mrms": MRMSDataset}).
- datetimes#
DatetimeIndex of timestamps valid for all active sources (intersection of each source’s own
datetimes).
- static_metadata#
Per-source static metadata aggregated from each sub-dataset’s static_metadata attribute. Example:
    {
        "era5": {"levels": [1000, 850, 500, 300], "datetime_fmt": "unix_ns"},
        "mrms": {"datetime_fmt": "unix_ns"},
    }
- datasets: dict[str, torch.utils.data.Dataset]#
- datetimes: pandas.DatetimeIndex#
- static_metadata: dict[str, dict]#
- __len__() → int#
- __getitem__(args: tuple) → dict#
Return a dict of per-source sample dicts.
The (t, i) tuple is passed unchanged to every active sub-dataset. Each sub-dataset applies its own field-type / step-index logic (e.g. ERA5 and MRMS both skip prognostic disk reads at i > 0).
- Parameters:
args – (t, i) where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler.
- Returns:
Dict keyed by lowercase source name. Each value is the sub-dataset’s own sample dict:
{"input": {...}, "target": {...}, "metadata": {...}}
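The delegation described for `__getitem__` can be sketched with stub classes. `StubSource` and `DelegatingDataset` are hypothetical stand-ins, not the real ERA5/MRMS datasets or the actual implementation. They only show the `(t, i)` tuple reaching every sub-dataset unchanged.

```python
class StubSource:
    """Hypothetical sub-dataset that echoes the (t, i) tuple it receives."""

    def __init__(self, name: str):
        self.name = name

    def __getitem__(self, args: tuple) -> dict:
        t, i = args
        return {
            "input": {f"{self.name}/field": (t, i)},
            "metadata": {"input_datetime": t},
        }


class DelegatingDataset:
    """Minimal wrapper that forwards (t, i) unchanged to every sub-dataset."""

    def __init__(self, datasets: dict):
        self.datasets = datasets

    def __getitem__(self, args: tuple) -> dict:
        # One entry per source, keyed by the lowercase source name.
        return {name: ds[args] for name, ds in self.datasets.items()}


wrapper = DelegatingDataset({"era5": StubSource("era5"), "mrms": StubSource("mrms")})
sample = wrapper[(1_000, 0)]  # t = 1000 (made-up timestamp), step index i = 0
```

Any per-source behaviour that depends on `i` stays inside the sub-dataset, so the wrapper itself needs no knowledge of field types.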
- _intersect_timestamps() → pandas.DatetimeIndex#
Return timestamps common to all active source datasets.
- Returns:
Sorted DatetimeIndex containing only timestamps present in every source’s datetimes index. Returns an empty DatetimeIndex when no sources are configured.
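The intersection semantics can be illustrated with a dependency-free sketch. It is not the method's actual code: `intersect_timestamps` is a hypothetical function and plain lists of `datetime` objects stand in for each source's pd.DatetimeIndex. With pandas, `DatetimeIndex.intersection` provides the equivalent pairwise operation.

```python
from datetime import datetime
from functools import reduce


def intersect_timestamps(per_source: dict) -> list:
    """Return the sorted timestamps present in every source's list."""
    if not per_source:
        # Mirrors the documented empty result when no sources are configured.
        return []
    common = reduce(set.intersection, (set(ts) for ts in per_source.values()))
    return sorted(common)


# Four hourly timestamps; the stand-in MRMS source covers only the middle two.
hours = [datetime(2020, 1, 1, h) for h in range(4)]
common = intersect_timestamps({"era5": hours, "mrms": hours[1:3]})
```

Only the two hours shared by both stand-in sources remain in `common`, so a sample drawn at any remaining timestamp is available from every source.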