credit.datasets.goes
====================

.. py:module:: credit.datasets.goes

.. autoapi-nested-parse::

   goes.py
   -------------------------------------------------------
   GOESDataset: PyTorch Dataset for GOES data with nested input/target structure.

   Sample structure returned by ``__getitem__``::

       {
           "input":    {<user_provided_name>: {"<user_provided_name>/prognostic/2d/CMI_C04": tensor,
                                               "<user_provided_name>/prognostic/2d/CMI_C07": tensor}},
           "target":   {<user_provided_name>: {"<user_provided_name>/prognostic/2d/CMI_C04": tensor,
                                               "<user_provided_name>/prognostic/2d/CMI_C07": tensor}},  # only populated when return_target=True
           "metadata": {<user_provided_name>: {"input_datetime": int, "target_datetime": int}},
       }

   All GOES variables are 2D. Tensor shape (no batch dimension):
       (1, 1, lat, lon)   — singleton level dim, consistent with CREDIT Gen2 2D convention

   After DataLoader collation the batch dimension is prepended:
       (batch, 1, 1, lat, lon)
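
The nested layout and shapes described above can be illustrated with a small sketch. It uses NumPy arrays as stand-ins for tensors and ``np.stack`` as a stand-in for DataLoader collation; the source name ``Example_GOES``, the toy grid size, and the helper ``make_sample`` are assumptions for illustration only, not part of the module.

```python
import numpy as np

source = "Example_GOES"  # hypothetical user-provided source name
lat, lon = 4, 5          # toy grid size, chosen only for this sketch


def make_sample(fill: float) -> dict:
    """Build one sample with the nested layout from the module docstring."""
    fields = {
        f"{source}/prognostic/2d/{name}": np.full((1, 1, lat, lon), fill, dtype=np.float32)
        for name in ("CMI_C04", "CMI_C07")
    }
    return {
        "input": {source: dict(fields)},
        "target": {source: dict(fields)},  # as if return_target=True
        "metadata": {source: {"input_datetime": 0, "target_datetime": 0}},  # placeholder ints
    }


samples = [make_sample(0.0), make_sample(1.0)]

# Stand-in for collation: stack one field across samples along a new leading axis.
key = f"{source}/prognostic/2d/CMI_C04"
batch = np.stack([s["input"][source][key] for s in samples], axis=0)
print(batch.shape)  # (2, 1, 1, 4, 5)
```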



Classes
-------

.. autoapisummary::

   credit.datasets.goes.GOESDataset


Functions
---------

.. autoapisummary::

   credit.datasets.goes._build_spatial_slices
   credit.datasets.goes._find_nearest_latlon


Module Contents
---------------

.. py:function:: _build_spatial_slices(extent: list[int] | None, lat2d: numpy.ndarray | None = None, lon2d: numpy.ndarray | None = None) -> tuple[slice, slice]

   Compute row (latitude) and column (longitude) slices that bound a geographic extent on a 2-D grid.

   Given an optional bounding box in geographic coordinates, returns a pair of
   ``slice`` objects that can be passed directly to ``xarray.Dataset.isel`` (or
   used for NumPy basic slicing) to crop a 2-D field to the requested region.

   :param extent: Bounding box as ``[lon_min, lon_max, lat_min, lat_max]`` in
                  decimal degrees. Pass ``None`` to select the entire grid (both
                  slices become ``slice(None)``). Longitudes must lie in
                  ``[-180, 180]`` and latitudes in ``[-90, 90]``.
   :param lat2d: 2-D array of latitudes (degrees) with the same shape as the
                 target grid. Required when ``extent`` is not ``None``.
   :param lon2d: 2-D array of longitudes (degrees) with the same shape as the
                 target grid. Required when ``extent`` is not ``None``.

   :returns: A ``(y_slice, x_slice)`` tuple where ``y_slice`` indexes rows
             (latitude axis) and ``x_slice`` indexes columns (longitude axis) of
             the 2-D grid.

   :raises ValueError: If ``extent`` is a list but ``lat2d`` or ``lon2d`` is
       ``None``.
   :raises TypeError: If ``extent`` is neither ``None`` nor a list.
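
One way to derive such slices is sketched below. This is a minimal illustration, assuming the slices are the smallest row and column ranges that contain every grid point inside the box; the helper name ``bounding_slices`` is hypothetical, and the actual ``_build_spatial_slices`` may differ in details and error handling.

```python
import numpy as np


def bounding_slices(extent, lat2d=None, lon2d=None):
    """Hypothetical helper: smallest row/column slices covering all points inside extent."""
    if extent is None:
        return slice(None), slice(None)
    lon_min, lon_max, lat_min, lat_max = extent
    # Boolean mask of grid points that fall inside the bounding box.
    inside = (lon2d >= lon_min) & (lon2d <= lon_max) & (lat2d >= lat_min) & (lat2d <= lat_max)
    rows = np.flatnonzero(inside.any(axis=1))
    cols = np.flatnonzero(inside.any(axis=0))
    if rows.size == 0:
        raise ValueError("extent does not overlap the grid")
    return slice(int(rows[0]), int(rows[-1]) + 1), slice(int(cols[0]), int(cols[-1]) + 1)


# Toy regular grid: latitudes 10..50 along rows, longitudes -140..-60 along columns.
lon2d, lat2d = np.meshgrid(np.arange(-140, -50, 10), np.arange(10, 60, 10))
y_slice, x_slice = bounding_slices([-130, -60, 20, 55], lat2d, lon2d)

field = np.zeros(lat2d.shape)
cropped = field[y_slice, x_slice]  # cropped.shape is (4, 8)
```

On a curvilinear grid the resulting rectangle in index space generally includes some points just outside the requested box, since rows and columns do not follow lines of constant latitude or longitude.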


.. py:function:: _find_nearest_latlon(lat2d: numpy.ndarray, lon2d: numpy.ndarray, lat_target: float, lon_target: float) -> tuple[int, int]

   Find the 2-D grid indices of the point nearest to a target lat/lon using Haversine distance.

   :param lat2d: 2-D array of latitudes in decimal degrees.
   :param lon2d: 2-D array of longitudes in decimal degrees.
   :param lat_target: Target latitude in decimal degrees, in ``[-90, 90]``.
   :param lon_target: Target longitude in decimal degrees, in ``[-180, 180]``.

   :returns: A ``(i, j)`` tuple of the row and column indices of the nearest grid
             point.
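
A nearest-point search with the Haversine formula can be sketched as follows. The helper name ``nearest_index`` is hypothetical and the code is an illustration under the assumption that off-grid points may be ``NaN``; it is not the module's implementation.

```python
import numpy as np


def nearest_index(lat2d, lon2d, lat_target, lon_target):
    """Hypothetical helper: (row, col) of the grid point with the smallest Haversine distance."""
    lat1, lon1 = np.radians(lat2d), np.radians(lon2d)
    lat2, lon2 = np.radians(lat_target), np.radians(lon_target)
    a = np.sin((lat1 - lat2) / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin((lon1 - lon2) / 2) ** 2
    # Central angle in radians; the Earth radius is a constant factor and not needed for an argmin.
    angle = 2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
    # nanargmin skips NaN entries (for example, points with no valid coordinates).
    i, j = np.unravel_index(np.nanargmin(angle), angle.shape)
    return int(i), int(j)


# Toy regular grid: latitudes 10..50 along rows, longitudes -140..-60 along columns.
lon2d, lat2d = np.meshgrid(np.arange(-140.0, -50.0, 10.0), np.arange(10.0, 60.0, 10.0))
i, j = nearest_index(lat2d, lon2d, lat_target=31.0, lon_target=-99.0)
print(i, j, lat2d[i, j], lon2d[i, j])  # 2 4 30.0 -100.0
```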


.. py:class:: GOESDataset(data_config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`credit.datasets.base_dataset.BaseDataset`


   PyTorch Dataset for GOES-R ABI Level-2 (L2) satellite imagery.

   Field types follow CREDIT Gen2 conventions: ``prognostic`` variables appear in
   both input (at step 0) and target; ``dynamic_forcing`` appears in input
   at every step; ``diagnostic`` appears in target only.  At step ``i > 0``
   the model's own prognostic predictions are fed back — no disk read occurs
   for prognostic fields at those steps.

   Supports loading directly from AWS S3 (remote mode) or from local
   NetCDF files (local mode). Spatial subsetting via ``extent``
   is applied at load time on the curvilinear GOES grid.

   See module docstring for full description of output format and file naming.

   Example YAML configuration (local mode)::

       data:
           source:
               Example_GOES:  # User-provided name (arbitrary key)
                   dataset_type: "goes"
                   goes_position: "east"  # or "west"
                   mode: "local"
                   product: "ABI-L2-MCMIPC"
                   variables:
                       prognostic:
                           vars_2D: ["CMI_C04", "CMI_C07", "CMI_C08", "CMI_C09", "CMI_C10", "CMI_C13"]
                           path: "/glade/derecho/scratch/kevinyang/datasets/goes/"
                       diagnostic: null
                       dynamic_forcing: null
               latlon2d_dir: "/glade/derecho/scratch/kevinyang/datasets/goes/"
               extent: [-130, -60, 20, 55]

           start_datetime: "2021-06-01"
           end_datetime: "2021-06-04"
           timestep: "6h"
           forecast_len: 0

   Example YAML configuration (remote mode)::

       data:
           source:
               Example_GOES:  # User-provided name (arbitrary key)
                   dataset_type: "goes"
                   goes_position: "east" # or "west"
                   mode: "remote"
                   product: "ABI-L2-MCMIPC"
                   variables:
                       prognostic:
                           vars_2D: ["CMI_C04", "CMI_C07", "CMI_C08", "CMI_C09", "CMI_C10", "CMI_C13"]
                       diagnostic: null
                       dynamic_forcing: null
               latlon2d_dir: "/glade/derecho/scratch/kevinyang/datasets/goes/"
               extent: [-130, -60, 20, 55]

   :param data_config: Data configuration dictionary (the ``data`` section in the
                  YAML examples above). The relevant sub-keys are:

                  - ``data_config["source"]["Example_GOES"]``: source entry keyed by the
                    user-provided source name.

                    - ``dataset_type`` (str): Must be ``"goes"`` to select this dataset class.
                    - ``goes_position`` (str): Satellite position, either ``"east"`` or ``"west"``.
                      Defaults to ``"east"``.
                    - ``mode`` (str): ``"local"`` or ``"remote"`` (S3). Defaults to
                      ``"local"``.
                    - ``product`` (str): ABI product string, e.g.
                      ``"ABI-L2-MCMIPC"``.
                    - ``extent`` (list, optional): Bounding box
                      ``[lon_min, lon_max, lat_min, lat_max]`` used to spatially crop
                      each field.
                    - ``latlon2d_dir`` (str): Directory containing pre-computed
                      lat/lon grid NetCDF files.
                    - ``qc_path`` (str, optional): Path to a Parquet QC table.
                      Timestamps with a missing file or a failed QC check are
                      replaced with ``None`` in the file map.
                    - ``variables`` (dict): Mapping of field type to variable spec
                      (``vars_2D`` and, in local mode, ``path``, as in the YAML
                      examples above).

                  - ``data_config["timestep"]`` (str): Model timestep as a
                    ``pandas.Timedelta``-parseable string (e.g. ``"1h"``).
                  - ``data_config["forecast_len"]`` (int): Number of autoregressive
                    forecast steps.
                  - ``data_config["start_datetime"]`` (str): Start of the data range.
                  - ``data_config["end_datetime"]`` (str): End of the data range.
   :param return_target: When ``True`` the sample also contains a ``"target"``
                         key populated with prognostic and diagnostic fields at ``t + dt``.
                         Defaults to ``False``.

   .. attribute:: datetimes

      Valid input times for which samples can
      be fetched.

      :type: pd.DatetimeIndex

   .. attribute:: file_dict

      Maps each field type to a list of
      ``(period_start, period_end, file path)`` tuples built during
      initialization.

      :type: dict

   .. attribute:: var_dict

      Maps each field type to
      ``{"vars_2D": [<variable names>]}``.

      :type: dict

   .. attribute:: y_slice

      Row crop derived from ``extent`` (or ``slice(None)``
      for the full grid).

      :type: slice

   .. attribute:: x_slice

      Column crop derived from ``extent`` (or
      ``slice(None)`` for the full grid).

      :type: slice

   :raises FileNotFoundError: If the lat/lon grid NetCDF cannot be found under
       ``latlon2d_dir``.


   .. py:attribute:: dataset_type
      :value: 'goes'



   .. py:attribute:: goes_position
      :type:  str


   .. py:attribute:: product
      :type:  str


   .. py:attribute:: static_metadata
      :type:  dict[str, Any]


   .. py:attribute:: qc_path
      :type:  str


   .. py:attribute:: tolerance


   .. py:attribute:: _fs
      :value: None



   .. py:attribute:: latlon2d_dir
      :type:  str


   .. py:attribute:: extent


   .. py:method:: _collect_GOES_file_path(base_dir: str = '', verbose: bool = False)

      Build a time-ordered file map for the dataset's datetime range.

      For each requested timestamp the method lists the appropriate S3 or
      local hourly directory, parses GOES L2 filenames, and associates each
      timestamp with the nearest file within ``tolerance`` (default 3
      minutes). QC filtering is applied automatically when ``self.qc_path``
      is set, masking bad intervals by setting their path entry to ``None``.

      :param base_dir: Root directory prepended to relative paths when ``mode``
                       is ``"local"``. Ignored for remote mode.
      :param verbose: When ``True``, print a warning for each hour directory
                      that cannot be listed (missing data, permission errors, etc.).

      :returns: A list of ``(period_start, period_end, file_path)`` tuples, one
                per timestamp in ``datetimes``. ``file_path`` is ``"NONE"`` when
                no file was found within tolerance, or ``None`` when the interval
                was masked by the QC table at ``self.qc_path``.

      :raises FileNotFoundError: If GOES L2 files are not found for
          the requested datetime.
      :raises ValueError: If the GOES L2 filenames do not match the expected
          naming convention (fewer than 6 underscore-separated tokens).



   .. py:method:: _get_file_source(field_config: dict[str, Any]) -> list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None

      Return the file source for a field. Override in subclasses for different modes/backends.

      :param field_config: Validated field-type config dict.
      :type field_config: dict[str, Any]

      :raises ValueError: If ``self.mode`` is not a recognised mode.

      :returns: Depending on the mode and field type, either a list of
                ``(start_time, end_time, file_path)`` tuples produced by ``_map_files``,
                a boolean indicating the presence of the field (e.g., for remote data),
                or ``None`` if the field is disabled. The return type should be
                consistent within a dataset class.
      :rtype: list[tuple[pd.Timestamp, pd.Timestamp, str]] | bool | None



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Load all 2-D variables for a field type at time ``t`` into ``sample``.

      Dispatches to ``_load_local_var`` or ``_load_remote_var`` depending on
      ``mode``, then stores each variable as a ``torch.Tensor`` of shape
      ``(1, 1, ny, nx)`` under the key
      ``"{source_name}/{field_type}/2d/{vname}"`` in ``sample``. Does nothing if
      the field type has no registered variables.

      :param field_type: One of ``"prognostic"``, ``"diagnostic"``, or
                         ``"dynamic_forcing"``.
      :param t: Timestamp for which to load data.
      :param sample: Output dictionary that is updated in-place.



   .. py:method:: _load_local_var(field_type: str, vnames: list[str], t: pandas.Timestamp)

      Load variables from a local NetCDF file and apply spatial cropping.

      :param field_type: Field type used to look up the file map in
                         ``file_dict``.
      :param vnames: Variable names to extract from the dataset.
      :param t: Timestamp used to locate the correct file via ``_find_file``.

      :returns: A dict mapping each variable name to its cropped ``numpy.ndarray``.

      :raises KeyError: If no files are registered for ``field_type``.



   .. py:method:: _load_remote_var(field_type: str, vnames: list[str], t: pandas.Timestamp)

      Load variables from a remote S3 NetCDF file and apply spatial cropping.

      Uses the cached ``_fs`` S3FileSystem to open the file as a byte stream
      and reads it with the ``h5netcdf`` engine.

      :param field_type: Field type used to look up the file map in
                         ``file_dict``.
      :param vnames: Variable names to extract from the dataset.
      :param t: Timestamp used to locate the correct file via ``_find_file``.

      :returns: A dict mapping each variable name to its cropped ``numpy.ndarray``.

      :raises KeyError: If no files are registered for ``field_type``.
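
The nearest-file matching described for ``_collect_GOES_file_path`` can be sketched with the standard library alone. The sketch assumes the common GOES-R filename layout, in which the fourth underscore-separated token is the scan start time (``s`` followed by year, day of year, and time of day); the helper names ``parse_start_time`` and ``nearest_within`` are hypothetical, and returning ``None`` for a miss is a simplification of this sketch rather than the method's documented sentinel values.

```python
from datetime import datetime, timedelta


def parse_start_time(filename: str) -> datetime:
    """Parse the scan start time from a GOES-R style filename (assumed layout)."""
    tokens = filename.split("_")
    if len(tokens) < 6:
        raise ValueError(f"unexpected GOES filename: {filename}")
    # Token 3 looks like 's20211520001170': 's' + YYYYJJJHHMMSS + tenths of a second.
    return datetime.strptime(tokens[3][1:14], "%Y%j%H%M%S")


def nearest_within(target, filenames, tolerance=timedelta(minutes=3)):
    """Return the filename whose start time is closest to target, or None if outside tolerance."""
    best = min(filenames, key=lambda f: abs(parse_start_time(f) - target), default=None)
    if best is None or abs(parse_start_time(best) - target) > tolerance:
        return None
    return best


files = [
    "OR_ABI-L2-MCMIPC-M6_G16_s20211520001170_e20211520003543_c20211520004046.nc",
    "OR_ABI-L2-MCMIPC-M6_G16_s20211520006170_e20211520008543_c20211520009046.nc",
]

hit = nearest_within(datetime(2021, 6, 1, 0, 0), files)    # first file, 77 s after the target
miss = nearest_within(datetime(2021, 6, 1, 0, 30), files)  # None: nothing within 3 minutes
```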



