credit.datasets.sequential_multistep
====================================

.. py:module:: credit.datasets.sequential_multistep

.. autoapi-nested-parse::

   PyTorch ``IterableDataset`` for multi-step training.

   Reference:
       Non-daemonic Python pool process
       https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic

       PyTorch IterableDataset
       https://colab.research.google.com/drive/1OFLZnX9y5QUFNONuvFsxOizq4M-tFvk-?usp=sharing#scrollTo=CxSCQPOMHgwo



Attributes
----------

.. autoapisummary::

   credit.datasets.sequential_multistep.logger


Classes
-------

.. autoapisummary::

   credit.datasets.sequential_multistep.DistributedSequentialDataset
   credit.datasets.sequential_multistep.DistributedSequentialDatasetBasic


Functions
---------

.. autoapisummary::

   credit.datasets.sequential_multistep.worker


Module Contents
---------------

.. py:data:: logger

.. py:class:: DistributedSequentialDataset(varname_upper_air: List[str], varname_surface: List[str], varname_dyn_forcing: List[str], varname_forcing: List[str], varname_static: List[str], varname_diagnostic: List[str], filenames: List[str], filename_surface: Optional[List[str]] = None, filename_dyn_forcing: Optional[List[str]] = None, filename_forcing: Optional[str] = None, filename_static: Optional[str] = None, filename_diagnostic: Optional[List[str]] = None, rank: int = 0, world_size: int = 1, history_len: int = 2, forecast_len: int = 0, transform: Optional[Callable] = None, seed: int = 42, skip_periods: Optional[int] = None, max_forecast_len: Optional[int] = None, shuffle: bool = True, num_workers: int = 0)

   Bases: :py:obj:`torch.utils.data.IterableDataset`


   An iterable Dataset.

   All datasets that represent an iterable of data samples should subclass
   :py:obj:`torch.utils.data.IterableDataset`. This form of dataset is
   particularly useful when data come from a stream.

   All subclasses should override :meth:`__iter__`, which should return an
   iterator of samples in this dataset.

   When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
   item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
   iterator. When :attr:`num_workers > 0`, each worker process has a
   separate copy of the dataset object, so each copy usually needs to be
   configured independently to avoid returning duplicate data from the
   workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
   process, returns information about the worker. It can be used in either the
   dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader`'s
   :attr:`worker_init_fn` option to modify each copy's behavior.

   Example 1: splitting workload across all workers in :meth:`__iter__`::

       >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
       >>> # xdoctest: +SKIP("Fails on MacOS12")
       >>> import math
       >>> import torch
       >>> class MyIterableDataset(torch.utils.data.IterableDataset):
       ...     def __init__(self, start, end):
       ...         super().__init__()
       ...         assert end > start, "this example only works with end > start"
       ...         self.start = start
       ...         self.end = end
       ...
       ...     def __iter__(self):
       ...         worker_info = torch.utils.data.get_worker_info()
       ...         if worker_info is None:  # single-process data loading, return the full iterator
       ...             iter_start = self.start
       ...             iter_end = self.end
       ...         else:  # in a worker process
       ...             # split workload
       ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
       ...             worker_id = worker_info.id
       ...             iter_start = self.start + worker_id * per_worker
       ...             iter_end = min(iter_start + per_worker, self.end)
       ...         return iter(range(iter_start, iter_end))
       ...
       >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
       >>> ds = MyIterableDataset(start=3, end=7)

       >>> # Single-process loading
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
       [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

       >>> # xdoctest: +REQUIRES(POSIX)
       >>> # Multi-process loading with two worker processes
       >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
       >>> # xdoctest: +IGNORE_WANT("non deterministic")
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
       [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

       >>> # With even more workers
       >>> # xdoctest: +IGNORE_WANT("non deterministic")
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
       [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

   Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

       >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
       >>> import math
       >>> import torch
       >>> class MyIterableDataset(torch.utils.data.IterableDataset):
       ...     def __init__(self, start, end):
       ...         super().__init__()
       ...         assert end > start, "this example only works with end > start"
       ...         self.start = start
       ...         self.end = end
       ...
       ...     def __iter__(self):
       ...         return iter(range(self.start, self.end))
       ...
       >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
       >>> ds = MyIterableDataset(start=3, end=7)

       >>> # Single-process loading
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
       [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
       >>>
       >>> # Directly doing multi-process loading yields duplicate data
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
       [tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

       >>> # Define a `worker_init_fn` that configures each dataset copy differently
       >>> def worker_init_fn(worker_id):
       ...     worker_info = torch.utils.data.get_worker_info()
       ...     dataset = worker_info.dataset  # the dataset copy in this worker process
       ...     overall_start = dataset.start
       ...     overall_end = dataset.end
       ...     # configure the dataset to only process the split workload
       ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
       ...     worker_id = worker_info.id
       ...     dataset.start = overall_start + worker_id * per_worker
       ...     dataset.end = min(dataset.start + per_worker, overall_end)
       ...

       >>> # Multi-process loading with the custom `worker_init_fn`
       >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
       [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

       >>> # With even more workers
       >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
       [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
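
   The examples above split work across :class:`~torch.utils.data.DataLoader`
   workers only, while this class also takes ``rank`` and ``world_size``. The
   sketch below is a minimal illustration, not the CREDIT implementation, of one
   common way to combine both levels: treat each (rank, worker) pair as one of
   ``world_size * num_workers`` shards and give each shard a strided slice of the
   sample indices. The helper ``shard_indices`` is hypothetical; inside
   ``__iter__``, ``worker_id`` and ``num_workers`` would typically come from
   :func:`~torch.utils.data.get_worker_info`.

   .. code-block:: python

      from typing import List


      def shard_indices(n_samples: int, rank: int, world_size: int,
                        worker_id: int, num_workers: int) -> List[int]:
          """Hypothetical helper: sample indices handled by one (rank, worker) pair."""
          # num_workers == 0 means loading happens in the main process: one shard per rank.
          workers = max(num_workers, 1)
          n_shards = world_size * workers
          shard_id = rank * workers + worker_id
          # A strided slice gives each shard a disjoint, nearly equal-sized subset,
          # and the union over all shards covers every index exactly once.
          return list(range(shard_id, n_samples, n_shards))

   For example, with 10 samples, ``world_size=2`` and ``num_workers=2``, worker 1
   on rank 1 would receive indices ``[3, 7]``.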


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: shuffle
      :value: True



   .. py:attribute:: current_epoch
      :value: 0



   .. py:attribute:: num_workers
      :value: 0



   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:method:: __post_init__()


   .. py:method:: __len__() -> int


   .. py:method:: set_epoch(epoch: int) -> None
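
      One common way to implement per-epoch shuffling, assumed here rather than
      taken from the CREDIT source, is to reseed a generator from ``seed + epoch``
      (the convention used by :class:`torch.utils.data.distributed.DistributedSampler`),
      so every rank computes the same permutation before selecting its own shard.
      The helper ``epoch_order`` is hypothetical and uses Python's :mod:`random`
      module in place of the class's ``rng`` attribute.

      .. code-block:: python

         import random
         from typing import List


         def epoch_order(n_samples: int, seed: int, epoch: int, shuffle: bool = True) -> List[int]:
             """Hypothetical helper: the sample order for one epoch.

             Every rank calls this with the same seed and epoch, so all ranks
             agree on the permutation before each takes its own shard of it.
             """
             order = list(range(n_samples))
             if shuffle:
                 # A generator seeded from seed + epoch is reproducible for a given
                 # epoch and usually gives a different order in the next epoch.
                 random.Random(seed + epoch).shuffle(order)
             return order

      A training loop would then call ``set_epoch(epoch)`` at the start of every
      epoch so that the next call to ``__iter__`` sees the new epoch.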


   .. py:method:: __iter__()


.. py:function:: worker(tuple_index: Tuple[int, int], ERA5_indices: Dict[str, List[int]], all_files: List[Any], surface_files: Optional[List[Any]], dyn_forcing_files: Optional[List[Any]], diagnostic_files: Optional[List[Any]], xarray_forcing: Optional[Any], xarray_static: Optional[Any], history_len: int, forecast_len: int, skip_periods: int, transform: Optional[Callable]) -> Dict[str, Any]

   Processes a given index to extract and transform data for a specific time slice.

   Parameters:

   - tuple_index (Tuple[int, int]): Tuple containing the current index and sub-index for processing.
   - ERA5_indices (Dict[str, List[int]]): Dictionary containing ERA5 indices metadata.
   - all_files (List[Any]): List of xarray datasets containing upper air data.
   - surface_files (Optional[List[Any]]): List of xarray datasets containing surface data.
   - dyn_forcing_files (Optional[List[Any]]): List of xarray datasets containing dynamic forcing data.
   - diagnostic_files (Optional[List[Any]]): List of xarray datasets containing diagnostic data.
   - xarray_forcing (Optional[Any]): xarray dataset containing forcing data.
   - xarray_static (Optional[Any]): xarray dataset containing static data.
   - history_len (int): Length of the history sequence.
   - forecast_len (int): Length of the forecast sequence.
   - skip_periods (int): Number of periods to skip between samples.
   - transform (Optional[Callable]): Transformation function to apply to the data.

   Returns:

   - Dict[str, Any]: A dictionary containing historical ERA5 images, target ERA5 images, the datetime index, and additional information.
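
   The sketch below shows how ``history_len``, ``forecast_len`` and
   ``skip_periods`` could map to raw time indices for one sample. It is a
   hypothetical helper (``window_indices``), assuming ``history_len`` input steps
   followed by ``forecast_len + 1`` target steps, each spaced ``skip_periods``
   raw time steps apart; the exact convention used by ``worker`` may differ.

   .. code-block:: python

      from typing import List, Optional, Tuple


      def window_indices(start: int, history_len: int, forecast_len: int,
                         skip_periods: Optional[int] = 1) -> Tuple[List[int], List[int]]:
          """Hypothetical helper: (input_times, target_times) for one multi-step sample."""
          step = skip_periods or 1  # treat None or 0 as "use every time step"
          # Assumption: forecast_len counts target steps beyond the first one.
          n_total = history_len + forecast_len + 1
          times = [start + k * step for k in range(n_total)]
          return times[:history_len], times[history_len:]

   For example, ``window_indices(0, 2, 3, 6)`` gives input times ``[0, 6]`` and
   target times ``[12, 18, 24, 30]``.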


.. py:class:: DistributedSequentialDatasetBasic(varname_upper_air: List[str], varname_surface: List[str], varname_dyn_forcing: List[str], varname_forcing: List[str], varname_static: List[str], varname_diagnostic: List[str], filenames: List[str], filename_surface: Optional[List[str]] = None, filename_dyn_forcing: Optional[List[str]] = None, filename_forcing: Optional[str] = None, filename_static: Optional[str] = None, filename_diagnostic: Optional[List[str]] = None, rank: int = 0, world_size: int = 1, history_len: int = 2, forecast_len: int = 0, transform: Optional[Callable] = None, seed: int = 42, skip_periods: Optional[int] = None, max_forecast_len: Optional[int] = None, shuffle: bool = True, num_workers: int = 0)

   Bases: :py:obj:`DistributedSequentialDataset`


   .. py:method:: __iter__()


