credit.datasets.MRMS
====================

.. py:module:: credit.datasets.MRMS

.. autoapi-nested-parse::

   MRMS.py
   -------------------------------------------------------
   MRMSDataset with nested input/target structure.

   Field type semantics (mirrors ERA5 conventions):

   - ``prognostic``: input at step 0 *and* target; at steps > 0 the model's
     own prediction is fed back instead (autoregressive rollout).
   - ``diagnostic``: target only; never fed back into the model.
   - ``dynamic_forcing``: input at every step; never a target.

   Sample structure returned by ``__getitem__``::

       {
           "input": {
               "mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor,
               "mrms/dynamic_forcing/2d/MultiSensor_QPE_06H_Pass2_00.00": tensor,
               ...
           },
           "target": {                                  # only when return_target=True
               "mrms/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor,
               ...
           },
           "metadata": {
               "input_datetime":  int,                  # nanoseconds since epoch
               "target_datetime": int,                  # only when return_target=True
           },
       }
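
   A minimal consumer-side sketch of unpacking such a sample, assuming only
   the layout shown above.  ``ns_to_datetime``, ``split_key`` and
   ``summarise`` are hypothetical helpers, not part of this module, and the
   tensors are replaced by placeholders so the sketch runs with the standard
   library alone:

   .. code-block:: python

      from datetime import datetime, timedelta, timezone

      _EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


      def ns_to_datetime(ns: int) -> datetime:
          """Integer nanoseconds since the Unix epoch -> UTC datetime.

          Integer arithmetic avoids float rounding; anything below one
          microsecond is truncated because datetime stops at microseconds.
          """
          seconds, rem_ns = divmod(ns, 10**9)
          return _EPOCH + timedelta(seconds=seconds, microseconds=rem_ns // 1000)


      def split_key(key: str) -> tuple[str, str, str, str]:
          """Split 'mrms/{field_type}/2d/{varname}' into its four parts."""
          source, field_type, dim, varname = key.split("/", 3)
          return source, field_type, dim, varname


      def summarise(sample: dict) -> dict:
          """Group input/target variable names by field type; decode metadata times."""
          summary = {}
          for part in ("input", "target"):
              grouped = {}
              for key in sample.get(part, {}):
                  _, field_type, _, varname = split_key(key)
                  grouped.setdefault(field_type, []).append(varname)
              summary[part] = grouped
          summary["times"] = {
              name: ns_to_datetime(ns) for name, ns in sample["metadata"].items()
          }
          return summary

   ``pandas.Timestamp(ns)`` performs the same nanosecond conversion when pandas
   is already a dependency.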

   All MRMS variables are 2D. Tensor shape (no batch dimension), including a
   singleton level dimension consistent with the ERA5 2D convention::

       (1, 1, lat, lon)

   After ``DataLoader`` collation the batch dimension is prepended::

       (batch, 1, 1, lat, lon)

   Modes:

   - ``local``: load from NetCDF (``.nc``) or Zarr (``.zarr``) files on disk,
     using the same ``filename_time_format`` strftime convention as ERA5.
   - ``remote``: stream directly from AWS S3 (``noaa-mrms-pds``, anonymous
     access) via ``s3fs`` + ``pygrib``.

   File naming (local mode):
       Controlled by the optional ``filename_time_format`` config key.
       Defaults to ``"%Y%m%d-%H%M%S"`` (one file per timestamp).

       Examples::

           filename_time_format: "%Y%m%d-%H%M%S"   # MRMS_20240601-060000.nc
           filename_time_format: "%Y%m%d"           # MRMS_20240601.nc  (daily)
           filename_time_format: "%Y%m"             # MRMS_202406.nc    (monthly)

       If only a single file matches the glob pattern, ``filename_time_format``
       is ignored and that file is used for all timestamps.
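
   One way the naming convention could be resolved for a single timestamp is
   sketched below.  ``resolve_local_file`` is hypothetical (the module itself
   looks files up through ``file_dict``), and it assumes the single ``*`` in
   the configured ``path`` is where the strftime-formatted timestamp goes, so
   ``/data/MRMS_*.nc`` becomes ``/data/MRMS_20240601-060000.nc``:

   .. code-block:: python

      import fnmatch
      from datetime import datetime


      def resolve_local_file(path_glob: str, available: list[str], t: datetime,
                             filename_time_format: str = "%Y%m%d-%H%M%S") -> str:
          """Pick the file holding timestamp *t* from the *available* paths."""
          matches = sorted(p for p in available if fnmatch.fnmatch(p, path_glob))
          if not matches:
              raise FileNotFoundError(f"no files match {path_glob!r}")
          # A single matching file is used for every timestamp.
          if len(matches) == 1:
              return matches[0]
          # Otherwise substitute the formatted timestamp for the wildcard.
          candidate = path_glob.replace("*", t.strftime(filename_time_format), 1)
          if candidate not in matches:
              raise FileNotFoundError(f"{candidate!r} not found for {t:%Y-%m-%d %H:%M}")
          return candidate

   With a daily format (``"%Y%m%d"``) every timestamp on the same day maps to
   the same file, which is why coarser formats imply multiple times per file.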



Attributes
----------

.. autoapisummary::

   credit.datasets.MRMS.logger
   credit.datasets.MRMS.VALID_FIELD_TYPES
   credit.datasets.MRMS._S3_URI


Classes
-------

.. autoapisummary::

   credit.datasets.MRMS.MRMSDataset


Functions
---------

.. autoapisummary::

   credit.datasets.MRMS._apply_extent


Module Contents
---------------

.. py:data:: logger

.. py:data:: VALID_FIELD_TYPES

.. py:data:: _S3_URI
   :value: 's3://noaa-mrms-pds/{region}/{varname}/{date_str}/MRMS_{varname}_{datetime_str}.grib2.gz'
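
   Filling the template for a given time could look like the sketch below.
   ``mrms_s3_uri`` is a hypothetical helper, and the ``date_str`` and
   ``datetime_str`` strftime formats are assumptions, chosen to follow the
   bucket's ``YYYYMMDD`` directory and ``YYYYMMDD-HHMMSS`` file-name pattern:

   .. code-block:: python

      from datetime import datetime

      # Template copied from the documented _S3_URI value.
      S3_URI = ("s3://noaa-mrms-pds/{region}/{varname}/{date_str}/"
                "MRMS_{varname}_{datetime_str}.grib2.gz")


      def mrms_s3_uri(region: str, varname: str, t: datetime) -> str:
          """Fill the S3 template for variable *varname* at time *t*."""
          return S3_URI.format(
              region=region,
              varname=varname,
              date_str=t.strftime("%Y%m%d"),             # assumed daily prefix
              datetime_str=t.strftime("%Y%m%d-%H%M%S"),  # assumed file timestamp
          )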


.. py:function:: _apply_extent(da: xarray.DataArray, extent: list[float] | None) -> xarray.DataArray

   Subset *da* to a spatial extent if provided.

   :param da: DataArray with ``lat`` and ``lon`` coordinates, longitude in the
              0–360 convention.
   :param extent: ``[min_lon, max_lon, min_lat, max_lat]`` in either the
                  -180 to 180 or the 0–360 longitude convention; normalised to
                  0–360 internally.  ``None`` returns *da* unchanged.

   :returns: Spatially subsetted DataArray, or *da* unchanged if *extent* is ``None``.
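
   The normalisation and selection steps might look like the following
   minimal sketch on plain coordinate lists.  ``normalise_extent`` and
   ``extent_indices`` are hypothetical names; selecting by index mask rather
   than a label slice also works if ``lat`` happens to be stored north to
   south.  The sketch does not handle extents that straddle the 0°/360°
   meridian:

   .. code-block:: python

      from __future__ import annotations


      def _to_360(lon: float) -> float:
          # Only negative longitudes shift, so 0-360 input (including 360) is kept.
          return lon + 360 if lon < 0 else lon


      def normalise_extent(extent: list[float] | None) -> list[float] | None:
          """Map [min_lon, max_lon, min_lat, max_lat] to 0-360 longitudes."""
          if extent is None:
              return None
          min_lon, max_lon, min_lat, max_lat = extent
          return [_to_360(min_lon), _to_360(max_lon), min_lat, max_lat]


      def extent_indices(lats: list[float], lons: list[float],
                         extent: list[float]) -> tuple[list[int], list[int]]:
          """Row/column indices of grid points inside *extent* (inclusive bounds)."""
          min_lon, max_lon, min_lat, max_lat = normalise_extent(extent)
          lat_idx = [i for i, lat in enumerate(lats) if min_lat <= lat <= max_lat]
          lon_idx = [j for j, lon in enumerate(lons) if min_lon <= lon <= max_lon]
          return lat_idx, lon_idx

   With xarray, index lists like these could be passed to ``DataArray.isel``.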


.. py:class:: MRMSDataset(config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset for MRMS data with nested input/target structure.

   Field types follow ERA5 conventions: ``prognostic`` variables appear in
   both input (at step 0) and target; ``dynamic_forcing`` appears in input
   at every step; ``diagnostic`` appears in target only.  At step ``i > 0``
   the model's own prognostic predictions are fed back — no disk read occurs
   for prognostic fields at those steps.

   Supports loading directly from AWS S3 (remote mode) or from local
   NetCDF / Zarr files (local mode). Spatial subsetting via ``extent``
   is applied at load time on the native MRMS grid.

   See the module docstring for a full description of the output format and file naming.

   Example YAML configuration (local mode)::

       data:
         source:
           MRMS:
             mode: "local"
             variables:
               prognostic:                         # input at step 0 + target
                 vars_2D:
                   - "MultiSensor_QPE_01H_Pass2_00.00"
                 path: "/data/MRMS_*.nc"
                 filename_time_format: "%Y%m%d-%H%M%S"
               dynamic_forcing:                    # input every step
                 vars_2D:
                   - "MultiSensor_QPE_06H_Pass2_00.00"
                 path: "/data/MRMS_*.nc"
                 filename_time_format: "%Y%m%d-%H%M%S"
             extent: [-130, -60, 20, 55]   # [min_lon, max_lon, min_lat, max_lat]

         start_datetime: "2024-06-01"
         end_datetime:   "2024-07-01"
         timestep:       "6h"
         forecast_len:   0

   Example YAML configuration (remote mode)::

       data:
         source:
           MRMS:
             mode: "remote"
             region: "CONUS"
             variables:
               prognostic:
                 vars_2D:
                   - "MultiSensor_QPE_01H_Pass2_00.00"
             extent: [-130, -60, 20, 55]

   Assumptions:
       1. Local files have ``time``, ``lat``, ``lon`` dimensions/coordinates.
       2. Longitude coordinates are in the 0–360 convention (both local and remote).
       3. ``extent`` is specified as ``[min_lon, max_lon, min_lat, max_lat]``
          in either the -180 to 180 or the 0–360 longitude convention; it is
          normalised to 0–360 internally.


   .. py:attribute:: source_name
      :type:  str
      :value: 'mrms'



   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: mode
      :type:  str


   .. py:attribute:: region
      :type:  str


   .. py:attribute:: extent
      :type:  list[float] | None


   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: dt


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime


   .. py:attribute:: end_datetime


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: file_dict
      :type:  dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]


   .. py:attribute:: var_dict
      :type:  dict[str, dict[str, list[str]]]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a nested input/target sample dict.

      Prognostic fields are loaded into ``input`` only at step ``i == 0``
      (consistent with ERA5 autoregressive rollout semantics).  Dynamic
      forcing is loaded at every step.  Diagnostic fields never appear
      in ``input``.

      :param args: ``(t, i)`` where *t* is the current timestamp (``int``
                   nanoseconds since the epoch, or a ``pandas.Timestamp``) and
                   *i* is the within-sequence step index produced by the sampler.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"mrms/{field_type}/2d/{varname}"``.
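
      The routing rules above can be summarised as a small lookup.
      ``fields_for_step`` and ``sample_key`` are hypothetical illustrations,
      not methods of ``MRMSDataset``, and the target list (prognostic plus
      diagnostic) is inferred from the field-type semantics in the module
      docstring:

      .. code-block:: python

         from __future__ import annotations


         def fields_for_step(i: int, registered: set[str], return_target: bool) -> dict:
             """Field types read from disk for within-sequence step *i*."""
             input_types = []
             # Prognostic fields come from disk only at the first step;
             # afterwards the model's own predictions are fed back.
             if i == 0 and "prognostic" in registered:
                 input_types.append("prognostic")
             # Dynamic forcing is read at every step.
             if "dynamic_forcing" in registered:
                 input_types.append("dynamic_forcing")
             plan = {"input": input_types}
             if return_target:
                 # Diagnostic fields only ever appear as targets.
                 plan["target"] = [ft for ft in ("prognostic", "diagnostic")
                                   if ft in registered]
             return plan


         def sample_key(field_type: str, varname: str) -> str:
             """Key format used inside the "input" and "target" dicts."""
             return f"mrms/{field_type}/2d/{varname}"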



   .. py:method:: _register_field(field_type: str, d: dict | None) -> None

      Validate and register one field type from the config variables block.

      :param field_type: One of ``"prognostic"``, ``"diagnostic"``,
                         ``"dynamic_forcing"``.
      :param d: Field-type config dict, or ``None`` (YAML ``null``) to disable the field.

      :raises KeyError: If *field_type* is not a recognised MRMS field type.
      :raises ValueError: If *d* defines no ``vars_2D``.
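
      A minimal sketch of the documented validation contract.
      ``register_field`` is a hypothetical stand-in; the contents of
      ``VALID_FIELD_TYPES`` and the ``{"vars_2D": [...]}`` layout stored in
      ``var_dict`` are assumptions based on the three field types and the
      YAML examples above:

      .. code-block:: python

         from __future__ import annotations

         # Assumed to match the module-level VALID_FIELD_TYPES constant.
         VALID_FIELD_TYPES = ("prognostic", "diagnostic", "dynamic_forcing")


         def register_field(var_dict: dict, field_type: str, d: dict | None) -> None:
             """Validate one field-type block and record its variables."""
             if field_type not in VALID_FIELD_TYPES:
                 raise KeyError(f"unknown MRMS field type {field_type!r}")
             if d is None:
                 return  # YAML null: this field type is disabled
             vars_2d = d.get("vars_2D") or []
             if not vars_2d:
                 raise ValueError(f"{field_type!r} defines no vars_2D")
             # Assumed storage layout: {field_type: {"vars_2D": [...]}}.
             var_dict[field_type] = {"vars_2D": list(vars_2d)}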



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return valid initialisation timestamps for the dataset.

      :returns: DatetimeIndex running from ``start_datetime`` up to
                ``end_datetime`` minus the forecast horizon, at the configured
                ``timestep`` frequency.
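
      A standard-library sketch, assuming the forecast horizon equals
      ``num_forecast_steps * timestep`` and that ``end_datetime`` is inclusive
      (neither detail is spelled out here).  ``build_timestamps`` is
      hypothetical; with pandas, ``pandas.date_range(start, last_init,
      freq=timestep)`` would yield the equivalent index:

      .. code-block:: python

         from datetime import datetime, timedelta


         def build_timestamps(start: datetime, end: datetime, timestep: timedelta,
                              num_forecast_steps: int) -> list[datetime]:
             """Initialisation times t with t + num_forecast_steps * timestep <= end."""
             last_init = end - num_forecast_steps * timestep
             times = []
             t = start
             while t <= last_init:
                 times.append(t)
                 t += timestep
             return times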



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Load all variables for *field_type* at time *t* into *sample*.

      Dispatches to local or remote loading based on ``self.mode``.

      :param field_type: Registered field type (e.g. ``"prognostic"``).
      :param t: Timestamp to load.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shape (no batch dimension): ``(1, 1, lat, lon)``.
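
      A minimal sketch of the dispatch, assuming injected loader callables
      with the same signatures as ``_load_local_var`` and
      ``_load_remote_var``.  ``extract_field`` is hypothetical, and nested
      lists stand in for arrays so the sketch needs only the standard library;
      with real NumPy output, something like ``torch.from_numpy(grid)[None,
      None]`` would add the same two leading singleton dimensions:

      .. code-block:: python

         from collections.abc import Callable


         def extract_field(mode: str, field_type: str, varnames: list[str], t,
                           sample: dict, load_local: Callable,
                           load_remote: Callable) -> None:
             """Write one (1, 1, lat, lon) entry per variable into *sample*."""
             for vname in varnames:
                 if mode == "local":
                     grid = load_local(field_type, vname, t)
                 elif mode == "remote":
                     grid = load_remote(vname, t)
                 else:
                     raise ValueError(f"unknown mode {mode!r}")
                 # Wrapping the 2-D (lat, lon) grid twice adds the two
                 # leading singleton dimensions.
                 sample[f"mrms/{field_type}/2d/{vname}"] = [[grid]]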



   .. py:method:: _load_local_var(field_type: str, vname: str, t: pandas.Timestamp)

      Load a single variable from a local NetCDF or Zarr file.

      :param field_type: Field type key used to look up file intervals.
      :param vname: Variable name within the dataset.
      :param t: Timestamp to select.

      :returns: 2-D numpy array ``(lat, lon)`` after optional extent subsetting.

      :raises KeyError: If no files are registered for *field_type*.



   .. py:method:: _load_remote_var(vname: str, t: pandas.Timestamp)

      Stream a single variable from the MRMS S3 bucket.

      Imports ``s3fs`` and ``pygrib`` lazily so they are only required
      when remote mode is actually used.

      :param vname: MRMS variable name (used in the S3 path).
      :param t: Timestamp to fetch.

      :returns: 2-D numpy array ``(lat, lon)`` after optional extent subsetting.
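
      The objects addressed by ``_S3_URI`` are gzip-compressed GRIB2 files.
      Assuming the bytes have already been fetched (for example with
      ``s3fs.S3FileSystem(anon=True)``), the standard-library step in between
      might look like the sketch below; ``gunzip_grib2`` is hypothetical, and
      handing the result to ``pygrib`` (for instance by writing it to a
      temporary file for ``pygrib.open``) is left out:

      .. code-block:: python

         import gzip
         import struct


         def gunzip_grib2(payload: bytes) -> bytes:
             """Decompress a .grib2.gz payload and sanity-check the GRIB2 framing."""
             raw = gzip.decompress(payload)
             # Section 0: b"GRIB", 2 reserved octets, discipline, edition,
             # then the total message length as an 8-byte big-endian integer.
             if len(raw) < 20 or raw[:4] != b"GRIB":
                 raise ValueError("payload is not a GRIB message")
             if raw[7] != 2:
                 raise ValueError(f"expected GRIB edition 2, got {raw[7]}")
             (total_len,) = struct.unpack(">Q", raw[8:16])
             # Section 8: the message ends with b"7777".
             if raw[total_len - 4:total_len] != b"7777":
                 raise ValueError("GRIB2 end marker b'7777' not found")
             return raw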



