credit.data#

Data.py contains modules for processing training data.

Helper functions:
  • generate_datetime(start_time, end_time, interval_hr)

  • hour_to_nanoseconds(input_hr)

  • nanoseconds_to_year(nanoseconds_value)

  • extract_month_day_hour(dates)

  • find_common_indices(list1, list2)

  • concat_and_reshape(x1, x2)

  • reshape_only(x1)

  • get_forward_data(filename)

  • drop_var_from_dataset()

  • previous_hourly_steps()

  • next_n_hour()

  • encode_datetime64()

Sample classes:
  • Sample

  • Sample_WRF

  • Sample_dscale

  • Sample_diag

  • Sample_LES

Deprecated:
  • ERA5_and_Forcing_Dataset(torch.utils.data.Dataset)

  • Predict_Dataset(torch.utils.data.IterableDataset)

Attributes#

Classes#

Sample

Simple class for structuring data for the ML model.

Sample_WRF

TypedDict with keys WRF_input, WRF_target, boundary_input, time_encode, and datetime_index.

Sample_dscale

TypedDict with keys LR_input, HR_input, HR_target, time_encode, and datetime_index.

Sample_diag

TypedDict with keys WRF_input, WRF_target, time_encode, and datetime_index.

Sample_LES

TypedDict with keys LES_input, LES_target, and datetime_index.

ERA5_and_Forcing_Dataset

Deprecated

ERA5_Dataset_Distributed

ERA5 Dataset for Distributed training (legacy).

Predict_Dataset

Same as ERA5_and_Forcing_Dataset, but works with the old rollout_to_netcdf.py.

Functions#

device_compatible_to(→ torch.Tensor)

Safely move tensor to device, with float32 casting on MPS (Metal Performance Shaders).

ensure_numpy_datetime(value)

Converts an input value (or array) to numpy.datetime64.

generate_datetime(start_time, end_time, interval_hr)

Generate a list of datetime.datetime based on start and end times and an hour interval.

hour_to_nanoseconds(input_hr)

Convert hour to nanoseconds.

nanoseconds_to_year(nanoseconds_value)

Given datetime info as nanoseconds, compute which year it belongs to.

extract_month_day_hour(dates)

Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each element into a zipped list.

find_common_indices(list1, list2)

Find indices of common elements between two lists.

concat_and_reshape(x1, x2)

Flatten the "level" coordinate of upper-air variables and concatenate them with surface variables.

reshape_only(x1)

Flatten the "level" coordinate of upper-air variables.

get_forward_data(→ xarray.Dataset)

Check nc vs. zarr files and open file as xr.Dataset.

flatten_list(list_of_lists)

Flatten a list of lists.

generate_integer_list_around(number[, spacing])

Generate a list of integers on either side of a given number with a specified spacing.

find_key_for_number(input_number, data_dict)

Find the key in the dictionary based on the given number.

drop_var_from_dataset(xarray_dataset, varname_keep)

Preserve a given set of variables from an xarray.Dataset, and drop the rest.

keep_dataset_vars(xarray_dataset, varnames_keep)

Return a version of an xarray dataset with only a selected subset of variables.

subset_patch(→ xarray.Dataset)

Return a spatial subset of shape (time, input_size[0], input_size[1]).

encode_datetime64(dt_array)

next_n_hour(dt, period_hours)

Round dt forward to the next N-hour boundary.

previous_hourly_steps(time_pick, hour, step)

Given a datetime64[ns] time_pick, compute time_pick - step * hour.

filter_ds(→ xarray.Dataset)

Return a new Dataset containing only the variables in varnames_keep.

Module Contents#

credit.data.Array#
credit.data.IMAGE_ATTR_NAMES = ('historical_ERA5_images', 'target_ERA5_images')#
credit.data.device_compatible_to(tensor: torch.Tensor, device: torch.device) torch.Tensor#

Safely move tensor to device, with float32 casting on MPS (Metal Performance Shaders). Addresses a runtime error on macOS caused by MPS not supporting float64.

Parameters:
  • tensor (torch.Tensor) – Input tensor to move.

  • device (torch.device) – Target device.

Returns:

Tensor moved to device (cast to float32 if device is MPS).

Return type:

torch.Tensor

credit.data.ensure_numpy_datetime(value)#

Converts an input value (or array) to numpy.datetime64. Handles numpy arrays, pandas timestamps, cftime objects, and strings.

credit.data.generate_datetime(start_time, end_time, interval_hr)#

Generate a list of datetime.datetime based on start and end times and an hour interval.

Parameters:
  • start_time (datetime.datetime) – start time

  • end_time (datetime.datetime) – end time

  • interval_hr (int) – hour interval
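
The behaviour described above can be sketched with the standard library alone. This is a minimal illustration, not the module's implementation: the name generate_datetime_sketch is hypothetical, and treating end_time as inclusive is an assumption.

```python
from datetime import datetime, timedelta


def generate_datetime_sketch(start_time, end_time, interval_hr):
    """Hypothetical stand-in: datetimes from start_time to end_time, every interval_hr hours."""
    step = timedelta(hours=interval_hr)
    times = []
    current = start_time
    # Assumption: end_time itself is included when it falls on the grid.
    while current <= end_time:
        times.append(current)
        current += step
    return times


times = generate_datetime_sketch(datetime(2020, 1, 1, 0), datetime(2020, 1, 1, 18), 6)
print(times)
```

With a 6-hour interval over 00:00 to 18:00, the sketch yields four timestamps.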

credit.data.hour_to_nanoseconds(input_hr)#

Convert hour to nanoseconds.
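
The conversion is plain arithmetic: one hour is 3600 seconds and one second is 10^9 nanoseconds. A minimal sketch, with the hypothetical name hour_to_nanoseconds_sketch standing in for the real function:

```python
# Nanoseconds in one hour: 60 min * 60 s * 1e9 ns
NS_PER_HOUR = 60 * 60 * 1_000_000_000


def hour_to_nanoseconds_sketch(input_hr):
    """Hypothetical stand-in: convert a number of hours to nanoseconds."""
    return input_hr * NS_PER_HOUR


six_hours_ns = hour_to_nanoseconds_sketch(6)
print(six_hours_ns)  # 21600000000000
```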

credit.data.nanoseconds_to_year(nanoseconds_value)#

Given datetime info as nanoseconds, compute which year it belongs to.
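
One way to read this is that the input is a count of nanoseconds since the Unix epoch (1970-01-01 UTC), which is how np.datetime64[ns] values are stored. That interpretation is an assumption, and nanoseconds_to_year_sketch is a hypothetical name used only for illustration.

```python
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def nanoseconds_to_year_sketch(nanoseconds_value):
    """Hypothetical stand-in: year of a nanoseconds-since-epoch timestamp (UTC)."""
    # datetime resolves microseconds at best, so floor the nanoseconds first.
    microseconds = int(nanoseconds_value) // 1000
    return (EPOCH + timedelta(microseconds=microseconds)).year


# 946684800 s after the epoch is 2000-01-01T00:00:00 UTC.
ns_2000 = 946_684_800 * 1_000_000_000
print(nanoseconds_to_year_sketch(ns_2000))      # 2000
print(nanoseconds_to_year_sketch(ns_2000 - 1))  # 1999
```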

credit.data.extract_month_day_hour(dates)#

Given a 1-D array of np.datetime64[ns], extract the month, day, and hour of each element into a zipped list.
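
A rough sketch of the idea, assuming NumPy is available and that "zipped list" means a list of (month, day, hour) tuples. The name extract_month_day_hour_sketch is hypothetical; the real function may extract the fields differently.

```python
import numpy as np


def extract_month_day_hour_sketch(dates):
    """Hypothetical stand-in: (month, day, hour) tuples for a 1-D datetime64 array."""
    # At second resolution, .tolist() yields datetime.datetime objects.
    as_datetimes = np.asarray(dates).astype("datetime64[s]").tolist()
    return [(d.month, d.day, d.hour) for d in as_datetimes]


dates = np.array(["2021-03-15T06:00", "2021-12-31T18:00"], dtype="datetime64[ns]")
triples = extract_month_day_hour_sketch(dates)
print(triples)  # [(3, 15, 6), (12, 31, 18)]
```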

credit.data.find_common_indices(list1, list2)#

Find indices of common elements between two lists.
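
The docstring does not say which list the indices refer to. The sketch below assumes the function returns one index list per input list; both that return shape and the name find_common_indices_sketch are assumptions made for illustration.

```python
def find_common_indices_sketch(list1, list2):
    """Hypothetical stand-in: positions in each list of elements present in both."""
    common = set(list1) & set(list2)
    indices1 = [i for i, item in enumerate(list1) if item in common]
    indices2 = [i for i, item in enumerate(list2) if item in common]
    return indices1, indices2


idx1, idx2 = find_common_indices_sketch(["T", "U", "V", "Q"], ["Q", "T", "SP"])
print(idx1, idx2)  # [0, 3] [0, 1]
```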

credit.data.concat_and_reshape(x1, x2)#

Flatten the “level” coordinate of upper-air variables and concatenate them with surface variables.

credit.data.reshape_only(x1)#

Flatten the “level” coordinate of upper-air variables.

As in concat_and_reshape, but without the concatenation.

credit.data.get_forward_data(filename) xarray.Dataset#

Check nc vs. zarr files and open file as xr.Dataset.

credit.data.flatten_list(list_of_lists)#

Flatten a list of lists.

Parameters:

list_of_lists (list) – A list containing sublists.

Returns:

A flattened list containing all elements from the sublists.

Return type:

list
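
A minimal sketch of one-level flattening using the standard library. The name flatten_list_sketch is hypothetical, and flattening only one level of nesting is an assumption based on the "list of lists" wording.

```python
from itertools import chain


def flatten_list_sketch(list_of_lists):
    """Hypothetical stand-in: flatten exactly one level of nesting."""
    return list(chain.from_iterable(list_of_lists))


flat = flatten_list_sketch([[1, 2], [3], [], [4, 5]])
print(flat)  # [1, 2, 3, 4, 5]
```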

credit.data.generate_integer_list_around(number, spacing=10)#

Generate a list of integers on either side of a given number with a specified spacing.

Parameters:
  • number (int) – The central number around which the list is generated.

  • spacing (int) – The spacing between consecutive integers in the list. Default is 10.

Returns:

List of integers on either side of the given number.

Return type:

list

credit.data.find_key_for_number(input_number, data_dict)#

Find the key in the dictionary based on the given number.

Parameters:
  • input_number (int) – The number to search for in the dictionary.

  • data_dict (dict) – The dictionary with keys and corresponding value lists.

Returns:

The key in the dictionary where the input number falls within the specified range.

Return type:

str
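
The layout of the value lists is not documented here. The sketch below assumes each value is an inclusive [start, end] pair and that None is returned when nothing matches; both are assumptions, as are the name find_key_for_number_sketch and the example dictionary.

```python
def find_key_for_number_sketch(input_number, data_dict):
    """Hypothetical stand-in: first key whose [start, end] range contains input_number."""
    for key, (start, end) in data_dict.items():
        # Assumption: both range bounds are inclusive.
        if start <= input_number <= end:
            return key
    return None  # Assumption: no match is signalled with None.


# Illustrative mapping of yearly files to hourly index ranges (made-up values).
file_ranges = {"2019": [0, 8759], "2020": [8760, 17543]}
print(find_key_for_number_sketch(9000, file_ranges))   # 2020
print(find_key_for_number_sketch(20000, file_ranges))  # None
```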

credit.data.drop_var_from_dataset(xarray_dataset, varname_keep)#

Preserve a given set of variables from an xarray.Dataset, and drop the rest. Raises an error if varname_keep is missing from xarray_dataset.

credit.data.keep_dataset_vars(xarray_dataset: xarray.Dataset, varnames_keep: List[str])#

Return a version of an xarray dataset with only a selected subset of variables.

Parameters:
  • xarray_dataset (xr.Dataset) – The xarray dataset.

  • varnames_keep (List[str]) – a list of variable names to be kept.

Returns:

xr.Dataset with only the variables in varnames_keep included.

credit.data.subset_patch(ds: xarray.Dataset, input_size, start, lat_name='yIndex', lon_name='xIndex') xarray.Dataset#

Return a spatial subset of shape (time, input_size[0], input_size[1]). Assumes ds has dims (time, lat, lon).

credit.data.encode_datetime64(dt_array)#
credit.data.next_n_hour(dt, period_hours)#

Round dt forward to the next N-hour boundary.

Parameters:
  • dt – np.datetime64[ns] or array of such values

  • period_hours – int, the interval in hours (e.g., 3, 6)

Returns:

np.datetime64[ns] rounded forward to the next period_hours boundary
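
The rounding can be sketched as a ceiling division on nanoseconds since the epoch. This assumes NumPy is available, that boundaries are counted from 1970-01-01T00:00, and that a value already on a boundary is returned unchanged; the real function may treat that last case differently. The name next_n_hour_sketch is hypothetical.

```python
import numpy as np


def next_n_hour_sketch(dt, period_hours):
    """Hypothetical stand-in: round datetime64 value(s) up to a period_hours boundary."""
    dt = np.asarray(dt, dtype="datetime64[ns]")
    epoch = np.datetime64(0, "ns")
    period_ns = period_hours * 3600 * 10**9
    elapsed_ns = (dt - epoch).astype("int64")
    # Ceiling division: values already on a boundary stay where they are.
    n_periods = -(-elapsed_ns // period_ns)
    return epoch + (n_periods * period_ns).astype("timedelta64[ns]")


rounded = next_n_hour_sketch(np.datetime64("2020-01-01T04:30"), 6)
on_boundary = next_n_hour_sketch(np.datetime64("2020-01-01T06:00"), 6)
print(rounded, on_boundary)
```

Both calls land on 2020-01-01T06:00 under these assumptions.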

credit.data.previous_hourly_steps(time_pick, hour, step)#

Given a datetime64[ns] time_pick, compute time_pick - step * hour.

credit.data.filter_ds(ds: xarray.Dataset, varnames_keep: Sequence[str]) xarray.Dataset#

Return a new Dataset containing only the variables in varnames_keep. Raises if any var in varnames_keep is missing.

class credit.data.Sample#

Bases: TypedDict

Simple class for structuring data for the ML model.

Using typing.TypedDict gives us several advantages:
  1. Single ‘source of truth’ for the type and documentation of each example.

  2. A static type checker can check the types are correct.

Instead of TypedDict, we could use typing.NamedTuple, which would provide runtime checks, but the deal-breaker with Tuples is that they’re immutable so we cannot change the values in the transforms.

historical_ERA5_images: Array#
target_ERA5_images: Array#
datetime_index: Array#
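
The mutability argument above can be illustrated with a small standalone sketch. SampleSketch and SampleTuple are hypothetical names, Array is replaced by a placeholder alias, and plain lists stand in for real arrays.

```python
from typing import Any, NamedTuple, TypedDict

Array = Any  # Placeholder alias; the module defines its own Array type.


class SampleSketch(TypedDict):
    historical_ERA5_images: Array
    target_ERA5_images: Array
    datetime_index: Array


class SampleTuple(NamedTuple):
    historical_ERA5_images: Array
    target_ERA5_images: Array
    datetime_index: Array


sample: SampleSketch = {
    "historical_ERA5_images": [1.0, 2.0],
    "target_ERA5_images": [3.0],
    "datetime_index": [0],
}
# A transform can overwrite a field of the TypedDict in place.
sample["historical_ERA5_images"] = [x * 2 for x in sample["historical_ERA5_images"]]

frozen = SampleTuple([1.0, 2.0], [3.0], [0])
try:
    frozen.historical_ERA5_images = [2.0, 4.0]
    tuple_is_mutable = True
except AttributeError:
    # NamedTuple fields cannot be reassigned.
    tuple_is_mutable = False

print(sample["historical_ERA5_images"], tuple_is_mutable)
```

At runtime a TypedDict instance is an ordinary dict, so the key types are only enforced by a static type checker.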
class credit.data.Sample_WRF#

Bases: TypedDict

A TypedDict with the keys listed below.

WRF_input: Array#
WRF_target: Array#
boundary_input: Array#
time_encode: Array#
datetime_index: Array#
class credit.data.Sample_dscale#

Bases: TypedDict

A TypedDict with the keys listed below.

LR_input: Array#
HR_input: Array#
HR_target: Array#
time_encode: Array#
datetime_index: Array#
class credit.data.Sample_diag#

Bases: TypedDict

A TypedDict with the keys listed below.

WRF_input: Array#
WRF_target: Array#
time_encode: Array#
datetime_index: Array#
class credit.data.Sample_LES#

Bases: TypedDict

A TypedDict with the keys listed below.

LES_input: Array#
LES_target: Array#
datetime_index: Array#
class credit.data.ERA5_and_Forcing_Dataset(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)#

Bases: torch.utils.data.Dataset

Deprecated. A PyTorch Dataset class that works on the following kinds of variables:

  • upper-air variables (time, level, lat, lon)

  • surface variables (time, lat, lon)

  • dynamic forcing variables (time, lat, lon)

  • forcing variables (time, lat, lon)

  • diagnostic variables (time, lat, lon)

  • static variables (lat, lon).

history_len = 2#
forecast_len = 0#
transform = None#
skip_periods = None#
one_shot = None#
total_seq_len = 2#
rng#
max_forecast_len = None#
sst_forcing = None#
all_files = []#
ERA5_indices#
filename_forcing = None#
filename_static = None#
__post_init__()#

Calculate total sequence length after init.

__len__()#

Length of Dataset.

__getitem__(index)#

Get single item from the dataset.

class credit.data.ERA5_Dataset_Distributed(varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface=None, filename_dyn_forcing=None, filename_forcing=None, filename_static=None, filename_diagnostic=None, history_len=2, forecast_len=0, transform=None, seed=42, skip_periods=None, one_shot=None, max_forecast_len=None, sst_forcing=None)#

Bases: torch.utils.data.Dataset

ERA5 Dataset for Distributed training (legacy).

history_len = 2#
forecast_len = 0#
transform = None#
varname_upper_air#
varname_surface#
varname_dyn_forcing#
varname_forcing#
varname_static#
skip_periods = None#
one_shot = None#
total_seq_len = 2#
rng#
max_forecast_len = None#
sst_forcing = None#
filenames#
all_files = []#
ERA5_indices#
filename_forcing = None#
filename_static = None#
__post_init__()#

Calculate total sequence length.

__len__()#

Length of dataset.

__getitem__(index)#

Get item.

Parameters:

index – index of timestep

Returns:

PyTorch tensor containing a full state.

class credit.data.Predict_Dataset(conf, varname_upper_air, varname_surface, varname_dyn_forcing, varname_forcing, varname_static, varname_diagnostic, filenames, filename_surface, filename_dyn_forcing, filename_forcing, filename_static, filename_diagnostic, fcst_datetime, history_len, rank, world_size, transform=None, rollout_p=0.0, which_forecast=None)#

Bases: torch.utils.data.IterableDataset

Same as ERA5_and_Forcing_Dataset, but works with the old rollout_to_netcdf.py.

rank#
world_size#
transform = None#
history_len#
init_datetime#
which_forecast = None#
filenames#
filename_surface#
filename_dyn_forcing#
filename_forcing#
filename_static#
filename_diagnostic#
varname_upper_air#
varname_surface#
varname_dyn_forcing#
varname_forcing#
varname_static#
varname_diagnostic#
all_files = []#
current_epoch = 0#
rollout_p = 0.0#
lead_time_periods#
skip_periods#
ds_read_and_subset(filename, time_start, time_end, varnames)#

Read and subset specified dataset.

Parameters:
  • filename (str) – path to specified dataset file.

  • time_start (int) – start time index.

  • time_end (int) – end time index.

  • varnames (list) – List of variables to be read.

load_zarr_as_input(i_file, i_init_start, i_init_end, mode='input')#

Load input data from zarr files.

Parameters:
  • i_file – index of the file

  • i_init_start – start index of the data being loaded

  • i_init_end – end index of the data being loaded.

  • mode – “input” or “target”

Returns:

xr.Dataset containing all the variables.

find_start_stop_indices(index)#

Find start and stop indices for a given yearly data zarr file.

Parameters:

index – indices of zarr file.

__len__()#

Length of dataset.

__iter__()#

Iterate through batch.