credit.datasets.mrms
====================

.. py:module:: credit.datasets.mrms

.. autoapi-nested-parse::

   mrms.py
   -------------------------------------------------------
   MRMSDataset: PyTorch Dataset for MRMS data with nested input/target structure.

   Sample structure returned by ``__getitem__``::

       {
           "input":    {<user_provided_name>: {"<user_provided_name>/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor,
                                               "<user_provided_name>/prognostic/2d/MultiSensor_QPE_06H_Pass2_00.00": tensor}},
           "target":   {<user_provided_name>: {"<user_provided_name>/prognostic/2d/MultiSensor_QPE_01H_Pass2_00.00": tensor,
                                               "<user_provided_name>/prognostic/2d/MultiSensor_QPE_06H_Pass2_00.00": tensor}},  # only populated when return_target=True
           "metadata": {<user_provided_name>: {"input_datetime": int, "target_datetime": int}},
       }
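
   The nested layout above can be navigated with ordinary dictionary indexing.
   The following is a minimal sketch, not output from the module: the source
   name ``Example_MRMS``, the string placeholder standing in for a tensor, the
   zero-valued datetimes, and the empty ``"target"`` dict are all assumptions
   made for illustration.

   .. code-block:: python

      # Hypothetical sample mirroring the documented layout; a string stands
      # in for the tensor that the dataset would return.
      source = "Example_MRMS"  # user-provided name (assumed for illustration)
      var = "MultiSensor_QPE_01H_Pass2_00.00"
      key = f"{source}/prognostic/2d/{var}"

      sample = {
          "input": {source: {key: "<tensor (1, 1, lat, lon)>"}},
          "target": {},  # assumed empty here; only populated when return_target=True
          "metadata": {source: {"input_datetime": 0, "target_datetime": 0}},
      }

      # Two levels of lookup: first the source name, then the full variable key.
      field = sample["input"][source][key]

      # The key encodes source name, field type, dimensionality and variable name.
      name, field_type, dim, var_name = key.split("/")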

   All MRMS variables are 2D. Tensor shape (no batch dimension):
       (1, 1, lat, lon)   — singleton level dim, consistent with ERA5 2D convention

   After DataLoader collation the batch dimension is prepended:
       (batch, 1, 1, lat, lon)

   Modes:
       local  — load from NetCDF (.nc) or Zarr (.zarr) files on disk using
                the same ``filename_time_format`` strftime convention as ERA5.
       remote — stream directly from AWS S3 (noaa-mrms-pds, anonymous access)
                via s3fs + pygrib.

   File naming (local mode):
       Controlled by the optional ``filename_time_format`` config key.
       Defaults to ``"%Y%m%d-%H%M%S"`` (one file per timestamp).

       Examples::

           filename_time_format: "%Y%m%d-%H%M%S"   # MRMS_20240601-060000.nc
           filename_time_format: "%Y%m%d"           # MRMS_20240601.nc  (daily)
           filename_time_format: "%Y%m"             # MRMS_202406.nc    (monthly)

       If only a single file matches the glob pattern, ``filename_time_format``
       is ignored and that file is used for all timestamps.
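
   The naming convention can be reproduced with the standard library. This is
   a sketch only: ``local_filename`` is a hypothetical helper, and the
   ``MRMS_`` prefix and ``.nc`` suffix are taken from the example comments
   above rather than from the loader itself.

   .. code-block:: python

      from datetime import datetime


      def local_filename(t, filename_time_format="%Y%m%d-%H%M%S"):
          """Hypothetical helper: build a file name for timestamp ``t``."""
          # Prefix and suffix follow the examples in this section.
          return f"MRMS_{t.strftime(filename_time_format)}.nc"


      t = datetime(2024, 6, 1, 6, 0, 0)
      per_timestamp = local_filename(t)      # "MRMS_20240601-060000.nc"
      daily = local_filename(t, "%Y%m%d")    # "MRMS_20240601.nc"
      monthly = local_filename(t, "%Y%m")    # "MRMS_202406.nc"

   Coarser formats map many timestamps to the same file name, which is why a
   daily or monthly file can serve every timestep that falls inside it.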



Classes
-------

.. autoapisummary::

   credit.datasets.mrms.MRMSDataset


Functions
---------

.. autoapisummary::

   credit.datasets.mrms._apply_extent


Module Contents
---------------

.. py:function:: _apply_extent(da: xarray.DataArray, extent: list[float] | None) -> xarray.DataArray

   Subset *da* to a spatial extent if provided.

   :param da: DataArray with ``lat`` and ``lon`` coordinates (longitude in the 0-360 convention).
   :param extent: ``[min_lon, max_lon, min_lat, max_lat]``, with longitudes in either
                  the -180 to 180 or the 0-360 convention; normalised to 0-360
                  internally.  ``None`` returns *da* unchanged.

   :returns: Spatially subsetted DataArray, or *da* unchanged if *extent* is ``None``.
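
   The longitude normalisation described above can be sketched with a modulo
   operation. ``normalise_extent`` is a hypothetical helper written for
   illustration; the actual function may treat edge cases (such as a
   longitude of exactly 360) differently.

   .. code-block:: python

      def normalise_extent(extent):
          """Hypothetical helper: map the longitudes of an extent to 0-360."""
          if extent is None:
              return None  # no subsetting requested
          min_lon, max_lon, min_lat, max_lat = extent
          # Python's % yields a non-negative result for a positive modulus,
          # so -130 becomes 230. Caveat: 360 itself would wrap to 0.
          return [min_lon % 360, max_lon % 360, min_lat, max_lat]


      conus = normalise_extent([-130, -60, 20, 55])     # [230, 300, 20, 55]
      unchanged = normalise_extent([230, 300, 20, 55])  # [230, 300, 20, 55]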


.. py:class:: MRMSDataset(data_config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`credit.datasets.base_dataset.BaseDataset`


   PyTorch Dataset for MRMS data with nested input/target structure.

   Field types follow CREDIT Gen2 conventions: ``prognostic`` variables appear in
   both input (at step 0) and target; ``dynamic_forcing`` appears in input
   at every step; ``diagnostic`` appears in target only.  At step ``i > 0``
   the model's own prognostic predictions are fed back, so no disk read occurs
   for prognostic fields at those steps.
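
   The rule for which field types are read at each step can be summarised in
   a few lines. This is an illustrative sketch of the convention as described
   here, not the class's implementation; ``input_fields_from_disk`` and
   ``TARGET_FIELDS`` are hypothetical names.

   .. code-block:: python

      def input_fields_from_disk(step):
          """Hypothetical helper: field types read from disk for the input at ``step``."""
          fields = ["dynamic_forcing"]  # read for the input at every step
          if step == 0:
              # Only the first step reads prognostic inputs; later steps
              # reuse the model's own predictions instead.
              fields.insert(0, "prognostic")
          return fields


      # Field types that appear in the target, per the convention above.
      TARGET_FIELDS = ["prognostic", "diagnostic"]

      first_step = input_fields_from_disk(0)  # ["prognostic", "dynamic_forcing"]
      later_step = input_fields_from_disk(3)  # ["dynamic_forcing"]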

   Supports loading directly from AWS S3 (remote mode) or from local
   NetCDF / Zarr files (local mode). Spatial subsetting via ``extent``
   is applied at load time on the native MRMS grid.

   See module docstring for full description of output format and file naming.

   Example YAML configuration (local mode)::

       data:
         source:
           Example_MRMS:  # User-provided name (arbitrary key)
             dataset_type: "mrms"
             mode: "local"
             variables:
               prognostic:                         # input at step 0 + target
                 vars_2D:
                   - "MultiSensor_QPE_01H_Pass2_00.00"
                 path: "/data/MRMS_*.nc"
                 filename_time_format: "%Y%m%d-%H%M%S"
               dynamic_forcing:                    # input every step
                 vars_2D:
                   - "MultiSensor_QPE_06H_Pass2_00.00"
                 path: "/data/MRMS_*.nc"
                 filename_time_format: "%Y%m%d-%H%M%S"
             extent: [-130, -60, 20, 55]   # [min_lon, max_lon, min_lat, max_lat]

         start_datetime: "2024-06-01"
         end_datetime:   "2024-07-01"
         timestep:       "6h"
         forecast_len:   0

   Example YAML configuration (remote mode)::

       data:
         source:
           Example_MRMS:  # User-provided name (arbitrary key)
             dataset_type: "mrms"
             mode: "remote"
             region: "CONUS"
             variables:
               prognostic:
                 vars_2D:
                   - "MultiSensor_QPE_01H_Pass2_00.00"
             extent: [-130, -60, 20, 55]

   Assumptions:
       1. Local files have ``time``, ``lat``, ``lon`` dimensions/coordinates.
       2. Longitude coordinates are in the 0–360 convention (both local and remote).
       3. ``extent`` is specified as ``[min_lon, max_lon, min_lat, max_lat]``,
          with longitudes in either the -180 to 180 or the 0-360 convention;
          it is normalised to 0-360 internally.


   .. py:attribute:: dataset_type
      :type:  str
      :value: 'mrms'



   .. py:attribute:: region
      :type:  str


   .. py:attribute:: extent
      :type:  list[float] | None


   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: _fs
      :value: None



   .. py:method:: _get_file_source(field_config: dict[str, Any]) -> list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None

      Return the file source for a field. Override in subclasses for different modes/backends.

      :param field_config: Validated field-type config dict.
      :type field_config: dict[str, Any]

      :raises ValueError: If ``self.mode`` is not a recognised mode.

      :returns: Depending on the mode and field type, one of: a list of
                ``(start_time, end_time, file_path)`` tuples produced by
                ``_map_files``; a boolean indicating whether the field is present
                (e.g., for remote data); or ``None`` if the field is disabled.
                The return type should be consistent within a dataset class.
      :rtype: list[tuple[pd.Timestamp, pd.Timestamp, str]] | bool | None



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Load all 2-D variables for *field_type* at time *t* into *sample*.

      Dispatches to local or remote loading based on ``self.mode``.

      :param field_type: Registered field type (e.g. ``"prognostic"``).
      :param t: Timestamp to load.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shape (no batch dimension): ``(1, 1, lat, lon)``.



   .. py:method:: _load_local_var(field_type: str, vname: str, t: pandas.Timestamp)

      Load a single variable from a local NetCDF or Zarr file.

      :param field_type: Field type key used to look up file intervals.
      :param vname: Variable name within the dataset.
      :param t: Timestamp to select.

      :returns: 2-D numpy array ``(lat, lon)`` after optional extent subsetting.

      :raises KeyError: If no files are registered for *field_type*.
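
      The file-interval lookup can be pictured as a scan over
      ``(start_time, end_time, file_path)`` tuples. The sketch below is an
      assumption-laden illustration: ``find_file`` is a hypothetical helper,
      the intervals are treated as half-open, and ``datetime`` objects stand
      in for ``pandas.Timestamp``.

      .. code-block:: python

         from datetime import datetime


         def find_file(intervals, t):
             """Hypothetical lookup: return the path whose interval contains ``t``."""
             for start, end, path in intervals:
                 # Assumption: intervals are half-open, [start, end).
                 if start <= t < end:
                     return path
             raise LookupError(f"no file covers {t}")


         intervals = [
             (datetime(2024, 6, 1), datetime(2024, 6, 2), "/data/MRMS_20240601.nc"),
             (datetime(2024, 6, 2), datetime(2024, 6, 3), "/data/MRMS_20240602.nc"),
         ]
         path = find_file(intervals, datetime(2024, 6, 1, 6))  # "/data/MRMS_20240601.nc"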



   .. py:method:: _load_remote_var(vname: str, t: pandas.Timestamp)

      Stream a single variable from the MRMS S3 bucket.

      Imports ``s3fs`` and ``pygrib`` lazily so they are only required
      when remote mode is actually used.

      :param vname: MRMS variable name (used in the S3 path).
      :param t: Timestamp to fetch.

      :returns: 2-D numpy array ``(lat, lon)`` after optional extent subsetting.
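
      The lazy-import pattern mentioned above can be sketched as follows.
      ``require`` and ``load_remote`` are hypothetical names used only for
      illustration; the method itself may simply place ``import`` statements
      inside its body.

      .. code-block:: python

         import importlib


         def require(module_name):
             """Hypothetical helper: import an optional dependency on first use."""
             try:
                 return importlib.import_module(module_name)
             except ImportError as exc:
                 raise ImportError(
                     f"remote mode needs the optional package '{module_name}'"
                 ) from exc


         def load_remote():
             # Deferred: nothing is imported until remote loading is requested,
             # so local-mode users do not need s3fs or pygrib installed.
             s3fs = require("s3fs")
             pygrib = require("pygrib")
             return s3fs, pygrib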



