credit.datasets.era5
====================

.. py:module:: credit.datasets.era5

.. autoapi-nested-parse::

   era5.py
   -------------------------------------------------------
   Refactored ERA5Dataset with nested input/target structure.

   Sample structure returned by __getitem__:

   .. code-block:: python

       {
           "input": {
               "era5/prognostic/3d/T":        tensor,  # (n_levels, 1, lat, lon)
               "era5/prognostic/2d/SP":       tensor,  # (1,        1, lat, lon)
               "era5/dynamic_forcing/2d/tsi": tensor,
               "era5/static/2d/LSM":          tensor,
               ...
           },
           "target": {                                  # only when return_target=True
               "era5/prognostic/3d/T":        tensor,
               "era5/prognostic/2d/SP":       tensor,
               ...
           },
           "metadata": {
               "input_datetime":  int,                  # nanoseconds since epoch
               "target_datetime": int,                  # only when return_target=True
           },
       }

   Output key format (flat, slash-delimited):
       "{source}/{field_type}/{dim}/{varname}"

       source    : "era5"
       field_type: "prognostic" | "dynamic_forcing" | "static" | "diagnostic"
       dim       : "2d"  (surface / single-level)
                   "3d"  (multi-level upper-air)
       varname   : variable name as given in config (e.g. "T", "SP", "tsi")
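
   The key format can be inverted with a plain string split. The sketch below is
   illustrative only: ``split_key`` and ``VarKey`` are hypothetical names, not part
   of ``credit``, and the allowed values are copied from the list above.

   .. code-block:: python

       from typing import NamedTuple

       # Allowed values, copied from the key-format description above.
       FIELD_TYPES = {"prognostic", "dynamic_forcing", "static", "diagnostic"}
       DIMS = {"2d", "3d"}


       class VarKey(NamedTuple):
           source: str
           field_type: str
           dim: str
           varname: str


       def split_key(key: str) -> VarKey:
           """Split "{source}/{field_type}/{dim}/{varname}" into its four parts."""
           parts = key.split("/")
           if len(parts) != 4:
               raise ValueError(f"expected 4 slash-delimited parts in {key!r}")
           parsed = VarKey(*parts)
           if parsed.field_type not in FIELD_TYPES or parsed.dim not in DIMS:
               raise ValueError(f"unrecognised field_type or dim in {key!r}")
           return parsed


       print(split_key("era5/prognostic/3d/T"))
       # VarKey(source='era5', field_type='prognostic', dim='3d', varname='T')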

   Tensor shapes (no batch dimension):
       3D variable : (n_levels, 1, lat, lon)   n_levels = len(levels) in the config
       2D variable : (1,        1, lat, lon)   singleton level dimension

   After DataLoader collation the batch dimension is prepended:
       3D variable : (batch, n_levels, 1, lat, lon)
       2D variable : (batch, 1,        1, lat, lon)
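
   A minimal sketch of that collation step with PyTorch's ``default_collate`` (the
   default ``collate_fn`` of ``DataLoader``), using two toy samples with arbitrary
   sizes. The integer timestamps in ``"metadata"`` are collated into a 1-D tensor.

   .. code-block:: python

       import torch
       from torch.utils.data import default_collate

       n_levels, lat, lon = 16, 4, 8  # toy sizes for illustration only

       # Two samples shaped as described above (no batch dimension yet).
       samples = [
           {
               "input": {
                   "era5/prognostic/3d/T": torch.zeros(n_levels, 1, lat, lon),
                   "era5/prognostic/2d/SP": torch.zeros(1, 1, lat, lon),
               },
               "metadata": {"input_datetime": 1_483_228_800_000_000_000},  # 2017-01-01 in ns
           }
           for _ in range(2)
       ]

       batch = default_collate(samples)
       print(batch["input"]["era5/prognostic/3d/T"].shape)   # torch.Size([2, 16, 1, 4, 8])
       print(batch["input"]["era5/prognostic/2d/SP"].shape)  # torch.Size([2, 1, 1, 4, 8])
       print(batch["metadata"]["input_datetime"].shape)      # torch.Size([2])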

   File naming:
       Each field type supports an optional ``filename_time_format`` config key
       that specifies a strftime format string describing how the datetime appears
       in the file name.  Defaults to ``"%Y"`` (annual files).

       Examples::

           filename_time_format: "%Y"       # era5_2021.zarr
           filename_time_format: "%Y_%m"    # era5_2021_06.nc
           filename_time_format: "%Y%m%d"   # era5_20210601.nc

       If only a single file matches the glob pattern, ``filename_time_format`` is
       ignored and that file is used for all timestamps.
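
   The lookup rule above might be sketched as follows. This assumes the glob has
   already been expanded into ``paths`` and that a file matches when its base name
   contains ``t.strftime(filename_time_format)``; ``resolve_file`` is a hypothetical
   helper, not the dataset's actual implementation.

   .. code-block:: python

       import os
       from datetime import datetime


       def resolve_file(paths: list[str], t: datetime, filename_time_format: str = "%Y") -> str:
           """Pick the file covering timestamp t (hypothetical helper)."""
           if len(paths) == 1:
               # A single matching file is used for every timestamp.
               return paths[0]
           token = t.strftime(filename_time_format)
           matches = [p for p in paths if token in os.path.basename(p)]
           if len(matches) != 1:
               raise FileNotFoundError(f"expected one file containing {token!r}, found {len(matches)}")
           return matches[0]


       monthly = ["/data/solar_2021_05.nc", "/data/solar_2021_06.nc"]
       print(resolve_file(monthly, datetime(2021, 6, 15, 6), "%Y_%m"))  # /data/solar_2021_06.nc
       print(resolve_file(["/data/lsm.nc"], datetime(2021, 6, 15)))     # /data/lsm.nc

   With a ``"%Y"`` format and monthly files, both names above contain ``"2021"``, so
   the sketch raises instead of silently picking one.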



Attributes
----------

.. autoapisummary::

   credit.datasets.era5.logger
   credit.datasets.era5.VALID_FIELD_TYPES


Classes
-------

.. autoapisummary::

   credit.datasets.era5.ERA5Dataset
   credit.datasets.era5.ARCOERA5Dataset


Module Contents
---------------

.. py:data:: logger

.. py:data:: VALID_FIELD_TYPES

.. py:class:: ERA5Dataset(config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset for processed ERA5 data with nested input/target structure.

   See the module docstring for a full description of the output format and file naming.

   Example YAML configuration::

       data:
         source:
           ERA5:
             level_coord: "level"
             levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
             variables:
               prognostic:
                 vars_3D: ['T', 'U', 'V', 'Q']
                 vars_2D: ['SP', 't2m']
                 path: "/data/era5_*.zarr"
                 filename_time_format: "%Y"        # annual (default)
               dynamic_forcing:
                 vars_2D: ['tsi']
                 path: "/data/solar_*.nc"
                 filename_time_format: "%Y_%m"     # monthly
               static:
                 vars_2D: ['Z_GDS4_SFC', 'LSM']
                 path: "/data/lsm.nc"
                 # single file — filename_time_format not needed
               diagnostic: null

         start_datetime: "2017-01-01"
         end_datetime: "2019-12-31"
         timestep: "6h"
         forecast_len: 1
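
   The keys this ``variables`` block maps to can be listed with a small sketch,
   assuming keys follow the module-level ``{source}/{field_type}/{dim}/{varname}``
   format and that a ``null`` field type contributes nothing. ``expected_keys`` is a
   hypothetical helper; which keys appear in a given sample still depends on the
   step index and ``return_target``.

   .. code-block:: python

       def expected_keys(variables: dict, source: str = "era5") -> list[str]:
           """List output keys implied by a config ``variables`` block (hypothetical)."""
           keys = []
           for field_type, block in variables.items():
               if block is None:  # e.g. ``diagnostic: null`` disables the field
                   continue
               for varname in block.get("vars_3D") or []:
                   keys.append(f"{source}/{field_type}/3d/{varname}")
               for varname in block.get("vars_2D") or []:
                   keys.append(f"{source}/{field_type}/2d/{varname}")
           return keys


       # The YAML ``variables`` block above, transcribed as a Python dict (paths omitted).
       variables = {
           "prognostic": {"vars_3D": ["T", "U", "V", "Q"], "vars_2D": ["SP", "t2m"]},
           "dynamic_forcing": {"vars_2D": ["tsi"]},
           "static": {"vars_2D": ["Z_GDS4_SFC", "LSM"]},
           "diagnostic": None,
       }
       for key in expected_keys(variables):
           print(key)  # era5/prognostic/3d/T, era5/prognostic/3d/U, ...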

   Assumptions:
       1. A "time" dimension / coordinate is present for non-static fields.
       2. A level coordinate (name given by ``level_coord``) represents the
          vertical axis of 3D variables.
       3. Dimension order: (time, level, latitude, longitude) for 3D;
          (time, latitude, longitude) for 2D; (latitude, longitude) for static.


   .. py:attribute:: source_name
      :type:  str
      :value: 'era5'



   .. py:attribute:: level_coord
      :type:  str


   .. py:attribute:: levels
      :type:  list[int]


   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: dt


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime


   .. py:attribute:: end_datetime


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: file_dict
      :type:  dict[str, list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | None]


   .. py:attribute:: var_dict
      :type:  dict[str, dict[str, list[str]]]


   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a nested input/target sample dict.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   since epoch or ``pandas.Timestamp``) and *i* is the
                   within-sequence step index produced by the sampler. When
                   ``i == 0``, prognostic and static fields are loaded in
                   addition to dynamic forcing.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"era5/{field_type}/{dim}/{varname}"``.
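
      A sketch of how ``(t, i)`` pairs for one sequence could be generated and then
      passed as ``dataset[(t, i)]``. The sampler ``sequence_indices`` and the helper
      ``input_field_types`` are hypothetical; the latter reads the description above
      as implying that steps with ``i > 0`` load dynamic forcing only.

      .. code-block:: python

          import pandas as pd


          def sequence_indices(t0: pd.Timestamp, dt: pd.Timedelta, n_steps: int) -> list[tuple[int, int]]:
              """Hypothetical sampler: (timestamp in ns, step index) pairs for one sequence."""
              return [((t0 + i * dt).value, i) for i in range(n_steps)]


          def input_field_types(i: int) -> list[str]:
              """Field types loaded into "input" at step i (assumed from the text above)."""
              if i == 0:
                  return ["prognostic", "static", "dynamic_forcing"]
              return ["dynamic_forcing"]


          for t_ns, i in sequence_indices(pd.Timestamp("2017-01-01"), pd.Timedelta("6h"), 3):
              # Integer ns timestamps (as in metadata["input_datetime"]) convert back via pd.Timestamp.
              print(pd.Timestamp(t_ns), i, input_field_types(i))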



   .. py:method:: _register_field(field_type: str, d: dict | None) -> None

      Validate and register one field type from the config variables block.

      Populates ``self.file_dict`` and ``self.var_dict`` for *field_type*.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param d: Field-type config dict, or ``None`` / null to disable the field.

      :raises KeyError: If *field_type* is not a recognised field type.
      :raises ValueError: If *d* defines neither ``vars_3D`` nor ``vars_2D``.
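
      The documented checks can be mirrored in a few lines. This is a hypothetical
      stand-in, not the method itself: ``file_dict`` bookkeeping is omitted, the
      contents of ``VALID_FIELD_TYPES`` are assumed from the module docstring, the
      per-field layout of ``var_dict`` is an assumption consistent with its type
      ``dict[str, dict[str, list[str]]]``, and an empty list is treated as "not defined".

      .. code-block:: python

          from typing import Optional

          # Assumed contents, taken from the field types listed in the module docstring.
          VALID_FIELD_TYPES = ("prognostic", "dynamic_forcing", "static", "diagnostic")


          def register_field(var_dict: dict, field_type: str, d: Optional[dict]) -> None:
              """Hypothetical stand-in mirroring the documented checks of _register_field."""
              if field_type not in VALID_FIELD_TYPES:
                  raise KeyError(f"unknown field type: {field_type!r}")
              if d is None:
                  return  # a null entry disables the field, so nothing is registered
              vars_3d = d.get("vars_3D") or []
              vars_2d = d.get("vars_2D") or []
              if not vars_3d and not vars_2d:
                  raise ValueError(f"{field_type!r} defines neither vars_3D nor vars_2D")
              var_dict[field_type] = {"vars_3D": list(vars_3d), "vars_2D": list(vars_2d)}


          var_dict = {}
          register_field(var_dict, "static", {"vars_2D": ["LSM"], "path": "/data/lsm.nc"})
          register_field(var_dict, "diagnostic", None)
          print(var_dict)  # {'static': {'vars_3D': [], 'vars_2D': ['LSM']}}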



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return valid initialisation timestamps for the dataset.

      :returns: DatetimeIndex from ``start_datetime`` to ``end_datetime`` minus
                the forecast horizon, at the configured timestep frequency.
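
      A minimal sketch using the example configuration, assuming the forecast
      horizon equals ``forecast_len * timestep`` and that ``end_datetime`` is read as
      midnight of that day. ``build_timestamps`` is hypothetical; the real method may
      treat the end point differently.

      .. code-block:: python

          import pandas as pd


          def build_timestamps(start: str, end: str, timestep: str, forecast_len: int) -> pd.DatetimeIndex:
              """Hypothetical sketch; assumes horizon = forecast_len * timestep."""
              dt = pd.Timedelta(timestep)
              last_init = pd.Timestamp(end) - forecast_len * dt  # latest init with a full forecast
              return pd.date_range(pd.Timestamp(start), last_init, freq=dt)


          idx = build_timestamps("2017-01-01", "2019-12-31", "6h", forecast_len=1)
          print(idx[0], idx[-1], len(idx))  # 2017-01-01 00:00:00 2019-12-30 18:00:00 4376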



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Open the dataset for *field_type* at time *t* and populate *sample*.

      Keys written are ``"era5/{field_type}/3d/{varname}"`` for 3D variables
      and ``"era5/{field_type}/2d/{varname}"`` for 2D variables.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param t: Timestamp to select.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shapes (no batch dimension):

                     - 3D variable: ``(n_levels, 1, lat, lon)``
                     - 2D variable: ``(1, 1, lat, lon)``
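
      The reshaping into these shapes can be sketched with a toy in-memory xarray
      Dataset standing in for the opened file. Variable names, the level subset and
      grid sizes are illustrative, the real method's selection logic may differ, and
      the sketch keeps NumPy arrays (``torch.from_numpy`` would turn them into tensors).

      .. code-block:: python

          import numpy as np
          import pandas as pd
          import xarray as xr

          levels = [10, 30, 137]  # toy subset of model levels
          times = pd.date_range("2017-01-01", periods=2, freq=pd.Timedelta("6h"))
          lat = np.linspace(90.0, -90.0, 4)
          lon = np.linspace(0.0, 315.0, 8)

          # Toy stand-in for an opened file, following the class's dimension-order assumptions.
          ds = xr.Dataset(
              {
                  "T": (("time", "level", "latitude", "longitude"), np.zeros((2, 3, 4, 8))),
                  "SP": (("time", "latitude", "longitude"), np.zeros((2, 4, 8))),
              },
              coords={"time": times, "level": levels, "latitude": lat, "longitude": lon},
          )

          t = times[1]
          sample = {}
          # 3D: select time and levels -> (n_levels, lat, lon), then insert the singleton axis.
          sample["era5/prognostic/3d/T"] = ds["T"].sel(time=t, level=levels).values[:, np.newaxis]
          # 2D: select time -> (lat, lon), then add two leading singleton axes.
          sample["era5/prognostic/2d/SP"] = ds["SP"].sel(time=t).values[np.newaxis, np.newaxis]

          print({k: v.shape for k, v in sample.items()})
          # {'era5/prognostic/3d/T': (3, 1, 4, 8), 'era5/prognostic/2d/SP': (1, 1, 4, 8)}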



   .. py:method:: _to_cftime(ts: pandas.Timestamp, calendar: str) -> cftime.datetime
      :staticmethod:


      Convert a pandas Timestamp to a cftime.datetime.

      :param ts: Pandas Timestamp to convert.
      :param calendar: cftime calendar string read from the dataset
                       (e.g. ``"noleap"``, ``"gregorian"``, ``"proleptic_gregorian"``).

      :returns: cftime.datetime with the specified calendar.
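
      One way to perform this conversion, assuming cftime 1.3 or later (where
      ``cftime.datetime`` accepts a ``calendar`` keyword). ``to_cftime`` is a
      hypothetical stand-in for the static method. Sub-microsecond precision is
      dropped, and dates that do not exist in the target calendar, such as
      29 February under ``"noleap"``, are expected to raise an error.

      .. code-block:: python

          import cftime
          import pandas as pd


          def to_cftime(ts: pd.Timestamp, calendar: str) -> cftime.datetime:
              """Sketch of a Timestamp -> cftime.datetime conversion (hypothetical)."""
              return cftime.datetime(
                  ts.year, ts.month, ts.day, ts.hour, ts.minute, ts.second, ts.microsecond,
                  calendar=calendar,
              )


          ct = to_cftime(pd.Timestamp("2017-06-01 06:00"), "noleap")
          print(ct.calendar, ct.year, ct.month, ct.day, ct.hour)  # noleap 2017 6 1 6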



.. py:class:: ARCOERA5Dataset(data_config: dict, return_target: bool = False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   PyTorch Dataset for Google Cloud ARCO ERA5 data with nested input/target structure.

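   See the module docstring for a full description of the output format. Output keys
   use the ``arco_era5`` source prefix, and data are read from the ARCO Zarr stores
   (``pressure_lev_era5_path``, ``model_lev_era5_path``) rather than from local files.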

   Example YAML configuration::

       data:
         source:
           ARCO_ERA5:
             level_coord: "hybrid"
             levels: [10, 30, 40, 50, 60, 70, 80, 90, 95, 100, 105, 110, 120, 130, 136, 137]
             variables:
               prognostic:
                 vars_3D: ["temperature", "u_component_of_wind", "v_component_of_wind", "specific_humidity"]
                 vars_2D: ["surface_pressure"]
               dynamic_forcing:
                 vars_2D: ["toa_incident_solar_radiation"]
               static:
                 vars_2D: ["land_sea_mask"]
               diagnostic:
                 vars_2D: ["total_precipitation"]

         start_datetime: "2017-01-01"
         end_datetime: "2019-12-31"
         timestep: "6h"
         forecast_len: 1

   Assumptions:
       1. A "time" dimension / coordinate is present for non-static fields.
       2. A level coordinate (name given by ``level_coord``) represents the
          vertical axis of 3D variables.
       3. Dimension order: (time, level, latitude, longitude) for 3D;
          (time, latitude, longitude) for 2D; (latitude, longitude) for static.


   .. py:attribute:: pressure_lev_era5_path
      :value: 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'



   .. py:attribute:: model_lev_era5_path
      :value: 'gs://gcp-public-data-arco-era5/ar/model-level-1h-0p25deg.zarr-v1'



   .. py:attribute:: model_lev_vars
      :value: ['divergence', 'fraction_of_cloud_cover', 'geopotential', 'ozone_mass_mixing_ratio',...
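
      One plausible reading of ``model_lev_era5_path``, ``pressure_lev_era5_path`` and
      ``model_lev_vars`` is that variables listed in ``model_lev_vars`` are read from
      the model-level store and everything else from the pressure-level store. The
      sketch below encodes only that assumption; ``store_for`` is hypothetical, and
      the set holds just the four names visible above because the full list is
      truncated here.

      .. code-block:: python

          PRESSURE_LEVEL_STORE = "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3"
          MODEL_LEVEL_STORE = "gs://gcp-public-data-arco-era5/ar/model-level-1h-0p25deg.zarr-v1"

          # Partial: only the entries visible in the attribute value above.
          MODEL_LEVEL_VARS_PARTIAL = {
              "divergence", "fraction_of_cloud_cover", "geopotential", "ozone_mass_mixing_ratio",
          }


          def store_for(varname: str, model_level_vars=MODEL_LEVEL_VARS_PARTIAL) -> str:
              """Hypothetical routing rule: model-level variables vs. everything else."""
              return MODEL_LEVEL_STORE if varname in model_level_vars else PRESSURE_LEVEL_STORE


          print(store_for("divergence"))     # the model-level store
          print(store_for("land_sea_mask"))  # the pressure-level store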



   .. py:attribute:: source_name
      :type:  str
      :value: 'arco_era5'



   .. py:attribute:: level_coord
      :type:  str


   .. py:attribute:: return_target
      :type:  bool
      :value: False



   .. py:attribute:: static_metadata
      :type:  dict


   .. py:attribute:: dt


   .. py:attribute:: num_forecast_steps
      :type:  int


   .. py:attribute:: start_datetime


   .. py:attribute:: end_datetime


   .. py:attribute:: datetimes
      :type:  pandas.DatetimeIndex


   .. py:attribute:: var_dict
      :type:  dict[str, dict[str, list[str]]]


   .. py:attribute:: fs
      :value: None



   .. py:attribute:: mod_level_store
      :value: None



   .. py:attribute:: pres_level_store
      :value: None



   .. py:method:: __len__() -> int


   .. py:method:: __getitem__(args: tuple) -> dict

      Return a nested input/target sample dict.

      :param args: ``(t, i)`` where *t* is the current timestamp (nanoseconds
                   since epoch or ``pandas.Timestamp``) and *i* is the
                   within-sequence step index produced by the sampler. When
                   ``i == 0``, prognostic and static fields are loaded in
                   addition to dynamic forcing.

      :returns: Dict with keys ``"input"``, ``"metadata"``, and optionally
                ``"target"`` (when ``return_target=True``). Both ``"input"`` and
                ``"target"`` are dicts of per-variable tensors keyed by
                ``"arco_era5/{field_type}/{dim}/{varname}"``.



   .. py:method:: _init_fs()


   .. py:method:: _register_field(field_type: str, d: dict | None) -> None

      Validate and register one field type from the config variables block.

      Populates ``self.var_dict`` for *field_type*.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param d: Field-type config dict, or ``None`` / null to disable the field.

      :raises KeyError: If *field_type* is not a recognised field type.
      :raises ValueError: If *d* defines neither ``vars_3D`` nor ``vars_2D``.



   .. py:method:: _build_timestamps() -> pandas.DatetimeIndex

      Return valid initialisation timestamps for the dataset.

      :returns: DatetimeIndex from ``start_datetime`` to ``end_datetime`` minus
                the forecast horizon, at the configured timestep frequency.



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Open the dataset for *field_type* at time *t* and populate *sample*.

      Keys written are ``"arco_era5/{field_type}/3d/{varname}"`` for 3D variables
      and ``"arco_era5/{field_type}/2d/{varname}"`` for 2D variables.

      :param field_type: One of ``"prognostic"``, ``"dynamic_forcing"``,
                         ``"static"``, ``"diagnostic"``.
      :param t: Timestamp to select.
      :param sample: Dict to write variable tensors into (modified in place).
                     Tensor shapes (no batch dimension):

                     - 3D variable: ``(n_levels, 1, lat, lon)``
                     - 2D variable: ``(1, 1, lat, lon)``



   .. py:method:: _to_cftime(ts: pandas.Timestamp, calendar: str) -> cftime.datetime
      :staticmethod:


      Convert a pandas Timestamp to a cftime.datetime.

      :param ts: Pandas Timestamp to convert.
      :param calendar: cftime calendar string read from the dataset
                       (e.g. ``"noleap"``, ``"gregorian"``, ``"proleptic_gregorian"``).

      :returns: cftime.datetime with the specified calendar.



