credit.datasets.era5_multistep
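Attributes
logger |
filename |
Classes
RepeatingIndexSampler | Yields each start index repeated (forecast_len + 1) times.
ERA5_and_Forcing_MultiStep | A PyTorch Dataset class for upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.
Functions
worker | Processes a given index to extract and transform data for a specific time slice.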
Module Contents
- credit.datasets.era5_multistep.logger
- credit.datasets.era5_multistep.worker(tuple_index: Tuple[int, int], ERA5_indices: Dict[str, List[int]], all_files: List[Any], surface_files: List[Any] | None, dyn_forcing_files: List[Any] | None, diagnostic_files: List[Any] | None, xarray_forcing: Any | None, xarray_static: Any | None, history_len: int, forecast_len: int, skip_periods: int, transform: Callable | None, sst_forcing: Any | None = None) -> Dict[str, Any]
Processes a given index to extract and transform data for a specific time slice.
Parameters:
- tuple_index (Tuple[int, int]): Tuple containing the current index and sub-index to process.
- ERA5_indices (Dict[str, List[int]]): Dictionary containing ERA5 index metadata.
- all_files (List[Any]): List of xarray datasets containing upper-air data.
- surface_files (Optional[List[Any]]): List of xarray datasets containing surface data.
- dyn_forcing_files (Optional[List[Any]]): List of xarray datasets containing dynamic forcing data.
- diagnostic_files (Optional[List[Any]]): List of xarray datasets containing diagnostic data.
- xarray_forcing (Optional[Any]): xarray dataset containing forcing data.
- xarray_static (Optional[Any]): xarray dataset containing static data.
- history_len (int): Length of the history sequence.
- forecast_len (int): Length of the forecast sequence.
- skip_periods (int): Number of periods to skip between samples.
- transform (Optional[Callable]): Transformation function to apply to the data.
Returns:
- Dict[str, Any]: A dictionary containing the historical ERA5 images, the target ERA5 images, the datetime index, and additional information.
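The reference above does not say how `tuple_index` maps onto time steps. As a hedged illustration only, the sketch below assumes that the first element is a start index, the second is the forecast step, and that both the history window and the target are spaced by `skip_periods`. The function `toy_worker` and its returned keys are hypothetical, not CREDIT's. The sketch also shows one plausible way to bind the fixed arguments with `functools.partial` so a dataset only has to pass `tuple_index`; the `worker` attribute on `ERA5_and_Forcing_MultiStep` hints at, but does not confirm, such a binding.

```python
from functools import partial
from typing import Any, Dict, List, Tuple


def toy_worker(
    tuple_index: Tuple[int, int],
    history_len: int,
    forecast_len: int,
    skip_periods: int,
) -> Dict[str, Any]:
    """Hypothetical stand-in for `worker` that computes time indices only.

    Assumption (not confirmed by the CREDIT docs): tuple_index is
    (start_index, forecast_step), and samples are spaced by skip_periods.
    """
    start, step = tuple_index
    if not 0 <= step <= forecast_len:
        raise ValueError(f"forecast step {step} outside 0..{forecast_len}")
    # Assumed: the window slides forward by one skip period per forecast step.
    first = start + step * skip_periods
    history: List[int] = [first + i * skip_periods for i in range(history_len)]
    target = first + history_len * skip_periods
    return {
        "history_time_indices": history,
        "target_time_index": target,
        "index": start,
        "forecast_step": step,
    }


# Bind the fixed arguments once, so callers only pass tuple_index.
bound_worker = partial(toy_worker, history_len=2, forecast_len=3, skip_periods=1)

sample = bound_worker((10, 1))
print(sample["history_time_indices"], sample["target_time_index"])  # [11, 12] 13
```

Binding with `partial` keeps the per-call interface down to a single tuple, which is convenient when the same large set of file handles and settings is reused for every sample.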
- class credit.datasets.era5_multistep.RepeatingIndexSampler(dataset, forecast_len, skip_periods=1, shuffle=True, seed=42, rank=0, num_replicas=1)
Bases: torch.utils.data.Sampler
Base class for all Samplers.
Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.
Example
>>> # xdoctest: +SKIP
>>> import torch
>>> from typing import Iterator, List
>>> from torch.utils.data import Sampler
>>> class AscendingSequenceLengthSampler(Sampler[int]):
...     def __init__(self, data: List[str]) -> None:
...         self.data = data
...
...     def __len__(self) -> int:
...         return len(self.data)
...
...     def __iter__(self) -> Iterator[int]:
...         # Yield indices ordered from the shortest to the longest sequence.
...         sizes = torch.tensor([len(x) for x in self.data])
...         yield from torch.argsort(sizes).tolist()
...
>>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
...     def __init__(self, data: List[str], batch_size: int) -> None:
...         self.data = data
...         self.batch_size = batch_size
...
...     def __len__(self) -> int:
...         # Ceiling division: the number of batches, including a partial last batch.
...         return (len(self.data) + self.batch_size - 1) // self.batch_size
...
...     def __iter__(self) -> Iterator[List[int]]:
...         sizes = torch.tensor([len(x) for x in self.data])
...         for batch in torch.chunk(torch.argsort(sizes), len(self)):
...             yield batch.tolist()
Note
The __len__() method isn't strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
- dataset
- forecast_len
- skip_periods = 1
- shuffle = True
- seed = 42
- rank = 0
- num_replicas = 1
- all_start_indices
- num_indices_per_rank
- __len__()
Returns the total number of indices for this rank.
- __iter__()
Yields each start index repeated (forecast_len + 1) times.
- batches_per_epoch()
Computes the number of batches per epoch for a given batch size.
Returns:
- int: Number of batches per epoch.
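The documented behavior of `RepeatingIndexSampler` (each start index repeated `forecast_len + 1` times, a per-rank index count, and a seeded shuffle) can be sketched in plain Python. This is a minimal sketch under stated assumptions, not CREDIT's implementation: the class name `ToyRepeatingIndexSampler` is hypothetical, it takes `dataset_len` instead of a dataset, start indices are assumed to be `range(0, dataset_len, skip_periods)`, ranks are assumed to take every `num_replicas`-th start index, and `batches_per_epoch` is given an explicit `batch_size` argument that the documented method does not show. PyTorch's `DataLoader` accepts any iterable with `__len__` as its `sampler` argument, so an object like this could be passed there.

```python
import math
import random
from typing import Iterator, List


class ToyRepeatingIndexSampler:
    """Hypothetical sketch of a repeating, rank-sharded index sampler."""

    def __init__(self, dataset_len: int, forecast_len: int, skip_periods: int = 1,
                 shuffle: bool = True, seed: int = 42, rank: int = 0,
                 num_replicas: int = 1) -> None:
        self.forecast_len = forecast_len
        self.shuffle = shuffle
        self.seed = seed
        # Assumed: candidate start indices are spaced by skip_periods.
        all_start_indices = list(range(0, dataset_len, skip_periods))
        # Assumed round-robin sharding: rank r keeps every num_replicas-th index.
        self.indices: List[int] = all_start_indices[rank::num_replicas]
        self.num_indices_per_rank = len(self.indices)

    def __iter__(self) -> Iterator[int]:
        order = list(self.indices)
        if self.shuffle:
            # Seeded shuffle of start indices; repeats stay adjacent.
            random.Random(self.seed).shuffle(order)
        for start in order:
            # One repeat per forecast step: steps 0..forecast_len.
            for _ in range(self.forecast_len + 1):
                yield start

    def __len__(self) -> int:
        # Counts every yielded item: start indices times repeats.
        return self.num_indices_per_rank * (self.forecast_len + 1)

    def batches_per_epoch(self, batch_size: int) -> int:
        # Assumed: batches are counted over start indices, ignoring repeats.
        return math.ceil(self.num_indices_per_rank / batch_size)


sampler = ToyRepeatingIndexSampler(dataset_len=6, forecast_len=1, shuffle=False,
                                   rank=0, num_replicas=2)
print(list(sampler))  # [0, 0, 2, 2, 4, 4]
```

Keeping the repeats of one start index adjacent is what lets a downstream dataset treat consecutive identical indices as successive forecast steps of the same rollout.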
- class credit.datasets.era5_multistep.ERA5_and_Forcing_MultiStep(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)
Bases: torch.utils.data.Dataset
A PyTorch Dataset class that works on:
  - upper-air variables (time, level, lat, lon)
  - surface variables (time, lat, lon)
  - dynamic forcing variables (time, lat, lon)
  - forcing variables (time, lat, lon)
  - diagnostic variables (time, lat, lon)
  - static variables (lat, lon)
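A toy illustration of the array layouts listed above, built with NumPy. The grid sizes are arbitrary, the variable groups are plain arrays rather than xarray datasets, and the window length assumes `total_seq_len = history_len + forecast_len`. That relation matches the defaults listed for this class (2 + 0 = 2) but is an inference, not something the reference states.

```python
import numpy as np

# Arbitrary toy sizes; real ERA5 grids are far larger.
n_time, n_level, n_lat, n_lon = 5, 3, 4, 8

time_varying = {
    "upper_air": np.zeros((n_time, n_level, n_lat, n_lon)),  # (time, level, lat, lon)
    "surface": np.zeros((n_time, n_lat, n_lon)),             # (time, lat, lon)
    "dyn_forcing": np.zeros((n_time, n_lat, n_lon)),         # (time, lat, lon)
    "forcing": np.zeros((n_time, n_lat, n_lon)),             # (time, lat, lon)
    "diagnostic": np.zeros((n_time, n_lat, n_lon)),          # (time, lat, lon)
}
static = np.zeros((n_lat, n_lon))  # (lat, lon): no time axis

history_len, forecast_len = 2, 0
total_seq_len = history_len + forecast_len  # assumed relation; equals 2 here

start = 1
window = slice(start, start + total_seq_len)
# Slice every time-varying group along its leading time axis.
sample = {name: arr[window] for name, arr in time_varying.items()}
sample["static"] = static  # static fields are shared by every time step
print({name: arr.shape for name, arr in sample.items()})
```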
- history_len = 2
- forecast_len = 0
- transform = None
- seed = 42
- rank = 0
- world_size = 1
- skip_periods = None
- one_shot = None
- total_seq_len = 2
- sst_forcing = None
- rng
- max_forecast_len = None
- all_files = []
- ERA5_indices
- filename_forcing = None
- filename_static = None
- worker
- total_length = 0
- current_epoch = None
- forecast_step_count = 0
- current_index = None
- initial_index = None
- __post_init__()
- __len__()
- set_epoch(epoch)
- __getitem__(index)
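The attributes `current_epoch`, `current_index`, and `forecast_step_count` suggest that the dataset tracks which forecast step it is on for a repeated start index. The sketch below is a hypothetical reconstruction of that bookkeeping, not CREDIT's `__getitem__`: the class `ToyStepTracker` and its method `next_tuple` are invented for illustration, and it assumes the sampler delivers each start index `forecast_len + 1` times in a row. It turns that stream into `(index, forecast_step)` tuples of the kind `worker` accepts.

```python
from typing import Optional, Tuple


class ToyStepTracker:
    """Hypothetical sketch: map repeated start indices to (index, step) tuples."""

    def __init__(self, forecast_len: int) -> None:
        self.forecast_len = forecast_len
        self.current_epoch: Optional[int] = None
        self.current_index: Optional[int] = None
        self.forecast_step_count = 0

    def set_epoch(self, epoch: int) -> None:
        # Reset per-epoch state so the next index starts at step 0.
        self.current_epoch = epoch
        self.current_index = None
        self.forecast_step_count = 0

    def next_tuple(self, index: int) -> Tuple[int, int]:
        if index == self.current_index and self.forecast_step_count < self.forecast_len:
            # Same start index again: advance to the next forecast step.
            self.forecast_step_count += 1
        else:
            # New start index, or the previous rollout finished: restart at step 0.
            self.current_index = index
            self.forecast_step_count = 0
        return (self.current_index, self.forecast_step_count)


tracker = ToyStepTracker(forecast_len=2)
print([tracker.next_tuple(i) for i in [5, 5, 5, 7, 7, 7]])
# [(5, 0), (5, 1), (5, 2), (7, 0), (7, 1), (7, 2)]
```

A stateful counter like this only works if the process holding it sees every repeat of a start index, in order. With several `DataLoader` workers, each worker has its own copy of the dataset, so that ordering would need to be guaranteed separately.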
- credit.datasets.era5_multistep.filename = '../../config/example-v2026.1.0.yml'