credit.datasets.les_singlestep
Classes

- LESDataset: LES model PyTorch Dataset class.
- LESPredict: An iterable Dataset.
Module Contents
- class credit.datasets.les_singlestep.LESDataset(param_interior, transform=None, seed=42)
Bases: torch.utils.data.Dataset

LES model PyTorch Dataset class.
- list_upper_ds
- list_surf_ds = []
- list_dyn_forcing_ds = []
- list_diag_ds = []
- history_len
- forecast_len
- total_seq_len
- LES_file_indices
- filename_forcing
- filename_static
- transform = None
- size_list
- size_full
- rng
- total_len = 0
- __post_init__()
- __len__()
- __getitem__(index)
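The attribute names listed for `LESDataset` (`history_len`, `forecast_len`, `total_seq_len`, `size_list`, `total_len`) suggest a map-style dataset that concatenates several files and draws fixed-length time windows from each. The sketch below is a hypothetical, dependency-free illustration of that indexing pattern, not the `LESDataset` implementation. It assumes that `total_seq_len = history_len + forecast_len`, that a file with `n` time steps yields `n - total_seq_len + 1` windows, and that a flat index is mapped back to its file by searching cumulative sizes. `WindowIndex` and its return layout are invented for illustration.

```python
from bisect import bisect_right
from itertools import accumulate


class WindowIndex:
    """Hypothetical flat-index -> (file, input steps, target steps) mapper.

    Assumption: every sample is a contiguous window of
    history_len + forecast_len time steps taken from a single file.
    """

    def __init__(self, file_lengths, history_len=1, forecast_len=1):
        self.history_len = history_len
        self.forecast_len = forecast_len
        # Assumption: one sample spans the history steps plus the forecast steps.
        self.total_seq_len = history_len + forecast_len
        # Number of complete windows each file can supply (never negative).
        self.size_list = [max(n - self.total_seq_len + 1, 0) for n in file_lengths]
        # Running totals let us locate the file a flat index falls into.
        self.cumulative = list(accumulate(self.size_list))
        self.total_len = self.cumulative[-1] if self.cumulative else 0

    def __len__(self):
        return self.total_len

    def __getitem__(self, index):
        if not 0 <= index < self.total_len:
            raise IndexError(index)
        # First file whose running total exceeds the index.
        i_file = bisect_right(self.cumulative, index)
        offset = self.cumulative[i_file - 1] if i_file > 0 else 0
        start = index - offset
        # Input steps are [start, split); target steps are [split, start + total_seq_len).
        split = start + self.history_len
        return i_file, (start, split), (split, start + self.total_seq_len)


if __name__ == "__main__":
    # Two hypothetical files with 5 and 3 time steps; 2 input steps, 1 target step.
    idx = WindowIndex([5, 3], history_len=2, forecast_len=1)
    print(len(idx))  # 4
    print(idx[3])    # (1, (0, 2), (2, 3))
```

Because the class implements `__len__` and `__getitem__`, the same logic could sit inside a `torch.utils.data.Dataset` subclass, where `__getitem__` would read the returned step ranges from the selected file.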
- class credit.datasets.les_singlestep.LESPredict(param_interior, data_lookup, rank, world_size, transform=None)
Bases: torch.utils.data.IterableDataset

An iterable Dataset.
All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream.
All subclasses should override `__iter__()`, which returns an iterator of samples in this dataset.

When a subclass is used with `DataLoader`, each item in the dataset is yielded from the `DataLoader` iterator. When `num_workers > 0`, each worker process has a different copy of the dataset object, so it is often desirable to configure each copy independently to avoid duplicate data being returned from the workers. `get_worker_info()`, when called in a worker process, returns information about the worker. It can be used either in the dataset's `__iter__()` method or in the `DataLoader`'s `worker_init_fn` option to modify each copy's behavior.

Example 1: splitting the workload across all workers in `__iter__()`:

```python
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> # xdoctest: +SKIP("Fails on MacOS12")
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
>>> # xdoctest: +REQUIRES(POSIX)
>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> # xdoctest: +IGNORE_WANT("non deterministic")
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
>>> # With even more workers
>>> # xdoctest: +IGNORE_WANT("non deterministic")
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
```
Example 2: splitting the workload across all workers using `worker_init_fn`:

```python
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([3]), tensor([4]), tensor([4]), tensor([5]), tensor([5]), tensor([6]), tensor([6])]
>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]
```
- list_upper_ds
- list_surf_ds = []
- list_dyn_forcing_ds = []
- list_diag_ds = []
- filenames
- filename_surface
- filename_dyn_forcing
- filename_forcing
- filename_static
- filename_diagnostic
- rank
- world_size
- transform = None
- history_len
- data_lookup
- load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')
- __len__()
- __iter__()
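`LESPredict` takes `rank` and `world_size` and subclasses `IterableDataset`, so each distributed process, and each `DataLoader` worker inside it, has to claim a disjoint share of the work to avoid the duplicate-data problem described in the inherited docstring. The sketch below shows one way to compute that share. It is an assumption about the pattern, not the `LESPredict.__iter__` code, and the helpers `_split` and `shard_bounds` are hypothetical. It splits `n_items` forecast initializations contiguously across ranks first, then across the workers of each rank, reusing the ceiling-division scheme from Example 1.

```python
import math


def _split(start, end, part, n_parts):
    """Return the [lo, hi) slice of [start, end) owned by `part` of `n_parts`."""
    # Ceiling division keeps every part except possibly the last the same size.
    per_part = math.ceil((end - start) / n_parts)
    # Clamp both bounds so surplus parts receive an empty slice instead of an inverted one.
    lo = min(start + part * per_part, end)
    hi = min(lo + per_part, end)
    return lo, hi


def shard_bounds(n_items, rank, world_size, worker_id=0, num_workers=1):
    """Hypothetical helper: the [start, end) item range for one worker of one rank."""
    # First split the items across distributed ranks...
    rank_lo, rank_hi = _split(0, n_items, rank, world_size)
    # ...then split that rank's share across its DataLoader workers.
    return _split(rank_lo, rank_hi, worker_id, num_workers)


if __name__ == "__main__":
    # Ten initializations, two ranks, two DataLoader workers per rank.
    for rank in range(2):
        for worker in range(2):
            # Prints (0, 3), (3, 5), (5, 8), (8, 10) for the four workers in turn.
            print(rank, worker, shard_bounds(10, rank, 2, worker, 2))
```

Inside an `__iter__()` method, `worker_id` and `num_workers` would typically come from `torch.utils.data.get_worker_info()`, which returns `None` during single-process loading; the defaults `worker_id=0, num_workers=1` cover that case.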