credit.datasets.base_dataset
============================

.. py:module:: credit.datasets.base_dataset

.. autoapi-nested-parse::

   base_dataset.py
   -------------------------------------------------------
   AbstractBaseDataset and BaseDataset: PyTorch Dataset classes for:

   1. Type hinting and annotations throughout CREDIT
   2. Scaffolding the development of future datasets
   3. Providing a minimal implementation of a Dataset for testing
   4. Avoiding redundant code across dataset classes



Attributes
----------

.. autoapisummary::

   credit.datasets.base_dataset.VALID_FIELD_TYPES


Classes
-------

.. autoapisummary::

   credit.datasets.base_dataset.AbstractBaseDataset
   credit.datasets.base_dataset.BaseDataset


Module Contents
---------------

.. py:data:: VALID_FIELD_TYPES

.. py:class:: AbstractBaseDataset(data_config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`\ [\ :py:obj:`Any`\ ]


   Abstract base dataset class for CREDIT, built on the PyTorch Dataset class.

   This class defines the expected methods and attributes for any dataset in CREDIT,
   but does not provide any implementation. The BaseDataset class inherits from this
   class and provides a minimal implementation. Any future dataset should inherit from
   either AbstractBaseDataset or BaseDataset depending on the level of functionality needed.

   For generality, the class inherits from ``torch.utils.data.Dataset[Any]``. Stricter typing
   than ``Any`` for the ``__getitem__`` return value may nevertheless be beneficial for
   consistency, especially if future torch releases add dataset-type-based accelerations.

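   As one illustration of such stricter typing, the sample returned by ``__getitem__`` could be
   described with a ``TypedDict`` instead of ``Any``. This is a hypothetical sketch: ``CreditSample``
   and ``_RequiredSample`` are not part of CREDIT, and the keys are taken from the
   ``BaseDataset.__getitem__`` docstring.

   ```python
   from typing import Any, TypedDict


   class _RequiredSample(TypedDict):
       # Keys present in every sample (hypothetical type, not CREDIT's).
       input: dict[str, Any]
       metadata: dict[str, Any]


   class CreditSample(_RequiredSample, total=False):
       # Optional key: only present when return_target=True.
       target: dict[str, Any]


   sample: CreditSample = {"input": {}, "metadata": {}}
   ```

   A dataset declared as ``Dataset[CreditSample]`` would then give static type checkers a precise
   return type for ``__getitem__``.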

   .. py:attribute:: curr_source_name
      :type:  str


   .. py:attribute:: dataset_type
      :type:  str


   .. py:attribute:: dt
      :type:  pandas.Timedelta


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime
      :type:  pandas.Timestamp


   .. py:attribute:: end_datetime
      :type:  pandas.Timestamp


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: return_target
      :type:  bool


   .. py:attribute:: mode
      :type:  str


   .. py:attribute:: file_dict
      :type:  dict[str, Any]


   .. py:attribute:: var_dict
      :type:  dict[str, Any]


   .. py:attribute:: static_metadata
      :type:  dict[str, Any]


   .. py:method:: __len__() -> int
      :abstractmethod:



   .. py:method:: __getitem__(args: tuple[pandas.Timestamp, int]) -> dict[str, Any]
      :abstractmethod:



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex
      :abstractmethod:



   .. py:method:: _get_field_name(field_type: VALID_FIELD_TYPES, dim_str: str, vname: str) -> str
      :abstractmethod:



   .. py:method:: init_register_all_fields() -> None
      :abstractmethod:



   .. py:method:: _register_field(field_type: VALID_FIELD_TYPES, field_config: dict[str, Any] | None) -> None
      :abstractmethod:



   .. py:method:: _get_file_source(field_config: dict[str, Any]) -> list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None
      :abstractmethod:



   .. py:method:: _extract_field(field_type: VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) -> None
      :abstractmethod:



.. py:class:: BaseDataset(data_config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`AbstractBaseDataset`


   PyTorch Dataset class for CREDIT that enables:

   1. Type hinting and annotations throughout CREDIT
   2. Scaffolding the development of future datasets
   3. Testing with a minimal Dataset implementation

   A minimal YAML config for a dataset has the following structure:

   ```yaml
       data:
         source:
           Example_Base:  # User-provided name (arbitrary key)
             # PARAMETERS FOR THIS DATASET TYPE
             # Ex: levels: [10, 20, 30]
             dataset_type: "base"  # Must match this dataset's type
             variables:
               prognostic: null
               #  vars_3D: ['T', 'U', 'V', 'Q'] # Your 3D variables
               #  vars_2D: ['SP', 't2m'] # Your 2D variables
               # dynamic_forcing: null
               #  vars_3D: ...
               #  vars_2D: ...
               # static: null
               #  vars_3D: ...
               #  vars_2D: ...
               # diagnostic: null
               #  vars_3D: ...
               #  vars_2D: ...
             # OPTIONAL: Override the clock bounds for this dataset
             start_datetime: "2012-04-03T00:00Z"

           # <YourName2>_<DatasetType2>: # Multiple datasets (see multi_source)

         # These parameters set the overall clock of the sampler
         start_datetime: "2000-01-01T00:00:00Z" # The earliest datetime across datasets
         end_datetime: "2020-12-31T23:00:00Z" # The latest datetime across datasets
         timestep: "12h" # The smallest time interval for the clock
         forecast_len: 1 # The number of timesteps forward that need to be rolled out per sample
   ```


   .. py:attribute:: curr_source_name


   .. py:attribute:: curr_source_cfg


   .. py:attribute:: dt
      :type:  pandas.Timedelta


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime
      :type:  pandas.Timestamp


   .. py:attribute:: end_datetime
      :type:  pandas.Timestamp


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: mode
      :value: 'local'



   .. py:attribute:: temporal_mode
      :type:  str


   .. py:attribute:: _persist_cache
      :type:  dict


   .. py:attribute:: static_metadata
      :type:  dict[str, Any]


   .. py:attribute:: file_dict
      :type:  dict[str, Any]


   .. py:attribute:: var_dict
      :type:  dict[str, Any]


   .. py:method:: __len__() -> int

      For a CREDIT dataset, the length is the number of unique datetimes that can be sampled from.

      :returns: Dataset length
      :rtype: int



   .. py:method:: __getitem__(args: tuple[pandas.Timestamp, int]) -> dict[str, Any]

      Return a nested input/target sample dict.

      When ``temporal_mode == "persist"``, the requested timestamp *t* is
      snapped to the last native timestamp at-or-before *t* via
      ``pd.DatetimeIndex.asof()``.  The result is cached so that multiple
      fine-resolution master-clock ticks within the same native interval
      only trigger a single file read.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   or pd.Timestamp) and *i* is the within-sequence step index
                   produced by the sampler. When ``i == 0``, prognostic and static
                   fields are loaded in addition to dynamic forcing.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"{source}/{field_type}/{dim}/{varname}"``.

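      A minimal sketch of how the ``(t, i)`` argument might be interpreted, assuming *t* arrives
      either as a ``pd.Timestamp`` or as an integer count of nanoseconds since the Unix epoch. The
      helpers ``normalize_timestamp`` and ``input_field_types`` are hypothetical; the field selection
      only encodes the ``i == 0`` rule stated above and says nothing about how targets are built.

      ```python
      import pandas as pd


      def normalize_timestamp(t: "pd.Timestamp | int") -> pd.Timestamp:
          # pd.Timestamp treats a bare integer as nanoseconds since the Unix epoch.
          return pd.Timestamp(t)


      def input_field_types(i: int) -> set[str]:
          # Dynamic forcing is loaded at every step of the sequence.
          fields = {"dynamic_forcing"}
          if i == 0:
              # The initial step also loads the prognostic state and static fields.
              fields |= {"prognostic", "static"}
          return fields


      print(normalize_timestamp(946_684_800_000_000_000))  # 2000-01-01 00:00:00
      print(sorted(input_field_types(0)))  # ['dynamic_forcing', 'prognostic', 'static']
      ```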


   .. py:method:: _resolve_persist_timestamp(t: pandas.Timestamp) -> pandas.Timestamp

      Snap *t* to the last native timestamp at or before *t*.

      Uses ``pd.DatetimeIndex.asof()`` so non-uniform or non-zero-aligned
      cadences are handled correctly.

      :param t: Master-clock timestamp to resolve.

      :returns: The last native timestamp ``<= t``.

      :raises ValueError: If *t* is before the first native timestamp.

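      The snapping rule can be reproduced directly with pandas. This is a sketch, not the CREDIT
      implementation: ``native_times`` stands in for the source's own timestamps, and
      ``resolve_persist_timestamp`` is a hypothetical free function. ``DatetimeIndex.asof`` returns a
      missing value (``NaT``) when *t* precedes the first entry, which is mapped to the documented
      ``ValueError``.

      ```python
      import pandas as pd


      def resolve_persist_timestamp(native_times: pd.DatetimeIndex, t: pd.Timestamp) -> pd.Timestamp:
          # asof() returns the last index entry <= t (the index must be sorted).
          snapped = native_times.asof(t)
          if pd.isna(snapped):
              # t falls before the first native timestamp.
              raise ValueError(f"{t} precedes the first native timestamp {native_times[0]}")
          return snapped


      # Hypothetical 6-hourly source queried at an off-cadence master-clock tick.
      native = pd.date_range("2020-01-01 00:00", periods=4, freq="6h")
      print(resolve_persist_timestamp(native, pd.Timestamp("2020-01-01 08:00")))  # 2020-01-01 06:00:00
      ```

      Because ``asof`` only requires a sorted index, the same call works for irregular cadences.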


   .. py:method:: _load_sample(t: pandas.Timestamp, i: int) -> dict[str, Any]

      Build and return the sample dict for timestamp *t* and step index *i*.

      This is the inner implementation called by ``__getitem__``, separated
      so the persist cache can call it without re-entering the dispatch logic.

      :param t: Timestamp to load (already resolved for persist sources).
      :param i: Within-sequence step index (0 = initial step).

      :returns: Sample dict with ``"input"``, ``"metadata"``, and optionally ``"target"``.

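      The persist cache described for ``__getitem__`` can be sketched with a plain dict keyed by the
      snapped timestamp. Everything here is hypothetical (``PersistCacheSketch``, ``load_fn``, the
      ``loads`` counter), and keying on ``(snapped, i)`` is an assumption made because the loaded
      fields depend on *i*. It only illustrates why several master-clock ticks inside one native
      interval trigger a single read.

      ```python
      from typing import Any, Callable

      import pandas as pd


      class PersistCacheSketch:
          """Hypothetical persist-mode cache; not the CREDIT implementation."""

          def __init__(self, native_times: pd.DatetimeIndex,
                       load_fn: Callable[[pd.Timestamp, int], dict[str, Any]]):
              self.native_times = native_times
              self.load_fn = load_fn
              self._persist_cache: dict[tuple[pd.Timestamp, int], dict[str, Any]] = {}
              self.loads = 0  # counts calls that actually reached load_fn

          def get(self, t: pd.Timestamp, i: int) -> dict[str, Any]:
              snapped = self.native_times.asof(t)  # last native timestamp <= t
              key = (snapped, i)
              if key not in self._persist_cache:
                  self.loads += 1
                  self._persist_cache[key] = self.load_fn(snapped, i)
              return self._persist_cache[key]


      native = pd.date_range("2020-01-01", periods=4, freq="6h")
      cache = PersistCacheSketch(native, lambda t, i: {"input": {}, "metadata": {"datetime": t}})
      for hour in (6, 7, 8):  # three hourly ticks inside the 06:00-12:00 interval
          cache.get(pd.Timestamp("2020-01-01") + pd.Timedelta(hours=hour), 1)
      print(cache.loads)  # 1
      ```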


   .. py:method:: _check_in_data_config(data_config: dict[str, Any], key: str) -> None

      Check that a required key is present in the data config, raising an error if it is missing.

      :param data_config: Portion of the config under "data"
      :type data_config: dict[str, Any]
      :param key: The key to check (e.g., "timestep")
      :type key: str

      :raises KeyError: When the key is not found in the data config

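      A minimal sketch of this check; ``check_in_data_config`` is a hypothetical free-function
      stand-in for the method, and the error message wording is an assumption.

      ```python
      from typing import Any


      def check_in_data_config(data_config: dict[str, Any], key: str) -> None:
          # Hypothetical stand-in for BaseDataset._check_in_data_config.
          if key not in data_config:
              raise KeyError(f"Required key '{key}' is missing from the 'data' config")


      data_config = {"timestep": "12h", "forecast_len": 1}
      check_in_data_config(data_config, "timestep")  # passes silently
      ```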


   .. py:method:: _in_source_config(data_config: dict[str, Any], curr_source_cfg: dict[str, Any], key: str) -> bool

      Helper to determine if a key is in the source config.

      :param data_config: Portion of the config under "data"
      :type data_config: dict[str, Any]
      :param curr_source_cfg: Portion of the config under a specific source
      :type curr_source_cfg: dict[str, Any]
      :param key: The key to check (e.g., "timestep")
      :type key: str

      :returns: True if the key is in the source config, False otherwise
      :rtype: bool



   .. py:method:: _load_dt(data_config: dict[str, Any], curr_source_config: dict[str, Any], dt_key: str = 'timestep') -> pandas.Timedelta

      The timestep (dt) is a required parameter for any dataset and is used to build the sampler clock.
      In general, the timestep in the data config should be the smallest timestep across all sources,
      although a subclassed dataset may need a coarser timestep (e.g., in multi-source datasets).

      :param data_config: Portion of the config under "data"
      :type data_config: dict[str, Any]
      :param curr_source_config: Portion of the config under a specific source
      :type curr_source_config: dict[str, Any]
      :param dt_key: The key for the timestep parameter. Defaults to "timestep".
      :type dt_key: str, optional

      :returns: The timestep for the dataset
      :rtype: pd.Timedelta

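      A minimal sketch of the lookup, assuming that a value in the source config takes precedence
      over the global one under ``data`` (the class-level YAML example shows this pattern for
      ``start_datetime``). The function ``load_dt`` and the source names are hypothetical.

      ```python
      from typing import Any

      import pandas as pd


      def load_dt(data_config: dict[str, Any], curr_source_config: dict[str, Any],
                  dt_key: str = "timestep") -> pd.Timedelta:
          # Assumption: a source-level value overrides the global clock.
          if dt_key in curr_source_config:
              return pd.Timedelta(curr_source_config[dt_key])
          if dt_key not in data_config:
              raise KeyError(f"'{dt_key}' must be set under 'data'")
          return pd.Timedelta(data_config[dt_key])


      # Global clock at the finest cadence; one source needs a coarser step.
      data_config = {"timestep": "6h", "source": {"coarse_src": {"timestep": "12h"}, "default_src": {}}}
      print(load_dt(data_config, data_config["source"]["coarse_src"]))   # 0 days 12:00:00
      print(load_dt(data_config, data_config["source"]["default_src"]))  # 0 days 06:00:00
      ```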


   .. py:method:: _load_num_forecast_steps(data_config: dict[str, Any], curr_source_config: dict[str, Any], num_forecast_steps_key: str = 'forecast_len') -> int

      The number of forecast steps (num_forecast_steps) is a required parameter for any dataset and
      is used by the sampler. In general, the number of forecast steps in the data config should be
      the largest across all sources, since this determines how far forward the sampler needs to roll
      out. A subclassed dataset may not be able to roll out that far.

      :param data_config: Portion of the config under "data"
      :type data_config: dict[str, Any]
      :param curr_source_config: Portion of the config under a specific source
      :type curr_source_config: dict[str, Any]
      :param num_forecast_steps_key: The key for the number of forecast steps parameter. Defaults to "forecast_len".
      :type num_forecast_steps_key: str, optional

      :returns: The number of forecast steps (i.e., length ahead) for the dataset
      :rtype: int



   .. py:method:: _load_start_datetime(data_config: dict[str, Any], curr_source_config: dict[str, Any], start_datetime_key: str = 'start_datetime') -> pandas.Timestamp

      The start_datetime is a required parameter for any dataset and is used by the sampler. In general,
      the start_datetime in the data config should be the earliest across all sources, since this determines
      the earliest point in time that the sampler can draw from. A subclassed dataset may not be able to go
      as far back.

      :param data_config: Portion of the config under "data"
      :type data_config: dict[str, Any]
      :param curr_source_config: Portion of the config under a specific source
      :type curr_source_config: dict[str, Any]
      :param start_datetime_key: The key for the start datetime parameter. Defaults to "start_datetime".
      :type start_datetime_key: str, optional

      :returns: The start datetime for the dataset
      :rtype: pd.Timestamp



   .. py:method:: _load_end_datetime(data_config: dict[str, Any], curr_source_config: dict[str, Any], end_datetime_key: str = 'end_datetime') -> pandas.Timestamp

      The end_datetime is a required parameter for any dataset and is used by the sampler. In general,
      the end_datetime in the data config should be the latest across all sources, since this determines
      the latest point in time that the sampler can draw from. A subclassed dataset may not be able to go
      as far forward.

      :param data_config: Portion of the config under "data"
      :type data_config: dict[str, Any]
      :param curr_source_config: Portion of the config under a specific source
      :type curr_source_config: dict[str, Any]
      :param end_datetime_key: The key for the end datetime parameter. Defaults to "end_datetime".
      :type end_datetime_key: str, optional

      :returns: The end datetime for the dataset
      :rtype: pd.Timestamp

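      The four ``_load_*`` methods share one rule of thumb for the global values under ``data``:
      the finest timestep, the longest rollout, the earliest start, and the latest end across sources.
      This sketch derives such a global clock from hypothetical per-source requirements; the source
      names and values are invented, and the docstrings describe what the user should configure
      rather than how CREDIT computes it.

      ```python
      import pandas as pd

      # Hypothetical per-source clock requirements.
      sources = {
          "reanalysis": {"timestep": "6h", "forecast_len": 2,
                         "start_datetime": "2000-01-01T00:00:00Z", "end_datetime": "2020-12-31T18:00:00Z"},
          "satellite": {"timestep": "1h", "forecast_len": 1,
                        "start_datetime": "2012-04-03T00:00:00Z", "end_datetime": "2021-06-30T23:00:00Z"},
      }

      # Global clock: finest timestep, longest rollout, earliest start, latest end.
      clock = {
          "timestep": min(pd.Timedelta(s["timestep"]) for s in sources.values()),
          "forecast_len": max(s["forecast_len"] for s in sources.values()),
          "start_datetime": min(pd.Timestamp(s["start_datetime"]) for s in sources.values()),
          "end_datetime": max(pd.Timestamp(s["end_datetime"]) for s in sources.values()),
      }
      ```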


   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return timestamps for the dataset using the class parameters.
      The timestamps should ensure that there are enough future timesteps to
      rollout based on num_forecast_steps and the dt timestep length, and
      should be at the configured timestep frequency.

      Note: Override this method if you want to apply quality-control checks that limit the
      datetimes available for sampling, or to enforce time bounds automatically for your dataset.
      In these cases, call ``super()._build_timestamps()`` to reuse the base functionality.

      :returns: DatetimeIndex from ``start_datetime`` to ``end_datetime`` minus
                the forecast horizon, at the configured timestep frequency.
      :rtype: pd.DatetimeIndex

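      A minimal sketch of the base behavior, assuming the forecast horizon equals
      ``num_forecast_steps * dt`` (the docstring does not spell this out). ``build_timestamps`` is a
      hypothetical free-function stand-in; the length of its result is what ``__len__`` reports.

      ```python
      import pandas as pd


      def build_timestamps(start: pd.Timestamp, end: pd.Timestamp,
                           dt: pd.Timedelta, num_forecast_steps: int) -> pd.DatetimeIndex:
          # Leave room for the rollout: the last initial time is num_forecast_steps * dt
          # before end (assumed definition of the forecast horizon).
          last_init = end - num_forecast_steps * dt
          return pd.date_range(start=start, end=last_init, freq=dt)


      times = build_timestamps(
          pd.Timestamp("2000-01-01T00:00:00Z"),
          pd.Timestamp("2000-01-03T00:00:00Z"),
          pd.Timedelta("12h"),
          num_forecast_steps=1,
      )
      print(len(times))  # 4
      ```

      A subclass applying quality-control checks could call the base implementation and then drop
      rejected timestamps, for example with ``times.difference(bad_times)``.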


   .. py:method:: _get_field_name(field_type: VALID_FIELD_TYPES, dim_str: str, vname: str) -> str

      Get the field name and enforce a consistent convention across datasets.

      The convention for the field name is: ``"{user's current source name}/{field_type}/{dim_str}/{vname}"``.

      :param field_type: The field type (e.g., "prognostic")
      :type field_type: VALID_FIELD_TYPES
      :param dim_str: The dimension string (e.g., "3d" or "2d")
      :type dim_str: str
      :param vname: The variable name (e.g., "T" or "t2m")
      :type vname: str

      :returns: The key string that will be used to access the field variable
      :rtype: str

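      A minimal sketch of the naming convention. ``get_field_name`` and ``FieldType`` are
      hypothetical; ``FieldType`` assumes ``VALID_FIELD_TYPES`` is a ``Literal`` over the four field
      types listed under ``_register_field``. Splitting the key back into its parts assumes no part
      contains ``"/"``.

      ```python
      from typing import Literal

      # Hypothetical mirror of VALID_FIELD_TYPES.
      FieldType = Literal["prognostic", "dynamic_forcing", "static", "diagnostic"]


      def get_field_name(source_name: str, field_type: FieldType, dim_str: str, vname: str) -> str:
          # "{source}/{field_type}/{dim}/{varname}"
          return f"{source_name}/{field_type}/{dim_str}/{vname}"


      key = get_field_name("Example_Base", "prognostic", "3d", "T")
      print(key)  # Example_Base/prognostic/3d/T
      source, field_type, dim_str, vname = key.split("/")
      ```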


   .. py:method:: init_register_all_fields() -> None

      Initialize and register all fields for the dataset.

      :raises KeyError: If the config does not include any variables.



   .. py:method:: _register_field(field_type: VALID_FIELD_TYPES, field_config: dict[str, Any] | None) -> None

      Validate and register one field type from the config variables block.

      Populates ``self.file_dict`` and ``self.var_dict`` for *field_type*.

      :param field_type: One of VALID_FIELD_TYPES, namely: ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param field_config: Field-type config dict, or ``None`` / null to disable the field.

      :raises KeyError: If *field_type* is not a recognised field type.
      :raises ValueError: If *field_config* defines neither ``vars_3D`` nor ``vars_2D``.

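      A sketch of the validation rules listed above. ``validate_field`` is hypothetical, and it only
      validates; it does not reproduce how ``file_dict`` and ``var_dict`` are populated.

      ```python
      from typing import Any, Optional

      VALID = ("prognostic", "dynamic_forcing", "static", "diagnostic")


      def validate_field(field_type: str, field_config: Optional[dict[str, Any]]) -> bool:
          """Return False for a disabled field, True for a valid one; raise otherwise."""
          if field_type not in VALID:
              raise KeyError(f"Unrecognised field type: {field_type!r}")
          if field_config is None:  # null in YAML disables the field
              return False
          if "vars_3D" not in field_config and "vars_2D" not in field_config:
              raise ValueError(f"{field_type!r} must define vars_3D and/or vars_2D")
          return True
      ```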


   .. py:method:: _get_file_source(field_config: dict[str, Any]) -> list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None

      Return the file source for a field. Override in subclasses for different modes/backends.

      :param field_config: Validated field-type config dict.
      :type field_config: dict[str, Any]

      :raises FileNotFoundError: If ``self.mode == "local"`` and the glob pattern matches no files.
      :raises ValueError: If ``self.mode`` is not a recognised mode.

      :returns: Depending on the mode and field type, either a list of
                ``(start_time, end_time, file_path)`` tuples produced by ``_map_files``,
                a boolean indicating the presence of the field (e.g., for remote data),
                or ``None`` if the field is disabled. The return type should be
                consistent within a dataset class.
      :rtype: list[tuple[pd.Timestamp, pd.Timestamp, str]] | bool | None

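      A minimal sketch of the ``"local"`` branch, assuming the field config stores a glob pattern
      under a key named ``path`` (a hypothetical name; the real key is not documented here). It
      returns the sorted matches rather than the ``(start_time, end_time, file_path)`` tuples built by
      ``_map_files``, whose logic is not shown in this section.

      ```python
      import glob
      from typing import Any, Optional


      def get_file_source(field_config: Optional[dict[str, Any]],
                          mode: str = "local") -> Optional[list[str]]:
          if field_config is None:  # disabled field
              return None
          if mode == "local":
              pattern = field_config["path"]  # hypothetical key holding the glob pattern
              matches = sorted(glob.glob(pattern))
              if not matches:
                  raise FileNotFoundError(f"No files match {pattern!r}")
              return matches
          raise ValueError(f"Unrecognised mode: {mode!r}")
      ```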


   .. py:method:: _extract_field(field_type: VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) -> None

      Base field-extraction method. Override it in a subclassed dataset to extract the data for
      each field type. The method should populate the sample with the extracted data for the given
      field type and timestamp, using keys in the format returned by ``_get_field_name``.

      The entries are added as tensors to the ``sample["input"]`` or ``sample["target"]`` dict in ``__getitem__``.

      :param field_type: One of VALID_FIELD_TYPES.
      :type field_type: VALID_FIELD_TYPES
      :param t: Query timestamp for which to extract the field data.
      :type t: pd.Timestamp
      :param sample: The sample dict being built in __getitem__
      :type sample: dict[str, Any]



