credit.datasets
===============

.. py:module:: credit.datasets


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/datasets/MRMS/index
   /autoapi/credit/datasets/_file_utils/index
   /autoapi/credit/datasets/count_channels/index
   /autoapi/credit/datasets/datamap/index
   /autoapi/credit/datasets/downscaling_dataset/index
   /autoapi/credit/datasets/era5/index
   /autoapi/credit/datasets/era5_multistep/index
   /autoapi/credit/datasets/era5_multistep_batcher/index
   /autoapi/credit/datasets/era5_singlestep/index
   /autoapi/credit/datasets/les_singlestep/index
   /autoapi/credit/datasets/load_dataset_and_dataloader/index
   /autoapi/credit/datasets/mrms_download/index
   /autoapi/credit/datasets/multi_source/index
   /autoapi/credit/datasets/om4_multistep_batcher/index
   /autoapi/credit/datasets/realtime_predict/index
   /autoapi/credit/datasets/sequential_multistep/index
   /autoapi/credit/datasets/wrf_singlestep/index
   /autoapi/credit/datasets/wrfmultistep/index


Classes
-------

.. autoapisummary::

   credit.datasets.MultiSourceDataset
   credit.datasets.ERA5Dataset
   credit.datasets.ARCOERA5Dataset
   credit.datasets.MRMSDataset


Package Contents
----------------

.. py:class:: MultiSourceDataset(config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset that combines multiple source datasets.

   Instantiates one sub-dataset per source key found in ``config["source"]``,
   computes the intersection of their valid timestamps, and delegates each
   ``__getitem__`` call to all active sub-datasets.

   See module docstring for full output structure and usage examples.

   .. attribute:: datasets

      Ordered mapping of lowercase source name to its Dataset
      instance (e.g. ``{"era5": ERA5Dataset, "mrms": MRMSDataset}``).

   .. attribute:: datetimes

      DatetimeIndex of timestamps valid for *all* active sources
      (intersection of each source's own ``datetimes``).

   .. attribute:: static_metadata

      Per-source static metadata aggregated from each
      sub-dataset's ``static_metadata`` attribute.  Example::
      
          {
              "era5": {"levels": [1000, 850, 500, 300], "datetime_fmt": "unix_ns"},
              "mrms": {"datetime_fmt": "unix_ns"},
          }


   .. py:attribute:: datasets
      :type:  dict[str, torch.utils.data.Dataset]


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: static_metadata
      :type:  dict[str, dict]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a dict of per-source sample dicts.

      The ``(t, i)`` tuple is passed unchanged to every active sub-dataset.
      Each sub-dataset applies its own field-type / step-index logic
      (e.g. ERA5 and MRMS both skip prognostic disk reads at ``i > 0``).

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   or pd.Timestamp) and *i* is the within-sequence step index
                   produced by the sampler.

      :returns: Dict keyed by lowercase source name.  Each value is the
                sub-dataset's own sample dict::

                    {"input": {...}, "target": {...}, "metadata": {...}}



   .. py:method:: _intersect_timestamps() -> pandas.DatetimeIndex

      Return timestamps common to all active source datasets.

      :returns: Sorted DatetimeIndex containing only timestamps present in every
                source's ``datetimes`` index.  Returns an empty DatetimeIndex
                when no sources are configured.
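
The intersection performed by ``MultiSourceDataset._intersect_timestamps`` can be pictured with plain Python sets. This is a minimal, stdlib-only sketch rather than the package's implementation: the real method returns a ``pandas.DatetimeIndex``, and both the helper name ``intersect_timestamps`` and the sample data are hypothetical.

```python
from datetime import datetime, timedelta


def intersect_timestamps(per_source: dict[str, list[datetime]]) -> list[datetime]:
    """Return the sorted timestamps present in every source (hypothetical helper)."""
    if not per_source:
        # Mirrors the documented behaviour: no sources gives an empty result.
        return []
    common = None
    for stamps in per_source.values():
        # Start from the first source, then narrow with each later one.
        common = set(stamps) if common is None else common & set(stamps)
    return sorted(common)


start = datetime(2024, 6, 1)
# Illustrative data: one 6-hourly source and one 12-hourly source.
era5_times = [start + timedelta(hours=6 * k) for k in range(4)]
mrms_times = [start + timedelta(hours=12 * k) for k in range(3)]

common = intersect_timestamps({"era5": era5_times, "mrms": mrms_times})
print(common)
```

Only the timestamps shared by both sources survive, so adding a sparser source can shrink the combined dataset.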



.. py:class:: ERA5Dataset(config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset for processed ERA5 data with nested input/target structure.

   See module docstring for full description of output format and file naming.

   Example YAML configuration::

       data:
         source:
           ERA5:
             level_coord: "level"
             levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
             variables:
               prognostic:
                 vars_3D: ['T', 'U', 'V', 'Q']
                 vars_2D: ['SP', 't2m']
                 path: "/data/era5_*.zarr"
                 filename_time_format: "%Y"        # annual (default)
               dynamic_forcing:
                 vars_2D: ['tsi']
                 path: "/data/solar_*.nc"
                 filename_time_format: "%Y_%m"     # monthly
               static:
                 vars_2D: ['Z_GDS4_SFC', 'LSM']
                 path: "/data/lsm.nc"
                 # single file — filename_time_format not needed
               diagnostic: null

         start_datetime: "2017-01-01"
         end_datetime: "2019-12-31"
         timestep: "6h"
         forecast_len: 1

   Assumptions:
       1. A "time" dimension / coordinate is present for non-static fields.
       2. A level coordinate (name given by ``level_coord``) represents the
          vertical axis of 3D variables.
       3. Dimension order: (time, level, latitude, longitude) for 3D;
          (time, latitude, longitude) for 2D; (latitude, longitude) for static.


   .. py:attribute:: source_name
      :type:  str
      :value: 'era5'



   .. py:attribute:: level_coord
      :type:  str


   .. py:attribute:: levels
      :type:  list[int]


   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: dt


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime


   .. py:attribute:: end_datetime


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: file_dict
      :type:  dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]


   .. py:attribute:: var_dict
      :type:  dict[str, dict[str, list[str]]]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a nested input/target sample dict.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   or pd.Timestamp) and *i* is the within-sequence step index
                   produced by the sampler. When ``i == 0``, prognostic and static
                   fields are loaded in addition to dynamic forcing.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"era5/{field_type}/{dim}/{varname}"``.



   .. py:method:: _register_field(field_type: str, d: dict | None) -> None

      Validate and register one field type from the config variables block.

      Populates ``self.file_dict`` and ``self.var_dict`` for *field_type*.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param d: Field-type config dict, or ``None`` / null to disable the field.

      :raises KeyError: If *field_type* is not a recognised field type.
      :raises ValueError: If *d* defines neither ``vars_3D`` nor ``vars_2D``.



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return valid initialisation timestamps for the dataset.

      :returns: DatetimeIndex from ``start_datetime`` to ``end_datetime`` minus
                the forecast horizon, at the configured timestep frequency.



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Open the dataset for *field_type* at time *t* and populate *sample*.

      Keys written are ``"era5/{field_type}/3d/{varname}"`` for 3D variables
      and ``"era5/{field_type}/2d/{varname}"`` for 2D variables.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param t: Timestamp to select.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shapes (no batch dimension):

                     - 3D variable: ``(n_levels, 1, lat, lon)``
                     - 2D variable: ``(1, 1, lat, lon)``



   .. py:method:: _to_cftime(ts: pandas.Timestamp, calendar: str) -> cftime.datetime
      :staticmethod:


      Convert a pandas Timestamp to a cftime.datetime.

      :param ts: Pandas Timestamp to convert.
      :param calendar: cftime calendar string read from the dataset
                       (e.g. ``"noleap"``, ``"gregorian"``, ``"proleptic_gregorian"``).

      :returns: cftime.datetime with the specified calendar.
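
The sample keys documented for ``ERA5Dataset`` follow the pattern ``"era5/{field_type}/{dim}/{varname}"``. The sketch below builds such keys from a ``var_dict``-shaped mapping. It assumes the inner keys are ``vars_3D`` and ``vars_2D``, as in the YAML example, and ``sample_keys`` is a hypothetical helper rather than part of the package.

```python
def sample_keys(source_name: str, var_dict: dict[str, dict[str, list[str]]]) -> list[str]:
    """Build "{source}/{field_type}/{dim}/{varname}" keys (hypothetical helper)."""
    keys = []
    for field_type, groups in var_dict.items():
        # Assumed mapping from config group name to the dim tag used in keys.
        for group, dim in (("vars_3D", "3d"), ("vars_2D", "2d")):
            for varname in groups.get(group, []):
                keys.append(f"{source_name}/{field_type}/{dim}/{varname}")
    return keys


# Illustrative subset of the variables from the YAML example.
var_dict = {
    "prognostic": {"vars_3D": ["T", "U"], "vars_2D": ["SP"]},
    "dynamic_forcing": {"vars_2D": ["tsi"]},
}

keys = sample_keys("era5", var_dict)
print(keys)
```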



.. py:class:: ARCOERA5Dataset(data_config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset for Google Cloud ARCO ERA5 data with nested input/target structure.

   See module docstring for full description of output format and file naming.

   Example YAML configuration::

       data:
         source:
           ARCO_ERA5:
             level_coord: "hybrid"
             levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
             variables:
               prognostic:
                 vars_3D: ["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity"]
                 vars_2D: ["surface_pressure"]
               dynamic_forcing:
                 vars_2D: ["toa_incident_solar_radiation"]
               static:
                 vars_2D: ["land_sea_mask"]
               diagnostic:
                 vars_2D: ["total_precipitation"]

         start_datetime: "2017-01-01"
         end_datetime: "2019-12-31"
         timestep: "6h"
         forecast_len: 1

   Assumptions:
       1. A "time" dimension / coordinate is present for non-static fields.
       2. A level coordinate (name given by ``level_coord``) represents the
          vertical axis of 3D variables.
       3. Dimension order: (time, level, latitude, longitude) for 3D;
          (time, latitude, longitude) for 2D; (latitude, longitude) for static.


   .. py:attribute:: pressure_lev_era5_path
      :value: 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'



   .. py:attribute:: model_lev_era5_path
      :value: 'gs://gcp-public-data-arco-era5/ar/model-level-1h-0p25deg.zarr-v1'



   .. py:attribute:: model_lev_vars
      :value: ['divergence', 'fraction_of_cloud_cover', 'geopotential', 'ozone_mass_mixing_ratio',...



   .. py:attribute:: source_name
      :type:  str
      :value: 'arco_era5'



   .. py:attribute:: level_coord
      :type:  str


   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: dt


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime


   .. py:attribute:: end_datetime


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: var_dict
      :type:  dict[str, dict[str, list[str]]]


   .. py:attribute:: fs
      :value: None



   .. py:attribute:: mod_level_store
      :value: None



   .. py:attribute:: pres_level_store
      :value: None



   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a nested input/target sample dict.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   or pd.Timestamp) and *i* is the within-sequence step index
                   produced by the sampler. When ``i == 0``, prognostic and static
                   fields are loaded in addition to dynamic forcing.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"arco_era5/{field_type}/{dim}/{varname}"``.



   .. py:method:: _init_fs()


   .. py:method:: _register_field(field_type: str, d: dict | None) -> None

      Validate and register one field type from the config variables block.

      Populates ``self.var_dict`` for *field_type*.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param d: Field-type config dict, or ``None`` / null to disable the field.

      :raises KeyError: If *field_type* is not a recognised field type.
      :raises ValueError: If *d* defines neither ``vars_3D`` nor ``vars_2D``.



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return valid initialisation timestamps for the dataset.

      :returns: DatetimeIndex from ``start_datetime`` to ``end_datetime`` minus
                the forecast horizon, at the configured timestep frequency.



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Open the dataset for *field_type* at time *t* and populate *sample*.

      Keys written are ``"arco_era5/{field_type}/3d/{varname}"`` for 3D variables
      and ``"arco_era5/{field_type}/2d/{varname}"`` for 2D variables.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param t: Timestamp to select.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shapes (no batch dimension):

                     - 3D variable: ``(n_levels, 1, lat, lon)``
                     - 2D variable: ``(1, 1, lat, lon)``



   .. py:method:: _to_cftime(ts: pandas.Timestamp, calendar: str) -> cftime.datetime
      :staticmethod:


      Convert a pandas Timestamp to a cftime.datetime.

      :param ts: Pandas Timestamp to convert.
      :param calendar: cftime calendar string read from the dataset
                       (e.g. ``"noleap"``, ``"gregorian"``, ``"proleptic_gregorian"``).

      :returns: cftime.datetime with the specified calendar.
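
``_build_timestamps`` is described as returning timestamps from ``start_datetime`` to ``end_datetime`` minus the forecast horizon, at the configured timestep. A minimal stdlib sketch of that rule follows. It assumes the horizon equals ``forecast_steps * timestep`` and that the last valid timestamp is inclusive; the documentation confirms neither. The real method returns a ``pandas.DatetimeIndex``, and ``build_timestamps`` is a hypothetical name.

```python
from datetime import datetime, timedelta


def build_timestamps(
    start: datetime, end: datetime, timestep: timedelta, forecast_steps: int
) -> list[datetime]:
    """List initialisation times that leave room for the forecast (hypothetical helper)."""
    # Assumption: the horizon is forecast_steps whole timesteps.
    last_init = end - forecast_steps * timestep
    stamps = []
    t = start
    while t <= last_init:  # assumption: the end point is inclusive
        stamps.append(t)
        t += timestep
    return stamps


stamps = build_timestamps(
    start=datetime(2017, 1, 1),
    end=datetime(2017, 1, 2),
    timestep=timedelta(hours=6),
    forecast_steps=1,
)
print(stamps)
```

Trimming the horizon from the end keeps every initialisation time paired with a target that still falls inside the configured date range.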



.. py:class:: MRMSDataset(config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset for MRMS data with nested input/target structure.

   Field types follow ERA5 conventions: ``prognostic`` variables appear in
   both input (at step 0) and target; ``dynamic_forcing`` appears in input
   at every step; ``diagnostic`` appears in target only.  At step ``i > 0``
   the model's own prognostic predictions are fed back, so no disk read occurs
   for prognostic fields at those steps.

   Supports loading directly from AWS S3 (remote mode) or from local
   NetCDF / Zarr files (local mode). Spatial subsetting via ``extent``
   is applied at load time on the native MRMS grid.

   See module docstring for full description of output format and file naming.

   Example YAML configuration (local mode)::

       data:
         source:
           MRMS:
             mode: "local"
             variables:
               prognostic:                         # input at step 0 + target
                 vars_2D:
                   - "MultiSensor_QPE_01H_Pass2_00.00"
                 path: "/data/MRMS_*.nc"
                 filename_time_format: "%Y%m%d-%H%M%S"
               dynamic_forcing:                    # input every step
                 vars_2D:
                   - "MultiSensor_QPE_06H_Pass2_00.00"
                 path: "/data/MRMS_*.nc"
                 filename_time_format: "%Y%m%d-%H%M%S"
             extent: [-130, -60, 20, 55]   # [min_lon, max_lon, min_lat, max_lat]

         start_datetime: "2024-06-01"
         end_datetime:   "2024-07-01"
         timestep:       "6h"
         forecast_len:   0

   Example YAML configuration (remote mode)::

       data:
         source:
           MRMS:
             mode: "remote"
             region: "CONUS"
             variables:
               prognostic:
                 vars_2D:
                   - "MultiSensor_QPE_01H_Pass2_00.00"
             extent: [-130, -60, 20, 55]

   Assumptions:
       1. Local files have ``time``, ``lat``, ``lon`` dimensions/coordinates.
       2. Longitude coordinates are in the 0–360 convention (both local and remote).
       3. ``extent`` is specified as ``[min_lon, max_lon, min_lat, max_lat]``
          in either the -180 to 180 or the 0–360 convention; it is normalised
          to 0–360 internally.


   .. py:attribute:: source_name
      :type:  str
      :value: 'mrms'



   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: mode
      :type:  str


   .. py:attribute:: region
      :type:  str


   .. py:attribute:: extent
      :type:  list[float] | None


   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: dt


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime


   .. py:attribute:: end_datetime


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: file_dict
      :type:  dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]


   .. py:attribute:: var_dict
      :type:  dict[str, dict[str, list[str]]]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a nested input/target sample dict.

      Prognostic fields are loaded into ``input`` only at step ``i == 0``
      (consistent with ERA5 autoregressive rollout semantics).  Dynamic
      forcing is loaded at every step.  Diagnostic fields never appear
      in ``input``.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   or pd.Timestamp) and *i* is the within-sequence step index
                   produced by the sampler.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"mrms/{field_type}/2d/{varname}"``.



   .. py:method:: _register_field(field_type: str, d: dict | None) -> None

      Validate and register one field type from the config variables block.

      :param field_type: One of ``"prognostic"``, ``"diagnostic"``,
                         ``"dynamic_forcing"``.
      :param d: Field-type config dict, or ``None`` / null to disable the field.

      :raises KeyError: If *field_type* is not a recognised MRMS field type.
      :raises ValueError: If *d* defines no ``vars_2D``.



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return valid initialisation timestamps for the dataset.

      :returns: DatetimeIndex from ``start_datetime`` to ``end_datetime`` minus
                the forecast horizon, at the configured timestep frequency.



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Load all variables for *field_type* at time *t* into *sample*.

      Dispatches to local or remote loading based on ``self.mode``.

      :param field_type: Registered field type (e.g. ``"prognostic"``).
      :param t: Timestamp to load.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shape (no batch dimension): ``(1, 1, lat, lon)``.



   .. py:method:: _load_local_var(field_type: str, vname: str, t: pandas.Timestamp)

      Load a single variable from a local NetCDF or Zarr file.

      :param field_type: Field type key used to look up file intervals.
      :param vname: Variable name within the dataset.
      :param t: Timestamp to select.

      :returns: 2-D numpy array ``(lat, lon)`` after optional extent subsetting.

      :raises KeyError: If no files are registered for *field_type*.



   .. py:method:: _load_remote_var(vname: str, t: pandas.Timestamp)

      Stream a single variable from the MRMS S3 bucket.

      Imports ``s3fs`` and ``pygrib`` lazily so they are only required
      when remote mode is actually used.

      :param vname: MRMS variable name (used in the S3 path).
      :param t: Timestamp to fetch.

      :returns: 2-D numpy array ``(lat, lon)`` after optional extent subsetting.
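
The ``extent`` normalisation noted in the ``MRMSDataset`` assumptions maps longitudes from the -180 to 180 convention onto 0 to 360. One way to express that mapping is a modulo. The sketch below is an assumption about the approach, not the package's code, and ``normalise_extent`` is a hypothetical name. It does not handle extents that straddle the 0/360 seam.

```python
def normalise_extent(extent: list[float]) -> list[float]:
    """Map [min_lon, max_lon, min_lat, max_lat] longitudes onto 0-360 (hypothetical helper)."""
    min_lon, max_lon, min_lat, max_lat = extent
    # Python's % returns a non-negative result for a positive modulus,
    # so -130 becomes 230. Note that an input of exactly 360 wraps to 0.
    return [min_lon % 360, max_lon % 360, min_lat, max_lat]


# The extent from the YAML example, given in the -180 to 180 convention.
west = normalise_extent([-130, -60, 20, 55])
# An extent already in the 0-360 convention passes through unchanged.
already = normalise_extent([230, 300, 20, 55])
print(west, already)
```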



