credit.datasets.wrfmultistep
============================

.. py:module:: credit.datasets.wrfmultistep


Attributes
----------

.. autoapisummary::

   credit.datasets.wrfmultistep.logger


Classes
-------

.. autoapisummary::

   credit.datasets.wrfmultistep.RepeatingIndexSampler
   credit.datasets.wrfmultistep.WRFMultiStep


Functions
---------

.. autoapisummary::

   credit.datasets.wrfmultistep.worker


Module Contents
---------------

.. py:data:: logger

.. py:function:: worker(tuple_index: Tuple[int, int], WRF_file_indices: Dict[str, List[int]], list_upper_ds: List[Any], list_surf_ds: Optional[List[Any]], list_dyn_forcing_ds: Optional[List[Any]], list_diag_ds: Optional[List[Any]], xarray_forcing: Optional[Any], xarray_static: Optional[Any], history_len: int, forecast_len: int, list_upper_ds_outside: Optional[List[Any]], list_surf_ds_outside: Optional[List[Any]], outside_file_year_range: Optional[List[Any]], outside_file_indices: Optional[List[Any]], history_len_outside: int, transform: Optional[Callable]) -> Dict[str, Any]

.. py:class:: RepeatingIndexSampler(dataset, forecast_len, shuffle=True, seed=42, rank=0, num_replicas=1)

   Bases: :py:obj:`torch.utils.data.Sampler`


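   Sampler for multi-step training that yields each start index of ``dataset``
   repeated ``forecast_len + 1`` times. The start indices (``all_start_indices``)
   are split across ``num_replicas`` ranks so that each rank iterates over its own
   share, optionally shuffled according to ``shuffle`` and ``seed``.

   Like every :class:`~torch.utils.data.Sampler` subclass, it provides an
   :meth:`__iter__` method for iterating over indices of dataset elements and a
   :meth:`__len__` method that returns the length of the returned iterator.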

   .. rubric:: Example

   Two generic :class:`~torch.utils.data.Sampler` subclasses, one yielding
   single indices and one yielding batches of indices:

   >>> # xdoctest: +SKIP
   >>> from typing import Iterator, List
   >>> import torch
   >>> from torch.utils.data import Sampler
   >>>
   >>> class AscendingSequenceLengthSampler(Sampler[int]):
   >>>     def __init__(self, data: List[str]) -> None:
   >>>         self.data = data
   >>>
   >>>     def __len__(self) -> int:
   >>>         return len(self.data)
   >>>
   >>>     def __iter__(self) -> Iterator[int]:
   >>>         # Yield indices ordered by ascending sequence length.
   >>>         sizes = torch.tensor([len(x) for x in self.data])
   >>>         yield from torch.argsort(sizes).tolist()
   >>>
   >>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
   >>>     def __init__(self, data: List[str], batch_size: int) -> None:
   >>>         self.data = data
   >>>         self.batch_size = batch_size
   >>>
   >>>     def __len__(self) -> int:
   >>>         # Ceiling division: a final partial batch still counts.
   >>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
   >>>
   >>>     def __iter__(self) -> Iterator[List[int]]:
   >>>         # Split the length-sorted indices into chunks of at most batch_size.
   >>>         sizes = torch.tensor([len(x) for x in self.data])
   >>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
   >>>             yield batch.tolist()

   .. note:: The :meth:`__len__` method isn't strictly required by
             :class:`~torch.utils.data.DataLoader`, but is expected in any
             calculation involving the length of a :class:`~torch.utils.data.DataLoader`.


   .. py:attribute:: dataset


   .. py:attribute:: forecast_len


   .. py:attribute:: shuffle
      :value: True



   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: num_replicas
      :value: 1



   .. py:attribute:: all_start_indices


   .. py:attribute:: num_indices_per_rank


   .. py:method:: __len__()

      Returns the total number of indices for this rank.



   .. py:method:: __iter__()

      Yields each start index repeated ``forecast_len + 1`` times.
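      A dependency-free sketch of this iteration pattern follows. The function
      name ``repeating_indices``, the strided split of start indices across
      ranks, and shuffling with ``random.Random(seed)`` are assumptions made for
      illustration; the actual sampler may partition or shuffle differently.

      .. code-block:: python

         import random
         from typing import Iterator, List


         def repeating_indices(
             start_indices: List[int],
             forecast_len: int,
             shuffle: bool = True,
             seed: int = 42,
             rank: int = 0,
             num_replicas: int = 1,
         ) -> Iterator[int]:
             """Hypothetical stand-in for RepeatingIndexSampler.__iter__."""
             indices = list(start_indices)
             if shuffle:
                 # Every rank uses the same seed, so all ranks see the same
                 # permutation and the strided slices below do not overlap.
                 random.Random(seed).shuffle(indices)
             # Assumption: rank r takes every num_replicas-th index starting at r.
             for idx in indices[rank::num_replicas]:
                 # Repeat the start index forecast_len + 1 times.
                 for _ in range(forecast_len + 1):
                     yield idx


         # Two ranks, no shuffling: each start index appears forecast_len + 1 = 3 times.
         print(list(repeating_indices([0, 1, 2, 3], 2, shuffle=False, rank=0, num_replicas=2)))
         # prints [0, 0, 0, 2, 2, 2]
         print(list(repeating_indices([0, 1, 2, 3], 2, shuffle=False, rank=1, num_replicas=2)))
         # prints [1, 1, 1, 3, 3, 3]

      Sharing one seed across ranks is what keeps the per-rank shares disjoint
      in this sketch; seeding each rank differently would let the shares overlap.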



   .. py:method:: batches_per_epoch()

      Computes the number of batches per epoch for a given batch size.

      :returns: Number of batches per epoch.
      :rtype: int
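      The method signature takes no batch size, so this sketch passes it
      explicitly. It shows the arithmetic :class:`~torch.utils.data.DataLoader`
      uses when grouping sampler indices into batches. The function name and
      its ``num_samples`` (for example ``len(sampler)``), ``batch_size`` and
      ``drop_last`` parameters are hypothetical, and whether the real method
      keeps a final partial batch is an assumption.

      .. code-block:: python

         def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
             """Hypothetical helper: number of batches built from num_samples indices."""
             if batch_size <= 0:
                 raise ValueError("batch_size must be positive")
             if drop_last:
                 # Only full batches are kept; a trailing partial batch is discarded.
                 return num_samples // batch_size
             # Integer ceiling division: a trailing partial batch counts as a batch.
             return (num_samples + batch_size - 1) // batch_size


         # 10 indices with batch size 4 -> batches of 4, 4 and 2.
         print(batches_per_epoch(10, 4))  # prints 3
         print(batches_per_epoch(10, 4, drop_last=True))  # prints 2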



.. py:class:: WRFMultiStep(param_interior, param_outside, transform=None, seed=42, rank=0, world_size=1, max_forecast_len=None)

   Bases: :py:obj:`torch.utils.data.Dataset`


   Map-style dataset for multi-step training on WRF output. Interior-domain
   data configured by ``param_interior`` (upper-air, surface, dynamic-forcing
   and diagnostic datasets, plus optional forcing and static files) are
   combined with upper-air and surface data from outside the WRF domain,
   configured by ``param_outside``.

   As a :class:`~torch.utils.data.Dataset` subclass, it overrides
   :meth:`__getitem__` to fetch the data sample for a given index and
   :meth:`__len__` to return the size of the dataset, which many
   :class:`~torch.utils.data.Sampler` implementations and the default options
   of :class:`~torch.utils.data.DataLoader` expect.

   .. note::
     :class:`~torch.utils.data.DataLoader` by default constructs an index
     sampler that yields integral indices.  To make it work with a map-style
     dataset with non-integral indices/keys, a custom sampler must be provided.


   .. py:attribute:: list_upper_ds


   .. py:attribute:: list_surf_ds
      :value: []



   .. py:attribute:: list_dyn_forcing_ds
      :value: []



   .. py:attribute:: list_diag_ds
      :value: []



   .. py:attribute:: history_len


   .. py:attribute:: forecast_len


   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: total_seq_len


   .. py:attribute:: WRF_file_indices


   .. py:attribute:: filename_forcing


   .. py:attribute:: filename_static


   .. py:attribute:: list_upper_ds_outside
      :value: []



   .. py:attribute:: list_surf_ds_outside
      :value: []



   .. py:attribute:: history_len_outside


   .. py:attribute:: forecast_len_outside


   .. py:attribute:: outside_file_year_range


   .. py:attribute:: outside_file_indices


   .. py:attribute:: transform
      :value: None



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: worker


   .. py:attribute:: total_length
      :value: 0



   .. py:attribute:: current_epoch
      :value: None



   .. py:attribute:: forecast_step_count
      :value: 0



   .. py:attribute:: current_index
      :value: None



   .. py:attribute:: initial_index
      :value: None



   .. py:method:: __post_init__()


   .. py:method:: __len__()


   .. py:method:: set_epoch(epoch)


   .. py:method:: __getitem__(index)


