credit.data
===========

.. py:module:: credit.data

.. autoapi-nested-parse::

   ``data.py`` contains functions and classes for processing training data.

   Content:
       - generate_datetime(start_time, end_time, interval_hr)
       - hour_to_nanoseconds(input_hr)
       - nanoseconds_to_year(nanoseconds_value)
       - extract_month_day_hour(dates)
       - find_common_indices(list1, list2)
       - concat_and_reshape(x1, x2)
       - reshape_only(x1)
       - get_forward_data(filename)
       - drop_var_from_dataset(xarray_dataset, varname_keep)
       - ERA5_and_Forcing_Dataset(torch.utils.data.Dataset)
       - Predict_Dataset(torch.utils.data.IterableDataset)



Attributes
----------

.. autoapisummary::

   credit.data.Array
   credit.data.IMAGE_ATTR_NAMES


Classes
-------

.. autoapisummary::

   credit.data.Sample
   credit.data.ERA5_and_Forcing_Dataset
   credit.data.ERA5_Dataset_Distributed
   credit.data.Predict_Dataset


Functions
---------

.. autoapisummary::

   credit.data.ensure_numpy_datetime
   credit.data.generate_datetime
   credit.data.hour_to_nanoseconds
   credit.data.nanoseconds_to_year
   credit.data.extract_month_day_hour
   credit.data.find_common_indices
   credit.data.concat_and_reshape
   credit.data.reshape_only
   credit.data.get_forward_data
   credit.data.flatten_list
   credit.data.generate_integer_list_around
   credit.data.find_key_for_number
   credit.data.drop_var_from_dataset
   credit.data.keep_dataset_vars


Module Contents
---------------

.. py:data:: Array

.. py:data:: IMAGE_ATTR_NAMES
   :value: ('historical_ERA5_images', 'target_ERA5_images')


.. py:function:: ensure_numpy_datetime(value)

   Converts an input value (or array) to numpy.datetime64.
   Handles numpy arrays, pandas timestamps, cftime objects, and strings.


.. py:function:: generate_datetime(start_time, end_time, interval_hr)

   Generate a list of ``datetime.datetime`` objects from start and end times and an hour interval.

   :param start_time: start time
   :type start_time: datetime.datetime
   :param end_time: end time
   :type end_time: datetime.datetime
   :param interval_hr: hour interval
   :type interval_hr: int
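
   A minimal standard-library sketch of the same idea. It assumes ``end_time`` is inclusive and ``interval_hr`` is a positive number of hours. The name ``generate_datetime_sketch`` is illustrative and not part of ``credit.data``.

   .. code-block:: python

      from datetime import datetime, timedelta


      def generate_datetime_sketch(start_time, end_time, interval_hr):
          """Illustrative stand-in for generate_datetime (end_time assumed inclusive)."""
          step = timedelta(hours=interval_hr)
          times = []
          current = start_time
          # Walk forward in fixed steps until we pass end_time.
          while current <= end_time:
              times.append(current)
              current += step
          return times


      # Usage: 6-hourly steps over one day gives 00, 06, 12, 18 and the next 00.
      times = generate_datetime_sketch(datetime(2020, 1, 1), datetime(2020, 1, 2), 6)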


.. py:function:: hour_to_nanoseconds(input_hr)

   Convert a duration in hours to nanoseconds.
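
   The conversion factor is 3600 s/h multiplied by 10**9 ns/s, so one hour is 3.6e12 ns. The sketch below spells this out. ``hour_to_nanoseconds_sketch`` is an illustrative name, and the sketch assumes the input is a plain int or float.

   .. code-block:: python

      NS_PER_SECOND = 10**9
      SECONDS_PER_HOUR = 3600


      def hour_to_nanoseconds_sketch(input_hr):
          """Illustrative stand-in: convert a number of hours to nanoseconds."""
          return input_hr * SECONDS_PER_HOUR * NS_PER_SECOND


      six_hours_ns = hour_to_nanoseconds_sketch(6)  # 21_600_000_000_000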


.. py:function:: nanoseconds_to_year(nanoseconds_value)

   Given a timestamp expressed in nanoseconds, compute the year it falls in.
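
   A standard-library sketch follows. It assumes the value counts nanoseconds since the Unix epoch (1970-01-01T00:00 UTC), which is the integer view of a ``numpy.datetime64[ns]``. ``nanoseconds_to_year_sketch`` is an illustrative name.

   .. code-block:: python

      from datetime import datetime, timedelta, timezone

      EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


      def nanoseconds_to_year_sketch(nanoseconds_value):
          """Illustrative stand-in: year of a timestamp given in ns since the Unix epoch."""
          # timedelta resolves to microseconds; floor division rounds toward the past,
          # which never moves a timestamp across a year boundary.
          return (EPOCH + timedelta(microseconds=nanoseconds_value // 1000)).year


      ns_2020 = int(datetime(2020, 7, 1, tzinfo=timezone.utc).timestamp()) * 10**9
      year = nanoseconds_to_year_sketch(ns_2020)  # 2020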


.. py:function:: extract_month_day_hour(dates)

   Given a 1-D array of ``np.datetime64[ns]``, extract the month, day, and hour of each element into a zipped list.
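
   The following standard-library analogue works on ``datetime.datetime`` objects rather than ``np.datetime64[ns]``. It assumes the result is a list of ``(month, day, hour)`` tuples. A ``datetime64[ns]`` array can be turned into ``datetime`` objects with ``.astype("datetime64[us]").tolist()``, because calling ``.tolist()`` directly at nanosecond precision yields integers. ``extract_month_day_hour_sketch`` is an illustrative name.

   .. code-block:: python

      from datetime import datetime


      def extract_month_day_hour_sketch(dates):
          """Illustrative stand-in: zip (month, day, hour) for each datetime."""
          months = [d.month for d in dates]
          days = [d.day for d in dates]
          hours = [d.hour for d in dates]
          return list(zip(months, days, hours))


      mdh = extract_month_day_hour_sketch(
          [datetime(2021, 3, 15, 6), datetime(2021, 12, 31, 18)]
      )
      # [(3, 15, 6), (12, 31, 18)]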


.. py:function:: find_common_indices(list1, list2)

   Find indices of common elements between two lists.
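
   A sketch that assumes the function returns one index list per input, giving the positions of the shared elements in ``list1`` and in ``list2``. The sketch also assumes the elements are hashable. ``find_common_indices_sketch`` is an illustrative name, and the return shape is an assumption.

   .. code-block:: python

      def find_common_indices_sketch(list1, list2):
          """Illustrative stand-in: positions of shared elements in each list."""
          common = set(list1) & set(list2)
          idx1 = [i for i, x in enumerate(list1) if x in common]
          idx2 = [i for i, x in enumerate(list2) if x in common]
          return idx1, idx2


      # Hypothetical variable-name lists.
      idx1, idx2 = find_common_indices_sketch(["U", "V", "T", "Q"], ["T", "SP", "U"])
      # idx1 == [0, 2], idx2 == [0, 2]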


.. py:function:: concat_and_reshape(x1, x2)

   Flatten the ``level`` coordinate of upper-air variables and concatenate the result with the surface variables.
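
   The channel bookkeeping can be illustrated for a single time step using nested lists in place of arrays or tensors. The sketch assumes upper-air data is laid out as ``[var][level][lat][lon]`` and surface data as ``[var_sfc][lat][lon]``. The actual function operates on arrays with additional time (and possibly batch) axes, and ``concat_and_reshape_sketch`` is an illustrative name. The variable-major channel order matches a C-order reshape of the ``(var, level)`` axes. Dropping the final concatenation gives the behaviour described for ``reshape_only``.

   .. code-block:: python

      def concat_and_reshape_sketch(upper_air, surface):
          """Illustrative stand-in for one time step using nested lists.

          upper_air: [var][level][lat][lon]  (assumed axis order)
          surface:   [var_sfc][lat][lon]
          returns:   [var * level + var_sfc][lat][lon]
          """
          # Flatten (var, level) into one channel axis, variable-major:
          # var0/level0, var0/level1, ..., var1/level0, ...
          channels = [field for var_levels in upper_air for field in var_levels]
          # Append the surface fields as extra channels.
          return channels + list(surface)


      # 2 upper-air variables x 3 levels, 1 surface variable, on a 1x1 grid.
      upper = [[[[10 * v + lev]] for lev in range(3)] for v in range(2)]
      sfc = [[[99]]]
      out = concat_and_reshape_sketch(upper, sfc)  # 2 * 3 + 1 = 7 channels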


.. py:function:: reshape_only(x1)

   Flatten the ``level`` coordinate of upper-air variables.

   As in ``concat_and_reshape``, but without the concatenation.


.. py:function:: get_forward_data(filename) -> xarray.Dataset

   Determine whether ``filename`` is a netCDF or zarr file and open it as an ``xr.Dataset``.
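
   A sketch of the dispatch follows. It assumes the format is chosen by file suffix, treating ``.nc``/``.nc4`` as netCDF and anything else as zarr; the real detection rule may differ. ``detect_store_format`` and ``get_forward_data_sketch`` are hypothetical names. The opening step uses ``xarray.open_dataset`` and ``xarray.open_zarr``. xarray is imported lazily so the helper can be exercised without it.

   .. code-block:: python

      from pathlib import Path


      def detect_store_format(filename):
          """Hypothetical helper: classify a path as 'netcdf' or 'zarr' by its suffix."""
          suffix = Path(str(filename).rstrip("/")).suffix.lower()
          if suffix in (".nc", ".nc4"):
              return "netcdf"
          # Anything else (e.g. a '.zarr' directory store) is treated as zarr.
          return "zarr"


      def get_forward_data_sketch(filename):
          """Illustrative stand-in: open a netCDF or zarr store as an xarray.Dataset."""
          import xarray as xr  # lazy import: only needed when actually opening data

          if detect_store_format(filename) == "netcdf":
              return xr.open_dataset(filename)
          return xr.open_zarr(filename)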


.. py:class:: Sample

   Bases: :py:obj:`TypedDict`


   Simple class for structuring data for the ML model.

   Using ``typing.TypedDict`` gives us two advantages:

   1. A single source of truth for the type and documentation of each example.
   2. A static type checker can verify that the types are correct.

   Instead of ``TypedDict``, we could use ``typing.NamedTuple``,
   which would provide runtime checks, but the deal-breaker with tuples is that they are immutable,
   so we could not change the values in the transforms.
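
   The mutability argument can be shown with a small sketch. ``Array`` is not defined on this page, so ``Any`` stands in for it here (an assumption). ``SampleSketch`` and ``scale_history`` are hypothetical names. A ``TypedDict`` instance is a plain ``dict`` at runtime, so a transform can reassign a field in place.

   .. code-block:: python

      from typing import Any, TypedDict

      # Assumption: stand-in for the module's Array alias.
      Array = Any


      class SampleSketch(TypedDict):
          historical_ERA5_images: Array
          target_ERA5_images: Array
          datetime_index: Array


      def scale_history(sample: SampleSketch, factor: float) -> SampleSketch:
          """Hypothetical transform: reassigns a field, which a NamedTuple would forbid."""
          sample["historical_ERA5_images"] = [x * factor for x in sample["historical_ERA5_images"]]
          return sample


      s: SampleSketch = {
          "historical_ERA5_images": [1.0, 2.0],
          "target_ERA5_images": [3.0],
          "datetime_index": [0, 1],
      }
      scale_history(s, 2.0)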


   .. py:attribute:: historical_ERA5_images
      :type:  Array


   .. py:attribute:: target_ERA5_images
      :type:  Array


   .. py:attribute:: datetime_index
      :type:  Array


.. py:function:: flatten_list(list_of_lists)

   Flatten a list of lists.

   :param list_of_lists: A list containing sublists.
   :type list_of_lists: list

   :returns: A flattened list containing all elements from the sublists.
   :rtype: list
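
   A one-level flatten can be written with ``itertools.chain.from_iterable``. This is a sketch, not necessarily how ``flatten_list`` is implemented, and ``flatten_list_sketch`` is an illustrative name.

   .. code-block:: python

      from itertools import chain


      def flatten_list_sketch(list_of_lists):
          """Illustrative stand-in: concatenate sublists one level deep."""
          return list(chain.from_iterable(list_of_lists))


      flat = flatten_list_sketch([[1, 2], [3], []])  # [1, 2, 3]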


.. py:function:: generate_integer_list_around(number, spacing=10)

   Generate a list of integers on either side of a given number with a specified spacing.

   :param number: The central number around which the list is generated.
   :type number: int
   :param spacing: The spacing between consecutive integers in the list. Default is 10.
   :type spacing: int

   :returns: List of integers on either side of the given number.
   :rtype: list


.. py:function:: find_key_for_number(input_number, data_dict)

   Find the key in the dictionary based on the given number.

   :param input_number: The number to search for in the dictionary.
   :type input_number: int
   :param data_dict: A dictionary mapping keys to their corresponding value lists.
   :type data_dict: dict

   :returns: The key in the dictionary whose range contains the input number.
   :rtype: str
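
   A sketch of the range lookup follows. It assumes each value's first two entries are an inclusive ``[start, end]`` range and that ``None`` is returned when nothing matches. ``find_key_for_number_sketch`` is an illustrative name. The per-year 6-hourly index ranges in the example are hypothetical data.

   .. code-block:: python

      def find_key_for_number_sketch(input_number, data_dict):
          """Illustrative stand-in: return the key whose [start, end] range holds the number."""
          for key, value in data_dict.items():
              start, end = value[0], value[1]  # assumed inclusive bounds
              if start <= input_number <= end:
                  return key
          return None


      # Hypothetical 6-hourly index ranges: 1460 steps in 2019, 1464 in 2020.
      year_ranges = {"2019": [0, 1459], "2020": [1460, 2923]}
      key = find_key_for_number_sketch(1460, year_ranges)  # "2020"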


.. py:function:: drop_var_from_dataset(xarray_dataset, varname_keep)

   Keep a given set of variables from an ``xarray.Dataset`` and drop the rest.

   It raises an error if any name in ``varname_keep`` is missing from ``xarray_dataset``.
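
   A sketch of the keep-and-drop logic follows. The name check is isolated in a pure helper, and the dropping step uses ``xarray.Dataset.data_vars`` and ``xarray.Dataset.drop_vars``. ``vars_to_drop`` and ``drop_var_from_dataset_sketch`` are hypothetical names. Raising ``KeyError`` is an assumption, since this page does not say which error type the real function raises.

   .. code-block:: python

      def vars_to_drop(all_varnames, varname_keep):
          """Hypothetical helper: names to drop, raising if a kept name is absent."""
          missing = [v for v in varname_keep if v not in all_varnames]
          if missing:
              raise KeyError(f"variables not found in dataset: {missing}")
          keep = set(varname_keep)
          return [v for v in all_varnames if v not in keep]


      def drop_var_from_dataset_sketch(xarray_dataset, varname_keep):
          """Illustrative stand-in built on xarray.Dataset.drop_vars."""
          to_drop = vars_to_drop(list(xarray_dataset.data_vars), varname_keep)
          return xarray_dataset.drop_vars(to_drop)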


.. py:function:: keep_dataset_vars(xarray_dataset: xarray.Dataset, varnames_keep: List[str])

   Return a version of an xarray dataset with only a selected subset of variables.

   :param xarray_dataset: The xarray dataset.
   :type xarray_dataset: xr.Dataset
   :param varnames_keep: a list of variable names to be kept.
   :type varnames_keep: List[str]

   :returns: xr.Dataset with only the variables in varnames_keep included.


.. py:class:: ERA5_and_Forcing_Dataset(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)

   Bases: :py:obj:`torch.utils.data.Dataset`


   A PyTorch Dataset class that handles the following kinds of variables:

   * upper-air variables (time, level, lat, lon)
   * surface variables (time, lat, lon)
   * dynamic forcing variables (time, lat, lon)
   * forcing variables (time, lat, lon)
   * diagnostic variables (time, lat, lon)
   * static variables (lat, lon)


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: one_shot
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:method:: __post_init__()

      Calculate total sequence length after init.



   .. py:method:: __len__()

      Return the length of the dataset.



   .. py:method:: __getitem__(index)

      Get a single item from the dataset.



.. py:class:: ERA5_Dataset_Distributed(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)

   Bases: :py:obj:`torch.utils.data.Dataset`


   ERA5 Dataset for Distributed training (legacy).


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: varname_upper_air


   .. py:attribute:: varname_surface


   .. py:attribute:: varname_dyn_forcing


   .. py:attribute:: varname_forcing


   .. py:attribute:: varname_static


   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: one_shot
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: filenames


   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:method:: __post_init__()

      Calculate total sequence length.



   .. py:method:: __len__()

      Length of dataset.



   .. py:method:: __getitem__(index)

      Get an item.

      :param index: index of the timestep.

      :returns: PyTorch tensor containing a full state.



.. py:class:: Predict_Dataset(conf, varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface, filename_dyn_forcing, filename_forcing, filename_static, filename_diagnostic, fcst_datetime, history_len, rank, world_size, transform=None, rollout_p=0.0, which_forecast=None)

   Bases: :py:obj:`torch.utils.data.IterableDataset`


   Same as :py:class:`ERA5_and_Forcing_Dataset`, but works with the legacy ``rollout_to_netcdf.py`` script.
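
   ``Predict_Dataset`` is an ``IterableDataset`` that stores ``rank`` and ``world_size``, so each process in a distributed run presumably iterates over its own share of forecasts. The sketch below shows one common round-robin partitioning pattern. It is an assumption for illustration, not necessarily how ``__iter__`` divides the work. ``shard_for_rank`` is a hypothetical name, and the initialization times are made-up data.

   .. code-block:: python

      def shard_for_rank(items, rank, world_size):
          """Hypothetical helper: round-robin split so each rank gets a disjoint subset."""
          if not 0 <= rank < world_size:
              raise ValueError("rank must satisfy 0 <= rank < world_size")
          return items[rank::world_size]


      init_times = [
          "2020-01-01T00", "2020-01-01T12", "2020-01-02T00",
          "2020-01-02T12", "2020-01-03T00",
      ]
      shards = [shard_for_rank(init_times, r, 2) for r in range(2)]
      # rank 0 gets items 0, 2, 4; rank 1 gets items 1, 3

   Round-robin striding keeps the per-rank workloads within one item of each other and needs no communication between ranks.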


   .. py:attribute:: rank


   .. py:attribute:: world_size


   .. py:attribute:: transform
      :value: None



   .. py:attribute:: history_len


   .. py:attribute:: init_datetime


   .. py:attribute:: which_forecast
      :value: None



   .. py:attribute:: filenames


   .. py:attribute:: filename_surface


   .. py:attribute:: filename_dyn_forcing


   .. py:attribute:: filename_forcing


   .. py:attribute:: filename_static


   .. py:attribute:: filename_diagnostic


   .. py:attribute:: varname_upper_air


   .. py:attribute:: varname_surface


   .. py:attribute:: varname_dyn_forcing


   .. py:attribute:: varname_forcing


   .. py:attribute:: varname_static


   .. py:attribute:: varname_diagnostic


   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: current_epoch
      :value: 0



   .. py:attribute:: rollout_p
      :value: 0.0



   .. py:attribute:: lead_time_periods


   .. py:attribute:: skip_periods


   .. py:method:: ds_read_and_subset(filename, time_start, time_end, varnames)

      Read and subset specified dataset.

      :param filename: path to specified dataset file.
      :type filename: str
      :param time_start: start time index.
      :type time_start: int
      :param time_end: end time index.
      :type time_end: int
      :param varnames: List of variables to be read.
      :type varnames: list



   .. py:method:: load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')

      Load input data from zarr files.

      :param i_file: index of the file
      :param i_init_start: start index of the data being loaded
      :param i_init_end: end index of the data being loaded.
      :param mode: "input" or "target"

      :returns: xr.Dataset containing all the variables.



   .. py:method:: find_start_stop_indices(index)

      Find the start and stop indices for a given yearly zarr data file.

      :param index: indices of zarr file.



   .. py:method:: __len__()

      Length of dataset.



   .. py:method:: __iter__()

      Iterate through batches.



