credit.datasets.sequential_multistep

PyTorch IterableDataset for multi-step training.

References:

Non-daemonic Python pool process: https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic

PyTorch IterableDataset: https://colab.research.google.com/drive/1OFLZnX9y5QUFNONuvFsxOizq4M-tFvk-?usp=sharing#scrollTo=CxSCQPOMHgwo

Attributes

Classes

DistributedSequentialDataset

An iterable Dataset.

DistributedSequentialDatasetBasic

An iterable Dataset.

Functions

worker(→ Dict[str, Any])

Processes a given index to extract and transform data for a specific time slice.

Module Contents

credit.datasets.sequential_multistep.logger
class credit.datasets.sequential_multistep.DistributedSequentialDataset(varname_upper_air: List[str], varname_surface: List[str], varname_dyn_forcing: List[str], varname_forcing: List[str], varname_static: List[str], varname_diagnostic: List[str], filenames: List[str], filename_surface: List[str] | None = None, filename_dyn_forcing: List[str] | None = None, filename_forcing: str | None = None, filename_static: str | None = None, filename_diagnostic: List[str] | None = None, rank: int = 0, world_size: int = 1, history_len: int = 2, forecast_len: int = 0, transform: Callable | None = None, seed: int = 42, skip_periods: int | None = None, max_forecast_len: int | None = None, shuffle: bool = True, num_workers: int = 0)

Bases: torch.utils.data.IterableDataset

An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it. This form of dataset is particularly useful when data come from a stream.

All subclasses should override __iter__(), which returns an iterator over the samples in this dataset.

When a subclass is used with DataLoader, each item in the dataset is yielded from the DataLoader iterator. When num_workers > 0, each worker process has its own copy of the dataset object, so each copy usually needs to be configured independently to avoid returning duplicate data from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset's __iter__() method or the DataLoader's worker_init_fn option to modify each copy's behavior.

Example 1: splitting workload across all workers in __iter__():

>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> # xdoctest: +SKIP("Fails on MacOS12")
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # xdoctest: +REQUIRES(POSIX)
>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> # xdoctest: +IGNORE_WANT("non deterministic")
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> # xdoctest: +IGNORE_WANT("non deterministic")
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn:

>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
history_len = 2
forecast_len = 0
transform = None
rank = 0
world_size = 1
shuffle = True
current_epoch = 0
num_workers = 0
skip_periods = None
total_seq_len = 2
rng
max_forecast_len = None
all_files = []
ERA5_indices
filename_forcing = None
filename_static = None
__post_init__()
__len__() → int
set_epoch(epoch: int) → None
__iter__()
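The member list shows seed-driven shuffling (seed, shuffle, rng), a per-epoch hook (set_epoch) and distributed and worker settings (rank, world_size, num_workers). This page does not show how __iter__() combines them. The sketch below illustrates one common pattern, assumed here rather than taken from the module: every rank derives the same permutation from seed + epoch, then keeps a disjoint, strided slice for its own rank and DataLoader worker. The helper name shard_indices and the striding scheme are hypothetical.

```python
import random
from typing import List


def shard_indices(
    num_samples: int,
    seed: int,
    epoch: int,
    rank: int,
    world_size: int,
    worker_id: int = 0,
    num_workers: int = 1,
    shuffle: bool = True,
) -> List[int]:
    """Hypothetical helper: sample indices visited by one rank/worker pair."""
    indices = list(range(num_samples))
    if shuffle:
        # Same seed + epoch on every rank gives an identical permutation on all
        # ranks, while a new epoch (e.g. after set_epoch) gives a new order.
        random.Random(seed + epoch).shuffle(indices)
    # Split across ranks first, then across this rank's DataLoader workers.
    per_rank = indices[rank::world_size]
    # num_workers == 0 means single-process loading, i.e. one "worker".
    return per_rank[worker_id::max(num_workers, 1)]
```

Inside __iter__(), worker_id and num_workers would typically come from torch.utils.data.get_worker_info(), which returns None in single-process loading.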
credit.datasets.sequential_multistep.worker(tuple_index: Tuple[int, int], ERA5_indices: Dict[str, List[int]], all_files: List[Any], surface_files: List[Any] | None, dyn_forcing_files: List[Any] | None, diagnostic_files: List[Any] | None, xarray_forcing: Any | None, xarray_static: Any | None, history_len: int, forecast_len: int, skip_periods: int, transform: Callable | None) → Dict[str, Any]

Processes a given index to extract and transform data for a specific time slice.

Parameters:

  • tuple_index (Tuple[int, int]): Tuple containing the current index and sub-index for processing.
  • ERA5_indices (Dict[str, List[int]]): Dictionary containing ERA5 indices metadata.
  • all_files (List[Any]): List of xarray datasets containing upper air data.
  • surface_files (Optional[List[Any]]): List of xarray datasets containing surface data.
  • dyn_forcing_files (Optional[List[Any]]): List of xarray datasets containing dynamic forcing data.
  • diagnostic_files (Optional[List[Any]]): List of xarray datasets containing diagnostic data.
  • xarray_forcing (Optional[Any]): xarray dataset containing forcing data.
  • xarray_static (Optional[Any]): xarray dataset containing static data.
  • history_len (int): Length of the history sequence.
  • forecast_len (int): Length of the forecast sequence.
  • skip_periods (int): Number of periods to skip between samples.
  • transform (Optional[Callable]): Transformation function to apply to the data.

Returns:

  • Dict[str, Any]: A dictionary containing historical ERA5 images, target ERA5 images, datetime index, and additional information.
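The worker docstring describes cutting one sample out of the time series: history_len input steps followed by forecast targets, with skip_periods controlling the spacing. The exact index arithmetic is not documented here. The sketch below is a hypothetical illustration that assumes consecutive steps are skip_periods apart and that the target holds forecast_len + 1 steps (so forecast_len = 0 means a single-step target); both assumptions should be checked against the module source.

```python
from typing import List, Optional, Tuple


def window_indices(
    start: int,
    history_len: int,
    forecast_len: int,
    skip_periods: Optional[int] = 1,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper returning (history, target) time indices for one sample.

    Assumption: steps are `skip_periods` apart and the target covers
    forecast_len + 1 steps immediately after the history window.
    """
    step = skip_periods or 1  # treat None or 0 as "no skipping"
    history = [start + i * step for i in range(history_len)]
    first_target = start + history_len * step
    target = [first_target + i * step for i in range(forecast_len + 1)]
    return history, target
```

For example, window_indices(10, 2, 1, 6) returns ([10, 16], [22, 28]): two inputs six periods apart, then two targets on the same spacing.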

class credit.datasets.sequential_multistep.DistributedSequentialDatasetBasic(varname_upper_air: List[str], varname_surface: List[str], varname_dyn_forcing: List[str], varname_forcing: List[str], varname_static: List[str], varname_diagnostic: List[str], filenames: List[str], filename_surface: List[str] | None = None, filename_dyn_forcing: List[str] | None = None, filename_forcing: str | None = None, filename_static: str | None = None, filename_diagnostic: List[str] | None = None, rank: int = 0, world_size: int = 1, history_len: int = 2, forecast_len: int = 0, transform: Callable | None = None, seed: int = 42, skip_periods: int | None = None, max_forecast_len: int | None = None, shuffle: bool = True, num_workers: int = 0)

Bases: DistributedSequentialDataset

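An iterable Dataset. Subclass of DistributedSequentialDataset that overrides __iter__(); see DistributedSequentialDataset for the inherited IterableDataset documentation and examples.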

__iter__()