credit.datasets.era5_multistep
==============================

.. py:module:: credit.datasets.era5_multistep


Attributes
----------

.. autoapisummary::

   credit.datasets.era5_multistep.logger
   credit.datasets.era5_multistep.filename


Classes
-------

.. autoapisummary::

   credit.datasets.era5_multistep.RepeatingIndexSampler
   credit.datasets.era5_multistep.ERA5_and_Forcing_MultiStep


Functions
---------

.. autoapisummary::

   credit.datasets.era5_multistep.worker


Module Contents
---------------

.. py:data:: logger

.. py:function:: worker(tuple_index: Tuple[int, int], ERA5_indices: Dict[str, List[int]], all_files: List[Any], surface_files: Optional[List[Any]], dyn_forcing_files: Optional[List[Any]], diagnostic_files: Optional[List[Any]], xarray_forcing: Optional[Any], xarray_static: Optional[Any], history_len: int, forecast_len: int, skip_periods: int, transform: Optional[Callable], sst_forcing: Optional[Any] = None) -> Dict[str, Any]

   Processes a given index to extract and transform data for a specific time slice.

   Parameters:

   - tuple_index (Tuple[int, int]): Tuple containing the current index and sub-index for processing.
   - ERA5_indices (Dict[str, List[int]]): Dictionary containing ERA5 indices metadata.
   - all_files (List[Any]): List of xarray datasets containing upper-air data.
   - surface_files (Optional[List[Any]]): List of xarray datasets containing surface data.
   - dyn_forcing_files (Optional[List[Any]]): List of xarray datasets containing dynamic forcing data.
   - diagnostic_files (Optional[List[Any]]): List of xarray datasets containing diagnostic data.
   - history_len (int): Length of the history sequence.
   - forecast_len (int): Length of the forecast sequence.
   - skip_periods (int): Number of periods to skip between samples.
   - xarray_forcing (Optional[Any]): xarray dataset containing forcing data.
   - xarray_static (Optional[Any]): xarray dataset containing static data.
   - transform (Optional[Callable]): Transformation function to apply to the data.
   - sst_forcing (Optional[Any]): Optional SST forcing argument. Defaults to ``None``.

   Returns:

   - Dict[str, Any]: A dictionary containing historical ERA5 images, target ERA5 images, datetime index, and additional information.


.. py:class:: RepeatingIndexSampler(dataset, forecast_len, skip_periods=1, shuffle=True, seed=42, rank=0, num_replicas=1)

   Bases: :py:obj:`torch.utils.data.Sampler`


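   Sampler that yields each start index ``forecast_len + 1`` times for the current rank.
   The description and example that follow are inherited from
   :class:`torch.utils.data.Sampler`, the base class for all samplers.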

   Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
   way to iterate over indices or lists of indices (batches) of dataset elements,
   and may provide a :meth:`__len__` method that returns the length of the returned iterators.

   .. rubric:: Example

   >>> # xdoctest: +SKIP
   >>> class AccedingSequenceLengthSampler(Sampler[int]):
   >>>     def __init__(self, data: List[str]) -> None:
   >>>         self.data = data
   >>>
   >>>     def __len__(self) -> int:
   >>>         return len(self.data)
   >>>
   >>>     def __iter__(self) -> Iterator[int]:
   >>>         sizes = torch.tensor([len(x) for x in self.data])
   >>>         yield from torch.argsort(sizes).tolist()
   >>>
   >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
   >>>     def __init__(self, data: List[str], batch_size: int) -> None:
   >>>         self.data = data
   >>>         self.batch_size = batch_size
   >>>
   >>>     def __len__(self) -> int:
   >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
   >>>
   >>>     def __iter__(self) -> Iterator[List[int]]:
   >>>         sizes = torch.tensor([len(x) for x in self.data])
   >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
   >>>             yield batch.tolist()

   .. note:: The :meth:`__len__` method isn't strictly required by
             :class:`~torch.utils.data.DataLoader`, but is expected in any
             calculation involving the length of a :class:`~torch.utils.data.DataLoader`.


   .. py:attribute:: dataset


   .. py:attribute:: forecast_len


   .. py:attribute:: skip_periods
      :value: 1



   .. py:attribute:: shuffle
      :value: True



   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: num_replicas
      :value: 1



   .. py:attribute:: all_start_indices


   .. py:attribute:: num_indices_per_rank


   .. py:method:: __len__()

      Returns the total number of indices for this rank.



   .. py:method:: __iter__()

      Yields each start index repeated (forecast_len + 1) times.
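
      The repetition described above can be illustrated with a minimal sketch.
      It is not the class's actual implementation: the helper names
      ``repeating_indices`` and ``shard_for_rank`` are hypothetical, and the
      strided split across ranks is an assumption made only for illustration.

      .. code-block:: python

         from typing import Iterator, List


         def shard_for_rank(indices: List[int], rank: int, num_replicas: int) -> List[int]:
             """Assumed sharding: every num_replicas-th index, starting at rank."""
             return indices[rank::num_replicas]


         def repeating_indices(start_indices: List[int], forecast_len: int) -> Iterator[int]:
             """Yield each start index (forecast_len + 1) times in a row."""
             for start in start_indices:
                 for _ in range(forecast_len + 1):
                     yield start


         # Hypothetical start indices; rank 0 of 2 keeps [0, 20] under the assumed split.
         local_starts = shard_for_rank([0, 10, 20, 30], rank=0, num_replicas=2)
         order = list(repeating_indices(local_starts, forecast_len=2))
         print(order)  # [0, 0, 0, 20, 20, 20]

      Repeating the start index presumably lets the dataset serve one forecast
      step per request while staying on the same initial time.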



   .. py:method:: batches_per_epoch()

      Computes the number of batches per epoch for a given batch size.

      Returns:
      - int: Number of batches per epoch.



.. py:class:: ERA5_and_Forcing_MultiStep(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)

   Bases: :py:obj:`torch.utils.data.Dataset`


   A PyTorch Dataset class that works on:

   - upper-air variables (time, level, lat, lon)
   - surface variables (time, lat, lon)
   - dynamic forcing variables (time, lat, lon)
   - forcing variables (time, lat, lon)
   - diagnostic variables (time, lat, lon)
   - static variables (lat, lon)
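
   The dimension layouts listed above can be summarized in a short sketch that
   maps each variable group to an expected array shape. The group keys, the
   ``expected_shape`` helper, and the grid sizes are hypothetical and chosen
   only for illustration; they are not part of this module.

   .. code-block:: python

      # Dimension names per variable group, as listed in the class description.
      GROUP_DIMS = {
          "upper_air": ("time", "level", "lat", "lon"),
          "surface": ("time", "lat", "lon"),
          "dyn_forcing": ("time", "lat", "lon"),
          "forcing": ("time", "lat", "lon"),
          "diagnostic": ("time", "lat", "lon"),
          "static": ("lat", "lon"),
      }


      def expected_shape(group, sizes):
          """Return the array shape for a group, given a size per dimension name."""
          return tuple(sizes[dim] for dim in GROUP_DIMS[group])


      # Hypothetical sizes: 2 time steps, 16 levels, a 181 x 360 grid.
      sizes = {"time": 2, "level": 16, "lat": 181, "lon": 360}
      shapes = {group: expected_shape(group, sizes) for group in GROUP_DIMS}
      print(shapes["upper_air"])  # (2, 16, 181, 360)
      print(shapes["static"])     # (181, 360)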


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: one_shot
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:attribute:: worker


   .. py:attribute:: total_length
      :value: 0



   .. py:attribute:: current_epoch
      :value: None



   .. py:attribute:: forecast_step_count
      :value: 0



   .. py:attribute:: current_index
      :value: None



   .. py:attribute:: initial_index
      :value: None



   .. py:method:: __post_init__()


   .. py:method:: __len__()


   .. py:method:: set_epoch(epoch)


   .. py:method:: __getitem__(index)


.. py:data:: filename
   :value: '../../config/example-v2026.1.0.yml'


