credit.datasets.load_dataset_and_dataloader

Attributes

dataset_id

Classes

BatchForecastLenSampler

BatchForecastLenSamplerSamudra

BatchForecastLenDataLoader

Functions

collate_fn(batch)

Custom collate function for use with the ERA5_MultiStep_Batcher dataset.

load_dataset(conf[, rank, world_size, is_train])

Load the dataset based on the configuration.

load_dataloader(conf, dataset[, rank, world_size, ...])

Load the DataLoader based on the dataset type.

Module Contents

class credit.datasets.load_dataset_and_dataloader.BatchForecastLenSampler(dataset)
dataset
forecast_len
len
__iter__()

Returns an iterator for the sampler.

The iterator yields len zeros, where len is the length computed from the dataset. The zeros are placeholders that only satisfy the interface requirements of a PyTorch Sampler.

Returns:

An iterator over a sequence of zeros.

__len__()

Returns the length of the sampler.

The length is the total number of iterations, computed from the forecast length and the dataset's batches per epoch.

Returns:

The total number of iterations.

Return type:

int
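The sampler's contract can be sketched in plain Python. PyTorch's DataLoader accepts any iterable with a __len__ as its sampler argument, so the sketch needs no base class. ToyBatcher and PlaceholderForecastSampler are hypothetical names, and the formula len = batches_per_epoch() * forecast_len is an assumption: the documentation above only states that the length depends on both quantities.

```python
import itertools


class ToyBatcher:
    """Hypothetical stand-in for a dataset exposing forecast_len and batches_per_epoch()."""

    def __init__(self, forecast_len, n_batches):
        self.forecast_len = forecast_len
        self._n_batches = n_batches

    def batches_per_epoch(self):
        return self._n_batches


class PlaceholderForecastSampler:
    """Sketch of a sampler that only controls how many steps an epoch has."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.forecast_len = dataset.forecast_len
        # Assumption: one iteration per forecast step of every batch.
        self.len = dataset.batches_per_epoch() * self.forecast_len

    def __iter__(self):
        # The index values are assumed to be ignored downstream, so zeros suffice.
        return itertools.repeat(0, self.len)

    def __len__(self):
        return self.len


sampler = PlaceholderForecastSampler(ToyBatcher(forecast_len=3, n_batches=4))
print(len(sampler), list(sampler)[:3])  # 12 [0, 0, 0]
```

Because __iter__ returns a fresh iterator on every call, the sampler can be reused across epochs.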

class credit.datasets.load_dataset_and_dataloader.BatchForecastLenSamplerSamudra(dataset)

Bases: BatchForecastLenSampler

forecast_len
len

class credit.datasets.load_dataset_and_dataloader.BatchForecastLenDataLoader(dataset, offset=1)
dataset
forecast_len
__iter__()

Iterates over the dataset.

This method directly yields samples from the dataset. The forecast length is not explicitly handled here; it is assumed to be accounted for in the dataset's structure or sampling.

Yields:

sample – A single sample from the dataset.

__len__()

Returns the length of the DataLoader.

The length is determined by the forecast length and either the dataset's batches_per_epoch() method (if available) or the dataset's overall length.

Returns:

The total number of samples or iterations.

Return type:

int
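A minimal sketch of the iteration and length rules described above, assuming the length is the epoch size multiplied by forecast_len (the documentation names the inputs but not the formula). PassThroughLoader, ToyRollout and ToySeries are hypothetical names, and the offset argument is left out because its role is not documented here.

```python
class PassThroughLoader:
    """Sketch: iterate a dataset directly, with a forecast-aware length."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.forecast_len = dataset.forecast_len

    def __iter__(self):
        # Samples are yielded as-is; any rollout structure lives in the dataset.
        yield from iter(self.dataset)

    def __len__(self):
        # Prefer the dataset's own notion of an epoch when it provides one.
        if hasattr(self.dataset, "batches_per_epoch"):
            n = self.dataset.batches_per_epoch()
        else:
            n = len(self.dataset)
        # Assumption: the length scales linearly with forecast_len.
        return n * self.forecast_len


class ToyRollout:
    """Hypothetical iterable dataset that exposes batches_per_epoch()."""

    def __init__(self, samples, forecast_len, n_batches):
        self.samples = samples
        self.forecast_len = forecast_len
        self._n_batches = n_batches

    def __iter__(self):
        return iter(self.samples)

    def batches_per_epoch(self):
        return self._n_batches


class ToySeries(list):
    """Hypothetical list-backed dataset without batches_per_epoch()."""

    forecast_len = 2


loader = PassThroughLoader(ToyRollout(["a", "b"], forecast_len=3, n_batches=5))
print(list(loader), len(loader))  # ['a', 'b'] 15
```

The hasattr check mirrors the "if available" fallback: a dataset without batches_per_epoch() falls back to len(dataset).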

credit.datasets.load_dataset_and_dataloader.collate_fn(batch)

Custom collate function for use with the ERA5_MultiStep_Batcher dataset.

This function ensures that the time and batch dimensions are not flipped during data loading. It assumes that the dataset is structured such that the first element of the batch contains the correctly formatted data.

Parameters:

batch (list) – A list of samples from the dataset, where each sample is expected to be identically structured.

Returns:

The first element of the batch, which contains the correctly formatted data.

Return type:

Any
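Reduced to its essentials, the behavior described above unwraps the list that the DataLoader assembles. The sketch uses a hypothetical name, first_element_collate, and assumes the DataLoader is built with batch_size=1; with PyTorch it would be passed as DataLoader(dataset, batch_size=1, collate_fn=first_element_collate). PyTorch's default collation would instead stack the one-element list into a new leading dimension of size 1, in front of the axes the dataset already arranged.

```python
def first_element_collate(batch):
    """Return the single pre-batched sample instead of stacking the list.

    Assumes the DataLoader uses batch_size=1, so ``batch`` is a one-element
    list whose only item already has its batch and time axes in place.
    """
    return batch[0]


# A pre-batched sample: 2 batch members x 3 time steps (toy values).
sample = {"x": [[1, 2, 3], [4, 5, 6]]}
out = first_element_collate([sample])
print(out is sample)  # True
```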

credit.datasets.load_dataset_and_dataloader.load_dataset(conf, rank=0, world_size=1, is_train=True)

Load the dataset based on the configuration.

Parameters:
  • conf (dict) – Configuration dictionary containing dataset and training parameters.

  • rank (int, optional) – Rank of the current process. Default is 0.

  • world_size (int, optional) – Number of processes participating in the job. Default is 1.

  • is_train (bool, optional) – Flag indicating whether the dataset is for training or validation. Default is True.

Returns:

The loaded dataset.

Return type:

Dataset
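The configuration keys that load_dataset reads are not listed here. As a sketch under stated assumptions, a configuration-driven loader can look up a dataset class by name and pass rank and world_size on so that each process sees a disjoint shard. The config keys data, dataset_type, train_times and valid_times, the two toy classes, and the strided sharding rule are all hypothetical and are not taken from the credit package.

```python
class ToyDataset:
    """Hypothetical map-style dataset over a list of time indices."""

    def __init__(self, times, rank=0, world_size=1):
        # Strided sharding: rank r keeps every world_size-th time, starting at r.
        self.times = list(times)[rank::world_size]

    def __len__(self):
        return len(self.times)

    def __getitem__(self, i):
        return self.times[i]


class ToyMultiStepDataset(ToyDataset):
    """Hypothetical variant standing in for a multi-step rollout dataset."""


DATASET_TYPES = {"single_step": ToyDataset, "multi_step": ToyMultiStepDataset}


def load_dataset_sketch(conf, rank=0, world_size=1, is_train=True):
    data = conf["data"]  # hypothetical config section
    cls = DATASET_TYPES.get(data["dataset_type"])
    if cls is None:
        raise ValueError(f"unknown dataset_type: {data['dataset_type']!r}")
    # is_train selects which (hypothetical) time range to load.
    times = data["train_times"] if is_train else data["valid_times"]
    return cls(times, rank=rank, world_size=world_size)


conf = {"data": {"dataset_type": "multi_step",
                 "train_times": range(10), "valid_times": range(100, 104)}}
print(load_dataset_sketch(conf, rank=1, world_size=2).times)  # [1, 3, 5, 7, 9]
```

Raising on an unknown type surfaces configuration typos at startup rather than deep inside training.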

credit.datasets.load_dataset_and_dataloader.load_dataloader(conf, dataset, rank=0, world_size=1, is_train=True)

Load the DataLoader based on the dataset type.

Parameters:
  • conf (dict) – Configuration dictionary containing dataloader parameters.

  • dataset (Dataset) – The dataset to be used in the DataLoader.

  • rank (int, optional) – Rank of the current process. Default is 0.

  • world_size (int, optional) – Number of processes participating in the job. Default is 1.

  • is_train (bool, optional) – Flag indicating whether the dataset is for training or validation. Default is True.

Returns:

The loaded DataLoader.

Return type:

DataLoader
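The exact branching inside load_dataloader is not documented here. The sketch below assumes that a batcher-style dataset can be recognized by a batches_per_epoch() method, and it returns the keyword arguments a DataLoader might receive in each case. The trainer, train_batch_size and valid_batch_size keys and the helper names are hypothetical. The sampler dictionary mirrors the num_replicas, rank and shuffle arguments of torch.utils.data.distributed.DistributedSampler so that the sketch runs without PyTorch.

```python
def first_element(batch):
    # Unwrap the one-element list instead of stacking it.
    return batch[0]


def dataloader_kwargs_sketch(conf, dataset, rank=0, world_size=1, is_train=True):
    """Hypothetical helper: choose DataLoader settings from the dataset type."""
    trainer = conf["trainer"]  # hypothetical config section
    if hasattr(dataset, "batches_per_epoch"):
        # Assumed batcher-style dataset: it emits whole batches and handles
        # rank / world_size itself, so load one item per step and unwrap it.
        return {"batch_size": 1, "collate_fn": first_element, "sampler": None}
    # Sample-level dataset: the DataLoader builds batches, and each rank
    # reads a disjoint shard through a distributed sampler.
    batch_key = "train_batch_size" if is_train else "valid_batch_size"
    return {
        "batch_size": trainer[batch_key],
        "collate_fn": None,  # None means PyTorch's default collation
        "sampler": {"num_replicas": world_size, "rank": rank, "shuffle": is_train},
    }


conf = {"trainer": {"train_batch_size": 8, "valid_batch_size": 2}}
print(dataloader_kwargs_sketch(conf, [1, 2, 3], rank=1, world_size=4)["sampler"])
```

Shuffling only when is_train is true keeps validation order deterministic across runs.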

credit.datasets.load_dataset_and_dataloader.dataset_id