credit.data#
Data.py contains modules for processing training data.
- Helper functions:
generate_datetime(start_time, end_time, interval_hr)
hour_to_nanoseconds(input_hr)
nanoseconds_to_year(nanoseconds_value)
extract_month_day_hour(dates)
find_common_indices(list1, list2)
concat_and_reshape(x1, x2)
reshape_only(x1)
get_forward_data(filename)
drop_var_from_dataset()
previous_hourly_steps()
next_n_hour()
encode_datetime64()
- Sample class:
Sample
Sample_WRF
Sample_dscale
Sample_diag
Sample_LES
- Deprecated
ERA5_and_Forcing_Dataset(torch.utils.data.Dataset)
Predict_Dataset(torch.utils.data.IterableDataset)
Attributes#
Classes#
Sample | Simple class for structuring data for the ML model.
Sample_WRF | TypedDict sample class (no class docstring).
Sample_dscale | TypedDict sample class (no class docstring).
Sample_diag | TypedDict sample class (no class docstring).
Sample_LES | TypedDict sample class (no class docstring).
ERA5_and_Forcing_Dataset | Deprecated.
ERA5_Dataset_Distributed | ERA5 Dataset for Distributed training (legacy).
Predict_Dataset | Same as ERA5_and_Forcing_Dataset() but works with the old rollout_to_netcdf.py.
Functions#
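device_compatible_to(tensor, device) | Safely move tensor to device, with float32 casting on MPS (Metal Performance Shaders).
ensure_numpy_datetime(value) | Converts an input value (or array) to numpy.datetime64.
generate_datetime(start_time, end_time, interval_hr) | Generate a list of datetime.datetime based on start and end times and an hour interval.
hour_to_nanoseconds(input_hr) | Convert hour to nanoseconds.
nanoseconds_to_year(nanoseconds_value) | Given datetime info as nanoseconds, compute which year it belongs to.
extract_month_day_hour(dates) | Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each entry into a zipped list.
find_common_indices(list1, list2) | Find indices of common elements between two lists.
concat_and_reshape(x1, x2) | Flatten the "level" coordinate of upper-air variables and concatenate the result with surface variables.
reshape_only(x1) | Flatten the "level" coordinate of upper-air variables.
get_forward_data(filename) | Check nc vs. zarr files and open file as xr.Dataset.
flatten_list(list_of_lists) | Flatten a list of lists.
generate_integer_list_around(number, spacing=10) | Generate a list of integers on either side of a given number with a specified spacing.
find_key_for_number(input_number, data_dict) | Find the key in the dictionary based on the given number.
drop_var_from_dataset(xarray_dataset, varname_keep) | Preserve a given set of variables from an xarray.Dataset, and drop the rest.
keep_dataset_vars(xarray_dataset, varnames_keep) | Return a version of an xarray dataset with only a selected subset of variables.
subset_patch(ds, input_size, start, lat_name, lon_name) | Return a spatial subset of shape (time, input_size[0], input_size[1]).
encode_datetime64(dt_array) |
next_n_hour(dt, period_hours) | Round dt forward to the next N-hour boundary.
previous_hourly_steps(time_pick, hour, step) | Given a datetime64[ns] time_pick, compute time_pick - step * hour.
filter_ds(ds, varnames_keep) | Return a new Dataset containing only the variables in varnames_keep.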
Module Contents#
- credit.data.Array#
- credit.data.IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')#
- credit.data.device_compatible_to(tensor: torch.Tensor, device: torch.device) torch.Tensor#
Safely move tensor to device, with float32 casting on MPS (Metal Performance Shaders). Addresses a runtime error on macOS caused by MPS not supporting float64.
- Parameters:
tensor (torch.Tensor) – Input tensor to move.
device (torch.device) – Target device.
- Returns:
Tensor moved to device (cast to float32 if device is MPS).
- Return type:
torch.Tensor
- credit.data.ensure_numpy_datetime(value)#
Converts an input value (or array) to numpy.datetime64. Handles numpy arrays, pandas timestamps, cftime objects, and strings.
- credit.data.generate_datetime(start_time, end_time, interval_hr)#
Generate a list of datetime.datetime based on start and end times and an hour interval.
- Parameters:
start_time (datetime.datetime) – start time
end_time (datetime.datetime) – end time
interval_hr (int) – hour interval
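A minimal sketch of the stepping described above, using only the standard library. The name generate_datetime_sketch is hypothetical, and treating end_time as inclusive is an assumption; the actual function may handle the end point differently.

```python
from datetime import datetime, timedelta


def generate_datetime_sketch(start_time, end_time, interval_hr):
    """Hypothetical stand-in: step from start_time to end_time every interval_hr hours."""
    step = timedelta(hours=interval_hr)
    times = []
    current = start_time
    # Assumption: end_time itself is included when it falls on a step.
    while current <= end_time:
        times.append(current)
        current += step
    return times


times = generate_datetime_sketch(datetime(2020, 1, 1, 0), datetime(2020, 1, 1, 18), 6)
```

With a 6-hour interval over 00:00 to 18:00, the sketch yields four timestamps.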
- credit.data.hour_to_nanoseconds(input_hr)#
Convert hour to nanoseconds.
- credit.data.nanoseconds_to_year(nanoseconds_value)#
Given datetime info as nanoseconds, compute which year it belongs to.
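The two conversions above can be sketched with plain integer arithmetic. Both function names are hypothetical, and the sketch assumes the nanosecond value counts from the Unix epoch (1970-01-01 UTC), as numpy.datetime64[ns] does; the real implementation may differ.

```python
from datetime import datetime, timedelta, timezone

NS_PER_HOUR = 3600 * 10**9  # seconds per hour times nanoseconds per second


def hour_to_nanoseconds_sketch(input_hr):
    """Hypothetical stand-in: convert a number of hours to nanoseconds."""
    return input_hr * NS_PER_HOUR


def nanoseconds_to_year_sketch(nanoseconds_value):
    """Hypothetical stand-in: year of a timestamp given as ns since the Unix epoch."""
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    # datetime resolves microseconds, so drop the sub-microsecond part.
    return (epoch + timedelta(microseconds=nanoseconds_value // 1000)).year


six_hours_ns = hour_to_nanoseconds_sketch(6)
year = nanoseconds_to_year_sketch(1577836800 * 10**9)  # 2020-01-01T00:00:00Z
```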
- credit.data.extract_month_day_hour(dates)#
Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each entry into a zipped list.
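One way to pull calendar fields out of a datetime64 array is to truncate to successively finer units and take differences. This is a sketch under that approach, not necessarily how the library does it; the function name is hypothetical.

```python
import numpy as np


def extract_month_day_hour_sketch(dates):
    """Hypothetical stand-in: return [(month, day, hour), ...] for a datetime64 array."""
    dates = np.asarray(dates, dtype="datetime64[ns]")
    year_start = dates.astype("datetime64[Y]")
    month_start = dates.astype("datetime64[M]")
    day_start = dates.astype("datetime64[D]")
    hour_start = dates.astype("datetime64[h]")
    # Differences between truncation levels give zero-based offsets.
    months = (month_start - year_start).astype("int64") + 1
    days = (day_start - month_start).astype("int64") + 1
    hours = (hour_start - day_start).astype("int64")
    return list(zip(months.tolist(), days.tolist(), hours.tolist()))


dates = np.array(["2020-03-15T06:00", "2021-12-31T18:00"], dtype="datetime64[ns]")
fields = extract_month_day_hour_sketch(dates)
```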
- credit.data.find_common_indices(list1, list2)#
Find indices of common elements between two lists.
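A minimal sketch of one plausible behavior: return, for each list, the positions of the elements that also occur in the other list. The return format (a pair of index lists) and the name find_common_indices_sketch are assumptions.

```python
def find_common_indices_sketch(list1, list2):
    """Hypothetical stand-in: positions in each list of elements present in both."""
    common = set(list1) & set(list2)
    indices1 = [i for i, item in enumerate(list1) if item in common]
    indices2 = [i for i, item in enumerate(list2) if item in common]
    return indices1, indices2


idx1, idx2 = find_common_indices_sketch(["a", "b", "c"], ["c", "a", "d"])
```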
- credit.data.concat_and_reshape(x1, x2)#
Flatten the “level” coordinate of upper-air variables and concatenate the result with surface variables.
- credit.data.reshape_only(x1)#
Flatten the “level” coordinate of upper-air variables.
Same as concat_and_reshape, but without the concatenation.
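The flattening can be illustrated with NumPy arrays. The axis layouts here are assumptions made for the sketch: upper-air data as (time, variable, level, lat, lon) and surface data as (time, variable, lat, lon). The library's functions may use a different layout, extra dimensions, or torch tensors, and the function names below are hypothetical.

```python
import numpy as np


def reshape_only_sketch(x1):
    """Hypothetical stand-in: merge the variable and level axes into one channel axis."""
    n_time, n_var, n_level, n_lat, n_lon = x1.shape
    return x1.reshape(n_time, n_var * n_level, n_lat, n_lon)


def concat_and_reshape_sketch(x1, x2):
    """Hypothetical stand-in: flatten levels, then append surface channels."""
    return np.concatenate([reshape_only_sketch(x1), x2], axis=1)


upper_air = np.zeros((2, 3, 4, 5, 6))  # 3 variables on 4 levels
surface = np.zeros((2, 7, 5, 6))       # 7 surface variables
flat = reshape_only_sketch(upper_air)
combined = concat_and_reshape_sketch(upper_air, surface)
```

Under these assumptions the 3 x 4 upper-air channels and 7 surface channels end up on one axis of length 19.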
- credit.data.get_forward_data(filename) xarray.Dataset#
Check nc vs. zarr files and open file as xr.Dataset.
- credit.data.flatten_list(list_of_lists)#
Flatten a list of lists.
- Parameters:
list_of_lists (list) – A list containing sublists.
- Returns:
A flattened list containing all elements from sublists.
- Return type:
flattened_list (list)
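Flattening one level of nesting is commonly done with itertools.chain. The sketch below shows that idiom under a hypothetical name; it is not taken from the library's source.

```python
from itertools import chain


def flatten_list_sketch(list_of_lists):
    """Hypothetical stand-in: flatten exactly one level of nesting."""
    return list(chain.from_iterable(list_of_lists))


flattened = flatten_list_sketch([[1, 2], [3], []])
```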
- credit.data.generate_integer_list_around(number, spacing=10)#
Generate a list of integers on either side of a given number with a specified spacing.
- Parameters:
number (int) – The central number around which the list is generated.
spacing (int) – The spacing between consecutive integers in the list. Default is 10.
- Returns:
List of integers on either side of the given number.
- Return type:
integer_list (list)
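The description leaves some room for interpretation. One reading, assumed here, is that the result contains every integer from number - spacing to number + spacing inclusive. The function name is hypothetical and the actual behavior may differ.

```python
def generate_integer_list_around_sketch(number, spacing=10):
    """Hypothetical stand-in: all integers within `spacing` of `number`, inclusive."""
    return list(range(number - spacing, number + spacing + 1))


around = generate_integer_list_around_sketch(5, spacing=2)
```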
- credit.data.find_key_for_number(input_number, data_dict)#
Find the key in the dictionary based on the given number.
- Parameters:
input_number (int) – The number to search for in the dictionary.
data_dict (dict) – The dictionary with keys and corresponding value lists.
- Returns:
The key in the dictionary where the input number falls within the specified range.
- Return type:
key_found (str)
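A sketch of a range lookup of this kind. It assumes each dictionary value is a [low, high] pair with inclusive bounds and that None is returned when nothing matches; the real value layout and the miss behavior are not stated above, so treat both as assumptions. The function name and the example keys are hypothetical.

```python
def find_key_for_number_sketch(input_number, data_dict):
    """Hypothetical stand-in: first key whose [low, high] range contains input_number."""
    for key, (low, high) in data_dict.items():
        if low <= input_number <= high:
            return key
    return None  # Assumption: no match yields None.


ranges = {"file_a": [0, 99], "file_b": [100, 199]}
key_found = find_key_for_number_sketch(150, ranges)
```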
- credit.data.drop_var_from_dataset(xarray_dataset, varname_keep)#
Preserve a given set of variables from an xarray.Dataset, and drop the rest. Raises an error if any name in varname_keep is missing from xarray_dataset.
- credit.data.keep_dataset_vars(xarray_dataset: xarray.Dataset, varnames_keep: List[str])#
Return a version of an xarray dataset with only a selected subset of variables.
- Parameters:
xarray_dataset (xr.Dataset) – The xarray dataset.
varnames_keep (List[str]) – a list of variable names to be kept.
- Returns:
xr.Dataset with only the variables in varnames_keep included.
- credit.data.subset_patch(ds: xarray.Dataset, input_size, start, lat_name='yIndex', lon_name='xIndex') xarray.Dataset#
Return a spatial subset of shape (time, input_size[0], input_size[1]). Assumes ds has dims (time, lat, lon).
- credit.data.encode_datetime64(dt_array)#
- credit.data.next_n_hour(dt, period_hours)#
Round dt forward to the next N-hour boundary.
- Parameters:
dt – np.datetime64[ns] or array of such values
period_hours – int, the interval in hours (e.g., 3, 6)
- Returns:
np.datetime64[ns] rounded forward to the next period_hours boundary
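Rounding forward can be written as a ceiling division on the integer nanosecond representation. The sketch assumes boundaries are counted from the Unix epoch (so 6-hour boundaries fall on 00, 06, 12, 18 UTC) and that a value already on a boundary is returned unchanged; both points are assumptions, and the name is hypothetical.

```python
import numpy as np


def next_n_hour_sketch(dt, period_hours):
    """Hypothetical stand-in: round datetime64 values up to a multiple of period_hours."""
    ticks = np.asarray(dt, dtype="datetime64[ns]").astype("int64")
    period_ns = period_hours * 3600 * 10**9
    # Ceiling division: negate, floor-divide, negate again.
    rounded = -(-ticks // period_ns) * period_ns
    return rounded.astype("datetime64[ns]")


stamps = np.array(
    ["2020-01-01T05:00", "2020-01-01T06:00", "2020-01-01T06:30"],
    dtype="datetime64[ns]",
)
rounded = next_n_hour_sketch(stamps, 6)
```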
- credit.data.previous_hourly_steps(time_pick, hour, step)#
Given a datetime64[ns] time_pick, compute time_pick - step * hour.
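With NumPy, stepping back in time is a subtraction of a timedelta64. The sketch below assumes hour is the length of one step in hours and step is the number of steps; the function name is hypothetical.

```python
import numpy as np


def previous_hourly_steps_sketch(time_pick, hour, step):
    """Hypothetical stand-in: move time_pick back by `step` steps of `hour` hours."""
    return time_pick - np.timedelta64(int(step * hour), "h")


start = np.datetime64("2020-01-02T00:00:00", "ns")
earlier = previous_hourly_steps_sketch(start, hour=6, step=2)
```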
- credit.data.filter_ds(ds: xarray.Dataset, varnames_keep: Sequence[str]) xarray.Dataset#
Return a new Dataset containing only the variables in varnames_keep. Raises an error if any variable in varnames_keep is missing.
- class credit.data.Sample#
Bases:
TypedDict
Simple class for structuring data for the ML model.
- Using typing.TypedDict gives us several advantages:
Single ‘source of truth’ for the type and documentation of each example.
A static type checker can check the types are correct.
Instead of TypedDict, we could use typing.NamedTuple, which would provide runtime checks, but the deal-breaker with Tuples is that they’re immutable so we cannot change the values in the transforms.
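The mutability point can be seen in a few lines. The field names below are borrowed from IMAGE_ATTR_NAMES; whether Sample declares exactly these keys is an assumption, and SampleSketch, FrozenSample, and scale_transform are hypothetical names used only for illustration.

```python
from typing import NamedTuple, TypedDict


class SampleSketch(TypedDict, total=False):
    # Assumed keys, for illustration only.
    historical_ERA5_images: list
    target_ERA5_images: list


class FrozenSample(NamedTuple):
    historical_ERA5_images: list


def scale_transform(sample: SampleSketch, factor: float) -> SampleSketch:
    """A transform can overwrite a TypedDict entry in place."""
    sample["historical_ERA5_images"] = [v * factor for v in sample["historical_ERA5_images"]]
    return sample


sample: SampleSketch = {"historical_ERA5_images": [1.0, 2.0], "target_ERA5_images": [3.0]}
sample = scale_transform(sample, 10.0)

frozen = FrozenSample(historical_ERA5_images=[1.0, 2.0])
try:
    frozen.historical_ERA5_images = [0.0]  # NamedTuple fields cannot be reassigned
    reassigned = True
except AttributeError:
    reassigned = False
```

At runtime a TypedDict instance is an ordinary dict, so the declared types are only checked by static tools such as mypy or pyright.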
- class credit.data.Sample_WRF#
Bases:
TypedDict
- class credit.data.Sample_dscale#
Bases:
TypedDict
- class credit.data.Sample_diag#
Bases:
TypedDict
- class credit.data.Sample_LES#
Bases:
TypedDict
- class credit.data.ERA5_and_Forcing_Dataset(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)#
Bases:
torch.utils.data.Dataset
Deprecated. A PyTorch Dataset class that works on the following kinds of variables:
upper-air variables (time, level, lat, lon)
surface variables (time, lat, lon)
dynamic forcing variables (time, lat, lon)
forcing variables (time, lat, lon)
diagnostic variables (time, lat, lon)
static variables (lat, lon).
- history_len = 2#
- forecast_len = 0#
- transform = None#
- skip_periods = None#
- one_shot = None#
- total_seq_len = 2#
- rng#
- max_forecast_len = None#
- sst_forcing = None#
- all_files = []#
- ERA5_indices#
- filename_forcing = None#
- filename_static = None#
- __post_init__()#
Calculate total sequence length after init.
- __len__()#
Length of Dataset.
- __getitem__(index)#
Get single item from the dataset.
- class credit.data.ERA5_Dataset_Distributed(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)#
Bases:
torch.utils.data.Dataset
ERA5 Dataset for Distributed training (legacy).
- history_len = 2#
- forecast_len = 0#
- transform = None#
- varname_upper_air#
- varname_surface#
- varname_dyn_forcing#
- varname_forcing#
- varname_static#
- skip_periods = None#
- one_shot = None#
- total_seq_len = 2#
- rng#
- max_forecast_len = None#
- sst_forcing = None#
- filenames#
- all_files = []#
- ERA5_indices#
- filename_forcing = None#
- filename_static = None#
- __post_init__()#
Calculate total sequence length.
- __len__()#
Length of dataset.
- __getitem__(index)#
Get item.
- Parameters:
index – index of timestep
- Returns:
pytorch Tensor containing a full state.
- class credit.data.Predict_Dataset(conf, varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface, filename_dyn_forcing, filename_forcing, filename_static, filename_diagnostic, fcst_datetime, history_len, rank, world_size, transform=None, rollout_p=0.0, which_forecast=None)#
Bases:
torch.utils.data.IterableDataset
Same as ERA5_and_Forcing_Dataset() but works with the old rollout_to_netcdf.py.
- rank#
- world_size#
- transform = None#
- history_len#
- init_datetime#
- which_forecast = None#
- filenames#
- filename_surface#
- filename_dyn_forcing#
- filename_forcing#
- filename_static#
- filename_diagnostic#
- varname_upper_air#
- varname_surface#
- varname_dyn_forcing#
- varname_forcing#
- varname_static#
- varname_diagnostic#
- all_files = []#
- current_epoch = 0#
- rollout_p = 0.0#
- lead_time_periods#
- skip_periods#
- ds_read_and_subset(filename, time_start, time_end, varnames)#
Read and subset specified dataset.
- Parameters:
filename (str) – path to specified dataset file.
time_start (int) – start time index.
time_end (int) – end time index.
varnames (list) – List of variables to be read.
- load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')#
Load input data from zarr files.
- Parameters:
i_file – index of the file
i_init_start – start index of the data being loaded
i_init_end – end index of the data being loaded.
mode – “input” or “target”
- Returns:
xr.Dataset containing all the variables.
- find_start_stop_indices(index)#
Find start and stop indices for a given yearly data zarr file.
- Parameters:
index – indices of zarr file.
- __len__()#
Length of dataset.
- __iter__()#
Iterate through batch.