credit.datasets.wrfmultistep#
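Attributes#
- logger
Classes#
- RepeatingIndexSampler: Base class for all Samplers.
- WRFMultiStep: An abstract class representing a Dataset.
Functions#
- worker(tuple_index, WRF_file_indices, list_upper_ds, ...)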
Module Contents#
- credit.datasets.wrfmultistep.logger#
- credit.datasets.wrfmultistep.worker(tuple_index: Tuple[int, int], WRF_file_indices: Dict[str, List[int]], list_upper_ds: List[Any], list_surf_ds: List[Any] | None, list_dyn_forcing_ds: List[Any] | None, list_diag_ds: List[Any] | None, xarray_forcing: Any | None, xarray_static: Any | None, history_len: int, forecast_len: int, list_upper_ds_outside: List[Any] | None, list_surf_ds_outside: List[Any] | None, outside_file_year_range: List[Any] | None, outside_file_indices: List[Any] | None, history_len_outside: int, transform: Callable | None) Dict[str, Any]#
- class credit.datasets.wrfmultistep.RepeatingIndexSampler(dataset, forecast_len, shuffle=True, seed=42, rank=0, num_replicas=1)#
Bases: torch.utils.data.Sampler
Base class for all Samplers.
Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.
Example
>>> # xdoctest: +SKIP
>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()
Note
The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
- dataset#
- forecast_len#
- shuffle = True#
- seed = 42#
- rank = 0#
- num_replicas = 1#
- all_start_indices#
- num_indices_per_rank#
- __len__()#
Returns the total number of indices for this rank.
- __iter__()#
Yields each start index repeated (forecast_len + 1) times.
- batches_per_epoch()#
Computes the number of batches per epoch for a given batch size.
Returns:
- int: Number of batches per epoch.
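The repetition behaviour described for `__iter__()` and the batch count described for `batches_per_epoch()` can be illustrated with a small stand-in. This is only a sketch under stated assumptions, not the CREDIT implementation: the class name `RepeatingIndexSketch`, the `num_samples` argument, the interleaved split of start indices across ranks, the `batch_size` argument, and the choice to count repeats in `__len__()` are all hypothetical. It avoids a torch dependency so it can run anywhere.

```python
import math
import random
from typing import Iterator, List


class RepeatingIndexSketch:
    """Hypothetical stand-in that yields each start index (forecast_len + 1) times."""

    def __init__(self, num_samples: int, forecast_len: int, shuffle: bool = True,
                 seed: int = 42, rank: int = 0, num_replicas: int = 1) -> None:
        self.forecast_len = forecast_len
        self.shuffle = shuffle
        self.seed = seed
        # Assumption: each rank takes an interleaved slice of the start indices.
        self.start_indices: List[int] = list(range(num_samples))[rank::num_replicas]

    def __len__(self) -> int:
        # Assumption: the length counts every repeated index for this rank.
        return len(self.start_indices) * (self.forecast_len + 1)

    def __iter__(self) -> Iterator[int]:
        indices = list(self.start_indices)
        if self.shuffle:
            # A seeded, local generator keeps the order reproducible.
            random.Random(self.seed).shuffle(indices)
        for start in indices:
            # Emit the same start index once per step of the rollout.
            for _ in range(self.forecast_len + 1):
                yield start

    def batches_per_epoch(self, batch_size: int) -> int:
        # Round up so a final partial batch is counted.
        return math.ceil(len(self) / batch_size)


sampler = RepeatingIndexSketch(num_samples=3, forecast_len=2, shuffle=False)
print(list(sampler))  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
print(len(sampler), sampler.batches_per_epoch(batch_size=4))  # 9 3
```

Repeating the start index lets a dataset receive the same key on consecutive calls, which is one way a multi-step rollout could advance an internal step counter rather than jump to a new sample.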
- class credit.datasets.wrfmultistep.WRFMultiStep(param_interior, param_outside, transform=None, seed=42, rank=0, world_size=1, max_forecast_len=None)#
Bases: torch.utils.data.Dataset
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__() to speed up loading of batched samples. This method accepts a list of sample indices for a batch and returns a list of samples.
Note
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
- list_upper_ds#
- list_surf_ds = []#
- list_dyn_forcing_ds = []#
- list_diag_ds = []#
- history_len#
- forecast_len#
- seed = 42#
- rank = 0#
- world_size = 1#
- total_seq_len#
- WRF_file_indices#
- filename_forcing#
- filename_static#
- list_upper_ds_outside = []#
- list_surf_ds_outside = []#
- history_len_outside#
- forecast_len_outside#
- outside_file_year_range#
- outside_file_indices#
- transform = None#
- rng#
- max_forecast_len = None#
- worker#
- total_length = 0#
- current_epoch = None#
- forecast_step_count = 0#
- current_index = None#
- initial_index = None#
- __post_init__()#
- __len__()#
- set_epoch(epoch)#
- __getitem__(index)#