credit.datasets.era5_multistep_batcher
Attributes
- logger
- option
Classes
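- ERA5_MultiStep_Batcher: A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.
- MultiprocessingBatcher: A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.
- MultiprocessingBatcherPrefetch: A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.
- Predict_Dataset_Batcher: A PyTorch Dataset class that works on upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.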
Module Contents
- credit.datasets.era5_multistep_batcher.logger
- class credit.datasets.era5_multistep_batcher.ERA5_MultiStep_Batcher(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, sst_forcing=None, history_len=2, forecast_len=0, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, max_forecast_len=None, batch_size=1, shuffle=True)
Bases: torch.utils.data.Dataset
A PyTorch Dataset class that works on:
upper-air variables (time, level, lat, lon)
surface variables (time, lat, lon)
dynamic forcing variables (time, lat, lon)
forcing variables (time, lat, lon)
diagnostic variables (time, lat, lon)
static variables (lat, lon)
- history_len = 2
- forecast_len = 0
- transform = None
- seed = 42
- rank = 0
- world_size = 1
- shuffle = True
- skip_periods = None
- total_seq_len = 2
- rng
- max_forecast_len = None
- sst_forcing = None
- all_files = []
- ERA5_indices
- filename_forcing = None
- filename_static = None
- worker
- current_epoch = None
- sampler
- batch_size = 1
- batch_indices = None
- time_steps = None
- forecast_step_counts = None
- initialize_batch()
Initializes batch indices using DistributedSampler’s indices. Ensures proper cycling when shuffle=False.
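The docstring says batch indices come from a DistributedSampler and cycle when shuffle=False. The sketch below is a pure-Python approximation under two assumptions: that the per-rank split matches the non-shuffled behavior of torch.utils.data.DistributedSampler (pad the index list by wrapping around to the start, then take every world_size-th index beginning at rank), and that "cycling" means wrapping around a rank's indices to fill batch_size slots. The helpers rank_indices and initial_batch are hypothetical and not part of credit.

```python
import math
from itertools import cycle, islice


def rank_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    """Approximate DistributedSampler(shuffle=False, drop_last=False) index assignment."""
    indices = list(range(num_samples))
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    # Pad by repeating indices from the start so every rank receives the same count.
    padding = total - num_samples
    indices += (indices * math.ceil(padding / num_samples))[:padding]
    # Each rank takes a strided slice of the padded list.
    return indices[rank:total:world_size]


def initial_batch(num_samples: int, rank: int, world_size: int, batch_size: int) -> list[int]:
    """Hypothetical helper: fill batch_size slots, wrapping around if the rank has too few indices."""
    return list(islice(cycle(rank_indices(num_samples, rank, world_size)), batch_size))
```

With 10 samples across 4 ranks, rank 2 receives [2, 6, 0]: index 0 reappears because of the wrap-around padding that keeps every rank's share equal.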
- __post_init__()
- __len__()
- set_epoch(epoch)
- batches_per_epoch()
- __getitem__(_)
Fetches the current forecast step data for each item in the batch. Resets items when their forecast length is exceeded.
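The per-item reset described above can be pictured as a small state machine over batch slots. This is a minimal sketch, assuming each slot tracks a start index and a count of forecast steps taken, that a slot is reset after forecast_len + 1 steps (so forecast_len=0 means one step per sample), and that a reset takes the next start index from a queue. RolloutSlots and its fields are hypothetical; the real class keeps related state in attributes such as batch_indices, time_steps, and forecast_step_counts, whose exact semantics are not documented here.

```python
from collections import deque


class RolloutSlots:
    """Hypothetical per-slot state: each batch slot rolls one sample forward in time."""

    def __init__(self, start_indices, batch_size, forecast_len):
        self.pending = deque(start_indices)  # start indices not yet assigned to a slot
        self.forecast_len = forecast_len
        self.starts = [self.pending.popleft() for _ in range(batch_size)]
        self.steps = [0] * batch_size        # forecast steps already taken per slot

    def next_step(self):
        """Return (start_index, step) for every slot, then advance all slots by one step."""
        out = list(zip(self.starts, self.steps))
        for i in range(len(self.starts)):
            self.steps[i] += 1
            # Assumption: a slot is reset once it has produced forecast_len + 1 steps.
            if self.steps[i] > self.forecast_len:
                self.starts[i] = self.pending.popleft()
                self.steps[i] = 0
        return out
```

With forecast_len=1, each slot yields steps 0 and 1 for its sample before moving on to the next start index.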
- class credit.datasets.era5_multistep_batcher.MultiprocessingBatcher(*args, num_workers=4, **kwargs)
Bases: ERA5_MultiStep_Batcher
A PyTorch Dataset class that works on:
upper-air variables (time, level, lat, lon)
surface variables (time, lat, lon)
dynamic forcing variables (time, lat, lon)
forcing variables (time, lat, lon)
diagnostic variables (time, lat, lon)
static variables (lat, lon)
- num_workers = 4
- manager
- results
- __getitem__(_)
Fetches the current forecast step data for each item in the batch. Utilizes multiprocessing to parallelize calls to self.worker. Ensures the results are returned in the correct order.
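The ordering guarantee can be illustrated with a keyed results dictionary: every task writes its output under its batch position, and the batch is reassembled by sorting on that key, so completion order does not matter. This sketch swaps in concurrent.futures.ThreadPoolExecutor so it runs without pickling concerns; the real class uses multiprocessing with a manager (per its manager and results attributes). load_item is a hypothetical stand-in for self.worker.

```python
from concurrent.futures import ThreadPoolExecutor


def load_item(k, index_pair):
    """Hypothetical stand-in for self.worker: returns a dummy sample for slot k."""
    start, step = index_pair
    return {"slot": k, "time_index": start + step}


def fetch_in_order(index_pairs, num_workers=4):
    results = {}  # batch position -> sample, filled in whatever order tasks finish

    def run(k, pair):
        results[k] = load_item(k, pair)

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = [pool.submit(run, k, pair) for k, pair in enumerate(index_pairs)]
        for f in futures:
            f.result()  # re-raise any exception raised inside a task
    # Reassemble by position so the batch order matches index_pairs.
    return [results[k] for k in sorted(results)]
```

Keying by position rather than appending to a shared list is what makes the final order deterministic.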
- __del__()
Cleans up the manager when the object is destroyed.
- class credit.datasets.era5_multistep_batcher.MultiprocessingBatcherPrefetch(*args, num_workers=4, prefetch_factor=4, **kwargs)
Bases: ERA5_MultiStep_Batcher
A PyTorch Dataset class that works on:
upper-air variables (time, level, lat, lon)
surface variables (time, lat, lon)
dynamic forcing variables (time, lat, lon)
forcing variables (time, lat, lon)
diagnostic variables (time, lat, lon)
static variables (lat, lon)
- num_workers = 4
- prefetch_factor = 4
- prefetch_queue
- stop_signal
- manager
- results
- stop_event
- prefetch_thread = None
- handle_signal(signum, frame)
- set_epoch(epoch)
- prefetch_batches()
Prefetch batches asynchronously and store them in a queue. Stops when the stop_signal is set.
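Asynchronous prefetching of this kind is a producer/consumer pattern. A minimal sketch, assuming a bounded queue.Queue sized by prefetch_factor, a threading.Event as the stop signal, and a background thread as the producer; make_batch is a hypothetical stand-in for _fetch_batch. The real class's attribute names (prefetch_queue, stop_signal, prefetch_thread) suggest a similar structure, but its exact queue and event types are not documented here.

```python
import queue
import threading


class Prefetcher:
    def __init__(self, make_batch, prefetch_factor=4):
        self.make_batch = make_batch  # hypothetical stand-in for _fetch_batch
        self.prefetch_queue = queue.Queue(maxsize=prefetch_factor)
        self.stop_signal = threading.Event()
        self.prefetch_thread = threading.Thread(target=self._prefetch, daemon=True)
        self.prefetch_thread.start()

    def _prefetch(self):
        while not self.stop_signal.is_set():
            batch = self.make_batch()
            # Block while the queue is full, waking periodically to check the stop signal.
            while not self.stop_signal.is_set():
                try:
                    self.prefetch_queue.put(batch, timeout=0.1)
                    break
                except queue.Full:
                    continue

    def get(self, timeout=5.0):
        """Consumer side: take the next ready batch (analogous to __getitem__)."""
        return self.prefetch_queue.get(timeout=timeout)

    def close(self):
        self.stop_signal.set()
        self.prefetch_thread.join()
```

The bounded queue caps memory use at roughly prefetch_factor batches, and the put timeout lets the producer exit promptly once the stop signal is set.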
- worker_process(k, index_pair, result_dict)
Worker function that processes individual tasks, with error handling for specific exceptions.
- _fetch_batch()
Fetches a batch using multiprocessing workers and splits the work efficiently.
- _process_chunk(task_chunk, result_dict)
Process a chunk of tasks and update the shared results dictionary.
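One common way to split a batch's tasks evenly across workers is to give each of num_workers contiguous chunks either the floor or the ceiling of len(tasks) / num_workers tasks. The sketch below assumes that scheme (the source does not say how chunks are sized) and uses a plain dict plus a hypothetical process_task in place of the shared results dictionary and worker_process.

```python
def split_tasks(tasks, num_workers):
    """Split tasks into at most num_workers contiguous chunks whose sizes differ by at most one."""
    n = min(num_workers, len(tasks))
    if n == 0:
        return []
    base, extra = divmod(len(tasks), n)
    chunks, start = [], 0
    for i in range(n):
        size = base + (1 if i < extra else 0)  # the first `extra` chunks get one more task
        chunks.append(tasks[start:start + size])
        start += size
    return chunks


def process_task(k, index_pair):
    """Hypothetical stand-in for worker_process: returns a dummy result."""
    return sum(index_pair)


def process_chunk(task_chunk, result_dict):
    # Each task carries its batch position k, so results can be reassembled in order later.
    for k, index_pair in task_chunk:
        result_dict[k] = process_task(k, index_pair)
```

For 7 tasks and 3 workers the chunk sizes are 3, 2, and 2.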
- __getitem__(_)
Get a batch from the prefetch queue.
- __del__()
Cleans up processes and threads when the object is destroyed.
- __enter__()
- __exit__(exc_type, exc_val, exc_tb)
- class credit.datasets.era5_multistep_batcher.Predict_Dataset_Batcher(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, sst_forcing=None, fcst_datetime=None, lead_time_periods=6, history_len=1, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, batch_size=1, skip_target=False)
Bases: torch.utils.data.Dataset
A PyTorch Dataset class that works on:
upper-air variables (time, level, lat, lon)
surface variables (time, lat, lon)
dynamic forcing variables (time, lat, lon)
forcing variables (time, lat, lon)
diagnostic variables (time, lat, lon)
static variables (lat, lon)
- history_len = 1
- transform = None
- init_datetime = None
- lead_time_periods = 6
- seed = 42
- rank = 0
- world_size = 1
- batch_size = 1
- skip_target = False
- skip_periods = None
- rng
- sst_forcing = None
- all_files = []
- filenames
- filename_surface = None
- filename_dyn_forcing = None
- filename_forcing = None
- filename_static = None
- filename_diagnostic = None
- varname_upper_air
- varname_surface
- varname_dyn_forcing
- varname_forcing
- varname_static
- varname_diagnostic
- ERA5_indices
- forecast_period = 0
- forecast_len = -1
- batch_indices
- batch_indices_splits = []
- batch_call_count = 0
- data_lookup = None
- __len__()
- ds_read_and_subset(filename, time_start, time_end, varnames)
- get_time_variable(filename, time_start, time_end) → xarray.Dataset
Open NetCDF or Zarr file and return only the time variable.
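Returning only the time variable avoids loading the physical fields just to work out which time steps exist. A minimal xarray sketch, assuming Zarr stores are recognized by a ".zarr" suffix and that time_start and time_end are integer positions with an exclusive end (the source does not say whether they are positions or timestamps). open_any, time_only, and get_time_variable_sketch are hypothetical names; xr.open_zarr and xr.open_dataset are real xarray functions.

```python
import xarray as xr


def open_any(filename):
    """Open a Zarr store with open_zarr and anything else (e.g. NetCDF) with open_dataset."""
    # Assumption: Zarr stores are identified by a ".zarr" suffix.
    if str(filename).rstrip("/").endswith(".zarr"):
        return xr.open_zarr(filename)
    return xr.open_dataset(filename)


def time_only(ds, time_start, time_end):
    """Return a Dataset holding only the time coordinate for positions [time_start, time_end)."""
    times = ds["time"].isel(time=slice(time_start, time_end)).values
    return xr.Dataset(coords={"time": times})


def get_time_variable_sketch(filename, time_start, time_end):
    return time_only(open_any(filename), time_start, time_end)
```

Because xarray opens files lazily, reading just the time coordinate this way does not pull the data variables into memory.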
- load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')
- find_start_stop_indices(index)
- initialize_batch()
Initializes batch indices using DistributedSampler’s indices. Ensures proper cycling when shuffle=False.
- batches_per_epoch()
- __getitem__(_)
- credit.datasets.era5_multistep_batcher.option