credit.datasets.hrrr
====================

.. py:module:: credit.datasets.hrrr

.. autoapi-nested-parse::

   hrrr.py
   -------------------------------------------------------
   HRRRDataset: PyTorch Dataset for HRRR GRIB2 data.

   Supports three HRRR products (``VALID_PRODUCTS``):

   * ``"wrfprs"`` — pressure-level output (default, ~200 MB/file)
   * ``"wrfnat"`` — native/hybrid-sigma level output (~200 MB/file, ~65 levels)
   * ``"wrfsubh"`` — 15-minute sub-hourly surface output (surface vars only)

   Tensor keys follow the pattern ``{user_provided_name}/{hrrr_product}/{field_type}/{dim}/{varname}``
   where *hrrr_product* is product-specific:

       "wrfprs"  → ``{user_provided_name}/wrfprs/{field_type}/{dim}/{varname}``
       "wrfnat"  → ``{user_provided_name}/wrfnat/{field_type}/{dim}/{varname}``
       "wrfsubh" → ``{user_provided_name}/wrfsubh/{field_type}/2d/{varname}``

   *dim* is ``"3d"`` for multi-level variables and ``"2d"`` for surface variables.

   Tensor shapes (before DataLoader batching):
       3D variables: ``(n_levels, 1, y, x)``
       2D variables: ``(1, 1, y, x)``

   The ``y`` / ``x`` spatial dimensions correspond to HRRR's native Lambert
   Conformal Conic grid; if ``extent`` is specified they reflect the cropped
   sub-domain rather than the full CONUS grid (~1059 x 1799).

   Two S3 path layouts are handled automatically:

       v1/v2  (before 2018-07-12):
           s3://noaa-hrrr-bdp-pds/hrrr.{YYYYMMDD}/hrrr.t{HH}z.{product}f{FF:02d}.grib2
       v3/v4  (2018-07-12 onward):
           s3://noaa-hrrr-bdp-pds/hrrr.{YYYYMMDD}/conus/hrrr.t{HH}z.{product}f{FF:02d}.grib2

   GRIB2 reading
   -------------
   Both local and remote modes use the same ``.idx`` + byte-range pipeline:

   *Remote mode*:

   1. Fetch the sidecar ``.idx`` inventory (~100 KB) via HTTPS to get exact byte
      offsets for every GRIB message.
   2. Issue one HTTP Range GET per required message (~50-200 KB each) via
      ``requests``, with all messages fetched in parallel using
      :class:`concurrent.futures.ThreadPoolExecutor`.

   *Local mode*: reads the ``.idx`` sidecar from disk, then uses
   ``file.seek()`` + ``file.read()`` — identical byte-range approach, no
   full-file scan.  The ``.idx`` sidecar must be present alongside the grib2;
   download it with ``hrrr_download.py``.

   For a typical training sample (5 vars x 6 levels ≈ 30 messages) remote mode
   transfers ~3 MB instead of ~200 MB (~60-100x reduction).

   Variable lookup is driven by :data:`VAR_REGISTRY`.  Extend it at import
   time to add variables without subclassing::

       from credit.datasets.hrrr import VAR_REGISTRY
       VAR_REGISTRY["MYVAR"] = {
           "shortName": "myvar", "typeOfLevel": "isobaricInhPa",
           "idx_name": "MYVAR", "idx_level": None,
       }

   Example YAML (wrfprs, local mode)::

       data:
         source:
           Example_HRRR:  # User-provided name (arbitrary key)
             dataset_type: "HRRR"
             # product: "wrfprs" # Optional for PRS product. Default is "wrfprs".
             mode: "local"
             base_path: "/data/hrrr"
             forecast_hour: 0
             levels: [250, 500, 700, 850, 925, 1000]
             variables:
               prognostic:
                 vars_3D: [T, U, V, Q, GH]
                 vars_2D: [t2m]
             extent: [-130, -60, 20, 55]

         start_datetime: "2021-06-01"
         end_datetime:   "2021-06-05"
         timestep:       "1h"
         forecast_len:   0

   Example YAML (wrfnat, remote mode)::

       data:
         source:
           Example_HRRR_NAT:  # User-provided name (arbitrary key)
             dataset_type: "HRRR"
             product: "wrfnat" # Options: "wrfprs" (default), "wrfnat", "wrfsubh"
             mode: "remote"
             forecast_hour: 0
             levels: [10, 20, 30, 40, 50]   # hybrid level indices 1-65
             variables:
               prognostic:
                 vars_3D: [T, U, V, Q]

         start_datetime: "2022-01-01"
         end_datetime:   "2022-01-31"
         timestep:       "1h"
         forecast_len:   0

   Example YAML (wrfsubh, remote mode — 15-min output)::

       data:
         source:
           Example_HRRR_SUBH:  # User-provided name (arbitrary key)
             dataset_type: "HRRR"
             product: "wrfsubh" # Options: "wrfprs" (default), "wrfnat", "wrfsubh"
             mode: "remote"
             variables:
               prognostic:
                 vars_2D: [t2m, sp, refc]

         start_datetime: "2022-01-01 00:15"
         end_datetime:   "2022-01-31 00:00"
         timestep:       "15min"
         forecast_len:   0
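
As a rough illustration of the tensor key pattern described in the module docstring, the sketch below assembles the key for one variable of the first YAML example. ``sketch_tensor_key`` is a hypothetical helper written for this page; it is not part of ``credit.datasets.hrrr``.

```python
def sketch_tensor_key(name: str, product: str, field_type: str, dim: str, varname: str) -> str:
    """Join the five documented key components with '/' (illustrative only)."""
    if product == "wrfsubh" and dim != "2d":
        # The module docstring states that wrfsubh carries surface variables only.
        raise ValueError("wrfsubh provides surface (2d) variables only")
    return "/".join([name, product, field_type, dim, varname])


key_3d = sketch_tensor_key("Example_HRRR", "wrfprs", "prognostic", "3d", "T")
key_2d = sketch_tensor_key("Example_HRRR", "wrfprs", "prognostic", "2d", "t2m")
```

Here ``key_3d`` is ``"Example_HRRR/wrfprs/prognostic/3d/T"`` and ``key_2d`` is ``"Example_HRRR/wrfprs/prognostic/2d/t2m"``.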



Attributes
----------

.. autoapisummary::

   credit.datasets.hrrr.logger
   credit.datasets.hrrr._HRRR_V3_CUTOFF
   credit.datasets.hrrr._S3_BUCKET
   credit.datasets.hrrr._HRRR_HTTPS_BASE
   credit.datasets.hrrr.VAR_REGISTRY
   credit.datasets.hrrr._MAX_REMOTE_WORKERS
   credit.datasets.hrrr._HTTP_TIMEOUT
   credit.datasets.hrrr.VALID_PRODUCTS


Classes
-------

.. autoapisummary::

   credit.datasets.hrrr.HRRRDataset


Functions
---------

.. autoapisummary::

   credit.datasets.hrrr._hrrr_s3_uri
   credit.datasets.hrrr._hrrr_local_path
   credit.datasets.hrrr._s3_uri_to_https
   credit.datasets.hrrr._parse_idx
   credit.datasets.hrrr._fetch_idx
   credit.datasets.hrrr._fetch_message
   credit.datasets.hrrr._build_prs_entry_map
   credit.datasets.hrrr._resolve_pressure_levels
   credit.datasets.hrrr._build_nat_entry_map
   credit.datasets.hrrr._resolve_nat_levels
   credit.datasets.hrrr._find_subhf_entry
   credit.datasets.hrrr._fetch_bytes_local
   credit.datasets.hrrr._load_idx_local
   credit.datasets.hrrr._to_float32
   credit.datasets.hrrr._validate_product_request


Module Contents
---------------

.. py:data:: logger

.. py:data:: _HRRR_V3_CUTOFF

.. py:data:: _S3_BUCKET
   :value: 'noaa-hrrr-bdp-pds'


.. py:data:: _HRRR_HTTPS_BASE
   :value: 'https://noaa-hrrr-bdp-pds.s3.amazonaws.com'


.. py:data:: VAR_REGISTRY
   :type:  dict[str, dict[str, str | None]]

.. py:data:: _MAX_REMOTE_WORKERS
   :value: 8


.. py:data:: _HTTP_TIMEOUT
   :type:  tuple[int, int]
   :value: (10, 120)


.. py:data:: VALID_PRODUCTS

.. py:function:: _hrrr_s3_uri(t: pandas.Timestamp, forecast_hour: int, product: VALID_PRODUCTS = 'wrfprs') -> str

   Construct the S3 URI for a HRRR grib2 file.

   :param t: Initialization timestamp (UTC).
   :type t: pd.Timestamp
   :param forecast_hour: Forecast lead hour (FF), e.g. ``0`` for analysis.
   :type forecast_hour: int
   :param product: HRRR product name. Defaults to "wrfprs".
   :type product: VALID_PRODUCTS, optional

   :returns: S3 URI.
   :rtype: str
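
The two S3 path layouts listed in the module docstring can be sketched as follows. This is a minimal illustration, not the module's implementation: ``sketch_hrrr_s3_uri`` and ``V3_CUTOFF`` are hypothetical names, the cutoff date is taken from the layout description, and a plain ``datetime`` stands in for ``pd.Timestamp`` to keep the example dependency-free.

```python
from datetime import datetime

BUCKET = "noaa-hrrr-bdp-pds"
V3_CUTOFF = datetime(2018, 7, 12)  # assumed from the documented layout change


def sketch_hrrr_s3_uri(t: datetime, forecast_hour: int, product: str = "wrfprs") -> str:
    """Build an S3 URI; files from the cutoff onward live under 'conus/'."""
    subdir = "conus/" if t >= V3_CUTOFF else ""
    return (
        f"s3://{BUCKET}/hrrr.{t:%Y%m%d}/{subdir}"
        f"hrrr.t{t:%H}z.{product}f{forecast_hour:02d}.grib2"
    )


new_layout = sketch_hrrr_s3_uri(datetime(2021, 6, 1, 12), 0)
old_layout = sketch_hrrr_s3_uri(datetime(2017, 1, 5, 3), 6, "wrfnat")
```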


.. py:function:: _hrrr_local_path(base_path: str, t: pandas.Timestamp, forecast_hour: int, product: VALID_PRODUCTS = 'wrfprs') -> str

   Construct the local filesystem path for a HRRR grib2 file.

   :param base_path: Root directory containing HRRR data.
   :type base_path: str
   :param t: Initialization timestamp (UTC).
   :type t: pd.Timestamp
   :param forecast_hour: Forecast lead hour (FF), e.g. ``0`` for analysis.
   :type forecast_hour: int
   :param product: HRRR product name. Defaults to "wrfprs".
   :type product: VALID_PRODUCTS, optional

   :returns: Local filesystem path to the grib2 file.
   :rtype: str


.. py:function:: _s3_uri_to_https(s3_uri: str) -> str

   Convert an ``s3://noaa-hrrr-bdp-pds/...`` URI to a public HTTPS URL.

   :param s3_uri: S3 URI
   :type s3_uri: str

   :returns: Public HTTPS URL
   :rtype: str
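
A minimal sketch of the conversion, assuming it amounts to swapping the ``s3://`` bucket prefix for the HTTPS base listed under ``_HRRR_HTTPS_BASE``. The function name and the prefix check are illustrative assumptions.

```python
HTTPS_BASE = "https://noaa-hrrr-bdp-pds.s3.amazonaws.com"
S3_PREFIX = "s3://noaa-hrrr-bdp-pds/"


def sketch_s3_uri_to_https(s3_uri: str) -> str:
    """Replace the s3:// bucket prefix with the public HTTPS endpoint."""
    if not s3_uri.startswith(S3_PREFIX):
        raise ValueError(f"not a noaa-hrrr-bdp-pds URI: {s3_uri!r}")
    return f"{HTTPS_BASE}/{s3_uri[len(S3_PREFIX):]}"


https_url = sketch_s3_uri_to_https(
    "s3://noaa-hrrr-bdp-pds/hrrr.20210601/conus/hrrr.t12z.wrfprsf00.grib2"
)
```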


.. py:function:: _parse_idx(text: str) -> list[dict[str, str | int | None]]

   Parse a HRRR ``.idx`` inventory file into a list of message entries.

   Each entry dict has keys: ``var``, ``level``, ``byte_start``, ``byte_end``
   (``None`` for the last entry, meaning read to EOF).

   :param text: The content of the .idx file.
   :type text: str

   :returns: Entries parsed from the .idx, in file order.
   :rtype: list[dict[str, str | int | None]]
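
To make the entry structure concrete, here is a hedged sketch of parsing a wgrib2-style inventory, where each line has colon-separated fields (message number, byte offset, date, variable, level, forecast step). ``sketch_parse_idx`` and the sample text are illustrative, and the ``step`` key is an assumption based on the ``_find_subhf_entry`` description; the real parser may differ in details.

```python
SAMPLE_IDX = """\
1:0:d=2021060100:REFC:entire atmosphere:anl:
2:375000:d=2021060100:TMP:500 mb:anl:
3:910000:d=2021060100:TMP:2 m above ground:anl:
"""


def sketch_parse_idx(text: str) -> list:
    """Turn inventory lines into dicts with inclusive byte ranges."""
    entries = []
    for line in text.splitlines():
        parts = line.split(":")
        if len(parts) < 6:
            continue  # skip blank or malformed lines
        entries.append({
            "var": parts[3],
            "level": parts[4],
            "step": parts[5],
            "byte_start": int(parts[1]),
            "byte_end": None,  # stays None for the last message (read to EOF)
        })
    # Each message ends one byte before the next message starts.
    for cur, nxt in zip(entries, entries[1:]):
        cur["byte_end"] = nxt["byte_start"] - 1
    return entries


entries = sketch_parse_idx(SAMPLE_IDX)
```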


.. py:function:: _fetch_idx(s3_uri: str) -> list[dict[str, str | int | None]]

   Fetch and parse the ``.idx`` sidecar for a HRRR grib2 file via HTTPS.

   :param s3_uri: S3 URI
   :type s3_uri: str

   :raises FileNotFoundError: If the ``.idx`` file is not found (older v1/v2 files
       may lack sidecars; pre-download with ``hrrr_download.py`` and use
       local mode instead).

   :returns: Entries parsed from the .idx, in file order.
   :rtype: list[dict[str, str | int | None]]


.. py:function:: _fetch_message(https_url: str, byte_start: int, byte_end: int | None, session=None) -> bytes

   Fetch a single GRIB message via an HTTP Range request.

   :param https_url: Public HTTPS URL of the grib2 file.
   :type https_url: str
   :param byte_start: First byte of the message (inclusive).
   :type byte_start: int
   :param byte_end: Last byte of the message (inclusive), or ``None`` for EOF.
   :type byte_end: int | None
   :param session: Optional ``requests.Session`` for connection reuse.  Falls
                   back to module-level ``requests.get`` if ``None``. Defaults to ``None``.
   :type session: requests.Session | None, optional

   :returns: The raw bytes of the GRIB message for that byte range.
   :rtype: bytes
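
The inclusive ``byte_start`` / ``byte_end`` pair maps directly onto an HTTP ``Range`` header. The helper below is a hypothetical sketch of that mapping only; it performs no network request, and how the module builds its headers is an assumption.

```python
from typing import Optional


def sketch_range_header(byte_start: int, byte_end: Optional[int]) -> dict:
    """Build a Range header; HTTP byte ranges are inclusive on both ends."""
    end = "" if byte_end is None else str(byte_end)  # open-ended range reads to EOF
    return {"Range": f"bytes={byte_start}-{end}"}


bounded = sketch_range_header(375000, 909999)
open_ended = sketch_range_header(910000, None)
```

Such a dict could be passed as ``headers=`` to ``requests.get`` or ``Session.get``; a server that honors the range responds with status 206.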


.. py:function:: _build_prs_entry_map(idx_entries: list[dict[str, str | int | None]], idx_name: str) -> dict[float, dict[str, str | None]]

   Return a ``{pressure_level_hPa: idx_entry}`` dict for a pressure-level variable.

   :param idx_entries: List of entries parsed from the .idx file.
   :type idx_entries: list[dict[str, str | int | None]]
   :param idx_name: Name of the variable to filter for.
   :type idx_name: str

   :returns: Mapping from pressure level (hPa) to the corresponding .idx entry for that variable.
   :rtype: dict[float, dict[str, str | None]]
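
A minimal sketch of such a mapping, assuming pressure levels appear in the ``.idx`` as strings like ``"500 mb"``. The function name and the exact filtering rule are illustrative assumptions.

```python
def sketch_build_prs_entry_map(idx_entries: list, idx_name: str) -> dict:
    """Map pressure level (hPa) to idx entry for one variable."""
    prs_map = {}
    for entry in idx_entries:
        level = entry["level"]
        if entry["var"] != idx_name or not level.endswith(" mb"):
            continue
        try:
            prs_map[float(level[: -len(" mb")])] = entry
        except ValueError:
            continue  # level prefix is not a plain number; ignore it
    return prs_map
```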


.. py:function:: _resolve_pressure_levels(requested: list[int] | None, prs_map: dict[float, dict[str, str | None]], var_name: str) -> list[float]

   Return the float pressure levels to fetch, validated against the levels available for the variable.

   :param requested: List of requested pressure levels.
   :type requested: list[int] | None
   :param prs_map: Mapping from _build_prs_entry_map()
   :type prs_map: dict[float, dict[str, str | None]]
   :param var_name: Variable name for error messages (e.g. "T", "U", "Q", etc.)
   :type var_name: str

   :raises ValueError: If any requested levels are not found in the available levels for that variable.

   :returns: The float pressure levels to fetch.
   :rtype: list[float]
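
The validation step can be sketched as below. The behavior for ``requested=None`` (return every available level, sorted) is an assumption, as are the function name and the error message wording.

```python
from typing import Optional


def sketch_resolve_pressure_levels(requested: Optional[list], prs_map: dict, var_name: str) -> list:
    """Return float levels to fetch, raising ValueError for unavailable ones."""
    available = sorted(prs_map)
    if requested is None:
        return available  # assumed default: use everything in the file
    missing = [lev for lev in requested if float(lev) not in prs_map]
    if missing:
        raise ValueError(f"{var_name}: levels {missing} not available; found {available}")
    return [float(lev) for lev in requested]
```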


.. py:function:: _build_nat_entry_map(idx_entries: list[dict[str, str | int | None]], idx_name: str) -> dict[int, dict[str, str | None]]

   Return ``{hybrid_level_index: idx_entry}`` for a wrfnat variable.

   HRRR native-level ``.idx`` entries look like::

       TMP:10 hybrid level:anl:

   i.e. ``level`` ends with ``" hybrid level"`` and the prefix is the integer
   level index (1-65, bottom-up).

   :param idx_entries: List of entries parsed from the .idx file.
   :type idx_entries: list[dict[str, str | int | None]]
   :param idx_name: Name of the variable to filter for.
   :type idx_name: str

   :returns: Mapping from hybrid level index to the corresponding .idx entry for that variable.
   :rtype: dict[int, dict[str, str | None]]
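
Following the ``"10 hybrid level"`` format shown above, a hedged sketch of the native-level mapping might look like this. The function name is hypothetical and the real implementation may parse the level string differently.

```python
def sketch_build_nat_entry_map(idx_entries: list, idx_name: str) -> dict:
    """Map hybrid level index to idx entry for one wrfnat variable."""
    suffix = " hybrid level"
    nat_map = {}
    for entry in idx_entries:
        level = entry["level"]
        if entry["var"] == idx_name and level.endswith(suffix):
            prefix = level[: -len(suffix)]
            if prefix.isdigit():  # keep only plain integer level indices
                nat_map[int(prefix)] = entry
    return nat_map
```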


.. py:function:: _resolve_nat_levels(requested: list[int] | None, nat_map: dict[int, dict[str, str | None]], var_name: str) -> list[int]

   Return the native level indices to fetch, validated against the levels available for the variable.

   :param requested: List of requested hybrid levels.
   :type requested: list[int] | None
   :param nat_map: Mapping from _build_nat_entry_map()
   :type nat_map: dict[int, dict[str, str | None]]
   :param var_name: Variable name for error messages (e.g. "T", "U", "Q", etc.)
   :type var_name: str

   :raises ValueError: If any requested levels are not found in the available levels for that variable.

   :returns: The integer native level indices to fetch.
   :rtype: list[int]


.. py:function:: _find_subhf_entry(idx_entries: list[dict[str, str | int | None]], idx_name: str, idx_level: str, step_min: int) -> dict[str, str | int | None]

   Return the idx entry for a wrfsubh variable at a specific sub-step.

   Sub-hourly ``.idx`` entries have a ``step`` field like ``"15 min fcst"``,
   ``"30 min fcst"``, ``"45 min fcst"``, ``"60 min fcst"``.

   :param idx_entries: Parsed ``.idx`` entries for the wrfsubh file.
   :type idx_entries: list[dict[str, str | int | None]]
   :param idx_name: Variable name as it appears in the ``.idx``.
   :type idx_name: str
   :param idx_level: Level string (e.g. ``"2 m above ground"``).
   :type idx_level: str
   :param step_min: Sub-step in minutes (15, 30, 45, 60, …).
   :type step_min: int

   :raises KeyError: If no matching entry is found.

   :returns: The matching .idx entry for that variable, level, and step.
   :rtype: dict[str, str | int | None]
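
The sketch below combines the sub-hourly time rules documented for ``HRRRDataset._extract_field`` with a lookup by ``step``. Both function names are hypothetical, a plain ``datetime`` stands in for ``pd.Timestamp``, and the entry dicts are assumed to carry a ``step`` key holding strings such as ``"15 min fcst"``.

```python
import math
from datetime import datetime, timedelta


def sketch_subh_coords(t: datetime):
    """Return (init_time, step_min, ff) for a 15-min-resolution timestamp."""
    init = t.replace(minute=0, second=0, microsecond=0)  # like t.floor("1h")
    step_min = int((t - init).total_seconds() // 60)
    if step_min == 0:
        # On the hour: treated as the 60-min step of the previous hour's run.
        init, step_min = init - timedelta(hours=1), 60
    return init, step_min, math.ceil(step_min / 60)


def sketch_find_subh_entry(idx_entries: list, idx_name: str, idx_level: str, step_min: int) -> dict:
    """Return the first entry matching variable, level, and sub-step."""
    wanted = f"{step_min} min fcst"
    for entry in idx_entries:
        if (entry["var"], entry["level"], entry["step"]) == (idx_name, idx_level, wanted):
            return entry
    raise KeyError(f"{idx_name} / {idx_level} / {wanted} not found")


quarter_past = sketch_subh_coords(datetime(2022, 1, 1, 0, 15))
on_the_hour = sketch_subh_coords(datetime(2022, 1, 1, 1, 0))
```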


.. py:function:: _fetch_bytes_local(path: str, byte_start: int, byte_end: int | None) -> bytes

   Read a byte range directly from a local GRIB2 file.

   :param path: Absolute path to the local grib2 file.
   :type path: str
   :param byte_start: First byte (inclusive).
   :type byte_start: int
   :param byte_end: Last byte (inclusive), or ``None`` to read to EOF.
   :type byte_end: int | None

   :returns: Raw bytes for that message.
   :rtype: bytes
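
The local byte-range read described in the module docstring (``file.seek()`` followed by ``file.read()``) can be sketched as follows. ``sketch_read_byte_range`` is a hypothetical stand-in, and a small temporary file replaces a real grib2 file for the demonstration.

```python
import os
import tempfile
from typing import Optional


def sketch_read_byte_range(path: str, byte_start: int, byte_end: Optional[int]) -> bytes:
    """Read bytes [byte_start, byte_end] inclusive, or to EOF when byte_end is None."""
    with open(path, "rb") as f:
        f.seek(byte_start)
        if byte_end is None:
            return f.read()
        return f.read(byte_end - byte_start + 1)


# Demonstration on a throwaway 10-byte file.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"0123456789")
demo_path = tmp.name
middle = sketch_read_byte_range(demo_path, 2, 5)
tail = sketch_read_byte_range(demo_path, 7, None)
os.remove(demo_path)
```

In this demonstration ``middle`` is ``b"2345"`` and ``tail`` is ``b"789"``.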


.. py:function:: _load_idx_local(grib2_path: str) -> list[dict[str, str | int | None]]

   Read and parse the ``.idx`` sidecar from local disk.

   Expects the index at ``{grib2_path}.idx``.  Download it alongside the
   grib2 with ``hrrr_download.py``.

   :param grib2_path: Absolute path to the local grib2 file.
   :type grib2_path: str

   :raises FileNotFoundError: If the ``.idx`` file is absent.

   :returns: Entries parsed from the .idx, in file order.
   :rtype: list[dict[str, str | int | None]]


.. py:function:: _to_float32(values: numpy.ndarray) -> numpy.ndarray

   Return float32, replacing masked values with NaN.

   :param values: Values to convert, potentially a masked array.
   :type values: np.ndarray

   :returns: Array with masked values filled with NaN and dtype float32.
   :rtype: np.ndarray


.. py:function:: _validate_product_request(product_request: str) -> VALID_PRODUCTS

   Validate the dataset request config, raising ValueError for invalid requests.

   :param product_request: The HRRR product name from the config (e.g. "wrfprs", "wrfnat", "wrfsubh").
   :type product_request: str

   :raises ValueError: If the product is not recognized or mapped to a valid HRRR product.

   :returns: The validated HRRR product name.
   :rtype: VALID_PRODUCTS


.. py:class:: HRRRDataset(data_config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`credit.datasets.base_dataset.BaseDataset`


   CREDIT Dataset for HRRR GRIB2 data (wrfprs / wrfnat / wrfsubh).

   Implements the same field-type semantics as BaseDataset:

   * ``prognostic``      — input at step 0 and target (autoregressive rollout)
   * ``dynamic_forcing`` — input at every step; never a target
   * ``diagnostic``      — target only
   * ``static``          — input at step 0; never a target, applies to all steps

   Both modes use ``pygrib`` for GRIB2 decoding.  Remote mode fetches the
   ``.idx`` sidecar and issues parallel HTTP Range requests — no full file
   download required.

   See module docstring for full output format, tensor shapes, and YAML
   configuration examples.

   .. attribute:: dataset_type

      Tensor key: ``"HRRR"``

   .. attribute:: product

      Active HRRR product (``"HRRR_PRS" / "wrfprs"``, ``"HRRR_NAT" / "wrfnat"``,
      or ``"HRRR_SUBH" / "wrfsubh"``) with default value ``"HRRR_PRS"``.

   .. attribute:: datetimes

      DatetimeIndex of valid initialization timestamps.

   .. attribute:: static_metadata

      Dataset-level metadata for MultiSourceDataset.


   .. py:attribute:: dataset_type


   .. py:attribute:: product
      :type:  VALID_PRODUCTS


   .. py:attribute:: mode
      :type:  str


   .. py:attribute:: base_path
      :type:  str | None


   .. py:attribute:: forecast_hour
      :type:  int


   .. py:attribute:: extent
      :type:  list[float] | None


   .. py:attribute:: global_levels
      :type:  list[int] | None


   .. py:attribute:: num_fetch_workers
      :type:  int


   .. py:attribute:: static_metadata
      :type:  dict[str, Any]


   .. py:attribute:: _idx_cache
      :type:  dict[str, list[dict[str, str | int | None]]]


   .. py:attribute:: _http_session
      :value: None



   .. py:attribute:: _spatial_slice
      :type:  tuple[slice, slice] | None
      :value: None



   .. py:method:: _get_session()

      Return the shared ``requests.Session``, creating it on first call.

      Created lazily so the session is never open before a DataLoader worker
      forks — each worker ends up with its own independent connection pool.



   .. py:method:: _get_spatial_slice(lats: numpy.ndarray, lons: numpy.ndarray) -> tuple[slice, slice]

      Return ``(row_slice, col_slice)`` for ``self.extent``, computed once.

      The HRRR grid is fixed (Lambert Conformal Conic, ~1059 × 1799), so the
      bounding-box row/col indices for a given ``extent`` are identical for
      every message and every timestep.  The result is cached after the first
      call so subsequent samples pay no recomputation cost.

      :param lats: 2D latitude array from a decoded pygrib message.
      :type lats: np.ndarray
      :param lons: 2D longitude array from a decoded pygrib message.
      :type lons: np.ndarray

      :raises ValueError: If ``self.extent`` does not intersect the HRRR domain.

      :returns: ``(row_slice, col_slice)`` ready for direct numpy indexing.
                Both slices are ``slice(None)`` when ``self.extent`` is ``None``.
      :rtype: tuple[slice, slice]



   .. py:method:: _register_field(field_type: credit.datasets.base_dataset.VALID_FIELD_TYPES, field_config: dict[str, list[str] | None] | None) -> None

      Extend ``BaseDataset._register_field`` to record levels and to validate each variable against the HRRR :data:`VAR_REGISTRY`.

      :param field_type: One of VALID_FIELD_TYPES, namely: ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :type field_type: VALID_FIELD_TYPES
      :param field_config: Field-type config dict, or ``None`` / null to disable the field.
      :type field_config: dict[str, list[str] | None] | None

      :raises KeyError: If a variable in the field config is not in the HRRR VAR_REGISTRY.



   .. py:method:: _extract_field(field_type: credit.datasets.base_dataset.VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) -> None

      Override ``BaseDataset._extract_field`` with HRRR-specific file
      resolution and fetching logic: load all variables for *field_type* at
      time *t* into *sample*.

      Resolves the file path / URI, loads the ``.idx`` (cached), then
      delegates to :meth:`_extract_from_idx` with the appropriate byte
      fetcher for the current mode.

      For ``wrfsubh``, *t* is a 15-min-resolution timestamp.  This method
      derives the HRRR init time and FF file number automatically:

      * ``init_hour = t.floor("1h")``
      * ``step_min  = minutes since init`` (15, 30, 45, 60, …)
      * ``ff        = ceil(step_min / 60)`` (file number within the run)
      * If *t* is exactly on the hour, it is treated as the 60-min step of
        the previous hour's run (``init_hour -= 1h``, ``step_min = 60``).

      :param field_type: One of VALID_FIELD_TYPES, namely: ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :type field_type: VALID_FIELD_TYPES
      :param t: Initialization timestamp (UTC).  For ``wrfsubh``, this is a
                15-min-resolution timestamp like ``2024-01-01T00:15:00Z``.
      :type t: pd.Timestamp
      :param sample: The sample dict being built in __getitem__
      :type sample: dict[str, Any]



   .. py:method:: _extract_from_idx(field_type: credit.datasets.base_dataset.VALID_FIELD_TYPES, idx_entries: list[dict[str, str | int | None]], fetcher: Callable[[dict[str, str | int | None]], bytes], vd: dict[str, list[str | int]], sample: dict[str, Any], step_min: int | None = None) -> None

      Shared fetch-plan → parallel byte fetch → decode → tensor pipeline.

      Used by both local and remote modes.  The only difference between modes
      is the *fetcher* callable that maps an idx entry to raw GRIB bytes.
      Product-specific level dispatch (pressure vs hybrid-sigma vs sub-hourly)
      is handled here based on ``self.product``.

      :param field_type: One of VALID_FIELD_TYPES.
      :type field_type: VALID_FIELD_TYPES
      :param idx_entries: Parsed ``.idx`` entries for the target file.
      :type idx_entries: list[dict[str, str | int | None]]
      :param fetcher: Callable ``(entry: dict) -> bytes`` that fetches the raw
                      GRIB message for a given idx entry.
      :type fetcher: Callable[[dict[str, str | int | None]], bytes]
      :param vd: Variable dict (``vars_3D``, ``vars_2D``, ``levels``).
      :type vd: dict[str, list[str | int]]
      :param sample: Output dict to populate in-place.
      :type sample: dict[str, Any]
      :param step_min: Sub-hourly step in minutes (15, 30, 45, 60, …).  Only
                       used when ``self.product == "wrfsubh"``.
      :type step_min: int | None
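
The "fetch plan, then parallel byte fetch" stage of ``_extract_from_idx`` can be sketched with ``concurrent.futures.ThreadPoolExecutor``. This is an illustration under stated assumptions, not the method's actual code: ``sketch_fetch_all``, the ``plan`` dict, and the in-memory fetcher are hypothetical, and the default of 8 workers simply mirrors the documented ``_MAX_REMOTE_WORKERS`` value.

```python
from concurrent.futures import ThreadPoolExecutor


def sketch_fetch_all(plan: dict, fetcher, max_workers: int = 8) -> dict:
    """Fetch raw bytes for every planned idx entry in parallel.

    plan maps an arbitrary key, e.g. (varname, level), to an idx entry.
    """
    keys = list(plan)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in input order, so keys and blobs align.
        blobs = list(pool.map(fetcher, (plan[k] for k in keys)))
    return dict(zip(keys, blobs))


# In-memory stand-in for a grib2 file; a real fetcher would read from disk or HTTP.
FAKE_FILE = bytes(range(100))


def fake_fetcher(entry: dict) -> bytes:
    end = None if entry["byte_end"] is None else entry["byte_end"] + 1
    return FAKE_FILE[entry["byte_start"]:end]


plan = {
    ("T", 500): {"byte_start": 0, "byte_end": 9},
    ("T", 850): {"byte_start": 10, "byte_end": 49},
    ("t2m", None): {"byte_start": 50, "byte_end": None},
}
raw = sketch_fetch_all(plan, fake_fetcher)
```

Because only the ``fetcher`` callable differs between modes, the same function would serve local and remote reads.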



