credit.data#
data.py contains functions and classes for processing training data.
- Content:
generate_datetime(start_time, end_time, interval_hr)
hour_to_nanoseconds(input_hr)
nanoseconds_to_year(nanoseconds_value)
extract_month_day_hour(dates)
find_common_indices(list1, list2)
concat_and_reshape(x1, x2)
reshape_only(x1)
get_forward_data(filename)
drop_var_from_dataset()
ERA5_and_Forcing_Dataset(torch.utils.data.Dataset)
Predict_Dataset(torch.utils.data.IterableDataset)
Attributes#
Classes#
Sample | Simple class for structuring data for the ML model.
ERA5_and_Forcing_Dataset | A PyTorch Dataset class that works on the following kinds of variables.
ERA5_Dataset_Distributed | ERA5 Dataset for distributed training (legacy).
Predict_Dataset | Same as ERA5_and_Forcing_Dataset() but works with the old rollout_to_netcdf.py.
Functions#
ensure_numpy_datetime | Converts an input value (or array) to numpy.datetime64.
generate_datetime | Generate a list of datetime.datetime based on start and end times and an hour interval.
hour_to_nanoseconds | Convert hours to nanoseconds.
nanoseconds_to_year | Given datetime info as nanoseconds, compute which year it belongs to.
extract_month_day_hour | Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each entry into a zipped list.
find_common_indices | Find indices of common elements between two lists.
concat_and_reshape | Flatten the "level" coordinate of upper-air variables and concatenate the result with surface variables.
reshape_only | Flatten the "level" coordinate of upper-air variables.
get_forward_data | Check whether the file is nc or zarr and open it as an xr.Dataset.
flatten_list | Flatten a list of lists.
generate_integer_list_around | Generate a list of integers on either side of a given number with a specified spacing.
find_key_for_number | Find the key in the dictionary based on the given number.
drop_var_from_dataset | Preserve a given set of variables from an xarray.Dataset, and drop the rest.
keep_dataset_vars | Return a version of an xarray dataset with only a selected subset of variables.
Module Contents#
- credit.data.Array#
- credit.data.IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')#
- credit.data.ensure_numpy_datetime(value)#
Converts an input value (or array) to numpy.datetime64. Handles numpy arrays, pandas timestamps, cftime objects, and strings.
- credit.data.generate_datetime(start_time, end_time, interval_hr)#
Generate a list of datetime.datetime based on start and end times and an hour interval.
- Parameters:
start_time (datetime.datetime) – start time
end_time (datetime.datetime) – end time
interval_hr (int) – hour interval
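To illustrate the behavior described above, here is a minimal standalone sketch. It is not the credit.data implementation: the helper name `generate_datetime_sketch` is hypothetical, and it assumes that `end_time` is included when it falls exactly on a step, which the docstring does not state.

```python
from datetime import datetime, timedelta


def generate_datetime_sketch(start_time, end_time, interval_hr):
    """Return datetimes from start_time to end_time in steps of interval_hr hours.

    Assumption: end_time is included when it lands exactly on a step.
    """
    step = timedelta(hours=interval_hr)
    times = []
    current = start_time
    while current <= end_time:
        times.append(current)
        current += step
    return times


# Four 6-hourly time stamps on 2020-01-01: 00, 06, 12, and 18 UTC.
times = generate_datetime_sketch(
    datetime(2020, 1, 1, 0), datetime(2020, 1, 1, 18), 6
)
```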
- credit.data.hour_to_nanoseconds(input_hr)#
Convert hour to nanoseconds.
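The conversion is plain arithmetic: one hour is 3600 seconds and one second is 10^9 nanoseconds. A minimal sketch follows; the name `hour_to_nanoseconds_sketch` is hypothetical and the code is not taken from credit.data.

```python
# 3600 seconds per hour times 1e9 nanoseconds per second.
NANOSECONDS_PER_HOUR = 3600 * 1_000_000_000


def hour_to_nanoseconds_sketch(input_hr):
    """Convert a number of hours to nanoseconds."""
    return input_hr * NANOSECONDS_PER_HOUR


six_hours_ns = hour_to_nanoseconds_sketch(6)
```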
- credit.data.nanoseconds_to_year(nanoseconds_value)#
Given datetime info as nanoseconds, compute which year it belongs to.
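One way to picture this conversion is sketched below. It assumes the value counts nanoseconds since the Unix epoch (1970-01-01 UTC), which is how np.datetime64[ns] stores time but is not stated in the docstring; the name `nanoseconds_to_year_sketch` is hypothetical.

```python
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def nanoseconds_to_year_sketch(nanoseconds_value):
    """Return the calendar year (UTC) of a nanoseconds-since-epoch value."""
    # timedelta only has microsecond resolution, so floor to microseconds first.
    return (EPOCH + timedelta(microseconds=nanoseconds_value // 1000)).year


# 1577836800 seconds after the epoch is 2020-01-01T00:00:00Z.
new_year_2020_ns = 1_577_836_800 * 1_000_000_000
year = nanoseconds_to_year_sketch(new_year_2020_ns)
```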
- credit.data.extract_month_day_hour(dates)#
Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each entry into a zipped list.
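The extraction can be sketched with NumPy unit casts alone. This is an illustration under assumptions, not the credit.data code: the name `extract_month_day_hour_sketch` is hypothetical, and returning a list of `(month, day, hour)` tuples of plain integers is a guess at what "zipped list" means.

```python
import numpy as np


def extract_month_day_hour_sketch(dates):
    """Return [(month, day, hour), ...] for a 1-D array of np.datetime64[ns]."""
    as_months = dates.astype("datetime64[M]")
    as_days = dates.astype("datetime64[D]")
    as_hours = dates.astype("datetime64[h]")
    # Months since 1970-01, wrapped to 1..12.
    months = as_months.astype("int64") % 12 + 1
    # Days elapsed since the first day of the month, shifted to 1-based.
    days = (as_days - as_months).astype("int64") + 1
    # Hours since the epoch, wrapped to 0..23.
    hours = as_hours.astype("int64") % 24
    return list(zip(months.tolist(), days.tolist(), hours.tolist()))


dates = np.array(["2020-03-15T06:00", "2021-12-31T23:00"], dtype="datetime64[ns]")
parts = extract_month_day_hour_sketch(dates)
```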
- credit.data.find_common_indices(list1, list2)#
Find indices of common elements between two lists.
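The docstring does not say which list the returned indices refer to. The sketch below assumes they are positions in `list1` of elements that also occur in `list2`; both that choice and the name `find_common_indices_sketch` are assumptions for illustration only.

```python
def find_common_indices_sketch(list1, list2):
    """Return positions in list1 whose elements also appear in list2."""
    lookup = set(list2)  # set membership keeps the scan linear in len(list1)
    return [i for i, item in enumerate(list1) if item in lookup]


indices = find_common_indices_sketch(["t2m", "u10", "v10", "sp"], ["sp", "u10"])
```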
- credit.data.concat_and_reshape(x1, x2)#
Flatten the “level” coordinate of upper-air variables and concatenate the result with surface variables.
- credit.data.reshape_only(x1)#
Flattening the “level” coordinate of upper-air variables.
As in “concat_and_reshape”, but no concat.
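The two reshaping helpers can be pictured with a small NumPy sketch. The array layouts used here, `(variable, level, lat, lon)` for upper-air data and `(variable, lat, lon)` for surface data, are assumptions chosen for illustration; the actual functions may operate on tensors with additional batch or time dimensions, and the `_sketch` names are hypothetical.

```python
import numpy as np


def reshape_only_sketch(upper_air):
    """Merge the variable and level axes: (var, level, lat, lon) -> (var * level, lat, lon)."""
    n_var, n_level, n_lat, n_lon = upper_air.shape
    return upper_air.reshape(n_var * n_level, n_lat, n_lon)


def concat_and_reshape_sketch(upper_air, surface):
    """Flatten the level axis, then stack surface fields along the leading axis."""
    return np.concatenate([reshape_only_sketch(upper_air), surface], axis=0)


upper_air = np.zeros((4, 3, 2, 5))  # 4 variables on 3 levels
surface = np.ones((2, 2, 5))        # 2 surface variables
combined = concat_and_reshape_sketch(upper_air, surface)
```

With these shapes, the 12 flattened upper-air channels come first and the 2 surface channels follow, giving 14 channels in total.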
- credit.data.get_forward_data(filename) → xarray.Dataset#
Check whether the file is nc or zarr and open it as an xr.Dataset.
- class credit.data.Sample#
Bases: TypedDict
Simple class for structuring data for the ML model.
- Using typing.TypedDict gives us several advantages:
Single ‘source of truth’ for the type and documentation of each example.
A static type checker can check the types are correct.
Instead of TypedDict, we could use typing.NamedTuple, which would provide runtime checks, but the deal-breaker with Tuples is that they’re immutable so we cannot change the values in the transforms.
- historical_ERA5_images: Array#
- target_ERA5_images: Array#
- datetime_index: Array#
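The mutability argument above can be shown with the standard library alone. In this sketch the classes `SampleSketch` and `SampleTuple` and the function `scale_transform` are hypothetical stand-ins, and plain lists of floats replace the real `Array` type.

```python
from typing import List, NamedTuple, TypedDict


class SampleSketch(TypedDict):
    historical_ERA5_images: List[float]
    target_ERA5_images: List[float]


class SampleTuple(NamedTuple):
    historical_ERA5_images: List[float]
    target_ERA5_images: List[float]


def scale_transform(sample: SampleSketch, factor: float) -> SampleSketch:
    """A toy transform that overwrites fields in place, which a dict allows."""
    sample["historical_ERA5_images"] = [v * factor for v in sample["historical_ERA5_images"]]
    sample["target_ERA5_images"] = [v * factor for v in sample["target_ERA5_images"]]
    return sample


sample: SampleSketch = {"historical_ERA5_images": [1.0, 2.0], "target_ERA5_images": [3.0]}
scaled = scale_transform(sample, 10.0)

# The same reassignment on a NamedTuple raises AttributeError.
frozen = SampleTuple([1.0, 2.0], [3.0])
try:
    frozen.historical_ERA5_images = [0.0]
    tuple_is_mutable = True
except AttributeError:
    tuple_is_mutable = False
```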
- credit.data.flatten_list(list_of_lists)#
Flatten a list of lists.
- Parameters:
list_of_lists (list) – the list of lists to flatten.
- Returns:
flattened_list (list) – a flattened list containing all elements from the sublists.
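A minimal sketch of the flattening, assuming exactly one level of nesting (deeper nesting is left untouched); the name `flatten_list_sketch` is hypothetical.

```python
def flatten_list_sketch(list_of_lists):
    """Concatenate the sublists, preserving order; empty sublists contribute nothing."""
    return [item for sublist in list_of_lists for item in sublist]


flat = flatten_list_sketch([[1, 2], [3], [], [4, 5]])
```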
- credit.data.generate_integer_list_around(number, spacing=10)#
Generate a list of integers on either side of a given number with a specified spacing.
- Parameters:
number (int) – the number around which the list is generated.
spacing (int) – the spacing to use; defaults to 10.
- Returns:
integer_list (list) – list of integers on either side of the given number.
- credit.data.find_key_for_number(input_number, data_dict)#
Find the key in the dictionary based on the given number.
- Parameters:
input_number (int) – the number to look up.
data_dict (dict) – the dictionary to search.
- Returns:
key_found (str) – the key in the dictionary where the input number falls within the specified range.
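The layout of `data_dict` is not documented here. The sketch below assumes each value is an inclusive `(start, end)` pair and that `None` is returned when no range matches; both are assumptions, and the name `find_key_for_number_sketch` is hypothetical.

```python
def find_key_for_number_sketch(input_number, data_dict):
    """Return the first key whose (start, end) range contains input_number."""
    for key, (start, end) in data_dict.items():
        if start <= input_number <= end:
            return key
    return None  # assumption: no match yields None


# Hypothetical index ranges for two yearly files of hourly data.
ranges = {"0": (0, 8759), "1": (8760, 17543)}
key_found = find_key_for_number_sketch(9000, ranges)
```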
- credit.data.drop_var_from_dataset(xarray_dataset, varname_keep)#
Preserve a given set of variables from an xarray.Dataset, and drop the rest.
It raises an error if varname_keep is missing from xarray_dataset.
- credit.data.keep_dataset_vars(xarray_dataset: xarray.Dataset, varnames_keep: List[str])#
Return a version of an xarray dataset with only a selected subset of variables.
- Parameters:
xarray_dataset (xr.Dataset) – The xarray dataset.
varnames_keep (List[str]) – a list of variable names to be kept.
- Returns:
xr.Dataset with only the variables in varnames_keep included.
- class credit.data.ERA5_and_Forcing_Dataset(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)#
Bases: torch.utils.data.Dataset
A PyTorch Dataset class that works on the following kinds of variables:
upper-air variables (time, level, lat, lon)
surface variables (time, lat, lon)
dynamic forcing variables (time, lat, lon)
forcing variables (time, lat, lon)
diagnostic variables (time, lat, lon)
static variables (lat, lon).
- history_len = 2#
- forecast_len = 0#
- transform = None#
- skip_periods = None#
- one_shot = None#
- total_seq_len = 2#
- rng#
- max_forecast_len = None#
- sst_forcing = None#
- all_files = []#
- ERA5_indices#
- filename_forcing = None#
- filename_static = None#
- __post_init__()#
Calculate total sequence length after init.
- __len__()#
Length of Dataset.
- __getitem__(index)#
Get single item from the dataset.
- class credit.data.ERA5_Dataset_Distributed(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)#
Bases: torch.utils.data.Dataset
ERA5 Dataset for distributed training (legacy).
- history_len = 2#
- forecast_len = 0#
- transform = None#
- varname_upper_air#
- varname_surface#
- varname_dyn_forcing#
- varname_forcing#
- varname_static#
- skip_periods = None#
- one_shot = None#
- total_seq_len = 2#
- rng#
- max_forecast_len = None#
- sst_forcing = None#
- filenames#
- all_files = []#
- ERA5_indices#
- filename_forcing = None#
- filename_static = None#
- __post_init__()#
Calculate total sequence length.
- __len__()#
Length of dataset.
- __getitem__(index)#
Get item.
- Parameters:
index – index of timestep
- Returns:
PyTorch tensor containing a full state.
- class credit.data.Predict_Dataset(conf, varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface, filename_dyn_forcing, filename_forcing, filename_static, filename_diagnostic, fcst_datetime, history_len, rank, world_size, transform=None, rollout_p=0.0, which_forecast=None)#
Bases: torch.utils.data.IterableDataset
Same as ERA5_and_Forcing_Dataset() but works with the old rollout_to_netcdf.py.
- rank#
- world_size#
- transform = None#
- history_len#
- init_datetime#
- which_forecast = None#
- filenames#
- filename_surface#
- filename_dyn_forcing#
- filename_forcing#
- filename_static#
- filename_diagnostic#
- varname_upper_air#
- varname_surface#
- varname_dyn_forcing#
- varname_forcing#
- varname_static#
- varname_diagnostic#
- all_files = []#
- current_epoch = 0#
- rollout_p = 0.0#
- lead_time_periods#
- skip_periods#
- ds_read_and_subset(filename, time_start, time_end, varnames)#
Read and subset specified dataset.
- Parameters:
filename (str) – path to specified dataset file.
time_start (int) – start time index.
time_end (int) – end time index.
varnames (list) – List of variables to be read.
- load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')#
Load input data from zarr files.
- Parameters:
i_file – index of the file
i_init_start – start index of the data being loaded
i_init_end – end index of the data being loaded.
mode – “input” or “target”
- Returns:
xr.Dataset containing all the variables.
- find_start_stop_indices(index)#
Find start and stop indices for a given yearly data zarr file.
- Parameters:
index – indices of zarr file.
- __len__()#
Length of dataset.
- __iter__()#
Iterate through batch.