credit.datasets.tisr
====================

.. py:module:: credit.datasets.tisr

.. autoapi-nested-parse::

   tisr.py
   -------------------------------------------------------
   TISRDataset: PyTorch Dataset for Total Incident Solar Radiation (TISR) at the top of the atmosphere (TOA).

   Sample structure returned by ``__getitem__``::

       {
           "input":    {<user_provided_name>: {"<user_provided_name>/dynamic_forcing/2d/tisr": tensor}},
           "target":   {<user_provided_name>: {}},  # empty since dynamic forcing is only input
           "metadata": {<user_provided_name>: {"input_datetime": int, "target_datetime": int}},
       }

   TISR is a single 2-D variable. Each sample tensor has shape ``(1, 1, lat, lon)``
   (no batch dimension), including a singleton level dimension for consistency with
   the CREDIT Gen2 2-D convention.

   After ``DataLoader`` collation the batch dimension is prepended, giving
   ``(batch, 1, 1, lat, lon)``.
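
   A minimal sketch of how these shapes arise, assuming PyTorch's stock
   ``torch.utils.data.default_collate`` (PyTorch 1.11 or later) is the collate function.
   The source name, grid size and datetime values are placeholders, not real TISR output.

   .. code-block:: python

      import torch
      from torch.utils.data import default_collate

      # Hypothetical source name and grid size, chosen only for illustration.
      name, lat, lon = "Example_TISR", 4, 8
      key = f"{name}/dynamic_forcing/2d/tisr"


      def fake_sample(hour: int) -> dict:
          """Build a dict shaped like one TISR sample (random values)."""
          return {
              "input": {name: {key: torch.rand(1, 1, lat, lon)}},
              "target": {name: {}},  # dynamic forcing only, so no target
              "metadata": {name: {"input_datetime": hour, "target_datetime": hour}},
          }


      batch = default_collate([fake_sample(h) for h in range(3)])
      print(batch["input"][name][key].shape)  # torch.Size([3, 1, 1, 4, 8])

   Nested dictionaries are collated key by key, so the empty ``target`` entry stays
   empty and the integer metadata becomes a 1-D tensor of length ``batch``.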

   Note that Total Incident Solar Radiation (TISR) and Total Solar Irradiance (TSI) are different
   physical quantities:

   - TSI is the total solar power per unit area on a plane perpendicular to the Sun's rays. It is
     defined at TOA and at the mean Sun-Earth distance (1 AU), and it fluctuates slightly with the
     Sun's 11-year solar cycle.
   - TISR is the solar radiation actually received by a horizontal surface, so it depends on the
     Sun's elevation (solar zenith angle) and the Earth-Sun distance, and it varies with time and
     location. It can be defined at TOA or at the surface; this module computes it at TOA.



Attributes
----------

.. autoapisummary::

   credit.datasets.tisr.logger
   credit.datasets.tisr._TORCH_DTYPE


Classes
-------

.. autoapisummary::

   credit.datasets.tisr.TISRDataset


Functions
---------

.. autoapisummary::

   credit.datasets.tisr._era5_tsi_data
   credit.datasets.tisr._get_tsi
   credit.datasets.tisr._get_latlon_grid
   credit.datasets.tisr._get_j2000_days
   credit.datasets.tisr._get_orbital_parameters
   credit.datasets.tisr._get_solar_time
   credit.datasets.tisr._get_hour_angle
   credit.datasets.tisr._get_cosine_zenith_angle
   credit.datasets.tisr._get_instantaneous_toa_tisr
   credit.datasets.tisr._get_integrated_toa_tisr
   credit.datasets.tisr._compute_tisr


Module Contents
---------------

.. py:data:: logger

.. py:data:: _TORCH_DTYPE
   :value: Ellipsis


.. py:function:: _era5_tsi_data(device: torch.device | str = 'cpu') -> tuple[torch.Tensor, torch.Tensor]

   ERA5-compatible Total Solar Irradiance (TSI) time series.

   Sourced from
   `GraphCast <https://github.com/google-deepmind/graphcast/blob/main/graphcast/solar_radiation.py>`_.

   ECMWF provided the data used for ERA5, which was hardcoded in the IFS
   (cycle 41r2, 2016). Values from 2009 onwards repeat the 13-year 1996–2008
   block (the last complete solar cycle available when the code was
   written). All values are scaled by 0.9965 to agree better with more
   recent solar observations.

   :param device: Device on which to allocate the returned tensors.
   :type device: torch.device | str

   :returns:

                 - **times** – 1-D tensor of fractional years, one entry per year
                   from 1951.5 to 2034.5 (mid-year sampling).
                 - **tsi_values** – 1-D tensor of TSI values (W m⁻²) corresponding
                   to each entry in *times*.
   :rtype: tuple[torch.Tensor, torch.Tensor]


.. py:function:: _get_tsi(timestamps: collections.abc.Sequence[str | pandas.Timestamp], tsi_times: torch.Tensor, tsi_values: torch.Tensor) -> torch.Tensor

   Interpolate Total Solar Irradiance (TSI) at the given timestamps.

   Converts each timestamp to a fractional year and performs piecewise
   linear interpolation against the provided annual TSI time series.

   :param timestamps: Sequence of timestamps (strings or ``pd.Timestamp``
                      objects) at which to evaluate TSI.
   :param tsi_times: 1-D tensor of fractional years (e.g. ``2003.5``),
                     sorted in ascending order.
   :param tsi_values: 1-D tensor of TSI values (W m⁻²) corresponding
                      element-wise to *tsi_times*.

   :returns: 1-D tensor of interpolated TSI values (W m⁻²),
             one per input timestamp.
   :rtype: torch.Tensor

   :raises ValueError: If any timestamp's **year** falls outside the integer
       range spanned by *tsi_times*.  Note that timestamps in the
       first half of the first year or the second half of the last
       year are still accepted and will be lightly extrapolated.
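
   A sketch of the two steps described above, under stated assumptions: the
   fractional-year convention (day of year plus time of day, divided by the year
   length) is one natural choice and is assumed, not confirmed, to match this module;
   ``fractional_year`` and ``get_tsi_sketch`` are hypothetical names. Outside the
   sampled range the sketch extrapolates linearly from the end segments, one way to
   realise the behaviour described under *Raises*.

   .. code-block:: python

      import math

      import numpy as np
      import pandas as pd
      import torch


      def fractional_year(timestamps) -> torch.Tensor:
          """Timestamps -> fractional years, e.g. 2001-07-02 12:00 -> 2001.5."""
          ts = pd.DatetimeIndex(timestamps)
          day_fraction = (ts - ts.normalize()) / pd.Timedelta(days=1)
          year_length = 365 + ts.is_leap_year.astype(int)
          frac = ts.year + (ts.dayofyear - 1 + day_fraction) / year_length
          return torch.as_tensor(np.asarray(frac, dtype=np.float64))


      def get_tsi_sketch(timestamps, tsi_times, tsi_values):
          """Piecewise-linear TSI interpolation (hypothetical stand-in for _get_tsi)."""
          years = pd.DatetimeIndex(timestamps).year
          first = math.floor(tsi_times[0].item())
          last = math.floor(tsi_times[-1].item())
          if years.min() < first or years.max() > last:
              raise ValueError(f"timestamps must fall within years {first}-{last}")
          x = fractional_year(timestamps).to(tsi_times.dtype)
          # Segment index i such that x lies in [tsi_times[i], tsi_times[i + 1]] for
          # in-range x; clamping makes points before the first / after the last
          # sample reuse the end segments, i.e. they are linearly extrapolated.
          i = (torch.searchsorted(tsi_times, x) - 1).clamp(0, tsi_times.numel() - 2)
          x0, x1 = tsi_times[i], tsi_times[i + 1]
          y0, y1 = tsi_values[i], tsi_values[i + 1]
          return y0 + (x - x0) / (x1 - x0) * (y1 - y0)

   For example, with yearly samples at 2000.5, 2001.5 and 2002.5, the timestamp
   ``2001-07-02 12:00`` maps to exactly 2001.5 (182.5 of 365 days) and returns the
   second sample unchanged.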


.. py:function:: _get_latlon_grid(path: str | None = None, lat_spec: collections.abc.Sequence[float] | None = None, lon_spec: collections.abc.Sequence[float] | None = None, device: torch.device | str = 'cpu') -> tuple[torch.Tensor, torch.Tensor]

   Obtain a lat/lon grid, either by reading a NetCDF file or building one from specs.

   Exactly one of the two modes must be supplied:

   * **From file** (``path``) – read latitude/longitude from a NetCDF file.
     Handles two representations:

     - *Curvilinear grids* – latitude/longitude stored as 2-D variables of
       shape ``(ny, nx)``, read directly.
     - *Rectangular grids* – latitude/longitude stored as 1-D coordinate
       arrays of lengths ``ny`` and ``nx``; a 2-D meshgrid is constructed.

     Common field name aliases (``latitude``/``lat``/``XLAT``/``nav_lat`` and
     ``longitude``/``lon``/``XLONG``/``nav_lon``) are tried in order so files
     from different models/tools are accepted without preprocessing.

   * **From specs** (``lat_spec`` and ``lon_spec``) – synthesize a rectangular
     grid in-memory without any file I/O. Each spec is ``[start, end, num_points]``
     with both endpoints inclusive (so ``[90, -90, 721]`` yields the ERA5 0.25°
     latitude axis). Both specs must be supplied together; callers are expected
     to validate the pair (see :class:`TISRDataset.__init__`).

   :param path: Path to a NetCDF file containing latitude/longitude
                information. Mutually exclusive with ``lat_spec``/``lon_spec``.
   :type path: str | None
   :param lat_spec: Latitude axis as ``[start, end, num_points]``
                    (e.g. ``[90, -90, 721]``). Mutually exclusive with ``path``.
   :type lat_spec: Sequence[float] | None
   :param lon_spec: Longitude axis as ``[start, end, num_points]``
                    (e.g. ``[0, 359.75, 1440]``). Mutually exclusive with ``path``.
   :type lon_spec: Sequence[float] | None
   :param device: Device on which to allocate the returned tensors.
   :type device: torch.device | str

   :returns:     - **lat_tensor** – Latitude grid in degrees, shape ``(1, ny, nx)``.
                 - **lon_tensor** – Longitude grid in degrees, shape ``(1, ny, nx)``.
   :rtype: tuple[torch.Tensor, torch.Tensor]

   :raises ValueError: If neither or both modes are supplied; if only one of
       ``lat_spec``/``lon_spec`` is given; if a spec is not length 3, or its
       ``num_points`` is not an int >= 1; if the file cannot be opened; if no
       recognised latitude/longitude field is found; or if the lat/lon arrays
       are neither 1-D nor 2-D after squeezing.
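
   For the in-memory mode, a minimal sketch assuming ``torch.linspace`` (which includes
   both endpoints) and ``torch.meshgrid(..., indexing="ij")``. ``latlon_from_specs`` is a
   hypothetical helper; it leaves the dtype at PyTorch's default rather than
   ``_TORCH_DTYPE`` and does not cover file mode.

   .. code-block:: python

      from collections.abc import Sequence

      import torch


      def latlon_from_specs(lat_spec: Sequence[float], lon_spec: Sequence[float], device="cpu"):
          """Build (1, ny, nx) lat/lon grids from [start, end, num_points] specs."""
          axes = []
          for name, spec in (("lat_spec", lat_spec), ("lon_spec", lon_spec)):
              if len(spec) != 3:
                  raise ValueError(f"{name} must be [start, end, num_points], got {spec!r}")
              start, end, n = spec
              if not isinstance(n, int) or n < 1:
                  raise ValueError(f"{name} num_points must be an int >= 1, got {n!r}")
              # linspace includes both endpoints, matching the spec convention.
              axes.append(torch.linspace(float(start), float(end), n, device=device))
          lat2d, lon2d = torch.meshgrid(axes[0], axes[1], indexing="ij")  # (ny, nx)
          return lat2d.unsqueeze(0), lon2d.unsqueeze(0)

   With ``[90, -90, 721]`` and ``[0, 359.75, 1440]`` both tensors have shape
   ``(1, 721, 1440)``, and the longitude spacing is 359.75 / 1439 = 0.25°.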


.. py:function:: _get_j2000_days(timestamp: pandas.Timestamp | pandas.DatetimeIndex, device: torch.device | str = 'cpu') -> torch.Tensor

   Convert UTC timestamp(s) to fractional days since the J2000.0 epoch.

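   :param timestamp: UTC timestamp, or collection of timestamps, to convert.
   :type timestamp: pd.Timestamp | pd.DatetimeIndex
   :param device: Device on which to allocate the returned tensor.
   :type device: torch.device | str

   :returns: Fractional days elapsed since J2000.0, taken here as
             2000-01-01 12:00:00 UTC (the formal J2000.0 epoch is 12:00 TT,
             which is 11:58:55.816 UTC). The shape matches the input and the
             dtype is ``_TORCH_DTYPE``.
   :rtype: torch.Tensor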
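
   A minimal sketch, assuming tz-naive timestamps that already represent UTC and the
   12:00 UTC epoch stated above; ``j2000_days_sketch`` and ``J2000_EPOCH`` are
   hypothetical names, and float64 is used instead of ``_TORCH_DTYPE``.

   .. code-block:: python

      import numpy as np
      import pandas as pd
      import torch

      # Assumption: epoch taken as 2000-01-01 12:00 UTC, as documented above.
      J2000_EPOCH = pd.Timestamp("2000-01-01T12:00:00")


      def j2000_days_sketch(timestamp, dtype=torch.float64, device="cpu"):
          """Fractional days since J2000 for a tz-naive UTC Timestamp or DatetimeIndex."""
          days = (timestamp - J2000_EPOCH) / pd.Timedelta(days=1)
          return torch.as_tensor(np.asarray(days, dtype=np.float64), dtype=dtype, device=device)

   The dtype matters: around 2025 the count is roughly 9,000 days, where the float32
   spacing is 2**-10 days (about 84 s), so float32 cannot resolve sub-minute offsets
   within a 1-hour accumulation window.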

   .. rubric:: References

   - https://en.wikipedia.org/wiki/Epoch_(astronomy)#Julian_years_and_J2000


.. py:function:: _get_orbital_parameters(j2000_days: torch.Tensor) -> dict[str, torch.Tensor]

   Compute solar orbital parameters from J2000 day count.

   Derives the key quantities needed for TISR calculation — solar declination,
   equation of time, and Earth-Sun distance — using low-order trigonometric
   approximations sourced from the ERA5 / IFS parameterization.

   This function is a PyTorch port of ``_get_orbital_parameters`` from GraphCast's
   ``graphcast/solar_radiation.py`` (Google DeepMind).  The logic, variable names,
   and numerical constants are kept identical to the original; the only changes are
   replacing JAX/NumPy array operations (``jnp.stack``, ``jnp.dot``, ``jnp.sin``,
   etc.) with their PyTorch equivalents (``torch.stack``, ``@``, ``torch.sin``,
   etc.), and passing explicit ``dtype`` and ``device`` arguments to tensor
   constructors to ensure compatibility with the calling context.

   :param j2000_days: Days elapsed since the J2000 epoch, as returned by
                      :func:`_get_j2000_days` (2000-01-01 12:00 UTC), shape ``(T,)``.
   :type j2000_days: torch.Tensor

   :returns: Dictionary with the following keys, each of
             shape ``(T,)``:

             - ``theta``: fractional Julian years since J2000.
             - ``rotational_phase``: UTC time-of-day as a day-fraction
               (0.0 = UTC noon, 0.5 = UTC midnight).
             - ``sin_declination``, ``cos_declination``: sine and cosine of the
               solar declination angle (dimensionless).
             - ``eq_of_time_seconds``: equation of time in seconds.
             - ``solar_distance_au``: Earth-Sun distance in Astronomical Units.
   :rtype: dict[str, torch.Tensor]
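
   Only the first two outputs are simple enough to sketch without reproducing the IFS
   coefficients. Assuming ``theta`` is the day count divided by the 365.25-day Julian
   year and ``rotational_phase`` is its fractional part (as the key descriptions
   suggest), they behave as follows; the declination, equation-of-time and distance
   terms are deliberately omitted.

   .. code-block:: python

      import torch

      JULIAN_YEAR_DAYS = 365.25  # length of a Julian year in days

      j2000_days = torch.tensor([-0.5, 0.0, 0.25, 365.25], dtype=torch.float64)
      theta = j2000_days / JULIAN_YEAR_DAYS  # fractional Julian years since J2000
      # The % operator (torch.remainder) returns values with the sign of the divisor,
      # so negative day counts still map into [0, 1).
      rotational_phase = j2000_days % 1.0
      # -> 0.5 (00:00 UTC), 0.0 (12:00 UTC), 0.25 (18:00 UTC), 0.25 (18:00 UTC)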

   .. rubric:: References

   - https://github.com/google-deepmind/graphcast/blob/08cf73625c9d12bd9aaa038868bcb2fe488f2a22/graphcast/solar_radiation.py#L293


.. py:function:: _get_solar_time(rotational_phase: torch.Tensor, eq_of_time_seconds: torch.Tensor) -> torch.Tensor

   Compute local apparent solar time as a fraction of a day.

   Adjusts the fractional UTC day (rotational phase) by the equation of time
   to account for the difference between apparent and mean solar time caused
   by Earth's orbital eccentricity and axial tilt.

   :param rotational_phase: Fractional part of the J2000 day count,
                            representing UTC time-of-day, shape ``(T,)``. Because the J2000 epoch
                            starts at noon, ``0.0`` represents UTC noon (12:00) and ``0.5``
                            represents UTC midnight (00:00).
   :type rotational_phase: torch.Tensor
   :param eq_of_time_seconds: Equation of time in seconds,
                              shape ``(T,)``. Positive values indicate apparent solar noon occurs
                              before mean solar noon.
   :type eq_of_time_seconds: torch.Tensor

   :returns:

             Apparent solar time at the prime meridian as a fraction
                 of a day in ``[0, 1)``, shape ``(T,)``.
   :rtype: torch.Tensor
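
   A sketch consistent with the description above, assuming the equation of time is
   converted to a day fraction, added to the rotational phase, and wrapped into
   ``[0, 1)`` with ``torch.remainder``. ``solar_time_sketch`` is a hypothetical name,
   and the wrapping step is inferred from the documented output range.

   .. code-block:: python

      import torch

      SECONDS_PER_DAY = 86400.0


      def solar_time_sketch(rotational_phase, eq_of_time_seconds):
          """Apparent solar time at the prime meridian as a day fraction in [0, 1)."""
          # A positive equation of time means the apparent Sun runs ahead of the
          # mean Sun, so it is added to the mean (UTC) time of day.
          return torch.remainder(rotational_phase + eq_of_time_seconds / SECONDS_PER_DAY, 1.0)

   For instance, a phase of 0.99 with a +1728 s (0.02 day) equation of time wraps to 0.01.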

   .. rubric:: References

   - https://en.wikipedia.org/wiki/Equation_of_time


.. py:function:: _get_hour_angle(solar_time: torch.Tensor, longitude: torch.Tensor) -> torch.Tensor

   Compute the solar hour angle from apparent solar time and longitude.

   The hour angle measures how far the Sun has moved across the sky relative
   to the local meridian: 0° at solar noon, increasing 15° per hour (360° per day).
   Longitude shifts the prime-meridian-referenced solar time to each grid point's
   local meridian.

   :param solar_time: Apparent solar time as a fraction of a day,
                      with 0.0 corresponding to UTC noon at the prime meridian (J2000 origin).
                      Shape broadcastable to ``(T,)``.
   :type solar_time: torch.Tensor
   :param longitude: Geographic longitude in degrees (positive east),
                     shape broadcastable to ``(ny, nx)``.
   :type longitude: torch.Tensor

   :returns: Solar hour angle in degrees, shape ``(T, ny, nx)`` after
             broadcasting: 0° at solar noon, ±180° at solar midnight. The
             result is not wrapped to ``[-180°, 180°]``; this is harmless because
             only its cosine is used, so convert it with ``torch.deg2rad`` and
             take the cosine directly.
   :rtype: torch.Tensor
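
   Assuming the hour angle is ``360° × solar_time + longitude`` (which satisfies the
   noon origin and the 15°-per-hour rule above), a sketch looks like this. The explicit
   reshape of ``solar_time`` to ``(T, 1, 1)`` is an assumption about how the
   ``(T, ny, nx)`` result is obtained; ``hour_angle_sketch`` is hypothetical.

   .. code-block:: python

      import torch


      def hour_angle_sketch(solar_time, longitude):
          """Hour angle in degrees: 360 deg per day of solar time, shifted by longitude."""
          # solar_time: (T,) -> (T, 1, 1) so it broadcasts against a (1, ny, nx) grid.
          return 360.0 * solar_time[:, None, None] + longitude

   With ``solar_time = 0`` (apparent noon at the prime meridian), a point at 90°E is six
   hours past its local noon and gets +90°. A value such as 270° has the same cosine as
   -90°, which is why no wrapping is needed.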

   .. rubric:: References

   - https://en.wikipedia.org/wiki/Hour_angle#Solar_hour_angle
     (conceptual reference; the page gives no explicit formula, only the
     prose rule "15° per hour before/after solar noon". The formula here
     uses a J2000 rotational phase origin at UTC noon, so the −180° offset
     present in midnight-origin derivations is absent.)


.. py:function:: _get_cosine_zenith_angle(cos_declination: torch.Tensor, sin_declination: torch.Tensor, latitude: torch.Tensor, hour_angle: torch.Tensor) -> torch.Tensor

   Compute the cosine of the solar zenith angle at each grid point and time.

   Uses the standard spherical-trigonometry identity::

       cos(θ_z) = cos(φ)·cos(δ)·cos(H) + sin(φ)·sin(δ)

   where ``φ`` is geographic latitude, ``δ`` is solar declination, and
   ``H`` is the hour angle.  Negative values (Sun below the horizon) are
   floored to zero; no upper clamp is applied since values above 1.0 are
   physically impossible and should surface as bugs rather than be silently
   masked.

   :param cos_declination: Cosine of solar declination, shape
                           ``(T,)`` – one value per timestamp.
   :type cos_declination: torch.Tensor
   :param sin_declination: Sine of solar declination, shape
                           ``(T,)``.
   :type sin_declination: torch.Tensor
   :param latitude: Geographic latitude in degrees, shape
                    ``(1, ny, nx)`` or broadcastable equivalent.
   :type latitude: torch.Tensor
   :param hour_angle: Solar hour angle in degrees, shape
                      broadcastable to ``(T, ny, nx)``.
   :type hour_angle: torch.Tensor

   :returns: Cosine of the solar zenith angle floored at 0,
             shape ``(T, ny, nx)``.  A value of 1.0 means the Sun is directly
             overhead; 0.0 means the Sun is on or below the horizon.
   :rtype: torch.Tensor
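
   A direct transcription of the identity above, assuming degree inputs and per-time
   declination terms reshaped to ``(T, 1, 1)``; ``cos_zenith_sketch`` is a hypothetical
   name. The equinox case (``δ = 0``) gives easy checks: 1 at the equator at solar noon,
   ``cos 60° = 0.5`` at 60° latitude, and 0 at solar midnight.

   .. code-block:: python

      import torch


      def cos_zenith_sketch(cos_declination, sin_declination, latitude, hour_angle):
          """cos(solar zenith) floored at 0; angles in degrees, output (T, ny, nx)."""
          phi = torch.deg2rad(latitude)            # (1, ny, nx)
          h = torch.deg2rad(hour_angle)            # broadcastable to (T, ny, nx)
          cos_d = cos_declination[:, None, None]   # (T, 1, 1)
          sin_d = sin_declination[:, None, None]   # (T, 1, 1)
          cos_z = torch.cos(phi) * cos_d * torch.cos(h) + torch.sin(phi) * sin_d
          return cos_z.clamp(min=0.0)  # Sun below the horizon -> 0; no upper clamp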

   .. rubric:: References

   - https://en.wikipedia.org/wiki/Solar_zenith_angle#Formula


.. py:function:: _get_instantaneous_toa_tisr(tsi: torch.Tensor, solar_factor: torch.Tensor, cos_zenith: torch.Tensor) -> torch.Tensor

   Compute the instantaneous total incident solar radiation at the top of the atmosphere.

   Applies the standard TISR formula::

       tisr = tsi * solar_factor * cos_zenith

   :param tsi: Total solar irradiance, in W/m². Broadcast-compatible
               with ``solar_factor`` and ``cos_zenith``.
   :type tsi: torch.Tensor
   :param solar_factor: Earth-Sun distance correction factor, defined
                        as the inverse square of the Earth-Sun distance in Astronomical Units
                        (AU). Also referred to as the eccentricity correction factor.
   :type solar_factor: torch.Tensor
   :param cos_zenith: Cosine of the solar zenith angle, floored at 0
                      (see :func:`_get_cosine_zenith_angle`).
   :type cos_zenith: torch.Tensor

   :returns:

             Instantaneous total incident solar radiation at the top of
                 the atmosphere, in W/m².
   :rtype: torch.Tensor
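
   A sketch of the formula, assuming a hypothetical helper that takes the Earth-Sun
   distance in AU (the ``solar_distance_au`` output of :func:`_get_orbital_parameters`)
   and forms ``solar_factor = 1 / d²`` itself.

   .. code-block:: python

      import torch


      def instantaneous_toa_tisr_sketch(tsi, solar_distance_au, cos_zenith):
          """TSI * (1 / d^2) * cos(zenith), in W/m^2; inputs must broadcast together."""
          solar_factor = 1.0 / solar_distance_au**2  # inverse-square distance correction
          return tsi * solar_factor * cos_zenith

   Near perihelion (about 0.983 AU) the factor is roughly 1.035, so TOA irradiance is
   about 3.5% above its mean-distance value.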


.. py:function:: _get_integrated_toa_tisr(instantaneous_toa_tisr: torch.Tensor, integration_period: pandas.Timedelta = pd.Timedelta(hours=1), num_integration_steps: int = 360) -> torch.Tensor

   Compute the integrated total incident solar radiation at the top of the atmosphere.

   Uses the trapezoidal rule to integrate instantaneous TOA TISR over a given
   period, following ERA5's convention of labeling accumulated fields at the end
   of the accumulation window: ``(target_time - integration_period, target_time]``.
   For example, with the default 1-hour period, the integrated TOA TISR at
   2021-06-01 00:00:00 covers 2021-05-31 23:00:00 to 2021-06-01 00:00:00.

   :param instantaneous_toa_tisr: Instantaneous total incident solar
                                  radiation at the top of the atmosphere. Shape: ``(time_steps, lat, lon)``
                                  where ``time_steps`` must equal ``num_integration_steps + 1``
                                  (both endpoints are required for trapezoidal integration).
   :type instantaneous_toa_tisr: torch.Tensor
   :param integration_period: Duration over which to integrate
                              ``instantaneous_toa_tisr``. Defaults to 1 hour (compatible with ERA5).
   :type integration_period: pd.Timedelta, optional
   :param num_integration_steps: Number of equally-spaced bins over
                                 the integration period. Defaults to 360.
   :type num_integration_steps: int, optional

   :raises ValueError: If ``num_integration_steps`` is not a positive integer.
   :raises ValueError: If ``instantaneous_toa_tisr.shape[0] != num_integration_steps + 1``.

   :returns:

             Integrated total incident solar radiation at the top of the
                 atmosphere, in J/m². Shape: ``(lat, lon)``.
   :rtype: torch.Tensor
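
   A sketch using ``torch.trapezoid`` along the time axis, assuming the samples are
   evenly spaced by ``integration_period / num_integration_steps`` seconds so that
   W/m² × s gives J/m². ``integrated_toa_tisr_sketch`` is a hypothetical name; its
   validation mirrors the documented errors.

   .. code-block:: python

      import pandas as pd
      import torch


      def integrated_toa_tisr_sketch(instantaneous_toa_tisr,
                                     integration_period=pd.Timedelta(hours=1),
                                     num_integration_steps=360):
          """Trapezoidal time integral of W/m^2 samples -> J/m^2, shape (lat, lon)."""
          if not isinstance(num_integration_steps, int) or num_integration_steps < 1:
              raise ValueError("num_integration_steps must be a positive integer")
          if instantaneous_toa_tisr.shape[0] != num_integration_steps + 1:
              raise ValueError(
                  f"expected {num_integration_steps + 1} samples along dim 0, "
                  f"got {instantaneous_toa_tisr.shape[0]}"
              )
          # Spacing between samples in seconds, so W/m^2 * s = J/m^2.
          dt = integration_period.total_seconds() / num_integration_steps
          return torch.trapezoid(instantaneous_toa_tisr, dx=dt, dim=0)

   With the ERA5 defaults the sample spacing is 3600 s / 360 = 10 s, and a constant
   1361 W/m² integrates to 1361 × 3600 = 4,899,600 J/m².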


.. py:function:: _compute_tisr(t: pandas.Timestamp, integration_period: pandas.Timedelta, num_integration_steps: int, latitude: torch.Tensor, longitude: torch.Tensor) -> torch.Tensor

   Full pipeline for integrated top-of-atmosphere TISR at a target timestamp.

   Vectorized over both space and time: all grid points and timesteps are
   processed in a single pass without Python-level loops. The function builds a
   time grid covering the accumulation window, retrieves total solar irradiance
   and orbital parameters for each time, computes per-grid-point cosine zenith
   angles on the supplied lat/lon grid, and finally integrates instantaneous TOA
   TISR over the accumulation window using the trapezoidal rule.

   Following ERA5 convention, the accumulation window is the half-open interval
   ``(t - integration_period, t]``. For example, with a 1-hour period,
   a target time of 2021-06-01 01:00:00 covers 2021-06-01 00:00:00 to
   2021-06-01 01:00:00.

   The ERA5 dataset contains one hourly, 31 km high-resolution realisation
   (referred to as "reanalysis" or "HRES") and a reduced-resolution ten-member
   ensemble (referred to as "ensemble" or "EDA"). For more details, see the
   `ERA5 data documentation <https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation>`_.

   .. note::

      This function always returns integrated TISR in J/m². To obtain
      instantaneous TISR in W/m², use :func:`_get_instantaneous_toa_tisr`
      directly.

   :param t: Target timestamp at the end of the accumulation window.
   :type t: pd.Timestamp
   :param integration_period: Length of the accumulation window. Callers
                              typically pass 1 hour, consistent with ERA5 hourly accumulations.
   :type integration_period: pd.Timedelta
   :param num_integration_steps: Number of equally-spaced sub-intervals used
                                 by the trapezoidal integrator. Higher values increase accuracy.
                                 Must be a positive integer.
   :type num_integration_steps: int
   :param latitude: Latitude grid in degrees, shape ``(1, ny, nx)``.
   :type latitude: torch.Tensor
   :param longitude: Longitude grid in degrees, shape ``(1, ny, nx)``.
   :type longitude: torch.Tensor

   :returns:

             Integrated TOA TISR over the accumulation window, in J/m².
                 Shape: ``(lat, lon)``.
   :rtype: torch.Tensor
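
   The first step, building the time grid, can be sketched with ``pandas.date_range``;
   ``accumulation_time_grid`` is a hypothetical helper. Both endpoints are included
   because the trapezoidal rule needs them, even though the window is labelled half-open.

   .. code-block:: python

      import pandas as pd


      def accumulation_time_grid(t, integration_period, num_integration_steps):
          """num_integration_steps + 1 equally spaced times from t - period to t."""
          return pd.date_range(t - integration_period, t, periods=num_integration_steps + 1)

   For ``t = 2021-06-01 01:00`` with a 1-hour window and 360 steps, this yields 361
   timestamps from 00:00 to 01:00 spaced 10 s apart.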


.. py:class:: TISRDataset(data_config: dict[str, Any], return_target: bool = False)

   Bases: :py:obj:`credit.datasets.base_dataset.BaseDataset`


   PyTorch Dataset for Total Incident Solar Radiation (TISR) at the top of the atmosphere (TOA).

   Computations in this class are designed to mimic ERA5's ``toa_incident_solar_radiation`` (``tisr``)
   variable (units: J/m², see https://codes.ecmwf.int/grib/param-db/212) by interpolating ERA5-compatible
   Total Solar Irradiance (TSI) values to the requested timestamps, then integrating the product of the TSI,
   a solar scaling factor, and the cosine of the solar zenith angle over the specified period. Defaults to
   ERA5-compatible settings: an integration period of one hour with 360 integration bins.

   While the default configuration targets ERA5 compatibility, both the integration period and bin
   count are configurable for other use cases. Input timestamps must fall within the TSI data
   range (1951-2034).

   Note that TISR is typically used as a dynamic forcing (input) variable rather than a target, so
   ``return_target`` defaults to ``False``. The TISR dataset does not load any data from local or
   remote files; it computes values on the fly, so unlike most other datasets no loading mode needs
   to be specified. Because computation happens inside ``__getitem__``,
   the dataset emits CPU tensors by default. When used with a multi-worker ``DataLoader``
   (``num_workers > 0``), keep ``device="cpu"`` (the default) and let the training loop move each
   collated batch to the GPU; constructing CUDA tensors in worker subprocesses is unsupported by
   PyTorch. The ``device`` config key is provided mainly for single-process (``num_workers=0``) use.
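
   The recommended pattern can be sketched with a stand-in dataset: the loader produces
   CPU tensors (``num_workers=0`` here so the snippet runs anywhere; use a positive value
   in real training) and the main process moves each collated batch to the GPU.
   ``ToyForcing`` and ``move_to_device`` are hypothetical and only mimic the nested-dict
   sample layout; they are not part of CREDIT.

   .. code-block:: python

      import torch
      from torch.utils.data import DataLoader, Dataset

      KEY = "Example_TISR/dynamic_forcing/2d/tisr"  # placeholder source name


      class ToyForcing(Dataset):
          """Stand-in returning CPU tensors shaped like a TISR sample (hypothetical)."""

          def __len__(self):
              return 4

          def __getitem__(self, idx):
              return {"input": {"Example_TISR": {KEY: torch.rand(1, 1, 3, 5)}}}


      def move_to_device(obj, device):
          """Recursively move tensors inside nested dicts to ``device``."""
          if isinstance(obj, torch.Tensor):
              return obj.to(device, non_blocking=True)
          if isinstance(obj, dict):
              return {k: move_to_device(v, device) for k, v in obj.items()}
          return obj


      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      loader = DataLoader(ToyForcing(), batch_size=2, num_workers=0,
                          pin_memory=torch.cuda.is_available())
      for batch in loader:
          batch = move_to_device(batch, device)  # transfer happens in the main process

   Pinned host memory together with ``non_blocking=True`` lets host-to-device copies
   overlap with computation when a GPU is present.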

   Exactly one grid source must be configured: either ``latlon_grid_path`` (read from a NetCDF
   file) or both ``lat_spec`` and ``lon_spec`` (build a rectangular grid in-memory, no file read).
   Each spec is a ``[start, end, num_points]`` list with both endpoints inclusive.

   See the module docstring for a full description of the output format and key naming.

   Example YAML configuration (grid read from file)::

       data:
           source:
               Example_TISR:  # User-provided name (arbitrary key)
                   dataset_type: "tisr"
                   variables:
                       prognostic: null
                       diagnostic: null
                       dynamic_forcing:
                           var_2d: ['tisr']  # only accept 'tisr'
                   num_integration_steps: 2160  # 360 steps per hour → 6h integration with 1h accumulation windows
                   latlon_grid_path: "/glade/derecho/scratch/cbecker/test_CREDIT_data/era5_local_testing_data_onedeg_2021.nc"

           start_datetime: "2021-06-01"
           end_datetime: "2021-06-04"
           timestep: "6h"
           forecast_len: 0

   Example YAML configuration (grid built in-memory from specs)::

       data:
           source:
               Example_TISR:
                   dataset_type: "tisr"
                   variables:
                       prognostic: null
                       diagnostic: null
                       dynamic_forcing:
                           var_2d: ['tisr']
                   num_integration_steps: 2160
                   lat_spec: [90, -90, 721]      # [start, end, num_points], endpoints inclusive
                   lon_spec: [0, 359.75, 1440]   # 0.25° grid; excludes the 360° wrap

           start_datetime: "2021-06-01"
           end_datetime: "2021-06-04"
           timestep: "6h"
           forecast_len: 0


   .. py:attribute:: dataset_type
      :value: 'tisr'



   .. py:attribute:: static_metadata
      :type:  dict[str, Any]


   .. py:attribute:: num_integration_steps
      :type:  int


   .. py:attribute:: device
      :type:  torch.device


   .. py:attribute:: latlon_grid_path
      :type:  str | None


   .. py:method:: _get_file_source(field_config: dict[str, Any]) -> None

      Return ``None``: the TISR dataset computes its values on the fly and reads no local or remote files.



   .. py:method:: _extract_field(field_type: str, t: pandas.Timestamp, sample: dict) -> None

      Compute the TISR 2-D variable for a field type (``dynamic_forcing`` only) at time ``t`` and store it in ``sample``.

      The value is the top-of-atmosphere solar radiation integrated over the
      accumulation window (``dt``) ending at ``t``, stored as a ``torch.Tensor`` of shape
      ``(1, 1, ny, nx)`` under the key ``"{source_name}/{field_type}/2d/tisr"``
      in ``sample``. Does nothing if the field type has no registered variables.

      :param field_type: Field type to populate; only ``"dynamic_forcing"`` is supported
                         (``prognostic`` and ``diagnostic`` are configured as ``null``).
      :param t: Timestamp for which to load data.
      :param sample: Output dictionary that is updated in-place.



