credit.datasets.era5_multistep_batcher
======================================

.. py:module:: credit.datasets.era5_multistep_batcher


Attributes
----------

.. autoapisummary::

   credit.datasets.era5_multistep_batcher.logger
   credit.datasets.era5_multistep_batcher.option


Classes
-------

.. autoapisummary::

   credit.datasets.era5_multistep_batcher.ERA5_MultiStep_Batcher
   credit.datasets.era5_multistep_batcher.MultiprocessingBatcher
   credit.datasets.era5_multistep_batcher.MultiprocessingBatcherPrefetch
   credit.datasets.era5_multistep_batcher.Predict_Dataset_Batcher


Module Contents
---------------

.. py:data:: logger

.. py:class:: ERA5_MultiStep_Batcher(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, sst_forcing=None, history_len=2, forecast_len=0, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, max_forecast_len=None, batch_size=1, shuffle=True)

   Bases: :py:obj:`torch.utils.data.Dataset`


   A PyTorch Dataset class that works on:
       - upper-air variables (time, level, lat, lon)
       - surface variables (time, lat, lon)
       - dynamic forcing variables (time, lat, lon)
       - forcing variables (time, lat, lon)
       - diagnostic variables (time, lat, lon)
       - static variables (lat, lon)


   .. py:attribute:: history_len
      :value: 2



   .. py:attribute:: forecast_len
      :value: 0



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: shuffle
      :value: True



   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: total_seq_len
      :value: 2



   .. py:attribute:: rng


   .. py:attribute:: max_forecast_len
      :value: None



   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: ERA5_indices


   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:attribute:: worker


   .. py:attribute:: current_epoch
      :value: None



   .. py:attribute:: sampler


   .. py:attribute:: batch_size
      :value: 1



   .. py:attribute:: batch_indices
      :value: None



   .. py:attribute:: time_steps
      :value: None



   .. py:attribute:: forecast_step_counts
      :value: None



   .. py:method:: initialize_batch()

      Initializes batch indices using DistributedSampler's indices.
      Ensures proper cycling when shuffle=False.
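
      The cycling behaviour can be pictured with a minimal sketch. It is not
      the class's actual implementation: it assumes the sampler yields a flat
      list of integer indices, and ``take_batch`` is a hypothetical helper
      introduced here only for illustration.

      .. code-block:: python

         def take_batch(indices, start, batch_size):
             """Return batch_size indices beginning at start, wrapping at the end."""
             n = len(indices)
             batch = [indices[(start + i) % n] for i in range(batch_size)]
             next_start = (start + batch_size) % n
             return batch, next_start

         # Hypothetical sampler output for one rank, in fixed (unshuffled) order.
         indices = [0, 1, 2, 3, 4]

         first, pos = take_batch(indices, 0, 2)     # [0, 1]
         second, pos = take_batch(indices, pos, 2)  # [2, 3]
         third, pos = take_batch(indices, pos, 2)   # [4, 0], wraps to the start

      Wrapping with the modulo operator keeps every batch at full size even
      when the number of indices is not a multiple of ``batch_size``.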



   .. py:method:: __post_init__()


   .. py:method:: __len__()


   .. py:method:: set_epoch(epoch)


   .. py:method:: batches_per_epoch()


   .. py:method:: __getitem__(_)

      Fetches the current forecast step data for each item in the batch.
      Resets items when their forecast length is exceeded.
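
      A rough sketch of this per-item stepping follows. It is an illustration
      under assumptions, not the method's real code: ``advance`` and
      ``draw_new_start`` are hypothetical names, and the sketch assumes that
      steps ``0`` through ``forecast_len`` are valid for each item.

      .. code-block:: python

         def advance(start_indices, step_counts, forecast_len, draw_new_start):
             """Return the time index to read for each batch item, then step forward."""
             current = []
             for k in range(len(start_indices)):
                 # Assumption: an item is exhausted once its counter passes forecast_len.
                 if step_counts[k] > forecast_len:
                     start_indices[k] = draw_new_start()
                     step_counts[k] = 0
                 current.append(start_indices[k] + step_counts[k])
                 step_counts[k] += 1
             return current

         starts = [10, 20]
         counts = [0, 0]
         new_starts = iter([30, 40])  # hypothetical replacement start indices

         def draw():
             return next(new_starts)

         step1 = advance(starts, counts, 1, draw)  # [10, 20]
         step2 = advance(starts, counts, 1, draw)  # [11, 21]
         step3 = advance(starts, counts, 1, draw)  # [30, 40], both items were reset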



.. py:class:: MultiprocessingBatcher(*args, num_workers=4, **kwargs)

   Bases: :py:obj:`ERA5_MultiStep_Batcher`


   A PyTorch Dataset class that works on:
       - upper-air variables (time, level, lat, lon)
       - surface variables (time, lat, lon)
       - dynamic forcing variables (time, lat, lon)
       - forcing variables (time, lat, lon)
       - diagnostic variables (time, lat, lon)
       - static variables (lat, lon)


   .. py:attribute:: num_workers
      :value: 4



   .. py:attribute:: manager


   .. py:attribute:: results


   .. py:method:: __getitem__(_)

      Fetches the current forecast step data for each item in the batch.
      Utilizes multiprocessing to parallelize calls to `self.worker`.
      Ensures the results are returned in the correct order.
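
      The ordering guarantee can be sketched as follows. This is not the
      class's implementation: a thread pool from ``concurrent.futures`` is
      swapped in for worker processes to keep the example self-contained, and
      ``load_sample`` is a hypothetical stand-in for ``self.worker``.

      .. code-block:: python

         from concurrent.futures import ThreadPoolExecutor, as_completed

         def load_sample(index):
             """Hypothetical stand-in for the per-item worker call."""
             return {"index": index, "value": index * index}

         def fetch_ordered(batch_indices, num_workers=4):
             """Load items in parallel and return them in batch order."""
             results = {}
             with ThreadPoolExecutor(max_workers=num_workers) as pool:
                 # Remember each task's position in the batch.
                 futures = {
                     pool.submit(load_sample, idx): k
                     for k, idx in enumerate(batch_indices)
                 }
                 # Tasks may finish in any order; store each under its position.
                 for future in as_completed(futures):
                     results[futures[future]] = future.result()
             return [results[k] for k in range(len(batch_indices))]

         batch = fetch_ordered([7, 3, 5])

      Keying the results by batch position, rather than appending them as
      they complete, is what keeps the output order independent of worker
      timing.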



   .. py:method:: __del__()

      Clean up the manager when the object is destroyed.



.. py:class:: MultiprocessingBatcherPrefetch(*args, num_workers=4, prefetch_factor=4, **kwargs)

   Bases: :py:obj:`ERA5_MultiStep_Batcher`


   A PyTorch Dataset class that works on:
       - upper-air variables (time, level, lat, lon)
       - surface variables (time, lat, lon)
       - dynamic forcing variables (time, lat, lon)
       - forcing variables (time, lat, lon)
       - diagnostic variables (time, lat, lon)
       - static variables (lat, lon)


   .. py:attribute:: num_workers
      :value: 4



   .. py:attribute:: prefetch_factor
      :value: 4



   .. py:attribute:: prefetch_queue


   .. py:attribute:: stop_signal


   .. py:attribute:: manager


   .. py:attribute:: results


   .. py:attribute:: stop_event


   .. py:attribute:: prefetch_thread
      :value: None



   .. py:method:: handle_signal(signum, frame)


   .. py:method:: set_epoch(epoch)


   .. py:method:: prefetch_batches()

      Prefetch batches asynchronously and store them in a queue.
      Stops when the `stop_signal` is set.
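
      The general pattern can be sketched with the standard library. This is
      an illustration, not the method's actual code: it assumes a bounded
      ``queue.Queue`` and a ``threading.Event`` as the stop signal, and
      ``prefetch`` and ``make_batch`` are hypothetical names.

      .. code-block:: python

         import queue
         import threading

         def prefetch(make_batch, out_queue, stop_event):
             """Keep the queue filled with batches until stop_event is set."""
             while not stop_event.is_set():
                 batch = make_batch()
                 # Retry with a timeout so a full queue cannot block shutdown.
                 while not stop_event.is_set():
                     try:
                         out_queue.put(batch, timeout=0.1)
                         break
                     except queue.Full:
                         continue

         counter = iter(range(1000))           # hypothetical batch source
         batches = queue.Queue(maxsize=2)      # bound plays the role of a prefetch limit
         stop_event = threading.Event()

         thread = threading.Thread(
             target=prefetch,
             args=(lambda: next(counter), batches, stop_event),
             daemon=True,
         )
         thread.start()

         received = [batches.get(timeout=5) for _ in range(3)]  # consumer side

         stop_event.set()
         thread.join(timeout=5)

      The bounded queue caps how far the producer can run ahead of the
      consumer, which limits memory use.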



   .. py:method:: worker_process(k, index_pair, result_dict)

      Worker function that processes individual tasks, with error handling for specific exceptions.



   .. py:method:: _fetch_batch()

      Fetches a batch using multiprocessing workers and splits the work efficiently.
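
      One way to split the work evenly is sketched below. It is an assumption
      about the approach, not the method's actual code, and
      ``split_into_chunks`` is a hypothetical helper.

      .. code-block:: python

         def split_into_chunks(tasks, num_workers):
             """Split tasks into at most num_workers contiguous, near-equal chunks."""
             base, extra = divmod(len(tasks), num_workers)
             chunks = []
             start = 0
             for w in range(num_workers):
                 # The first `extra` workers take one additional task.
                 size = base + (1 if w < extra else 0)
                 if size:
                     chunks.append(tasks[start:start + size])
                 start += size
             return chunks

         # Pair each index with its batch position so results can be reordered later.
         tasks = list(enumerate([10, 11, 12, 13, 14]))
         chunks = split_into_chunks(tasks, 2)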



   .. py:method:: _process_chunk(task_chunk, result_dict)

      Process a chunk of tasks and update the shared results dictionary.



   .. py:method:: __getitem__(_)

      Get a batch from the prefetch queue.



   .. py:method:: __del__()

      Clean up processes and threads when the object is destroyed.



   .. py:method:: __enter__()


   .. py:method:: __exit__(exc_type, exc_val, exc_tb)


.. py:class:: Predict_Dataset_Batcher(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, sst_forcing=None, fcst_datetime=None, lead_time_periods=6, history_len=1, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, batch_size=1, skip_target=False)

   Bases: :py:obj:`torch.utils.data.Dataset`


   A PyTorch Dataset class that works on:
       - upper-air variables (time, level, lat, lon)
       - surface variables (time, lat, lon)
       - dynamic forcing variables (time, lat, lon)
       - forcing variables (time, lat, lon)
       - diagnostic variables (time, lat, lon)
       - static variables (lat, lon)


   .. py:attribute:: history_len
      :value: 1



   .. py:attribute:: transform
      :value: None



   .. py:attribute:: init_datetime
      :value: None



   .. py:attribute:: lead_time_periods
      :value: 6



   .. py:attribute:: seed
      :value: 42



   .. py:attribute:: rank
      :value: 0



   .. py:attribute:: world_size
      :value: 1



   .. py:attribute:: batch_size
      :value: 1



   .. py:attribute:: skip_target
      :value: False



   .. py:attribute:: skip_periods
      :value: None



   .. py:attribute:: rng


   .. py:attribute:: sst_forcing
      :value: None



   .. py:attribute:: all_files
      :value: []



   .. py:attribute:: filenames


   .. py:attribute:: filename_surface
      :value: None



   .. py:attribute:: filename_dyn_forcing
      :value: None



   .. py:attribute:: filename_forcing
      :value: None



   .. py:attribute:: filename_static
      :value: None



   .. py:attribute:: filename_diagnostic
      :value: None



   .. py:attribute:: varname_upper_air


   .. py:attribute:: varname_surface


   .. py:attribute:: varname_dyn_forcing


   .. py:attribute:: varname_forcing


   .. py:attribute:: varname_static


   .. py:attribute:: varname_diagnostic


   .. py:attribute:: ERA5_indices


   .. py:attribute:: forecast_period
      :value: 0



   .. py:attribute:: forecast_len
      :value: -1



   .. py:attribute:: batch_indices


   .. py:attribute:: batch_indices_splits
      :value: []



   .. py:attribute:: batch_call_count
      :value: 0



   .. py:attribute:: data_lookup
      :value: None



   .. py:method:: __len__()


   .. py:method:: ds_read_and_subset(filename, time_start, time_end, varnames)


   .. py:method:: get_time_variable(filename, time_start, time_end) -> xarray.Dataset

      Open NetCDF or Zarr file and return only the time variable.



   .. py:method:: load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')


   .. py:method:: find_start_stop_indices(index)


   .. py:method:: initialize_batch()

      Initializes batch indices using DistributedSampler's indices.
      Ensures proper cycling when shuffle=False.



   .. py:method:: batches_per_epoch()


   .. py:method:: __getitem__(_)


.. py:data:: option

