credit.datasets.multi_source
============================

.. py:module:: credit.datasets.multi_source

.. autoapi-nested-parse::

   multi_source.py
   ---------------
   MultiSourceDataset: primary entry point for all data loading.

   Wraps one or more registered source datasets and returns a dict nested by
   source name.  Only sources whose keys appear under ``config["source"]`` are
   instantiated; absent sources are silently skipped.

   Sample structure returned by __getitem__::

       {
           "input":    {<user_provided_name>: {"<user_provided_name>/prognostic/3d/T": tensor, ...}, ...},
           "target":   {<user_provided_name>: {"<user_provided_name>/prognostic/3d/T": tensor, ...}, ...},  # present only when return_target=True
           "metadata": {<user_provided_name>: {"input_datetime": int, "target_datetime": int}, ...},
       }
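
   The three-level nesting above can be walked generically.  The following is
   a minimal sketch, not part of ``credit``: ``iter_sample_keys`` is a
   hypothetical helper, the source name ``"era5"`` is illustrative, and
   placeholder strings and integers stand in for tensors and datetimes.

   .. code-block:: python

      from typing import Any, Iterator


      def iter_sample_keys(sample: dict[str, dict[str, Any]]) -> Iterator[tuple[str, str, str]]:
          """Yield (data_type, source_name, field) for every leaf of a sample dict."""
          for data_type, per_source in sample.items():        # "input" / "target" / "metadata"
              for source_name, fields in per_source.items():  # user-provided source names
                  for field in fields:                        # variable keys or metadata names
                      yield data_type, source_name, field


      # Placeholder values stand in for tensors and datetimes.
      sample = {
          "input": {"era5": {"era5/prognostic/3d/T": "tensor"}},
          "metadata": {"era5": {"input_datetime": 0, "target_datetime": 1}},
      }
      for key in iter_sample_keys(sample):
          print(key)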

   Usage::

       from credit.datasets.multi_source import MultiSourceDataset
       from credit.samplers import DistributedMultiStepBatchSampler
       from torch.utils.data import DataLoader

       dataset = MultiSourceDataset(config["data"], return_target=True)
       sampler = DistributedMultiStepBatchSampler(
           dataset, batch_size=4, shuffle=True, num_replicas=1, rank=0
       )
       loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4)

   Extending with a new source::

       # In _SOURCE_REGISTRY, add:
       "NEW_SOURCE": ("credit.datasets.new_source", "NewSourceDataset"),

       # The dataset class must accept (config, return_target) and expose
       # a `datetimes` attribute holding a pd.DatetimeIndex.
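
   The contract for a new source class can be expressed structurally.  This is
   a sketch under stated assumptions: ``SourceDatasetLike`` is a hypothetical
   ``typing.Protocol`` (not defined in ``credit``), the toy
   ``NewSourceDataset`` ignores its config, and a plain list of ``datetime``
   objects stands in for the ``pd.DatetimeIndex`` a real source must expose.

   .. code-block:: python

      from datetime import datetime, timedelta
      from typing import Any, Protocol, Sequence, runtime_checkable


      @runtime_checkable
      class SourceDatasetLike(Protocol):
          """Structural form of the contract above (hypothetical, not in credit)."""

          datetimes: Sequence[datetime]  # a pd.DatetimeIndex in real sources


      class NewSourceDataset:
          """Toy source with hourly timestamps for one day (illustrative values)."""

          def __init__(self, config: dict[str, Any], return_target: bool = False) -> None:
              self.config = config
              self.return_target = return_target
              # A real source would derive its timestamps from its config block.
              start = datetime(2020, 1, 1)
              self.datetimes = [start + timedelta(hours=h) for h in range(24)]


      ds = NewSourceDataset({"source": {}}, return_target=True)
      print(isinstance(ds, SourceDatasetLike))  # True

   ``isinstance`` on a runtime-checkable protocol only checks that the
   attribute exists, not its type, so it is a quick sanity check rather than
   full validation.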



Attributes
----------

.. autoapisummary::

   credit.datasets.multi_source.logger
   credit.datasets.multi_source._SOURCE_REGISTRY


Classes
-------

.. autoapisummary::

   credit.datasets.multi_source.MultiSourceDataset


Functions
---------

.. autoapisummary::

   credit.datasets.multi_source.make_single_source_subconfig
   credit.datasets.multi_source.route_to_dataset_class


Module Contents
---------------

.. py:data:: logger

.. py:data:: _SOURCE_REGISTRY
   :type:  dict[str, tuple[str, str]]

.. py:function:: make_single_source_subconfig(config: dict[str, Any], user_dataset_name: str) -> dict[str, Any]

   Return a modified config dict containing only the specified source.

   Used internally so that each sub-dataset is instantiated with a config
   holding only its own source block, so it cannot confuse its block with
   those of similarly named sources (e.g. HRRR vs. HRRR_NAT vs. HRRR_SUBH).

   :param config: Original multisource config dict.
   :param user_dataset_name: Unique dataset name specified by the user in ``config["source"]`` (e.g. ``"Example_ERA5"``).

   :returns: New config dict containing only the specified source's config block.
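
   A minimal sketch of this behaviour, assuming sources live under
   ``config["source"]`` keyed by user-provided name and that every other
   top-level field is copied unchanged.  ``subconfig_for`` is a hypothetical
   stand-in, not the library function; raising ``KeyError`` for an unknown
   name and the ``"dataset_type"`` values shown are assumptions.

   .. code-block:: python

      import copy
      from typing import Any


      def subconfig_for(config: dict[str, Any], user_dataset_name: str) -> dict[str, Any]:
          """Return a deep copy of config whose "source" block holds one entry only."""
          sources = config["source"]
          if user_dataset_name not in sources:
              # Error type is an assumption of this sketch.
              raise KeyError(f"{user_dataset_name!r} not found in config['source']")
          sub = copy.deepcopy(config)  # leave the caller's multisource config untouched
          sub["source"] = {user_dataset_name: sub["source"][user_dataset_name]}
          return sub


      config = {
          "timestep": "6h",  # illustrative global field, kept as-is
          "source": {
              "Example_ERA5": {"dataset_type": "ERA5"},
              "Example_MRMS": {"dataset_type": "MRMS"},
          },
      }
      print(subconfig_for(config, "Example_ERA5"))
      # {'timestep': '6h', 'source': {'Example_ERA5': {'dataset_type': 'ERA5'}}}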


.. py:function:: route_to_dataset_class(source_cfg: dict[str, Any]) -> type

   Return the appropriate Dataset class based on the ``"dataset_type"`` field in the source config.

   The module containing the class is imported lazily on first call so that
   optional heavy dependencies are not loaded unless this source type is used.

   :param source_cfg: Config dict for a single source (e.g. ``config["source"]["Example_ERA5"]``).

   :returns: Dataset class corresponding to the ``"dataset_type"`` field.

   :raises ValueError: If the ``"dataset_type"`` field is missing or does not correspond to a registered dataset.
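
   Lazy routing through a ``(module path, class name)`` registry can be
   sketched with ``importlib.import_module``.  In this sketch ``REGISTRY`` and
   ``route`` are hypothetical, and standard-library classes stand in for CREDIT
   dataset classes so the example runs anywhere; real entries follow the
   ``_SOURCE_REGISTRY`` pattern shown in the module docstring.

   .. code-block:: python

      import importlib
      from typing import Any

      # Hypothetical registry: stdlib classes stand in for dataset classes.
      REGISTRY: dict[str, tuple[str, str]] = {
          "ORDERED": ("collections", "OrderedDict"),
          "FRACTION": ("fractions", "Fraction"),
      }


      def route(source_cfg: dict[str, Any]) -> type:
          """Resolve source_cfg["dataset_type"] to a class, importing its module lazily."""
          dataset_type = source_cfg.get("dataset_type")
          if dataset_type is None:
              raise ValueError("source config is missing 'dataset_type'")
          if dataset_type not in REGISTRY:
              raise ValueError(
                  f"unknown dataset_type {dataset_type!r}; expected one of {sorted(REGISTRY)}"
              )
          module_name, class_name = REGISTRY[dataset_type]
          module = importlib.import_module(module_name)  # import happens only here
          return getattr(module, class_name)


      print(route({"dataset_type": "FRACTION"}))

   Because ``importlib.import_module`` caches modules in ``sys.modules``,
   repeated calls for the same source type pay the import cost only once.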


.. py:class:: MultiSourceDataset(config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`credit.datasets.base_dataset.AbstractBaseDataset`


   CREDIT Dataset that combines multiple source datasets.

   Instantiates one sub-dataset per source key found in ``config["source"]``,
   computes the intersection of their valid timestamps, and delegates each
   ``__getitem__`` call to all active sub-datasets.
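
   The timestamp intersection can be sketched with plain sets.  This assumes
   every source behaves like a normal (non-persist) source; lists of
   ``datetime`` objects stand in for ``pd.DatetimeIndex``, on which pandas'
   ``Index.intersection`` performs the same step.  ``common_datetimes`` and the
   example timestamps are hypothetical.

   .. code-block:: python

      from datetime import datetime, timedelta


      def common_datetimes(per_source: dict[str, list[datetime]]) -> list[datetime]:
          """Sorted timestamps present in every source (empty input -> empty list)."""
          sets = [set(ts) for ts in per_source.values()]
          return sorted(set.intersection(*sets)) if sets else []


      t0 = datetime(2020, 1, 1)
      per_source = {
          "era5": [t0 + timedelta(hours=6 * k) for k in range(8)],    # 6-hourly, 2 days
          "mrms": [t0 + timedelta(hours=24 + h) for h in range(24)],  # hourly, day 2 only
      }
      # 4 timestamps: 00:00, 06:00, 12:00 and 18:00 on 2020-01-02
      print(common_datetimes(per_source))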

   See the module docstring for the full output structure and usage examples.

   Note that we inherit from ``AbstractBaseDataset`` *rather* than ``BaseDataset``.

   .. attribute:: datasets

      Ordered mapping of lowercase source name to its Dataset instance
      (e.g. ``{"era5": <ERA5Dataset instance>, "mrms": <MRMSDataset instance>}``).

   .. attribute:: datetimes

      DatetimeIndex of timestamps valid for *all* active sources
      (intersection of each source's own ``datetimes``).

   .. attribute:: static_metadata

      Per-source static metadata aggregated from each
      sub-dataset's ``static_metadata`` attribute.


   .. py:attribute:: datasets
      :type:  dict[str, credit.datasets.base_dataset.BaseDataset]


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: static_metadata
      :type:  dict[str, dict[str, Any]]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple[pandas.Timestamp, int]) -> dict[str, dict[str, Any]]

      Return a dict of per-source sample dicts.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   or pd.Timestamp) and *i* is the within-sequence step index
                   produced by the sampler.

      :returns: Dict keyed by data type, each value being a dict of source name
                to that source's data::

                    {"input": {"era5": {...}, ...}, "target": {...}, "metadata": {...}}
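
      The regrouping from per-source samples into the data-type-first layout
      can be sketched as below.  It assumes each sub-dataset returns a dict
      whose top-level keys are data types (``"input"``, ``"target"``,
      ``"metadata"``); ``regroup_by_data_type`` and ``FakeSource`` are
      hypothetical stand-ins, not the actual implementation.

      .. code-block:: python

         from typing import Any


         def regroup_by_data_type(per_source: dict[str, dict[str, Any]]) -> dict[str, dict[str, Any]]:
             """Turn {source: {data_type: payload}} into {data_type: {source: payload}}."""
             out: dict[str, dict[str, Any]] = {}
             for source_name, sample in per_source.items():
                 for data_type, payload in sample.items():
                     # setdefault creates the per-data-type dict on first use
                     out.setdefault(data_type, {})[source_name] = payload
             return out


         class FakeSource:
             """Hypothetical stand-in sub-dataset that echoes the (t, i) it receives."""

             def __init__(self, name: str) -> None:
                 self.name = name

             def __getitem__(self, args: tuple[int, int]) -> dict[str, Any]:
                 t, i = args
                 return {"input": {f"{self.name}/x": (t, i)}, "metadata": {"input_datetime": t}}


         datasets = {"era5": FakeSource("era5"), "mrms": FakeSource("mrms")}
         args = (0, 1)  # (timestamp, within-sequence step) as produced by the sampler
         sample = regroup_by_data_type({name: ds[args] for name, ds in datasets.items()})
         print(sorted(sample))           # ['input', 'metadata']
         print(sorted(sample["input"]))  # ['era5', 'mrms']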



   .. py:method:: _build_master_clock(config: dict[str, Any]) -> pandas.DatetimeIndex

      Build the master sampling clock from the global config.

      The clock is anchored to the global ``start_datetime``, ``end_datetime``,
      and ``timestep``.  For each source:

      - **Normal sources** (no ``temporal_mode``): the clock is filtered to
        timestamps that exist exactly in that source's native datetimes.  A
        warning is emitted when the source's native timestep differs from the
        master clock timestep and ``temporal_mode`` is not set.
      - **Persist sources** (``temporal_mode: persist``): the clock is only
        clipped to the source's coverage range.  Master-clock ticks finer than
        the source's native timestep are snapped, inside
        ``BaseDataset.__getitem__``, to the most recent native timestamp at
        or before each tick.
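
      The per-source clock rules can be sketched as follows.  All functions
      here are hypothetical; lists of ``datetime`` objects stand in for
      ``pd.DatetimeIndex``, the warning for mismatched timesteps is omitted,
      and "last native timestamp" is read as the most recent one at or before
      the tick.

      .. code-block:: python

         import bisect
         from datetime import datetime, timedelta


         def master_ticks(start: datetime, end: datetime, step: timedelta) -> list[datetime]:
             """Ticks start, start + step, ... up to and including end."""
             ticks = []
             t = start
             while t <= end:
                 ticks.append(t)
                 t += step
             return ticks


         def clock_for_normal_source(ticks: list[datetime], native: list[datetime]) -> list[datetime]:
             """Keep only ticks that exist exactly in the source's native datetimes."""
             native_set = set(native)
             return [t for t in ticks if t in native_set]


         def clock_for_persist_source(ticks: list[datetime], native: list[datetime]) -> list[datetime]:
             """Clip ticks to the source's coverage range [first, last native timestamp]."""
             return [t for t in ticks if native[0] <= t <= native[-1]]


         def snap_to_last_native(t: datetime, native: list[datetime]) -> datetime:
             """Most recent native timestamp at or before t; native must be sorted."""
             idx = bisect.bisect_right(native, t) - 1
             if idx < 0:
                 raise ValueError(f"{t} precedes the first native timestamp")
             return native[idx]


         t0 = datetime(2020, 1, 1)
         ticks = master_ticks(t0, t0 + timedelta(hours=18), timedelta(hours=1))  # hourly master clock
         native = [t0 + timedelta(hours=h) for h in (0, 6, 12)]  # 6-hourly source ending at 12:00

         print(len(clock_for_normal_source(ticks, native)))   # 3
         print(len(clock_for_persist_source(ticks, native)))  # 13
         print(snap_to_last_native(t0 + timedelta(hours=7), native))  # 2020-01-01 06:00:00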



