credit.datasets.multi_source
============================

.. py:module:: credit.datasets.multi_source

.. autoapi-nested-parse::

   multi_source.py
   ---------------
   MultiSourceDataset: primary entry point for all data loading.

   Wraps one or more registered source datasets (ERA5, MRMS, …) and returns a
   dict nested by source name.  Only sources whose keys appear under
   ``config["source"]`` are instantiated; absent sources are silently skipped.
   The wrapper is always used as the entry point, even when only one source
   is configured.
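
   The selection rule can be sketched as follows. This is a minimal
   illustration, not the module's code. The stub classes, ``SOURCE_REGISTRY``
   and ``build_datasets`` are hypothetical. The sketch assumes that registry
   keys match the keys under ``config["source"]`` exactly and are lowercased
   to form the output keys, as in ``MultiSourceDataset.datasets``.

   .. code-block:: python

       # Hypothetical stand-ins for the real source dataset classes.
       class StubERA5:
           def __init__(self, config, return_target=False):
               self.return_target = return_target


       class StubMRMS:
           def __init__(self, config, return_target=False):
               self.return_target = return_target


       # Assumed registry shape: source name -> dataset class.
       SOURCE_REGISTRY = {"ERA5": StubERA5, "MRMS": StubMRMS}


       def build_datasets(config, return_target=False):
           """Instantiate only the registered sources listed under config["source"]."""
           datasets = {}
           for name, cls in SOURCE_REGISTRY.items():
               if name not in config.get("source", {}):
                   continue  # absent sources are skipped without a warning
               # Output keys are lowercased, matching MultiSourceDataset.datasets.
               datasets[name.lower()] = cls(config, return_target=return_target)
           return datasets


       config = {"source": {"ERA5": {}}}  # MRMS is not configured
       datasets = build_datasets(config, return_target=True)
       print(list(datasets))  # ['era5']

   Passing the whole ``config`` to each class follows the
   ``(config, return_target)`` constructor contract stated under *Extending
   with a new source*. This page does not say whether the real wrapper passes
   the full config or a per-source sub-dict.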

   Sample structure returned by ``__getitem__``::

       {
           "era5": {
               "input":    {"era5/prognostic/3d/T": tensor, ...},
               "target":   {"era5/prognostic/3d/T": tensor, ...},  # only if return_target=True
               "metadata": {"input_datetime": int, "target_datetime": int},
           },
           "mrms": {
               "input":    {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor, ...},
               "target":   {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor, ...},  # only if return_target=True
               "metadata": {"input_datetime": int, "target_datetime": int},
           },
       }
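
   Every variable key in these examples follows a
   ``source/field_type/dimensionality/variable`` pattern. Assuming that pattern
   holds for all keys (an inference from the examples above, not a documented
   guarantee), a consumer can walk the nested sample and split the keys as
   sketched below. The stub ``sample`` uses plain lists in place of tensors.

   .. code-block:: python

       # Stub sample mirroring the structure above; lists stand in for tensors.
       sample = {
           "era5": {
               "input": {"era5/prognostic/3d/T": [0.0]},
               "metadata": {"input_datetime": 0, "target_datetime": 0},
           },
           "mrms": {
               "input": {"mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": [0.0]},
               "metadata": {"input_datetime": 0, "target_datetime": 0},
           },
       }

       rows = []
       for source, parts in sample.items():
           for key in parts["input"]:
               # maxsplit=3 keeps any further "/" inside the variable name intact.
               src, field_type, dim, var = key.split("/", 3)
               assert src == source  # keys are namespaced by their source
               rows.append((src, field_type, dim, var))

       for row in rows:
           print(row)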

   Pre-block pipeline (applied after the DataLoader)::

       MultiSourceDataset
           → ERA5ScalePreBlock   (scale on native ERA5 grid)
           → MRMSScalePreBlock   (scale on native MRMS grid)
           → MRMSRegridPreBlock  (regrid MRMS → ERA5 grid)
           → MergePreBlock       (flatten nested dict → single input/target dict)
           → Model
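
   The ordering and the final flattening step can be pictured with the sketch
   below. It covers the idea only. ``merge_sources`` and ``run_pipeline`` are
   hypothetical helpers, not the ``MergePreBlock`` API. The three scale and
   regrid blocks are replaced by an identity function, and
   ``mrms/prognostic/2d/QPE`` is a shortened, made-up key. Because each
   variable key is already prefixed with its source name, merging the
   per-source dicts does not let one source overwrite another's fields.

   .. code-block:: python

       def merge_sources(sample):
           """Flatten {source: {"input": ..., "target": ...}} into one input/target dict.

           Hypothetical helper; metadata is kept per source here, which may
           differ from what the real MergePreBlock does.
           """
           merged = {"input": {}, "target": {}, "metadata": {}}
           for source, parts in sample.items():
               for section in ("input", "target"):
                   for key, value in parts.get(section, {}).items():
                       if key in merged[section]:
                           raise KeyError(f"duplicate variable key {key!r}")
                       merged[section][key] = value
               merged["metadata"][source] = parts.get("metadata", {})
           return merged


       def run_pipeline(sample, blocks):
           """Apply each pre-block in order, feeding one block's output to the next."""
           for block in blocks:
               sample = block(sample)
           return sample


       def identity(sample):
           """Stand-in for the ERA5/MRMS scale and MRMS regrid pre-blocks."""
           return sample


       sample = {
           "era5": {"input": {"era5/prognostic/3d/T": 1.0}, "metadata": {"input_datetime": 0}},
           "mrms": {"input": {"mrms/prognostic/2d/QPE": 2.0}, "metadata": {"input_datetime": 0}},
       }
       out = run_pipeline(sample, [identity, identity, identity, merge_sources])
       print(sorted(out["input"]))  # ['era5/prognostic/3d/T', 'mrms/prognostic/2d/QPE']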

   Usage::

       from credit.datasets.multi_source import MultiSourceDataset
       from credit.samplers import DistributedMultiStepBatchSampler
       from torch.utils.data import DataLoader

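       # config is the parsed experiment configuration (a dict);
       # config["data"]["source"] names the sources to load.
       dataset = MultiSourceDataset(config["data"], return_target=True)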
       sampler = DistributedMultiStepBatchSampler(dataset, batch_size=4,
                     shuffle=True, num_replicas=1, rank=0)
       loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

   Extending with a new source::

       # In _SOURCE_REGISTRY, add:
       "NewSource": NewSourceDataset,

       # The dataset class must accept (config, return_target), expose a
       # ``datetimes`` attribute (pd.DatetimeIndex) and a ``static_metadata``
       # dict, and implement __getitem__ for (t, i) tuples.
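
   A source that meets this contract might look like the sketch below.
   Everything in it is illustrative. ``NewSourceDataset``, the
   ``start``/``end``/``freq`` config layout, the ``newsource/prognostic/2d/X``
   key, the zero-valued placeholder data and the one-step-ahead
   ``target_datetime`` are assumptions. A real source would presumably
   subclass ``torch.utils.data.Dataset``; the sketch omits that base class so
   that it needs only pandas.

   .. code-block:: python

       import pandas as pd


       class NewSourceDataset:
           """Hypothetical source that follows the (config, return_target) contract."""

           def __init__(self, config, return_target=False):
               src = config["source"]["NewSource"]  # assumed config layout
               self.return_target = return_target
               self.step = pd.Timedelta(src["freq"])
               # Timestamps this source can serve; the wrapper intersects these.
               self.datetimes = pd.date_range(src["start"], src["end"], freq=src["freq"])
               self.static_metadata = {"datetime_fmt": "unix_ns"}

           def __len__(self):
               return len(self.datetimes)

           def __getitem__(self, args):
               t, i = args  # the step index i is ignored by this stub
               t = pd.Timestamp(t)  # accepts integer nanoseconds or a Timestamp
               key = "newsource/prognostic/2d/X"
               sample = {
                   "input": {key: 0.0},  # placeholder instead of a tensor
                   "metadata": {
                       "input_datetime": t.value,
                       # assumed: the target is one step after the input
                       "target_datetime": (t + self.step).value,
                   },
               }
               if self.return_target:
                   sample["target"] = {key: 0.0}
               return sample


       # In the real module the class is registered via _SOURCE_REGISTRY, as above.
       config = {"source": {"NewSource": {"start": "2020-01-01", "end": "2020-01-02", "freq": "6h"}}}
       ds = NewSourceDataset(config, return_target=True)
       sample = ds[(ds.datetimes[0].value, 0)]
       print(len(ds), sorted(sample))  # 5 ['input', 'metadata', 'target']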



Attributes
----------

.. autoapisummary::

   credit.datasets.multi_source.logger
   credit.datasets.multi_source._SOURCE_REGISTRY


Classes
-------

.. autoapisummary::

   credit.datasets.multi_source.MultiSourceDataset


Module Contents
---------------

.. py:data:: logger

.. py:data:: _SOURCE_REGISTRY
   :type:  dict[str, type]

.. py:class:: MultiSourceDataset(config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset that combines multiple source datasets.

   Instantiates one sub-dataset per source key found in ``config["source"]``,
   computes the intersection of their valid timestamps, and delegates each
   ``__getitem__`` call to all active sub-datasets.

   See the module docstring for the full output structure and usage examples.

   .. attribute:: datasets

      Ordered mapping of lowercase source name to its Dataset
      instance (e.g. ``{"era5": ERA5Dataset, "mrms": MRMSDataset}``).

   .. attribute:: datetimes

      DatetimeIndex of timestamps valid for *all* active sources
      (intersection of each source's own ``datetimes``).

   .. attribute:: static_metadata

      Per-source static metadata aggregated from each
      sub-dataset's ``static_metadata`` attribute.  Example::
      
          {
              "era5": {"levels": [1000, 850, 500, 300], "datetime_fmt": "unix_ns"},
              "mrms": {"datetime_fmt": "unix_ns"},
          }


   .. py:attribute:: datasets
      :type:  dict[str, torch.utils.data.Dataset]


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: static_metadata
      :type:  dict[str, dict]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a dict of per-source sample dicts.

      The ``(t, i)`` tuple is passed unchanged to every active sub-dataset.
      Each sub-dataset applies its own logic based on field type and step
      index; for example, ERA5 and MRMS both skip prognostic disk reads when
      ``i > 0``.

      :param args: ``(t, i)`` where *t* is the current timestamp (integer
                   nanoseconds since the Unix epoch, or a ``pd.Timestamp``)
                   and *i* is the within-sequence step index produced by the
                   sampler.

      :returns: Dict keyed by lowercase source name.  Each value is the
                sub-dataset's own sample dict::

                    {"input": {...}, "target": {...}, "metadata": {...}}
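
      The delegation can be sketched with two stub sources. ``StubSource`` and
      ``getitem`` are hypothetical stand-ins, not the real ERA5/MRMS classes or
      this method. Each stub records a prognostic "read" only at ``i == 0`` to
      mimic the behaviour described above, and *t* is given as integer
      nanoseconds.

      .. code-block:: python

         class StubSource:
             """Hypothetical sub-dataset that only 'reads' prognostic fields at i == 0."""

             def __init__(self, name):
                 self.name = name
                 self.reads = []  # timestamps for which a prognostic read happened

             def __getitem__(self, args):
                 t, i = args
                 sample = {"input": {}, "metadata": {"input_datetime": t}}
                 if i == 0:
                     self.reads.append(t)
                     sample["input"][f"{self.name}/prognostic/2d/X"] = 0.0
                 return sample


         def getitem(datasets, args):
             # The same (t, i) tuple goes to every active source, unchanged.
             return {name: ds[args] for name, ds in datasets.items()}


         datasets = {"era5": StubSource("era5"), "mrms": StubSource("mrms")}
         first = getitem(datasets, (0, 0))
         later = getitem(datasets, (3_600_000_000_000, 1))
         print(sorted(first), datasets["era5"].reads)  # ['era5', 'mrms'] [0]
         print(later["mrms"]["input"])  # {}

      Because the tuple is forwarded unchanged, every source sees the same *t*
      and *i*. Any difference in what each source loads comes from the
      sub-datasets themselves.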



   .. py:method:: _intersect_timestamps() -> pandas.DatetimeIndex

      Return timestamps common to all active source datasets.

      :returns: Sorted DatetimeIndex containing only timestamps present in every
                source's ``datetimes`` index.  Returns an empty DatetimeIndex
                when no sources are configured.
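
      A minimal sketch of the intersection follows. It assumes each source
      exposes ``datetimes`` as a ``pandas.DatetimeIndex``. ``intersect_timestamps``
      is a hypothetical free function, and ``SimpleNamespace`` objects stand in
      for real sub-datasets.

      .. code-block:: python

         from functools import reduce
         from types import SimpleNamespace

         import pandas as pd


         def intersect_timestamps(datasets):
             """Sorted timestamps shared by every source; empty when there are none."""
             indexes = [ds.datetimes for ds in datasets.values()]
             if not indexes:
                 return pd.DatetimeIndex([])
             common = reduce(lambda a, b: a.intersection(b), indexes)
             return common.sort_values()


         era5 = SimpleNamespace(datetimes=pd.date_range("2020-01-01", periods=4, freq="6h"))
         mrms = SimpleNamespace(datetimes=pd.date_range("2020-01-01 06:00", periods=4, freq="6h"))
         common = intersect_timestamps({"era5": era5, "mrms": mrms})
         print(len(common))  # 3

      Here the two 6-hourly indexes are offset by one step, so only the three
      overlapping timestamps survive.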



