credit.datasets.era5_multistep_batcher

Attributes#

Classes#

ERA5_MultiStep_Batcher

A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.

MultiprocessingBatcher

A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.

MultiprocessingBatcherPrefetch

A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.

Predict_Dataset_Batcher

A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.

Module Contents#

credit.datasets.era5_multistep_batcher.logger#
class credit.datasets.era5_multistep_batcher.ERA5_MultiStep_Batcher(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, sst_forcing=None, history_len=2, forecast_len=0, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, max_forecast_len=None, batch_size=1, shuffle=True)#

Bases: torch.utils.data.Dataset

A PyTorch Dataset class that works on:
  • upper-air variables (time, level, lat, lon)
  • surface variables (time, lat, lon)
  • dynamic forcing variables (time, lat, lon)
  • forcing variables (time, lat, lon)
  • diagnostic variables (time, lat, lon)
  • static variables (lat, lon)

history_len = 2#
forecast_len = 0#
transform = None#
seed = 42#
rank = 0#
world_size = 1#
shuffle = True#
skip_periods = None#
total_seq_len = 2#
rng#
max_forecast_len = None#
sst_forcing = None#
all_files = []#
ERA5_indices#
filename_forcing = None#
filename_static = None#
worker#
current_epoch = None#
sampler#
batch_size = 1#
batch_indices = None#
time_steps = None#
forecast_step_counts = None#
initialize_batch()#

Initializes batch indices using DistributedSampler’s indices. Ensures proper cycling when shuffle=False.
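
The sketch below illustrates one plausible reading of this description and is not CREDIT's code. It mimics in plain Python how torch.utils.data.DistributedSampler shards indices (pad by wrapping around, then take every world_size-th index starting at rank), and then cycles through a rank's indices so batches keep coming once the list is exhausted. The helper names shard_indices and batch_stream are hypothetical, and random.Random stands in for the torch generator that the real sampler uses.

```python
import itertools
import random


def shard_indices(n_samples, rank, world_size, shuffle, seed, epoch):
    """Return the sample indices owned by one rank, DistributedSampler-style."""
    indices = list(range(n_samples))
    if shuffle:
        # Seed from seed + epoch so every rank builds the same permutation.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by wrapping around so the total divides evenly across ranks.
    per_rank = -(-n_samples // world_size)  # ceiling division
    total = per_rank * world_size
    indices = list(itertools.islice(itertools.cycle(indices), total))
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]


def batch_stream(rank_indices, batch_size):
    """Yield batches forever, wrapping to the start when indices run out."""
    cycler = itertools.cycle(rank_indices)
    while True:
        yield [next(cycler) for _ in range(batch_size)]


rank0 = shard_indices(n_samples=10, rank=0, world_size=4, shuffle=False, seed=42, epoch=0)
rank1 = shard_indices(n_samples=10, rank=1, world_size=4, shuffle=False, seed=42, epoch=0)
stream = batch_stream(rank0, batch_size=2)
first_batches = [next(stream) for _ in range(3)]
```

With shuffle=False the order is deterministic, so the wrap-around in batch_stream is what keeps batches full when batch_size does not divide the number of indices held by a rank.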

__post_init__()#
__len__()#
set_epoch(epoch)#
batches_per_epoch()#
__getitem__(_)#

Fetches the current forecast step data for each item in the batch. Resets items when their forecast length is exceeded.
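
The per-item bookkeeping described here can be sketched in a few lines. This is an illustration under stated assumptions, not the class's implementation: n_steps (the number of steps in one rollout) and draw_start (a callable that supplies a fresh starting index) are hypothetical names, and the tuples stand in for the data that would actually be loaded.

```python
def next_forecast_steps(start_indices, step_counts, n_steps, draw_start):
    """Return one (time_index, step) pair per batch slot and advance each slot.

    A slot whose rollout is complete is reset to a new start before fetching.
    Both lists are updated in place.
    """
    fetched = []
    for slot, count in enumerate(step_counts):
        if count >= n_steps:
            # Rollout finished: restart this slot from a new initial time.
            start_indices[slot] = draw_start()
            count = 0
        fetched.append((start_indices[slot] + count, count))
        step_counts[slot] = count + 1
    return fetched


new_starts = iter([20, 60])  # hypothetical source of fresh starting indices
starts, counts = [10, 50], [0, 0]
calls = [
    next_forecast_steps(starts, counts, 2, lambda: next(new_starts))
    for _ in range(3)
]
```

Here the first two calls walk both slots through steps 0 and 1, and the third call resets them to the new starts 20 and 60. Because the state lives in the dataset rather than in the index argument, this matches the signature __getitem__(_), which ignores its argument.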

class credit.datasets.era5_multistep_batcher.MultiprocessingBatcher(*args, num_workers=4, **kwargs)#

Bases: ERA5_MultiStep_Batcher

A PyTorch Dataset class that works on:
  • upper-air variables (time, level, lat, lon)
  • surface variables (time, lat, lon)
  • dynamic forcing variables (time, lat, lon)
  • forcing variables (time, lat, lon)
  • diagnostic variables (time, lat, lon)
  • static variables (lat, lon)

num_workers = 4#
manager#
results#
__getitem__(_)#

Fetches the current forecast step data for each item in the batch. Utilizes multiprocessing to parallelize calls to self.worker. Ensures the results are returned in the correct order.
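
A common way to keep parallel results in order is to have each worker write into a shared mapping keyed by its position in the batch, then read the mapping back by position. The sketch below shows that pattern only. To stay portable it swaps in threads and a plain dict for the processes and manager-backed dictionary that the class attributes (manager, results) suggest, and the worker body is a hypothetical stand-in for self.worker.

```python
import threading
import time


def worker(position, index_pair, results):
    """Hypothetical stand-in for self.worker: 'load' one sample, store by position."""
    start, stop = index_pair
    time.sleep(0.01 * (3 - position))  # make later positions finish first
    results[position] = list(range(start, stop))


def fetch_in_order(index_pairs):
    """Run one worker per item in parallel and return results in input order."""
    results = {}
    threads = [
        threading.Thread(target=worker, args=(k, pair, results))
        for k, pair in enumerate(index_pairs)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Completion order is arbitrary; reading by key restores batch order.
    return [results[k] for k in range(len(index_pairs))]


batch = fetch_in_order([(0, 2), (5, 7), (9, 11)])
```

The workers finish in reverse order here, yet the returned list follows the order of the input pairs because it is assembled by key rather than by completion time.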

__del__()#

Clean up the manager when the object is destroyed.

class credit.datasets.era5_multistep_batcher.MultiprocessingBatcherPrefetch(*args, num_workers=4, prefetch_factor=4, **kwargs)#

Bases: ERA5_MultiStep_Batcher

A PyTorch Dataset class that works on:
  • upper-air variables (time, level, lat, lon)
  • surface variables (time, lat, lon)
  • dynamic forcing variables (time, lat, lon)
  • forcing variables (time, lat, lon)
  • diagnostic variables (time, lat, lon)
  • static variables (lat, lon)

num_workers = 4#
prefetch_factor = 4#
prefetch_queue#
stop_signal#
manager#
results#
stop_event#
prefetch_thread = None#
handle_signal(signum, frame)#
set_epoch(epoch)#
prefetch_batches()#

Prefetch batches asynchronously and store them in a queue. Stops when the stop_signal is set.
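
As a rough illustration of this producer pattern, the sketch below runs a background thread that fills a bounded queue until a stop event is set. It reuses the attribute names listed above (prefetch_queue, stop_event, prefetch_thread, prefetch_factor), but the class PrefetchSketch, the make_batch callable, and the get and close methods are hypothetical, and the real class may differ in how it builds batches and shuts down.

```python
import itertools
import queue
import threading


class PrefetchSketch:
    """Minimal background prefetcher: one producer thread, one bounded queue."""

    def __init__(self, make_batch, prefetch_factor=4):
        self.make_batch = make_batch
        self.prefetch_queue = queue.Queue(maxsize=prefetch_factor)
        self.stop_event = threading.Event()
        self.prefetch_thread = threading.Thread(
            target=self.prefetch_batches, daemon=True
        )
        self.prefetch_thread.start()

    def prefetch_batches(self):
        while not self.stop_event.is_set():
            batch = self.make_batch()
            # Retry the put with a timeout so a full queue cannot block shutdown.
            while not self.stop_event.is_set():
                try:
                    self.prefetch_queue.put(batch, timeout=0.1)
                    break
                except queue.Full:
                    continue

    def get(self):
        """Block until the next prefetched batch is available."""
        return self.prefetch_queue.get()

    def close(self):
        """Signal the producer to stop and wait for it to exit."""
        self.stop_event.set()
        self.prefetch_thread.join()


counter = itertools.count()
prefetcher = PrefetchSketch(lambda: next(counter), prefetch_factor=2)
first_three = [prefetcher.get() for _ in range(3)]
prefetcher.close()
```

Bounding the queue at prefetch_factor caps how many batches sit in memory at once, and checking the stop event inside the put loop lets the thread exit promptly even when the consumer has stopped reading.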

worker_process(k, index_pair, result_dict)#

Worker function that processes individual tasks, with error handling for specific exceptions.

_fetch_batch()#

Fetches a batch using multiprocessing workers and splits the work efficiently.

_process_chunk(task_chunk, result_dict)#

Process a chunk of tasks and update the shared results dictionary.
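
One way to split a batch across workers is to give each worker a contiguous chunk of (position, index_pair) tasks and let it record results by position. The sketch below assumes that task layout, which mirrors the worker_process(k, index_pair, result_dict) signature above; split_into_chunks is a hypothetical helper, the chunks are processed sequentially for simplicity, and the stored value is a placeholder for real data loading.

```python
def split_into_chunks(tasks, num_workers):
    """Split tasks into at most num_workers contiguous chunks of near-equal size."""
    n_chunks = min(num_workers, len(tasks))
    if n_chunks == 0:
        return []
    base, extra = divmod(len(tasks), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        size = base + (1 if i < extra else 0)  # first `extra` chunks get one more
        chunks.append(tasks[start:start + size])
        start += size
    return chunks


def process_chunk(task_chunk, result_dict):
    """Handle each (position, index_pair) task and record its result by position."""
    for position, (start, stop) in task_chunk:
        result_dict[position] = stop - start  # placeholder for loading data


tasks = list(enumerate([(0, 2), (2, 5), (5, 6), (6, 10), (10, 11)]))
chunks = split_into_chunks(tasks, num_workers=2)
results = {}
for chunk in chunks:  # a parallel version would hand each chunk to a worker
    process_chunk(chunk, results)
ordered = [results[k] for k in range(len(tasks))]
```

Chunking means each worker is started once per batch rather than once per sample, which keeps start-up overhead low when batch_size is much larger than num_workers.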

__getitem__(_)#

Get a batch from the prefetch queue.

__del__()#

Clean up processes and threads when the object is destroyed.

__enter__()#
__exit__(exc_type, exc_val, exc_tb)#
class credit.datasets.era5_multistep_batcher.Predict_Dataset_Batcher(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, sst_forcing=None, fcst_datetime=None, lead_time_periods=6, history_len=1, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, batch_size=1, skip_target=False)#

Bases: torch.utils.data.Dataset

A PyTorch Dataset class that works on:
  • upper-air variables (time, level, lat, lon)
  • surface variables (time, lat, lon)
  • dynamic forcing variables (time, lat, lon)
  • forcing variables (time, lat, lon)
  • diagnostic variables (time, lat, lon)
  • static variables (lat, lon)

history_len = 1#
transform = None#
init_datetime = None#
lead_time_periods = 6#
seed = 42#
rank = 0#
world_size = 1#
batch_size = 1#
skip_target = False#
skip_periods = None#
rng#
sst_forcing = None#
all_files = []#
filenames#
filename_surface = None#
filename_dyn_forcing = None#
filename_forcing = None#
filename_static = None#
filename_diagnostic = None#
varname_upper_air#
varname_surface#
varname_dyn_forcing#
varname_forcing#
varname_static#
varname_diagnostic#
ERA5_indices#
forecast_period = 0#
forecast_len = -1#
batch_indices#
batch_indices_splits = []#
batch_call_count = 0#
data_lookup = None#
__len__()#
ds_read_and_subset(filename, time_start, time_end, varnames)#
get_time_variable(filename, time_start, time_end) → xarray.Dataset#

Open NetCDF or Zarr file and return only the time variable.

load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')#
find_start_stop_indices(index)#
initialize_batch()#

Initializes batch indices using DistributedSampler’s indices. Ensures proper cycling when shuffle=False.

batches_per_epoch()#
__getitem__(_)#
credit.datasets.era5_multistep_batcher.option#