credit.datasets.base_dataset
AbstractBaseDataset and BaseDataset: PyTorch Dataset classes for:
1. Type hinting and annotations throughout CREDIT
2. Scaffolding the development of future datasets
3. Providing a minimal implementation of a Dataset for testing
4. Avoiding redundant code across dataset classes
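Attributes
- VALID_FIELD_TYPES

Classes
- AbstractBaseDataset: Abstract base dataset class based on PyTorch Dataset class for CREDIT.
- BaseDataset: PyTorch Dataset class for CREDIT that enables type hinting, scaffolding of future datasets, and a minimal test implementation.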
Module Contents
- credit.datasets.base_dataset.VALID_FIELD_TYPES
- class credit.datasets.base_dataset.AbstractBaseDataset(data_config: dict[str, Any], return_target: bool = False)
Bases: torch.utils.data.Dataset[Any]
Abstract base dataset class based on PyTorch Dataset class for CREDIT.
This class defines the expected methods and attributes for any dataset in CREDIT, but does not provide any implementation. The BaseDataset class inherits from this class and provides a minimal implementation. Any future dataset should inherit from either AbstractBaseDataset or BaseDataset depending on the level of functionality needed.
For generality, the class inherits from torch.utils.data.Dataset[Any]. However, typing stricter than Any may help keep the __getitem__ return type consistent, especially if future torch releases add accelerations based on dataset types.
- curr_source_name: str
- dataset_type: str
- dt: pandas.Timedelta
- num_forecast_steps: int
- start_datetime: pandas.Timestamp
- end_datetime: pandas.Timestamp
- datetimes: pandas.DatetimeIndex
- return_target: bool
- mode: str
- file_dict: dict[str, Any]
- var_dict: dict[str, Any]
- static_metadata: dict[str, Any]
- abstractmethod __len__() -> int
- abstractmethod __getitem__(args: tuple[pandas.Timestamp, int]) -> dict[str, Any]
- abstractmethod _build_timestamps() -> pandas.DatetimeIndex
- abstractmethod _get_field_name(field_type: VALID_FIELD_TYPES, dim_str: str, vname: str) -> str
- abstractmethod init_register_all_fields() -> None
- abstractmethod _register_field(field_type: VALID_FIELD_TYPES, field_config: dict[str, Any] | None) -> None
- abstractmethod _get_file_source(field_config: dict[str, Any]) -> list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None
- abstractmethod _extract_field(field_type: VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) -> None
- class credit.datasets.base_dataset.BaseDataset(data_config: dict[str, Any], return_target: bool = False)
Bases: AbstractBaseDataset
PyTorch Dataset class for CREDIT that enables:
1. Type hinting and annotations throughout CREDIT
2. Scaffolding the development of future datasets
3. Providing a minimal implementation of a Dataset for testing
A minimal YAML config for a dataset has the following structure:

```yaml
data:
  source:
    Example_Base:  # User-provided name (arbitrary key)
      # PARAMETERS FOR THIS DATASET TYPE
      # Ex: levels: [10, 20, 30]
      dataset_type: "base"  # Must match the type of dataset!
      variables:
        prognostic: null
          # vars_3D: ['T', 'U', 'V', 'Q']  # Your 3D variables
          # vars_2D: ['SP', 't2m']         # Your 2D variables
        dynamic_forcing: null
          # vars_3D: ...
          # vars_2D: ...
        static: null
          # vars_3D: ...
          # vars_2D: ...
        diagnostic: null
          # vars_3D: ...
          # vars_2D: ...
      # OPTIONAL: Override the clock bounds for this dataset
      start_datetime: "2012-04-03T00:00Z"
    # <YourName2>_<DatasetType2>:  # Multiple datasets (see multi_source)

  # These parameters set the overall clock of the sampler
  start_datetime: "2000-01-01T00:00:00Z"  # The earliest datetime across datasets
  end_datetime: "2020-12-31T23:00:00Z"    # The latest datetime across datasets
  timestep: "12h"                         # The smallest time interval for the clock
  forecast_len: 1                         # The number of timesteps forward that need to be rolled out per sample
```
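For orientation, the YAML above loads into nested dicts: the block under data is what the methods below call data_config, and each entry under source is a curr_source_cfg. The sketch below is a hand-written Python mirror of that layout (prognostic variables uncommented, no YAML parser involved) that walks the sources and reports which field types are enabled. The helper enabled_field_types is hypothetical and not part of CREDIT.

```python
from typing import Any

# Hand-written mirror of the YAML example above (prognostic variables uncommented).
data_config: dict[str, Any] = {
    "source": {
        "Example_Base": {
            "dataset_type": "base",
            "variables": {
                "prognostic": {"vars_3D": ["T", "U", "V", "Q"], "vars_2D": ["SP", "t2m"]},
                "dynamic_forcing": None,
                "static": None,
                "diagnostic": None,
            },
            "start_datetime": "2012-04-03T00:00Z",  # per-source override of the clock bound
        },
    },
    "start_datetime": "2000-01-01T00:00:00Z",
    "end_datetime": "2020-12-31T23:00:00Z",
    "timestep": "12h",
    "forecast_len": 1,
}


def enabled_field_types(source_cfg: dict[str, Any]) -> list[str]:
    """Hypothetical helper: field types whose config is not null."""
    return [name for name, cfg in source_cfg["variables"].items() if cfg is not None]


for source_name, curr_source_cfg in data_config["source"].items():
    print(source_name, curr_source_cfg["dataset_type"], enabled_field_types(curr_source_cfg))
# Example_Base base ['prognostic']
```

Treating null as "disabled" follows the _register_field description, where None / null disables a field.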
- curr_source_name
- curr_source_cfg
- dt: pandas.Timedelta
- num_forecast_steps: int
- start_datetime: pandas.Timestamp
- end_datetime: pandas.Timestamp
- datetimes: pandas.DatetimeIndex
- return_target: bool = False
- mode = 'local'
- temporal_mode: str
- _persist_cache: dict
- static_metadata: dict[str, Any]
- file_dict: dict[str, Any]
- var_dict: dict[str, Any]
- __len__() -> int
For a CREDIT dataset, the length is the number of unique datetimes that can be sampled from.
- Returns:
Dataset length
- Return type:
int
- __getitem__(args: tuple[pandas.Timestamp, int]) -> dict[str, Any]
Return a nested input/target sample dict.
When temporal_mode == "persist", the requested timestamp t is snapped to the last native timestamp at or before t via pd.DatetimeIndex.asof(). The result is cached so that multiple fine-resolution master-clock ticks within the same native interval trigger only a single file read.
- Parameters:
args – (t, i), where t is the current timestamp (nanoseconds or pd.Timestamp) and i is the within-sequence step index produced by the sampler. When i == 0, prognostic and static fields are loaded in addition to dynamic forcing.
- Returns:
Dict with keys "input", "metadata", and optionally "target" (when return_target=True). Both "input" and "target" are dicts of per-variable tensors keyed by "{source}/{field_type}/{dim}/{varname}".
- _resolve_persist_timestamp(t: pandas.Timestamp) -> pandas.Timestamp
Snap t to the last native timestamp at or before t.
Uses pd.DatetimeIndex.asof(), so non-uniform or non-zero-aligned cadences are handled correctly.
- Parameters:
t – Master-clock timestamp to resolve.
- Returns:
The last native timestamp <= t.
- Raises:
ValueError – If t is before the first native timestamp.
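The snapping and caching described for __getitem__ and _resolve_persist_timestamp can be sketched with the standard library, using bisect on a sorted list of datetimes as a stand-in for pd.DatetimeIndex.asof(). Note that asof itself returns NaT for a timestamp before the first entry, so the explicit ValueError here follows the documented behavior rather than pandas'. PersistCache and its load counter are hypothetical; they only illustrate why several master-clock ticks inside one native interval cost a single read.

```python
import bisect
from datetime import datetime, timedelta


def resolve_persist_timestamp(native: list[datetime], t: datetime) -> datetime:
    """Return the last native timestamp <= t (native must be sorted ascending)."""
    pos = bisect.bisect_right(native, t)  # number of native timestamps <= t
    if pos == 0:
        raise ValueError(f"{t} is before the first native timestamp {native[0]}")
    return native[pos - 1]


class PersistCache:
    """Hypothetical toy cache: one load per resolved native timestamp."""

    def __init__(self, native: list[datetime]) -> None:
        self.native = native
        self.cache: dict[datetime, str] = {}
        self.loads = 0

    def get(self, t: datetime) -> str:
        key = resolve_persist_timestamp(self.native, t)
        if key not in self.cache:
            self.loads += 1  # stands in for a file read
            self.cache[key] = f"data@{key.isoformat()}"
        return self.cache[key]


# Native data every 6 h, aligned to 03:00 rather than midnight: 03, 09, 15, 21.
native = [datetime(2020, 1, 1, 3) + timedelta(hours=6 * k) for k in range(4)]
cache = PersistCache(native)
for hour in range(3, 9):  # hourly master-clock ticks 03..08 share one native interval
    cache.get(datetime(2020, 1, 1, hour))
print(cache.loads)  # 1
```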
- _load_sample(t: pandas.Timestamp, i: int) -> dict[str, Any]
Build and return the sample dict for timestamp t and step index i.
This is the inner implementation called by __getitem__, separated out so that the persist cache can call it without re-entering the dispatch logic.
- Parameters:
t – Timestamp to load (already resolved for persist sources).
i – Within-sequence step index (0 = initial step).
- Returns:
Sample dict with "input", "metadata", and optionally "target".
- _check_in_data_config(data_config: dict[str, Any], key: str) -> None
Check that a key is in the data config. If it is not, raise an error, since the key is required for any dataset.
- Parameters:
data_config (dict[str, Any]) – Portion of the config under “data”
key (str) – The key to check (e.g., “timestep”)
- Raises:
KeyError – When the key is not found in the data config
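A minimal sketch of the check, assuming a plain membership test on the data block (the actual implementation may differ, for example in its error message):

```python
from typing import Any


def check_in_data_config(data_config: dict[str, Any], key: str) -> None:
    """Raise KeyError if a required key is missing from the "data" block."""
    if key not in data_config:
        raise KeyError(f"Required key '{key}' not found in the data config")


data_config = {"timestep": "12h", "forecast_len": 1}
for required in ("timestep", "forecast_len"):
    check_in_data_config(data_config, required)  # passes silently
```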
- _in_source_config(data_config: dict[str, Any], curr_source_cfg: dict[str, Any], key: str) -> bool
Helper to determine if a key is in the source config.
- Parameters:
data_config (dict[str, Any]) – Portion of the config under “data”
curr_source_cfg (dict[str, Any]) – Portion of the config under a specific source
key (str) – The key to check (e.g., “timestep”)
- Returns:
True if the key is in the source config, False otherwise
- Return type:
bool
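The YAML example lets a source override clock bounds such as start_datetime. A sketch of that precedence follows, assuming a per-source value wins over the global one. Here in_source_config only inspects curr_source_cfg; the documented method also receives data_config, whose role is not described in this section. resolve_key is a hypothetical helper.

```python
from typing import Any


def in_source_config(curr_source_cfg: dict[str, Any], key: str) -> bool:
    """True if the per-source block sets `key` (sketch only)."""
    return key in curr_source_cfg


def resolve_key(data_config: dict[str, Any], curr_source_cfg: dict[str, Any], key: str) -> Any:
    """Hypothetical helper: a per-source value overrides the global "data" value."""
    if in_source_config(curr_source_cfg, key):
        return curr_source_cfg[key]
    if key not in data_config:
        raise KeyError(f"Required key '{key}' not found in the data config")
    return data_config[key]


data_config = {"start_datetime": "2000-01-01T00:00:00Z", "end_datetime": "2020-12-31T23:00:00Z"}
curr_source_cfg = {"start_datetime": "2012-04-03T00:00Z"}
print(resolve_key(data_config, curr_source_cfg, "start_datetime"))  # 2012-04-03T00:00Z
print(resolve_key(data_config, curr_source_cfg, "end_datetime"))    # 2020-12-31T23:00:00Z
```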
- _load_dt(data_config: dict[str, Any], curr_source_config: dict[str, Any], dt_key: str = 'timestep') -> pandas.Timedelta
The timestep (dt) is a required parameter for any dataset, and is used to build the clock of the sampler. In general, the timestep in the data config should be the smallest timestep across all sources. The inherited dataset may need a coarser timestep (e.g., in the case of multi-source datasets).
- Parameters:
data_config (dict[str, Any]) – Portion of the config under “data”
curr_source_config (dict[str, Any]) – Portion of the config under a specific source
dt_key (str, optional) – The key for the timestep parameter. Defaults to “timestep”.
- Returns:
The timestep for the dataset
- Return type:
pd.Timedelta
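The timestep arrives as a string such as "12h". The real method returns a pandas.Timedelta; the sketch below instead uses a hypothetical, deliberately small stdlib parser (integers with s, min, h, or d units) to show how a coarser per-source timestep relates to the clock timestep.

```python
import re
from datetime import timedelta

# Hypothetical subset of unit strings; pandas accepts many more.
_UNITS = {"s": "seconds", "min": "minutes", "h": "hours", "d": "days"}


def parse_timestep(value: str) -> timedelta:
    """Parse '<int><unit>' strings such as '12h' into a timedelta."""
    match = re.fullmatch(r"(\d+)\s*(s|min|h|d)", value.strip().lower())
    if match is None:
        raise ValueError(f"Unsupported timestep string: {value!r}")
    amount, unit = match.groups()
    return timedelta(**{_UNITS[unit]: int(amount)})


global_dt = parse_timestep("12h")  # smallest timestep across sources (the clock)
source_dt = parse_timestep("1d")   # a coarser per-source timestep
print(source_dt // global_dt)      # 2 clock ticks per native step
```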
- _load_num_forecast_steps(data_config: dict[str, Any], curr_source_config: dict[str, Any], num_forecast_steps_key: str = 'forecast_len') -> int
The number of forecast steps (num_forecast_steps) is a required parameter for any dataset and is used in the sampler. In general, the number of forecast steps in the data config should be the largest across all sources, since this determines how far forward the sampler needs to roll out. The inherited dataset may not be able to roll out as far.
- Parameters:
data_config (dict[str, Any]) – Portion of the config under “data”
curr_source_config (dict[str, Any]) – Portion of the config under a specific source
num_forecast_steps_key (str, optional) – The key for the number of forecast steps parameter. Defaults to “forecast_len”.
- Returns:
The number of forecast steps (i.e., length ahead) for the dataset
- Return type:
int
- _load_start_datetime(data_config: dict[str, Any], curr_source_config: dict[str, Any], start_datetime_key: str = 'start_datetime') -> pandas.Timestamp
The start_datetime is a required parameter for any dataset, and is used in the sampler. In general, the start_datetime in the data config should be the earliest across all sources, since this determines the earliest point in time that the sampler can draw from. The inherited dataset may not be able to go as far back.
- Parameters:
data_config (dict[str, Any]) – Portion of the config under “data”
curr_source_config (dict[str, Any]) – Portion of the config under a specific source
start_datetime_key (str, optional) – The key for the start datetime parameter. Defaults to “start_datetime”.
- Returns:
The start datetime for the dataset
- Return type:
pd.Timestamp
- _load_end_datetime(data_config: dict[str, Any], curr_source_config: dict[str, Any], end_datetime_key: str = 'end_datetime') -> pandas.Timestamp
The end_datetime is a required parameter for any dataset, and is used in the sampler. In general, the end_datetime in the data config should be the latest across all sources, since this determines the latest point in time that the sampler can draw from. The inherited dataset may not be able to go as far forward.
- Parameters:
data_config (dict[str, Any]) – Portion of the config under “data”
curr_source_config (dict[str, Any]) – Portion of the config under a specific source
end_datetime_key (str, optional) – The key for the end datetime parameter. Defaults to “end_datetime”.
- Returns:
The end datetime for the dataset
- Return type:
pd.Timestamp
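The three docstrings above recommend that the global values be the earliest start, the latest end, and the largest forecast length across sources. A sketch of that aggregation follows, with hypothetical source names (era5, obs) and values. The ISO strings are parsed with datetime.fromisoformat after replacing the "Z" suffix, which fromisoformat only accepts from Python 3.11.

```python
from datetime import datetime


def parse_iso_utc(value: str) -> datetime:
    """Parse ISO strings with a trailing 'Z' into timezone-aware datetimes."""
    return datetime.fromisoformat(value.replace("Z", "+00:00"))


# Hypothetical per-source clock settings.
sources = {
    "era5": {"start_datetime": "2000-01-01T00:00:00Z", "end_datetime": "2020-12-31T23:00:00Z", "forecast_len": 1},
    "obs": {"start_datetime": "2012-04-03T00:00Z", "end_datetime": "2021-06-30T00:00Z", "forecast_len": 4},
}

start = min(parse_iso_utc(cfg["start_datetime"]) for cfg in sources.values())
end = max(parse_iso_utc(cfg["end_datetime"]) for cfg in sources.values())
forecast_len = max(cfg["forecast_len"] for cfg in sources.values())
print(start.isoformat(), end.isoformat(), forecast_len)
# 2000-01-01T00:00:00+00:00 2021-06-30T00:00:00+00:00 4
```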
- _build_timestamps() -> pandas.DatetimeIndex
Return timestamps for the dataset using the class parameters. The timestamps should ensure that there are enough future timesteps to roll out based on num_forecast_steps and the dt timestep length, and should be at the configured timestep frequency.
Note: Override this method if you would like to apply quality-control checks that limit the datetimes to sample from, or to enforce time bounds automatically for your dataset. In these cases you can call super()._build_timestamps() to reuse the base functionality.
- Returns:
DatetimeIndex from start_datetime to end_datetime minus the forecast horizon, at the configured timestep frequency.
- Return type:
pd.DatetimeIndex
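A stdlib sketch of the default behavior, under two assumptions: the forecast horizon is num_forecast_steps * dt, and both endpoints are inclusive (as with pd.date_range by default). The length of the result is what __len__ would report.

```python
from datetime import datetime, timedelta


def build_timestamps(start: datetime, end: datetime, dt: timedelta, num_forecast_steps: int) -> list[datetime]:
    """Valid initial times from start to (end - num_forecast_steps * dt), every dt, inclusive."""
    last_valid = end - num_forecast_steps * dt  # leave room to roll out the forecast
    timestamps = []
    t = start
    while t <= last_valid:
        timestamps.append(t)
        t += dt
    return timestamps


ts = build_timestamps(datetime(2000, 1, 1), datetime(2000, 1, 3), timedelta(hours=12), num_forecast_steps=1)
print(len(ts), ts[-1])  # 4 2000-01-02 12:00:00
```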
- _get_field_name(field_type: VALID_FIELD_TYPES, dim_str: str, vname: str) -> str
Get the field name and enforce a consistent convention across datasets.
The convention for the field name is: "{user's current source name}/{field_type}/{dim_str}/{vname}".
- Parameters:
field_type (VALID_FIELD_TYPES) – The field type (e.g., “prognostic”)
dim_str (str) – The dimension string (e.g., “3d” or “2d”)
vname (str) – The variable name (e.g., “T” or “t2m”)
- Returns:
The key string that will be used to access the field variable
- Return type:
str
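The naming convention reduces to an f-string. A sketch, with the source name passed explicitly rather than read from self.curr_source_name:

```python
def get_field_name(source_name: str, field_type: str, dim_str: str, vname: str) -> str:
    """Build the "{source}/{field_type}/{dim_str}/{vname}" key."""
    return f"{source_name}/{field_type}/{dim_str}/{vname}"


key = get_field_name("Example_Base", "prognostic", "3d", "T")
print(key)  # Example_Base/prognostic/3d/T

# Splitting on "/" recovers the parts, provided no part contains a slash.
source, field_type, dim_str, vname = key.split("/")
```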
- init_register_all_fields() -> None
Initialize and register all fields for the dataset.
- Raises:
KeyError – If the config does not include any variables.
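A sketch of the registration loop over the field types listed under _register_field. Treating a missing field type like an explicit null (disabled) is an assumption. VALID_FIELD_TYPES is redefined as a plain tuple stand-in, and register_all_fields is hypothetical: it returns a dict, whereas the real method returns None and populates instance attributes.

```python
from typing import Any, Optional

# Stand-in for CREDIT's VALID_FIELD_TYPES, using the values listed under _register_field.
VALID_FIELD_TYPES = ("prognostic", "dynamic_forcing", "static", "diagnostic")


def register_all_fields(curr_source_cfg: dict[str, Any]) -> dict[str, Optional[dict[str, Any]]]:
    """Hypothetical sketch: map each field type to its config (None means disabled)."""
    variables = curr_source_cfg.get("variables")
    if not variables:
        raise KeyError("The source config does not include any variables")
    registered: dict[str, Optional[dict[str, Any]]] = {}
    for field_type in VALID_FIELD_TYPES:
        # Assumption: a missing field type is treated like an explicit null.
        registered[field_type] = variables.get(field_type)
    return registered


registered = register_all_fields({"variables": {"prognostic": {"vars_3D": ["T"]}, "static": None}})
print([name for name, cfg in registered.items() if cfg is not None])  # ['prognostic']
```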
- _register_field(field_type: VALID_FIELD_TYPES, field_config: dict[str, Any] | None) -> None
Validate and register one field type from the config variables block.
Populates self.file_dict and self.var_dict for field_type.
- Parameters:
field_type – One of VALID_FIELD_TYPES, namely: "prognostic", "dynamic_forcing", "static", "diagnostic".
field_config – Field-type config dict, or None / null to disable the field.
- Raises:
KeyError – If field_type is not a recognised field type.
ValueError – If field_config defines neither vars_3D nor vars_2D.
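The two documented error cases can be sketched as below. Treating empty lists as not defining vars_3D or vars_2D is an assumption, and the returned {"3d": ..., "2d": ...} dict is a hypothetical stand-in for the entries added to self.var_dict.

```python
from typing import Any, Optional

VALID_FIELD_TYPES = ("prognostic", "dynamic_forcing", "static", "diagnostic")


def validate_field(field_type: str, field_config: Optional[dict[str, Any]]) -> Optional[dict[str, list[str]]]:
    """Hypothetical sketch of the validation step in _register_field."""
    if field_type not in VALID_FIELD_TYPES:
        raise KeyError(f"Unrecognised field type: {field_type!r}")
    if field_config is None:
        return None  # None / null disables the field
    vars_3d = field_config.get("vars_3D") or []
    vars_2d = field_config.get("vars_2D") or []
    if not vars_3d and not vars_2d:
        raise ValueError(f"{field_type!r} defines neither vars_3D nor vars_2D")
    return {"3d": list(vars_3d), "2d": list(vars_2d)}


print(validate_field("prognostic", {"vars_3D": ["T", "U"], "vars_2D": ["SP"]}))
# {'3d': ['T', 'U'], '2d': ['SP']}
```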
- _get_file_source(field_config: dict[str, Any]) -> list[tuple[pandas.Timestamp, pandas.Timestamp, str]] | bool | None
Return the file source for a field. Override in subclasses for different modes/backends.
- Parameters:
field_config (dict[str, Any]) – Validated field-type config dict.
- Raises:
FileNotFoundError – If self.mode == "local" and the glob pattern matches no files.
ValueError – If self.mode is not a recognised mode.
- Returns:
Depending on the mode and field type, this method may return a list of (start_time, end_time, file_path) tuples produced by _map_files, a boolean indicating the presence of the field (e.g., for remote data), or None if the field is disabled. The return type should be consistent within a dataset class.
- Return type:
list[tuple[pd.Timestamp, pd.Timestamp, str]] | bool | None
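A sketch of the local branch only: expand a glob pattern and raise the documented errors. The config key path is hypothetical, and the function returns sorted file paths rather than the (start_time, end_time, file_path) tuples that _map_files would produce, since that mapping is not described here.

```python
import glob
import os
import tempfile
from typing import Any


def get_file_source(mode: str, field_config: dict[str, Any]) -> list[str]:
    """Sketch of the local branch: expand a glob pattern and fail loudly if nothing matches."""
    if mode != "local":
        # Only the local mode is sketched; other backends are not described here.
        raise ValueError(f"Unrecognised mode: {mode!r}")
    pattern = field_config["path"]  # hypothetical key holding the glob pattern
    files = sorted(glob.glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files match {pattern!r}")
    return files


with tempfile.TemporaryDirectory() as tmp:
    for year in (2000, 2001):
        open(os.path.join(tmp, f"data_{year}.nc"), "w").close()  # empty placeholder files
    files = get_file_source("local", {"path": os.path.join(tmp, "data_*.nc")})
    names = [os.path.basename(f) for f in files]
print(names)  # ['data_2000.nc', 'data_2001.nc']
```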
- _extract_field(field_type: VALID_FIELD_TYPES, t: pandas.Timestamp, sample: dict[str, Any]) -> None
Base extract-field method, which should be overridden in the inherited dataset class to extract the data for each field type. The method should populate sample with the extracted data for the given field type and timestamp, using keys that follow the format in _get_field_name.
The entries are added as tensors to the sample["input"] or sample["target"] dict in __getitem__.
- Parameters:
field_type (VALID_FIELD_TYPES) – One of VALID_FIELD_TYPES.
t (pd.Timestamp) – Query timestamp for which to extract the field data.
sample (dict[str, Any]) – The sample dict being built in __getitem__
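To tie _extract_field back to __getitem__, the toy class below overrides it against a hypothetical in-memory store, with lists standing in for tensors. load_inputs is a hypothetical simplification of _load_sample: it follows the documented rule that step 0 also loads prognostic and static fields, writes only to sample["input"], and ignores targets and diagnostics.

```python
from datetime import datetime
from typing import Any

# Hypothetical in-memory data: {(field_type, dim_str, vname): {timestamp: values}}.
STORE: dict[tuple[str, str, str], dict[datetime, list[float]]] = {
    ("prognostic", "3d", "T"): {datetime(2020, 1, 1): [250.0, 260.0, 270.0]},
    ("static", "2d", "orography"): {datetime(2020, 1, 1): [12.0]},
    ("dynamic_forcing", "2d", "tsi"): {datetime(2020, 1, 1): [1361.0]},
}


class ToyDataset:
    """Hypothetical stand-in for a BaseDataset subclass; only field extraction is sketched."""

    curr_source_name = "Example_Base"

    def _get_field_name(self, field_type: str, dim_str: str, vname: str) -> str:
        return f"{self.curr_source_name}/{field_type}/{dim_str}/{vname}"

    def _extract_field(self, field_type: str, t: datetime, sample: dict[str, Any]) -> None:
        # A real override would read files for time t and store torch tensors.
        for (ftype, dim_str, vname), series in STORE.items():
            if ftype == field_type:
                sample["input"][self._get_field_name(ftype, dim_str, vname)] = series[t]

    def load_inputs(self, t: datetime, i: int) -> dict[str, Any]:
        # Per the __getitem__ docstring: step 0 also loads prognostic and static fields.
        field_types = ["prognostic", "static", "dynamic_forcing"] if i == 0 else ["dynamic_forcing"]
        sample: dict[str, Any] = {"input": {}, "metadata": {}}
        for field_type in field_types:
            self._extract_field(field_type, t, sample)
        return sample


ds = ToyDataset()
print(sorted(ds.load_inputs(datetime(2020, 1, 1), i=0)["input"]))
# ['Example_Base/dynamic_forcing/2d/tsi', 'Example_Base/prognostic/3d/T', 'Example_Base/static/2d/orography']
print(sorted(ds.load_inputs(datetime(2020, 1, 1), i=1)["input"]))
# ['Example_Base/dynamic_forcing/2d/tsi']
```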