credit.data
===========

.. py:module:: credit.data

.. autoapi-nested-parse::

   ``data.py`` contains helper functions and dataset classes for processing training data.

   Helper functions:
       - generate_datetime(start_time, end_time, interval_hr)
       - hour_to_nanoseconds(input_hr)
       - nanoseconds_to_year(nanoseconds_value)
       - extract_month_day_hour(dates)
       - find_common_indices(list1, list2)
       - concat_and_reshape(x1, x2)
       - reshape_only(x1)
       - get_forward_data(filename)
       - drop_var_from_dataset(xarray_dataset, varname_keep)
       - previous_hourly_steps(time_pick, hour, step)
       - next_n_hour(dt, period_hours)
       - encode_datetime64(dt_array)

   Sample classes:
       - Sample
       - Sample_WRF
       - Sample_dscale
       - Sample_diag
       - Sample_LES

   Deprecated classes:
       - ERA5_and_Forcing_Dataset(torch.utils.data.Dataset)
       - Predict_Dataset(torch.utils.data.IterableDataset)



Attributes
----------

.. autoapisummary::

   credit.data.Array
   credit.data.IMAGE_ATTR_NAMES


Classes
-------

.. autoapisummary::

   credit.data.Sample
   credit.data.Sample_WRF
   credit.data.Sample_dscale
   credit.data.Sample_diag
   credit.data.Sample_LES
   credit.data.ERA5_and_Forcing_Dataset
   credit.data.ERA5_Dataset_Distributed
   credit.data.Predict_Dataset


Functions
---------

.. autoapisummary::

   credit.data.device_compatible_to
   credit.data.ensure_numpy_datetime
   credit.data.generate_datetime
   credit.data.hour_to_nanoseconds
   credit.data.nanoseconds_to_year
   credit.data.extract_month_day_hour
   credit.data.find_common_indices
   credit.data.concat_and_reshape
   credit.data.reshape_only
   credit.data.get_forward_data
   credit.data.flatten_list
   credit.data.generate_integer_list_around
   credit.data.find_key_for_number
   credit.data.drop_var_from_dataset
   credit.data.keep_dataset_vars
   credit.data.subset_patch
   credit.data.encode_datetime64
   credit.data.next_n_hour
   credit.data.previous_hourly_steps
   credit.data.filter_ds


Module Contents
---------------

.. py:data:: Array

.. py:data:: IMAGE_ATTR_NAMES
   :value: ('historical_ERA5_images', 'target_ERA5_images')


.. py:function:: device_compatible_to(tensor: torch.Tensor, device: torch.device) -> torch.Tensor

   Safely move a tensor to a device, casting it to float32 on MPS (Metal Performance Shaders).
   This avoids a runtime error on macOS, where MPS does not support float64.

   :param tensor: Input tensor to move.
   :type tensor: torch.Tensor
   :param device: Target device.
   :type device: torch.device

   :returns: Tensor moved to device (cast to float32 if device is MPS).
   :rtype: torch.Tensor


.. py:function:: ensure_numpy_datetime(value)

   Converts an input value (or array) to numpy.datetime64.
   Handles numpy arrays, pandas timestamps, cftime objects, and strings.
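
   A minimal sketch of the conversion, assuming the target unit is nanoseconds. It covers only
   numpy arrays, ``np.datetime64`` values, ``datetime.datetime`` objects, and ISO-format strings;
   the pandas and cftime branches of the real function are not reproduced, and
   ``ensure_numpy_datetime_sketch`` is a hypothetical name.

   .. code-block:: python

      import datetime

      import numpy as np


      def ensure_numpy_datetime_sketch(value):
          """Hypothetical sketch: convert a scalar or array to datetime64[ns]."""
          # Arrays (e.g. of strings or datetime64 values) are cast element-wise.
          if isinstance(value, np.ndarray):
              return value.astype("datetime64[ns]")
          # Scalars: np.datetime64 parses ISO strings and datetime.datetime objects,
          # and re-expresses existing datetime64 values in nanoseconds.
          return np.datetime64(value, "ns")


      # A datetime.datetime and an equivalent ISO string map to the same value.
      t1 = ensure_numpy_datetime_sketch(datetime.datetime(2020, 1, 1, 6))
      t2 = ensure_numpy_datetime_sketch("2020-01-01T06:00")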


.. py:function:: generate_datetime(start_time, end_time, interval_hr)

   Generate a list of datetime.datetime objects from start and end times and an hour interval.

   :param start_time: start time
   :type start_time: datetime.datetime
   :param end_time: end time
   :type end_time: datetime.datetime
   :param interval_hr: hour interval
   :type interval_hr: int
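
   A standard-library sketch of the same idea. It assumes ``end_time`` is included when it falls
   on the interval grid; the real function's endpoint handling is not documented here, and
   ``generate_datetime_sketch`` is a hypothetical name.

   .. code-block:: python

      import datetime


      def generate_datetime_sketch(start_time, end_time, interval_hr):
          """Hypothetical sketch: datetimes from start_time to end_time every interval_hr hours."""
          step = datetime.timedelta(hours=interval_hr)
          times = []
          current = start_time
          # Assumption: end_time is included if it lies on the interval grid.
          while current <= end_time:
              times.append(current)
              current += step
          return times


      times = generate_datetime_sketch(
          datetime.datetime(2020, 1, 1, 0),
          datetime.datetime(2020, 1, 2, 0),
          6,
      )
      # times holds 00, 06, 12, 18 on 2020-01-01 and 00 on 2020-01-02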


.. py:function:: hour_to_nanoseconds(input_hr)

   Convert hour to nanoseconds.
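
   The conversion factor is 3600 seconds per hour times :math:`10^9` nanoseconds per second,
   i.e. :math:`3.6 \times 10^{12}` ns per hour. A minimal sketch, assuming an integer result is
   wanted (``hour_to_nanoseconds_sketch`` is a hypothetical name; the real return type may differ):

   .. code-block:: python

      NS_PER_HOUR = 3_600 * 1_000_000_000  # 3600 s/h * 1e9 ns/s


      def hour_to_nanoseconds_sketch(input_hr):
          """Hypothetical sketch: convert hours to integer nanoseconds."""
          return int(input_hr * NS_PER_HOUR)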


.. py:function:: nanoseconds_to_year(nanoseconds_value)

   Given a timestamp expressed in nanoseconds, compute the year it belongs to.
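
   A NumPy sketch, assuming the value counts nanoseconds since the Unix epoch
   (1970-01-01T00:00). Casting to ``datetime64[Y]`` truncates to whole years since 1970;
   ``nanoseconds_to_year_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def nanoseconds_to_year_sketch(nanoseconds_value):
          """Hypothetical sketch: calendar year of a ns-since-epoch timestamp."""
          # Assumption: the value counts nanoseconds since 1970-01-01T00:00.
          dt = np.datetime64(int(nanoseconds_value), "ns")
          # datetime64[Y] stores whole years since 1970.
          return int(dt.astype("datetime64[Y]").astype(np.int64)) + 1970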


.. py:function:: extract_month_day_hour(dates)

   Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each element into a zipped list.
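
   A vectorized NumPy sketch of the extraction, assuming the output is a list of
   ``(month, day, hour)`` tuples. Truncating to month, day, and hour units and subtracting gives
   each component; ``extract_month_day_hour_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def extract_month_day_hour_sketch(dates):
          """Hypothetical sketch: (month, day, hour) tuples from datetime64 values."""
          dates = np.asarray(dates, dtype="datetime64[ns]")
          month_start = dates.astype("datetime64[M]")
          day_start = dates.astype("datetime64[D]")
          # datetime64[M] counts months since 1970-01, so the remainder gives the month.
          months = month_start.astype(np.int64) % 12 + 1
          # Whole days since the first of the month, plus one.
          days = (day_start - month_start.astype("datetime64[D]")).astype(np.int64) + 1
          # Whole hours since midnight.
          hours = (dates.astype("datetime64[h]") - day_start.astype("datetime64[h]")).astype(np.int64)
          return list(zip(months.tolist(), days.tolist(), hours.tolist()))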


.. py:function:: find_common_indices(list1, list2)

   Find indices of common elements between two lists.
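
   The return format is not documented here; the sketch below assumes the function returns two
   aligned index lists (positions in ``list1`` and the first matching positions in ``list2``).
   ``find_common_indices_sketch`` is a hypothetical name. A dictionary lookup keeps the cost
   linear in the combined list length.

   .. code-block:: python

      def find_common_indices_sketch(list1, list2):
          """Hypothetical sketch: aligned indices of elements shared by two lists."""
          # Map each element of list2 to the first position where it appears.
          pos2 = {}
          for j, item in enumerate(list2):
              pos2.setdefault(item, j)
          idx1, idx2 = [], []
          for i, item in enumerate(list1):
              if item in pos2:
                  idx1.append(i)
                  idx2.append(pos2[item])
          return idx1, idx2


      idx1, idx2 = find_common_indices_sketch(["a", "b", "c"], ["c", "a"])
      # "a" is at 0 in list1 and 1 in list2; "c" is at 2 in list1 and 0 in list2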


.. py:function:: concat_and_reshape(x1, x2)

   Flatten the "level" coordinate of upper-air variables and concatenate the result with the surface variables.
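
   A NumPy sketch of the flatten-then-concatenate step. The array layouts are assumptions:
   upper-air data as ``(time, var, level, lat, lon)`` and surface data as
   ``(time, var_sfc, lat, lon)``; the real function may use a different axis order or operate on
   torch tensors. ``concat_and_reshape_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def concat_and_reshape_sketch(x1, x2):
          """Hypothetical sketch: fold levels into the variable axis, then append surface vars."""
          # Assumed layouts: x1 is (time, var, level, lat, lon), x2 is (time, var_sfc, lat, lon).
          n_time, n_var, n_level, n_lat, n_lon = x1.shape
          # Each (variable, level) pair becomes its own channel.
          x1_flat = x1.reshape(n_time, n_var * n_level, n_lat, n_lon)
          # Stack the surface variables after the flattened upper-air channels.
          return np.concatenate([x1_flat, x2], axis=1)


      upper_air = np.zeros((2, 4, 15, 8, 16), dtype=np.float32)
      surface = np.ones((2, 3, 8, 16), dtype=np.float32)
      stacked = concat_and_reshape_sketch(upper_air, surface)
      # 4 * 15 = 60 upper-air channels followed by 3 surface channels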


.. py:function:: reshape_only(x1)

   Flatten the "level" coordinate of upper-air variables.

   As in ``concat_and_reshape``, but without the concatenation.
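
   A NumPy sketch, assuming the same ``(time, var, level, lat, lon)`` layout as above; because
   ``reshape`` on a C-ordered array keeps the level index varying fastest, variable ``v`` at
   level ``k`` lands in channel ``v * n_level + k``. ``reshape_only_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def reshape_only_sketch(x1):
          """Hypothetical sketch: fold the level axis of x1 into the variable axis."""
          # Assumed layout: (time, var, level, lat, lon) -> (time, var * level, lat, lon).
          n_time, n_var, n_level, n_lat, n_lon = x1.shape
          return x1.reshape(n_time, n_var * n_level, n_lat, n_lon)


      x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
      flat = reshape_only_sketch(x)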


.. py:function:: get_forward_data(filename) -> xarray.Dataset

   Check whether ``filename`` is a netCDF or zarr file and open it as an xr.Dataset.


.. py:function:: flatten_list(list_of_lists)

   Flatten a list of lists.

   :param list_of_lists: A list containing sublists.
   :type list_of_lists: list

   :returns: A flattened list containing all elements from the sublists.
   :rtype: list
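
   A one-level flatten can be written with ``itertools.chain.from_iterable``, which walks each
   sublist in order. This is a sketch with the hypothetical name ``flatten_list_sketch``; it assumes
   only one level of nesting, as the docstring describes.

   .. code-block:: python

      from itertools import chain


      def flatten_list_sketch(list_of_lists):
          """Hypothetical sketch: concatenate sublists into one flat list."""
          # Nested sublists deeper than one level are left as-is.
          return list(chain.from_iterable(list_of_lists))


      flat = flatten_list_sketch([[1, 2], [3], [], [4, 5]])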


.. py:function:: generate_integer_list_around(number, spacing=10)

   Generate a list of integers on either side of a given number with a specified spacing.

   :param number: The central number around which the list is generated.
   :type number: int
   :param spacing: The spacing between consecutive integers in the list. Default is 10.
   :type spacing: int

   :returns: List of integers on either side of the given number.
   :rtype: list


.. py:function:: find_key_for_number(input_number, data_dict)

   Find the key in the dictionary based on the given number.

   :param input_number: The number to search for in the dictionary.
   :type input_number: int
   :param data_dict: The dictionary with keys and corresponding value lists.
   :type data_dict: dict

   :returns: The key in the dictionary where the input number falls within the specified range.
   :rtype: str
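
   A sketch of a range lookup, assuming the first and last entries of each value list bound an
   inclusive range and that ``None`` is returned when nothing matches; neither detail is confirmed
   by this page, and ``find_key_for_number_sketch`` is a hypothetical name.

   .. code-block:: python

      def find_key_for_number_sketch(input_number, data_dict):
          """Hypothetical sketch: key whose [first, last] value range contains input_number."""
          for key, values in data_dict.items():
              # Assumption: values[0] and values[-1] are the inclusive range bounds.
              if values[0] <= input_number <= values[-1]:
                  return key
          return None  # assumption: no match returns None


      index_ranges = {"2019": [0, 1459], "2020": [1460, 2923]}
      key = find_key_for_number_sketch(1500, index_ranges)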


.. py:function:: drop_var_from_dataset(xarray_dataset, varname_keep)

   Preserve a given set of variables from an xarray.Dataset and drop the rest.
   Raises an error if any variable in ``varname_keep`` is missing from ``xarray_dataset``.


.. py:function:: keep_dataset_vars(xarray_dataset: xarray.Dataset, varnames_keep: List[str])

   Return a version of an xarray dataset with only a selected subset of variables.

   :param xarray_dataset: The xarray dataset.
   :type xarray_dataset: xr.Dataset
   :param varnames_keep: A list of variable names to keep.
   :type varnames_keep: List[str]

   :returns: An xr.Dataset containing only the variables in ``varnames_keep``.


.. py:function:: subset_patch(ds: xarray.Dataset, input_size, start, lat_name='yIndex', lon_name='xIndex') -> xarray.Dataset

   Return a spatial subset of shape (time, input_size[0], input_size[1]).
   Assumes ``ds`` has dimensions (time, lat, lon), where the latitude and longitude dimensions are
   named by ``lat_name`` and ``lon_name``.


.. py:function:: encode_datetime64(dt_array)

.. py:function:: next_n_hour(dt, period_hours)

   Round dt forward to the next N-hour boundary.

   :param dt: np.datetime64[ns] or array of such values
   :param period_hours: int, the interval in hours (e.g., 3, 6)

   :returns: np.datetime64[ns] rounded forward to the next period_hours boundary
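
   A NumPy sketch using integer ceiling division on nanoseconds. It assumes boundaries are measured
   from the Unix epoch (for periods that divide 24, such as 3 or 6, this matches the usual synoptic
   hours) and that a value already on a boundary is returned unchanged; whether the real function
   instead advances such values is not stated here. ``next_n_hour_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np

      NS_PER_HOUR = 3_600_000_000_000


      def next_n_hour_sketch(dt, period_hours):
          """Hypothetical sketch: round datetime64 value(s) up to an N-hour boundary."""
          ns = np.asarray(dt, dtype="datetime64[ns]").astype(np.int64)
          period = np.int64(period_hours) * NS_PER_HOUR
          # Ceiling division: values already on a boundary are left unchanged (assumption).
          rounded = -(-ns // period) * period
          return rounded.astype("datetime64[ns]")


      t = next_n_hour_sketch(np.datetime64("2020-01-01T04:30"), 6)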


.. py:function:: previous_hourly_steps(time_pick, hour, step)

   Given a datetime64[ns] ``time_pick``, compute the time ``step * hour`` hours earlier.
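
   A NumPy sketch, assuming ``hour`` is the step length in hours and the result is a single
   ``datetime64[ns]`` value; ``previous_hourly_steps_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def previous_hourly_steps_sketch(time_pick, hour, step):
          """Hypothetical sketch: the time step * hour hours before time_pick."""
          offset = np.timedelta64(int(step * hour), "h")
          # datetime64[ns] minus timedelta64[h] stays in nanosecond resolution.
          return np.datetime64(time_pick, "ns") - offset


      earlier = previous_hourly_steps_sketch(np.datetime64("2020-01-02T00:00"), 6, 2)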


.. py:function:: filter_ds(ds: xarray.Dataset, varnames_keep: Sequence[str]) -> xarray.Dataset

   Return a new Dataset containing only the variables in ``varnames_keep``.
   Raises an error if any variable in ``varnames_keep`` is missing.


.. py:class:: Sample

   Bases: :py:obj:`TypedDict`


   Simple class for structuring data for the ML model.

   Using ``typing.TypedDict`` gives us two advantages:

   1. A single "source of truth" for the type and documentation of each example.
   2. A static type checker can verify that the types are correct.

   We could instead use ``typing.NamedTuple``, which would provide runtime checks, but the
   deal-breaker is that tuples are immutable, so the transforms could not change their values.
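
   A minimal sketch of the pattern, assuming ``Array`` can be treated as a NumPy array;
   ``SampleSketch`` is a hypothetical stand-in that mirrors the three fields documented below.
   It shows why mutability matters: a TypedDict instance is a plain ``dict``, so a transform can
   overwrite a field in place.

   .. code-block:: python

      from typing import TypedDict

      import numpy as np

      # Assumption: treat credit.data.Array as a NumPy array for this sketch.
      Array = np.ndarray


      class SampleSketch(TypedDict):
          """Hypothetical mirror of Sample with the same three fields."""

          historical_ERA5_images: Array
          target_ERA5_images: Array
          datetime_index: Array


      sample: SampleSketch = {
          "historical_ERA5_images": np.ones((2, 4, 8, 8), dtype=np.float32),
          "target_ERA5_images": np.ones((1, 4, 8, 8), dtype=np.float32),
          "datetime_index": np.array([0, 1, 2]),
      }

      # In-place update, which a NamedTuple would not allow.
      sample["historical_ERA5_images"] = sample["historical_ERA5_images"] * 2.0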


   .. py:attribute:: historical_ERA5_images
      :type:  Array


   .. py:attribute:: target_ERA5_images
      :type:  Array


   .. py:attribute:: datetime_index
      :type:  Array


.. py:class:: Sample_WRF

   Bases: :py:obj:`TypedDict`


   Typed dictionary structuring a WRF sample (``WRF_input``, ``WRF_target``, ``boundary_input``) for the ML model.


   .. py:attribute:: WRF_input
      :type:  Array


   .. py:attribute:: WRF_target
      :type:  Array


   .. py:attribute:: boundary_input
      :type:  Array


   .. py:attribute:: time_encode
      :type:  Array


   .. py:attribute:: datetime_index
      :type:  Array


.. py:class:: Sample_dscale

   Bases: :py:obj:`TypedDict`


   Typed dictionary structuring a downscaling sample (low-resolution ``LR_input``; high-resolution
   ``HR_input`` and ``HR_target``) for the ML model.


   .. py:attribute:: LR_input
      :type:  Array


   .. py:attribute:: HR_input
      :type:  Array


   .. py:attribute:: HR_target
      :type:  Array


   .. py:attribute:: time_encode
      :type:  Array


   .. py:attribute:: datetime_index
      :type:  Array


.. py:class:: Sample_diag

   Bases: :py:obj:`TypedDict`


   Typed dictionary structuring a sample with ``WRF_input`` and ``WRF_target`` (and no
   ``boundary_input``) for the ML model.


   .. py:attribute:: WRF_input
      :type:  Array


   .. py:attribute:: WRF_target
      :type:  Array


   .. py:attribute:: time_encode
      :type:  Array


   .. py:attribute:: datetime_index
      :type:  Array


.. py:class:: Sample_LES

   Bases: :py:obj:`TypedDict`


   Typed dictionary structuring an LES sample (``LES_input``, ``LES_target``) for the ML model.


   .. py:attribute:: LES_input
      :type:  Array


   .. py:attribute:: LES_target
      :type:  Array


   .. py:attribute:: datetime_index
      :type:  Array


.. py:class:: ERA5_and_Forcing_Dataset(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)

   Bases: :py:obj:`torch.utils.data.Dataset`


   **Deprecated.** A PyTorch Dataset class that works with the following kinds of variables:

   * upper-air variables (time, level, lat, lon)
   * surface variables (time, lat, lon)
   * dynamic forcing variables (time, lat, lon)
   * forcing variables (time, lat, lon)
   * diagnostic variables (time, lat, lon)
   * static variables (lat, lon)


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: one_shot
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:method:: __post_init__()

      Calculate total sequence length after init.



   .. py:method:: __len__()

      Return the length of the dataset.



   .. py:method:: __getitem__(index)

      Get a single item from the dataset.



.. py:class:: ERA5_Dataset_Distributed(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)

   Bases: :py:obj:`torch.utils.data.Dataset`


   ERA5 dataset for distributed training (legacy).


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: varname_upper_air


   .. py:attribute:: varname_surface


   .. py:attribute:: varname_dyn_forcing


   .. py:attribute:: varname_forcing


   .. py:attribute:: varname_static


   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: one_shot
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: filenames


   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:method:: __post_init__()

      Calculate total sequence length.



   .. py:method:: __len__()

      Return the length of the dataset.



   .. py:method:: __getitem__(index)

      Get a single item.

      :param index: Index of the time step.

      :returns: A PyTorch tensor containing a full state.



.. py:class:: Predict_Dataset(conf, varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface, filename_dyn_forcing, filename_forcing, filename_static, filename_diagnostic, fcst_datetime, history_len, rank, world_size, transform=None, rollout_p=0.0, which_forecast=None)

   Bases: :py:obj:`torch.utils.data.IterableDataset`


   Same as :py:class:`ERA5_and_Forcing_Dataset`, but works with the old ``rollout_to_netcdf.py`` script.
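
   The ``rank`` and ``world_size`` arguments suggest that each process handles only part of the
   forecast initialization times. How this class actually divides the work is not documented here;
   the sketch below shows one common round-robin scheme using a hypothetical ``shard_for_rank``
   helper, so that every item is assigned to exactly one rank.

   .. code-block:: python

      def shard_for_rank(items, rank, world_size):
          """Hypothetical helper: round-robin split of work items across ranks."""
          # Rank r takes items r, r + world_size, r + 2 * world_size, ...
          return items[rank::world_size]


      init_times = [
          "2020-01-01T00", "2020-01-01T12", "2020-01-02T00",
          "2020-01-02T12", "2020-01-03T00",
      ]
      shards = [shard_for_rank(init_times, r, 2) for r in range(2)]
      # rank 0 gets three init times and rank 1 gets two; together they cover each item once

   Round-robin striding keeps per-rank workloads within one item of each other without any
   communication between processes.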


   .. py:attribute:: rank


   .. py:attribute:: world_size


   .. py:attribute:: transform
      :value: None



   .. py:attribute:: history_len


   .. py:attribute:: init_datetime


   .. py:attribute:: which_forecast
      :value: None



   .. py:attribute:: filenames


   .. py:attribute:: filename_surface


   .. py:attribute:: filename_dyn_forcing


   .. py:attribute:: filename_forcing


   .. py:attribute:: filename_static


   .. py:attribute:: filename_diagnostic


   .. py:attribute:: varname_upper_air


   .. py:attribute:: varname_surface


   .. py:attribute:: varname_dyn_forcing


   .. py:attribute:: varname_forcing


   .. py:attribute:: varname_static


   .. py:attribute:: varname_diagnostic


   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: current_epoch
      :value: 0



   .. py:attribute:: rollout_p
      :value: 0.0



   .. py:attribute:: lead_time_periods


   .. py:attribute:: skip_periods


   .. py:method:: ds_read_and_subset(filename, time_start, time_end, varnames)

      Read and subset specified dataset.

      :param filename: path to specified dataset file.
      :type filename: str
      :param time_start: start time index.
      :type time_start: int
      :param time_end: end time index.
      :type time_end: int
      :param varnames: List of variables to be read.
      :type varnames: list



   .. py:method:: load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')

      Load input data from zarr files.

      :param i_file: index of the file
      :param i_init_start: start index of the data being loaded
      :param i_init_end: end index of the data being loaded.
      :param mode: "input" or "target"

      :returns: xr.Dataset containing all the variables.



   .. py:method:: find_start_stop_indices(index)

      Find start and stop indices for a given yearly data zarr file.

      :param index: indices of zarr file.



   .. py:method:: __len__()

      Return the length of the dataset.



   .. py:method:: __iter__()

      Iterate through batches.



