credit.datasets.era5_multistep

Attributes

logger

filename

Classes

RepeatingIndexSampler

Sampler that yields each start index repeated (forecast_len + 1) times.

ERA5_and_Forcing_MultiStep

A PyTorch Dataset class for upper-air, surface, dynamic forcing, forcing, diagnostic, and static variables.

Functions

worker(→ Dict[str, Any])

Processes a given index to extract and transform data for a specific time slice.

Module Contents

credit.datasets.era5_multistep.logger
credit.datasets.era5_multistep.worker(tuple_index: Tuple[int, int], ERA5_indices: Dict[str, List[int]], all_files: List[Any], surface_files: List[Any] | None, dyn_forcing_files: List[Any] | None, diagnostic_files: List[Any] | None, xarray_forcing: Any | None, xarray_static: Any | None, history_len: int, forecast_len: int, skip_periods: int, transform: Callable | None, sst_forcing: Any | None = None) → Dict[str, Any]

Processes a given index to extract and transform data for a specific time slice.

Parameters:
- tuple_index (Tuple[int, int]): Tuple containing the current index and sub-index for processing.
- ERA5_indices (Dict[str, List[int]]): Dictionary containing ERA5 indices metadata.
- all_files (List[Any]): List of xarray datasets containing upper-air data.
- surface_files (Optional[List[Any]]): List of xarray datasets containing surface data.
- dyn_forcing_files (Optional[List[Any]]): List of xarray datasets containing dynamic forcing data.
- diagnostic_files (Optional[List[Any]]): List of xarray datasets containing diagnostic data.
- xarray_forcing (Optional[Any]): xarray dataset containing forcing data.
- xarray_static (Optional[Any]): xarray dataset containing static data.
- history_len (int): Length of the history sequence.
- forecast_len (int): Length of the forecast sequence.
- skip_periods (int): Number of periods to skip between samples.
- transform (Optional[Callable]): Transformation function to apply to the data.

Returns:
- Dict[str, Any]: A dictionary containing historical ERA5 images, target ERA5 images, the datetime index, and additional information.
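The time bookkeeping behind history_len and skip_periods can be sketched as follows. This is a minimal illustration under three assumptions: history steps are spaced skip_periods time indices apart, a single target step follows the last history step, and indexing is into one flat time axis. window_indices is a hypothetical helper, not part of credit.datasets.era5_multistep, and the real worker may index files differently.

```python
from typing import List, Tuple


def window_indices(
    start: int, history_len: int, skip_periods: int = 1
) -> Tuple[List[int], int]:
    """Hypothetical helper: return (history time indices, target time index).

    Assumes history steps are spaced `skip_periods` apart and that the
    target is the next step after the last history step.
    """
    if history_len < 1 or skip_periods < 1:
        raise ValueError("history_len and skip_periods must be >= 1")
    # Input steps: start, start + skip, start + 2*skip, ...
    history = [start + k * skip_periods for k in range(history_len)]
    # Assumed single target step one stride after the last input step
    target = start + history_len * skip_periods
    return history, target


# Example: 2 history steps, using every 6th time index (e.g. 6-hourly from hourly data)
hist, tgt = window_indices(start=10, history_len=2, skip_periods=6)
print(hist, tgt)  # [10, 16] 22
```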

class credit.datasets.era5_multistep.RepeatingIndexSampler(dataset, forecast_len, skip_periods=1, shuffle=True, seed=42, rank=0, num_replicas=1)

Bases: torch.utils.data.Sampler

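Sampler that yields each start index repeated (forecast_len + 1) times. As with any subclass of torch.utils.data.Sampler (the base class for all Samplers), the following contract applies: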

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.

Example

>>> # xdoctest: +SKIP
>>> from typing import Iterator, List
>>> import torch
>>> from torch.utils.data import Sampler
>>> class AscendingSequenceLengthSampler(Sampler[int]):
...     def __init__(self, data: List[str]) -> None:
...         self.data = data
...
...     def __len__(self) -> int:
...         return len(self.data)
...
...     def __iter__(self) -> Iterator[int]:
...         # Yield dataset indices ordered from shortest to longest element
...         sizes = torch.tensor([len(x) for x in self.data])
...         yield from torch.argsort(sizes).tolist()
>>> class AscendingSequenceLengthBatchSampler(Sampler[List[int]]):
...     def __init__(self, data: List[str], batch_size: int) -> None:
...         self.data = data
...         self.batch_size = batch_size
...
...     def __len__(self) -> int:
...         # Ceiling division: the last batch may be smaller than batch_size
...         return (len(self.data) + self.batch_size - 1) // self.batch_size
...
...     def __iter__(self) -> Iterator[List[int]]:
...         sizes = torch.tensor([len(x) for x in self.data])
...         # torch.split gives chunks of exactly batch_size (the last may be shorter)
...         for batch in torch.split(torch.argsort(sizes), self.batch_size):
...             yield batch.tolist()

Note

The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
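Here is the arithmetic behind that note. With automatic batching, a DataLoader delegates len() to torch.utils.data.BatchSampler. BatchSampler divides len(sampler) by batch_size and rounds up, unless drop_last=True, in which case it rounds down. The plain-Python function below (expected_dataloader_len, a hypothetical name) restates that rule without importing torch.

```python
def expected_dataloader_len(
    num_sampler_indices: int, batch_size: int, drop_last: bool = False
) -> int:
    """Restate the rule BatchSampler.__len__ uses to size a DataLoader."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if drop_last:
        # Incomplete final batch is discarded
        return num_sampler_indices // batch_size
    # Incomplete final batch is kept (ceiling division)
    return (num_sampler_indices + batch_size - 1) // batch_size


print(expected_dataloader_len(10, 4))                  # 3
print(expected_dataloader_len(10, 4, drop_last=True))  # 2
```

If a sampler's __len__() disagrees with the number of indices its __iter__() yields, this count, and anything derived from it such as a progress bar or learning-rate schedule, will be off.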

dataset
forecast_len
skip_periods = 1
shuffle = True
seed = 42
rank = 0
num_replicas = 1
all_start_indices
num_indices_per_rank
__len__()

Returns the total number of indices for this rank.

__iter__()

Yields each start index repeated (forecast_len + 1) times.

batches_per_epoch()

Computes the number of batches per epoch for a given batch size.

Returns:
- int: Number of batches per epoch.
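The sampling pattern described for this class can be sketched in plain Python. The sketch assumes two things: each rank's share is a contiguous, equal-sized block of one seeded shuffle (remainder dropped), and __len__ counts the repeated yields. RepeatingIndexSamplerSketch and its num_start_indices argument are hypothetical. The real class takes the dataset itself, subclasses torch.utils.data.Sampler, and also accepts skip_periods, which this sketch does not model.

```python
import random
from typing import Iterator, List


class RepeatingIndexSamplerSketch:
    """Hypothetical stand-in for RepeatingIndexSampler (plain Python, no torch).

    Each start index owned by this rank is yielded (forecast_len + 1) times
    in a row.
    """

    def __init__(self, num_start_indices: int, forecast_len: int,
                 shuffle: bool = True, seed: int = 42,
                 rank: int = 0, num_replicas: int = 1) -> None:
        self.forecast_len = forecast_len
        indices: List[int] = list(range(num_start_indices))
        if shuffle:
            # Same seed on every rank -> identical global permutation
            random.Random(seed).shuffle(indices)
        # Assumed sharding: equal contiguous blocks, remainder dropped
        self.num_indices_per_rank = num_start_indices // num_replicas
        begin = rank * self.num_indices_per_rank
        self.rank_indices = indices[begin:begin + self.num_indices_per_rank]

    def __len__(self) -> int:
        # Counts repeats, i.e. the number of items __iter__ actually yields
        return self.num_indices_per_rank * (self.forecast_len + 1)

    def __iter__(self) -> Iterator[int]:
        for start in self.rank_indices:
            for _ in range(self.forecast_len + 1):
                yield start


sampler = RepeatingIndexSamplerSketch(4, forecast_len=2, shuffle=False)
print(list(sampler))  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
```

Repeating an index, rather than yielding forecast_len + 1 different indices, leaves it to the dataset to decide what each repeat means, for example the next step of an autoregressive rollout.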

class credit.datasets.era5_multistep.ERA5_and_Forcing_MultiStep(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, rank=0, world_size=1, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)

Bases: torch.utils.data.Dataset

A PyTorch Dataset class that works on:
  • upper-air variables (time, level, lat, lon)

  • surface variables (time, lat, lon)

  • dynamic forcing variables (time, lat, lon)

  • forcing variables (time, lat, lon)

  • diagnostic variables (time, lat, lon)

  • static variables (lat, lon)
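The dimension conventions listed above can be turned into a small pre-flight check on input data. EXPECTED_DIMS and check_dims are hypothetical helpers, not part of the module. Requiring an exact dimension order is also an assumption; the dataset may accept other layouts. With xarray, a DataArray's dimension names are available as the tuple `da.dims`.

```python
from typing import Dict, Tuple

# Expected dimension order for each variable group, as listed above
EXPECTED_DIMS: Dict[str, Tuple[str, ...]] = {
    "upper_air": ("time", "level", "lat", "lon"),
    "surface": ("time", "lat", "lon"),
    "dyn_forcing": ("time", "lat", "lon"),
    "forcing": ("time", "lat", "lon"),
    "diagnostic": ("time", "lat", "lon"),
    "static": ("lat", "lon"),
}


def check_dims(group: str, dims: Tuple[str, ...]) -> None:
    """Hypothetical check: raise ValueError if `dims` doesn't match `group`."""
    expected = EXPECTED_DIMS[group]
    if tuple(dims) != expected:
        raise ValueError(f"{group}: expected dims {expected}, got {tuple(dims)}")


check_dims("upper_air", ("time", "level", "lat", "lon"))  # passes silently
check_dims("static", ("lat", "lon"))                      # passes silently
```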

history_len = 2
forecast_len = 0
transform = None
seed = 42
rank = 0
world_size = 1
skip_periods = None
one_shot = None
total_seq_len = 2
sst_forcing = None
rng
max_forecast_len = None
all_files = []
ERA5_indices
filename_forcing = None
filename_static = None
worker
total_length = 0
current_epoch = None
forecast_step_count = 0
current_index = None
initial_index = None
__post_init__()
__len__()
set_epoch(epoch)
__getitem__(index)
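The forecast_step_count attribute suggests that __getitem__ keeps track of how far into a rollout it is. The sketch below shows that bookkeeping for the index stream a RepeatingIndexSampler produces. It assumes each start index arrives exactly forecast_len + 1 times in a row. rollout_steps is hypothetical and is not the module's __getitem__.

```python
from typing import Iterable, Iterator, Tuple


def rollout_steps(
    index_stream: Iterable[int], forecast_len: int
) -> Iterator[Tuple[int, int]]:
    """Hypothetical: pair each repeated start index with its forecast step.

    The counter runs 0, 1, ..., forecast_len for each start index and then
    resets, assuming the stream repeats every index forecast_len + 1 times.
    """
    forecast_step_count = 0
    for index in index_stream:
        yield index, forecast_step_count
        # Wrap back to step 0 once a full rollout has been produced
        if forecast_step_count == forecast_len:
            forecast_step_count = 0
        else:
            forecast_step_count += 1


print(list(rollout_steps([7, 7, 7, 3, 3, 3], forecast_len=2)))
# [(7, 0), (7, 1), (7, 2), (3, 0), (3, 1), (3, 2)]
```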
credit.datasets.era5_multistep.filename = '../../config/example-v2026.1.0.yml'