credit.datasets.base_dataset


AbstractBaseDataset and BaseDataset: PyTorch Dataset classes for:

1. Type hinting and annotations throughout CREDIT
2. Scaffolding the development of future datasets
3. Providing a minimal implementation of a Dataset for testing
4. Avoiding redundant code across dataset classes

Attributes

VALID_FIELD_TYPES

Classes

AbstractBaseDataset

Abstract base dataset class based on PyTorch Dataset class for CREDIT.

BaseDataset

PyTorch Dataset class for CREDIT that provides a minimal implementation of AbstractBaseDataset.

Module Contents

credit.datasets.base_dataset.VALID_FIELD_TYPES

The recognised field types: "prognostic", "dynamic_forcing", "static", and "diagnostic".

class credit.datasets.base_dataset.AbstractBaseDataset(data_config: dict[str, Any], return_target: bool = False)

Bases: torch.utils.data.Dataset[Any]

Abstract base dataset class based on PyTorch Dataset class for CREDIT.

This class defines the expected methods and attributes for any dataset in CREDIT, but does not provide any implementation. The BaseDataset class inherits from this class and provides a minimal implementation. Any future dataset should inherit from either AbstractBaseDataset or BaseDataset depending on the level of functionality needed.

For generality, the class inherits from torch.utils.data.Dataset[Any]; however, stricter typing than Any may be beneficial for consistency in the __getitem__ return type, especially if future torch releases support dataset type accelerations.

curr_source_name: str
dataset_type: str
dt: pandas.Timedelta
num_forecast_steps: int
start_datetime: pandas.Timestamp
end_datetime: pandas.Timestamp
datetimes: pandas.DatetimeIndex
return_target: bool
mode: str
file_dict: dict[str, Any]
var_dict: dict[str, Any]
static_metadata: dict[str, Any]
abstractmethod __len__() → int
abstractmethod __getitem__(args: tuple[pandas.Timestamp, int]) → dict[str, Any]
abstractmethod _build_timestamps() → pandas.DatetimeIndex
abstractmethod _get_field_name(field_type: VALID_FIELD_TYPES, dim_str: str, vname: str) → str
abstractmethod init_register_all_fields() → None
abstractmethod _register_field(field_type: VALID_FIELD_TYPES, field_config: dict[str, Any] | None) → None
abstractmethod _get_file_source(field_config: dict[str, Any]) → list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None
abstractmethod _extract_field(field_type: VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) → None
class credit.datasets.base_dataset.BaseDataset(data_config: dict[str, Any], return_target: bool = False)

Bases: AbstractBaseDataset

PyTorch Dataset class for CREDIT that provides:

1. Type hinting and annotations throughout CREDIT
2. Scaffolding for the development of future datasets
3. A minimal Dataset implementation for testing

A minimal YAML config for a dataset has the following structure:

```yaml
data:
  source:
    Example_Base:  # User-provided name (arbitrary key)
      # PARAMETERS FOR THIS DATASET TYPE
      # Ex: levels: [10, 20, 30]
      dataset_type: "base"  # Needs to match per type of dataset!
      variables:
        prognostic: null
        #   vars_3D: ['T', 'U', 'V', 'Q']  # Your 3D variables
        #   vars_2D: ['SP', 't2m']  # Your 2D variables
        dynamic_forcing: null
        #   vars_3D: …
        #   vars_2D: …
        static: null
        #   vars_3D: …
        #   vars_2D: …
        diagnostic: null
        #   vars_3D: …
        #   vars_2D: …

      # OPTIONAL: Override the clock bounds for this dataset
      start_datetime: "2012-04-03T00:00Z"

    # <YourName2>_<DatasetType2>:  # Multiple datasets (see multi_source)

  # These parameters set the overall clock of the sampler
  start_datetime: "2000-01-01T00:00:00Z"  # The earliest datetime across datasets
  end_datetime: "2020-12-31T23:00:00Z"  # The latest datetime across datasets
  timestep: "12h"  # The smallest time interval for the clock
  forecast_len: 1  # The number of timesteps forward that need to be rolled out per sample
```

curr_source_name
curr_source_cfg
dt: pandas.Timedelta
num_forecast_steps: int
start_datetime: pandas.Timestamp
end_datetime: pandas.Timestamp
datetimes: pandas.DatetimeIndex
return_target: bool = False
mode = 'local'
temporal_mode: str
_persist_cache: dict
static_metadata: dict[str, Any]
file_dict: dict[str, Any]
var_dict: dict[str, Any]
__len__() → int

For a CREDIT dataset, the length is the number of unique datetimes that can be sampled from.

Returns:

Dataset length

Return type:

int

__getitem__(args: tuple[pandas.Timestamp, int]) → dict[str, Any]

Return a nested input/target sample dict.

When temporal_mode == "persist", the requested timestamp t is snapped to the last native timestamp at-or-before t via pd.DatetimeIndex.asof(). The result is cached so that multiple fine-resolution master-clock ticks within the same native interval only trigger a single file read.

Parameters:

args – (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler. When i == 0, prognostic and static fields are loaded in addition to dynamic forcing.

Returns:

Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "{source}/{field_type}/{dim}/{varname}".

_resolve_persist_timestamp(t: pandas.Timestamp) → pandas.Timestamp

Snap t to the last native timestamp at or before t.

Uses pd.DatetimeIndex.asof() so non-uniform or non-zero-aligned cadences are handled correctly.

Parameters:

t – Master-clock timestamp to resolve.

Returns:

The last native timestamp <= t.

Raises:

ValueError – If t is before the first native timestamp.
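The snapping rule can be illustrated outside the class. The following is a minimal sketch, assuming the native timestamps are available as a sorted pd.DatetimeIndex; the standalone function resolve_persist_timestamp is hypothetical and only mirrors the documented behavior of _resolve_persist_timestamp.

```python
import pandas as pd


def resolve_persist_timestamp(native_times: pd.DatetimeIndex, t: pd.Timestamp) -> pd.Timestamp:
    """Hypothetical standalone version of _resolve_persist_timestamp.

    Snaps a master-clock timestamp to the last native timestamp <= t.
    """
    # asof() returns t itself if it is in the index, otherwise the latest
    # label before t; it returns NaT when t precedes every label.
    resolved = native_times.asof(t)
    if pd.isna(resolved):
        raise ValueError(f"{t} is before the first native timestamp {native_times[0]}")
    return resolved


# Native data every 6 hours; the master clock may tick at any finer interval.
native = pd.date_range("2020-01-01 00:00", periods=4, freq=pd.Timedelta(hours=6))
print(resolve_persist_timestamp(native, pd.Timestamp("2020-01-01 05:00")))  # 2020-01-01 00:00:00
print(resolve_persist_timestamp(native, pd.Timestamp("2020-01-01 06:00")))  # 2020-01-01 06:00:00
```

Because asof() looks up the nearest preceding label rather than rounding arithmetically, the same code works for irregular native cadences.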

_load_sample(t: pandas.Timestamp, i: int) → dict[str, Any]

Build and return the sample dict for timestamp t and step index i.

This is the inner implementation called by __getitem__, separated so the persist cache can call it without re-entering the dispatch logic.

Parameters:
  • t – Timestamp to load (already resolved for persist sources).

  • i – Within-sequence step index (0 = initial step).

Returns:

Sample dict with "input", "metadata", and optionally "target".
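The split between __getitem__ and _load_sample, together with the persist cache, can be sketched as below. This is a toy model under stated assumptions, not CREDIT's implementation: the class PersistCacheSketch is hypothetical, the cache is an unbounded dict keyed by (resolved timestamp, i == 0) so that initial steps, which load extra fields, are cached separately, and _load_sample is a stub that counts calls in place of real file reads.

```python
import pandas as pd


class PersistCacheSketch:
    """Hypothetical toy model of the persist cache described for BaseDataset."""

    def __init__(self, native_times: pd.DatetimeIndex):
        self.native_times = native_times
        self._persist_cache: dict = {}
        self.reads = 0  # counts simulated file reads

    def _load_sample(self, t: pd.Timestamp, i: int) -> dict:
        self.reads += 1  # stands in for an expensive file read
        return {"input": {}, "metadata": {"datetime": t, "step": i}}

    def __getitem__(self, args: tuple[pd.Timestamp, int]) -> dict:
        t, i = args
        # pd.Timestamp() accepts both Timestamps and integer nanoseconds.
        native_t = self.native_times.asof(pd.Timestamp(t))
        if pd.isna(native_t):
            raise ValueError(f"{t} is before the first native timestamp")
        # Assumed key: the initial step (i == 0) loads extra fields.
        key = (native_t, i == 0)
        if key not in self._persist_cache:
            self._persist_cache[key] = self._load_sample(native_t, i)
        return self._persist_cache[key]


native = pd.date_range("2020-01-01", periods=4, freq=pd.Timedelta(hours=6))
ds = PersistCacheSketch(native)
for hour in range(1, 6):
    # Hourly master-clock ticks 01:00..05:00 all resolve to 00:00.
    ds[(pd.Timestamp("2020-01-01") + pd.Timedelta(hours=hour), 1)]
print(ds.reads)  # 1
```

A production cache would likely need eviction, since an unbounded dict grows with every distinct native timestamp visited.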

_check_in_data_config(data_config: dict[str, Any], key: str) → None

Check that a key is present in the data config. Since the key is required for any dataset, raise an error if it is missing.

Parameters:
  • data_config (dict[str, Any]) – Portion of the config under “data”

  • key (str) – The key to check (e.g., “timestep”)

Raises:

KeyError – When the key is not found in the data config

_in_source_config(data_config: dict[str, Any], curr_source_cfg: dict[str, Any], key: str) → bool

Helper to determine if a key is in the source config.

Parameters:
  • data_config (dict[str, Any]) – Portion of the config under “data”

  • curr_source_cfg (dict[str, Any]) – Portion of the config under a specific source

  • key (str) – The key to check (e.g., “timestep”)

Returns:

True if the key is in the source config, False otherwise

Return type:

bool

_load_dt(data_config: dict[str, Any], curr_source_config: dict[str, Any], dt_key: str = 'timestep') → pandas.Timedelta

The timestep (dt) is a required parameter for any dataset, and is used to build the clock of the sampler. In general, the timestep in the data config should be the smallest timestep across all sources. The inherited dataset may need a coarser timestep (e.g., in the case of multi-source datasets).

Parameters:
  • data_config (dict[str, Any]) – Portion of the config under “data”

  • curr_source_config (dict[str, Any]) – Portion of the config under a specific source

  • dt_key (str, optional) – The key for the timestep parameter. Defaults to “timestep”.

Returns:

The timestep for the dataset

Return type:

pd.Timedelta

_load_num_forecast_steps(data_config: dict[str, Any], curr_source_config: dict[str, Any], num_forecast_steps_key: str = 'forecast_len') → int

The number of forecast steps (num_forecast_steps) is a required parameter for any dataset, and is used in the sampler. In general, the number of forecast steps in the data config should be the largest across all sources, since this determines how far forward the sampler needs to roll out. The inherited dataset may not be able to roll out further.

Parameters:
  • data_config (dict[str, Any]) – Portion of the config under “data”

  • curr_source_config (dict[str, Any]) – Portion of the config under a specific source

  • num_forecast_steps_key (str, optional) – The key for the number of forecast steps parameter. Defaults to “forecast_len”.

Returns:

The number of forecast steps (i.e., length ahead) for the dataset

Return type:

int

_load_start_datetime(data_config: dict[str, Any], curr_source_config: dict[str, Any], start_datetime_key: str = 'start_datetime') → pandas.Timestamp

The start_datetime is a required parameter for any dataset, and is used in the sampler. In general, the start_datetime in the data config should be the earliest across all sources, since this determines the earliest point in time that the sampler can draw from. The inherited dataset may not be able to go as far back.

Parameters:
  • data_config (dict[str, Any]) – Portion of the config under “data”

  • curr_source_config (dict[str, Any]) – Portion of the config under a specific source

  • start_datetime_key (str, optional) – The key for the start datetime parameter. Defaults to “start_datetime”.

Returns:

The start datetime for the dataset

Return type:

pd.Timestamp

_load_end_datetime(data_config: dict[str, Any], curr_source_config: dict[str, Any], end_datetime_key: str = 'end_datetime') → pandas.Timestamp

The end_datetime is a required parameter for any dataset, and is used in the sampler. In general, the end_datetime in the data config should be the latest across all sources, since this determines the latest point in time that the sampler can draw from. The inherited dataset may not be able to go as far forward.

Parameters:
  • data_config (dict[str, Any]) – Portion of the config under “data”

  • curr_source_config (dict[str, Any]) – Portion of the config under a specific source

  • end_datetime_key (str, optional) – The key for the end datetime parameter. Defaults to “end_datetime”.

Returns:

The end datetime for the dataset

Return type:

pd.Timestamp
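The clock loaders (_load_dt, _load_num_forecast_steps, _load_start_datetime, _load_end_datetime) all read a key from the data config and convert it to a pandas type. A minimal sketch of that pattern follows, assuming that a value under the current source overrides the top-level value, as the YAML example does for start_datetime. The helpers check_in_data_config and resolve_key are hypothetical and do not reproduce CREDIT's exact lookup order.

```python
from typing import Any

import pandas as pd


def check_in_data_config(data_config: dict[str, Any], key: str) -> None:
    # Mirrors the documented contract of _check_in_data_config.
    if key not in data_config:
        raise KeyError(f"'{key}' is required in the data config")


def resolve_key(data_config: dict[str, Any], source_cfg: dict[str, Any], key: str) -> Any:
    """Hypothetical lookup: a source-level value overrides the data-level one."""
    check_in_data_config(data_config, key)
    if source_cfg.get(key) is not None:
        return source_cfg[key]
    return data_config[key]


data_config = {
    "start_datetime": "2000-01-01T00:00:00Z",
    "end_datetime": "2020-12-31T23:00:00Z",
    "timestep": "12h",
    "forecast_len": 1,
}
source_cfg = {"dataset_type": "base", "start_datetime": "2012-04-03T00:00Z"}

dt = pd.Timedelta(resolve_key(data_config, source_cfg, "timestep"))
num_forecast_steps = int(resolve_key(data_config, source_cfg, "forecast_len"))
start = pd.Timestamp(resolve_key(data_config, source_cfg, "start_datetime"))  # source override
end = pd.Timestamp(resolve_key(data_config, source_cfg, "end_datetime"))
print(dt, num_forecast_steps, start, end)
```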

_build_timestamps() → pandas.DatetimeIndex

Return timestamps for the dataset using the class parameters. The timestamps should ensure that there are enough future timesteps to rollout based on num_forecast_steps and the dt timestep length, and should be at the configured timestep frequency.

Note: Override this method to apply quality-control checks that limit the datetimes available for sampling, or to enforce time bounds automatically for your dataset. In these cases, call super()._build_timestamps() to retain the base functionality.

Returns:

DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.

Return type:

pd.DatetimeIndex
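The construction described above can be sketched with pd.date_range. This is a minimal sketch, assuming the forecast horizon equals num_forecast_steps * dt; the function build_timestamps is hypothetical, and the exact horizon arithmetic in CREDIT may differ.

```python
import pandas as pd


def build_timestamps(
    start: pd.Timestamp, end: pd.Timestamp, dt: pd.Timedelta, num_forecast_steps: int
) -> pd.DatetimeIndex:
    # Assumed horizon: the last sampleable time must leave num_forecast_steps
    # further steps of size dt before end_datetime.
    last_start = end - num_forecast_steps * dt
    return pd.date_range(start=start, end=last_start, freq=dt)


times = build_timestamps(
    pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-03"), pd.Timedelta("12h"), 2
)
print(times)
```

Here the last sampleable time is 2020-01-02 00:00, because two further 12-hour steps reach end_datetime exactly.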

_get_field_name(field_type: VALID_FIELD_TYPES, dim_str: str, vname: str) → str

Get the field name and enforce a consistent convention across datasets.

The convention for the field name is: "{user's current source name}/{field_type}/{dim_str}/{vname}".

Parameters:
  • field_type (VALID_FIELD_TYPES) – The field type (e.g., “prognostic”)

  • dim_str (str) – The dimension string (e.g., “3d” or “2d”)

  • vname (str) – The variable name (e.g., “T” or “t2m”)

Returns:

The key string that will be used to access the field variable

Return type:

str
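The naming convention amounts to joining four components with "/". A minimal sketch follows; the function get_field_name is hypothetical (the real method reads the source name from the instance), and modelling VALID_FIELD_TYPES as a Literal is an assumption, since these docs list the four values but not the exact type.

```python
from typing import Literal

# Assumed shape of VALID_FIELD_TYPES; only the four values are documented.
FieldType = Literal["prognostic", "dynamic_forcing", "static", "diagnostic"]


def get_field_name(source_name: str, field_type: FieldType, dim_str: str, vname: str) -> str:
    """Build the "{source}/{field_type}/{dim_str}/{vname}" key."""
    return f"{source_name}/{field_type}/{dim_str}/{vname}"


key = get_field_name("Example_Base", "prognostic", "3d", "T")
print(key)  # Example_Base/prognostic/3d/T
```

Because "/" is the separator, the key can be split back into its parts with key.split("/"), provided no component itself contains "/".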

init_register_all_fields() → None

Initialize and register all fields for the dataset.

Raises:

KeyError – If the config does not include any variables.

_register_field(field_type: VALID_FIELD_TYPES, field_config: dict[str, Any] | None) → None

Validate and register one field type from the config variables block.

Populates self.file_dict and self.var_dict for field_type.

Parameters:
  • field_type – One of VALID_FIELD_TYPES, namely: "prognostic", "dynamic_forcing", "static", "diagnostic".

  • field_config – Field-type config dict, or None / null to disable the field.

Raises:
  • KeyError – If field_type is not a recognised field type.

  • ValueError – If field_config defines neither vars_3D nor vars_2D.
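The validation rules listed above can be sketched as follows. This is a hypothetical helper, register_field, which only records variable names in a var_dict; file discovery for file_dict is omitted, and the stored layout is an assumption for illustration.

```python
from typing import Any, Optional

VALID_FIELD_TYPES = ("prognostic", "dynamic_forcing", "static", "diagnostic")


def register_field(
    var_dict: dict[str, Any], field_type: str, field_config: Optional[dict[str, Any]]
) -> None:
    """Hypothetical sketch of the checks documented for _register_field."""
    if field_type not in VALID_FIELD_TYPES:
        raise KeyError(f"Unknown field type: {field_type!r}")
    if field_config is None:  # null in YAML disables the field
        return
    if "vars_3D" not in field_config and "vars_2D" not in field_config:
        raise ValueError(f"{field_type} must define vars_3D and/or vars_2D")
    # Assumed layout: store the variable lists per field type.
    var_dict[field_type] = {
        "vars_3D": list(field_config.get("vars_3D") or []),
        "vars_2D": list(field_config.get("vars_2D") or []),
    }


var_dict: dict[str, Any] = {}
register_field(var_dict, "prognostic", {"vars_3D": ["T", "U"], "vars_2D": ["SP"]})
register_field(var_dict, "static", None)
print(var_dict)  # {'prognostic': {'vars_3D': ['T', 'U'], 'vars_2D': ['SP']}}
```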

_get_file_source(field_config: dict[str, Any]) → list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None

Return the file source for a field. Override in subclasses for different modes/backends.

Parameters:

field_config (dict[str, Any]) – Validated field-type config dict.

Raises:
  • FileNotFoundError – If self.mode == "local" and the glob pattern matches no files.

  • ValueError – If self.mode is not a recognised mode.

Returns:

Depending on the mode and field type, this method may return a list of (start_time, end_time, file_path) tuples produced by _map_files, a boolean indicating the presence of the field (e.g., for remote data), or None if the field is disabled. The return type should be consistent within a dataset class.

Return type:

list[tuple[pd.Timestamp, pd.Timestamp, str]] | bool | None
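The "local" branch and the error cases can be sketched with the standard glob module. This is a minimal sketch under loud assumptions: the function get_file_source is hypothetical, the glob pattern is passed in directly rather than read from field_config (the real config key is not documented here), and it returns sorted paths instead of the (start_time, end_time, file_path) tuples that _map_files produces.

```python
import glob
import os
import tempfile


def get_file_source(mode: str, pattern: str) -> list[str]:
    """Hypothetical sketch of the mode dispatch in _get_file_source."""
    if mode == "local":
        files = sorted(glob.glob(pattern))
        if not files:
            raise FileNotFoundError(f"No files match {pattern!r}")
        return files
    raise ValueError(f"Unrecognised mode: {mode!r}")


with tempfile.TemporaryDirectory() as d:
    for name in ("b_2001.nc", "a_2000.nc"):
        open(os.path.join(d, name), "w").close()  # create empty placeholder files
    found = get_file_source("local", os.path.join(d, "*.nc"))
    print([os.path.basename(f) for f in found])  # ['a_2000.nc', 'b_2001.nc']
```

Sorting the matches gives a deterministic file order, which matters if files are later mapped to consecutive time ranges.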

_extract_field(field_type: VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) → None

Base extract-field method, which inherited dataset classes should override to extract the data for each field type. The override should populate sample with the extracted data for the given field type and timestamp, using keys that follow the format of _get_field_name.

The entries are added as tensors to the sample["input"] or sample["target"] dict in __getitem__.

Parameters:
  • field_type (VALID_FIELD_TYPES) – One of VALID_FIELD_TYPES.

  • t (pd.Timestamp) – Query timestamp for which to extract the field data.

  • sample (dict[str, Any]) – The sample dict being built in __getitem__
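A subclass override of _extract_field could look like the sketch below. Everything here is hypothetical: ToyDataset is not a CREDIT class and does not inherit from BaseDataset, data live in an in-memory dict instead of files, plain floats stand in for the tensors CREDIT would produce, and only sample["input"] is populated.

```python
from typing import Any

import pandas as pd


class ToyDataset:
    """Hypothetical stand-in for a BaseDataset subclass."""

    curr_source_name = "Example_Base"

    def __init__(self, store: dict[tuple[str, pd.Timestamp], float]):
        self.store = store  # (variable name, timestamp) -> value
        self.var_dict = {"prognostic": {"vars_2D": ["t2m"]}}

    def _get_field_name(self, field_type: str, dim_str: str, vname: str) -> str:
        return f"{self.curr_source_name}/{field_type}/{dim_str}/{vname}"

    def _extract_field(self, field_type: str, t: pd.Timestamp, sample: dict[str, Any]) -> None:
        # Write each 2D variable of this field type into sample["input"].
        for vname in self.var_dict.get(field_type, {}).get("vars_2D", []):
            key = self._get_field_name(field_type, "2d", vname)
            sample["input"][key] = self.store[(vname, t)]


t0 = pd.Timestamp("2020-01-01")
ds = ToyDataset({("t2m", t0): 280.5})
sample: dict[str, Any] = {"input": {}, "metadata": {}}
ds._extract_field("prognostic", t0, sample)
print(sample["input"])  # {'Example_Base/prognostic/2d/t2m': 280.5}
```

Field types absent from var_dict (for example "static" here) leave the sample unchanged, which matches disabling a field with null in the config.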