credit.datasets.om4_multistep_batcher

Attributes#

Classes#

StandardScaler

Ocean_MultiStep_Batcher

Ocean dataset that handles both single-step and multi-step autoregressive training.

Ocean_Tensor_Batcher

Ocean dataset that handles both single-step and multi-step autoregressive training, loading cached .pt samples.

Predict_Ocean_Batcher

Ocean dataset that uses fixed forecast windows instead of continuous sampling.

Functions#

load_transform(conf)

Load data and return the StandardScaler normalization object.

Module Contents#

credit.datasets.om4_multistep_batcher.load_transform(conf)#

Load data and return the StandardScaler normalization object. Only essentials are kept.

class credit.datasets.om4_multistep_batcher.StandardScaler(data_mean: xarray.Dataset, data_std: xarray.Dataset, prognostic_vars: str, boundary_vars: str, wet_mask: torch.Tensor)#
prognostic_mean#
prognostic_std#
boundary_mean#
boundary_std#
wet_mask#
_prognostic_mean_np#
_prognostic_std_np#
_wet_mask_np#
_to_tensor(array: numpy.ndarray, device: torch.device) → torch.Tensor#

Convert numpy array to tensor on specified device.

normalize_prognostics(data: xarray.Dataset, fill_nan=True, fill_value=0.0) → xarray.Dataset#

Normalize the prognostic variables of an xarray dataset.

normalize_boundary(data: xarray.Dataset, fill_nan=True, fill_value=0.0) → xarray.Dataset#

Normalize the boundary variables of an xarray dataset.

unnormalize_prognostics(data: xarray.Dataset) → xarray.Dataset#

Unnormalize the prognostic variables of an xarray dataset.

normalize_tensor_prognostics(data: torch.Tensor, fill_nan=True, fill_value=0.0) → torch.Tensor#

Normalize the prognostic variables of a torch tensor.

unnormalize_tensor_prognostics(data: torch.Tensor) → torch.Tensor#

Unnormalize the prognostic variables of a torch tensor.

normalize_numpy_prognostics(data: numpy.ndarray, fill_nan=True, fill_value=0.0) → numpy.ndarray#

Normalize the prognostic variables of a numpy array.

unnormalize_numpy_prognostics(data: numpy.ndarray) → numpy.ndarray#

Unnormalize the prognostic variables of a numpy array.

transform_array(data: torch.Tensor, fill_nan=True, fill_value=0.0) → torch.Tensor#

Normalize prognostic variables of a torch tensor. Expects shape (B, C, H, W) or (B, C, T, H, W).

inverse_transform(data: torch.Tensor) → torch.Tensor#

Unnormalize prognostic variables of a torch tensor. Expects shape (B, C, H, W) or (B, C, T, H, W).
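
The normalization formula itself is not stated on this page. The following is a minimal NumPy sketch, assuming the usual standard-score form `(x - mean) / std` applied per channel, and assuming `fill_nan` replaces NaNs (for example, land points) with `fill_value`. The function names `normalize` and `unnormalize` and the sample statistics are hypothetical, and the wet mask handling of the real class is omitted.

```python
import numpy as np


def normalize(data, mean, std, fill_nan=True, fill_value=0.0):
    """Standard-score normalize along the channel axis (axis 1). Illustrative only."""
    # Reshape (C,) statistics to (1, C, 1, ..., 1) so they broadcast over
    # both (B, C, H, W) and (B, C, T, H, W) inputs.
    shape = (1, -1) + (1,) * (data.ndim - 2)
    out = (data - mean.reshape(shape)) / std.reshape(shape)
    if fill_nan:
        # Assumption: NaNs are replaced after scaling.
        out = np.where(np.isnan(out), fill_value, out)
    return out


def unnormalize(data, mean, std):
    """Invert normalize() for values that were not NaN-filled."""
    shape = (1, -1) + (1,) * (data.ndim - 2)
    return data * std.reshape(shape) + mean.reshape(shape)


# Hypothetical per-channel statistics for two prognostic variables.
mean = np.array([10.0, 35.0])
std = np.array([2.0, 0.5])

# A (B, C, H, W) field that is NaN everywhere except one ocean point.
x = np.full((1, 2, 3, 4), np.nan)
x[0, 0, 0, 0] = 12.0
x[0, 1, 0, 0] = 34.5

z = normalize(x, mean, std)
```

Note that NaN-filled points cannot be recovered by the inverse: in this sketch they come back as the channel mean, which is one reason a mask such as `wet_mask` is useful downstream.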

class credit.datasets.om4_multistep_batcher.Ocean_MultiStep_Batcher(conf, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)#

Bases: torch.utils.data.Dataset

Ocean dataset that handles both single-step and multi-step autoregressive training. Returns tensors with shape (batch, channels, time, lat, lon).

input_length#
output_length#
forecast_len#
seed = 42#
rank = 0#
world_size = 1#
shuffle = True#
batch_size = 1#
rng#
_prognostic_data#
_boundary_data#
num_prognostic_vars#
num_boundary_vars#
wet#
wet_surface#
normalize#
size#
sampler#
batch_indices = None#
time_steps = None#
forecast_step_counts = None#
current_epoch = None#
initialize_batch()#

Initialize batch indices using DistributedSampler’s indices.

__len__()#
set_epoch(epoch)#

Set epoch for distributed training.
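
To illustrate why `set_epoch` matters before `initialize_batch`, here is a pure-Python sketch of the partitioning scheme used by `torch.utils.data.distributed.DistributedSampler`: shuffle with a seed derived from `seed + epoch`, pad so every rank gets the same count, then take every `world_size`-th index starting at `rank`. The helper `shard_indices` is hypothetical and does not reproduce the exact permutation PyTorch would generate.

```python
import math
import random


def shard_indices(size, rank, world_size, seed=42, epoch=0, shuffle=True):
    """Return the sample indices one rank would see in one epoch (illustrative)."""
    indices = list(range(size))
    if shuffle:
        # Seeding with seed + epoch gives a new, but reproducible, order per epoch.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating leading indices so the total divides evenly across ranks.
    per_rank = math.ceil(size / world_size)
    total = per_rank * world_size
    indices += indices[: total - size]
    # Each rank takes a strided slice of the shared, shuffled list.
    return indices[rank:total:world_size]


# Ten samples split across four ranks.
shards = [shard_indices(10, r, 4) for r in range(4)]
```

Every rank receives the same number of indices, and together the shards cover the whole dataset, with a few indices repeated as padding.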

batches_per_epoch()#
_get_batch_samples(sample_indices)#

Get multiple ocean samples efficiently by batching xarray operations. Returns batch dict with stacked tensors.

__getitem__(_)#

Returns batch with tensors shaped (batch, channels, time, lat, lon). Optimized for speed.

_get_ocean_sample(idx)#

Get single ocean sample with proper time dimensions. Returns tensors with shape (channels, time, lat, lon).
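
The relationship between per-sample shape `(channels, time, lat, lon)` and batch shape `(batch, channels, time, lat, lon)` can be sketched as a stack along a new leading axis. This is an illustration in NumPy rather than the class's own code. The helper `stack_samples` and the dictionary keys `"x"` and `"y"` are hypothetical, since the real batch keys are not listed on this page.

```python
import numpy as np


def stack_samples(samples):
    """Stack a list of per-sample dicts into one batch dict (illustrative)."""
    keys = samples[0].keys()
    # axis=0 inserts the batch dimension in front of (channels, time, lat, lon).
    return {k: np.stack([s[k] for s in samples], axis=0) for k in keys}


# Two samples: 3 channels, 2 input steps and 1 target step, on a 4 x 5 grid.
samples = [
    {"x": np.zeros((3, 2, 4, 5)), "y": np.zeros((3, 1, 4, 5))}
    for _ in range(2)
]
batch = stack_samples(samples)
```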

class credit.datasets.om4_multistep_batcher.Ocean_Tensor_Batcher(conf, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)#

Bases: torch.utils.data.Dataset

Ocean dataset that handles both single-step and multi-step autoregressive training. Loads cached .pt samples while preserving autoregressive logic.

wet#
wet_surface#
seed = 42#
rank = 0#
world_size = 1#
shuffle = True#
batch_size = 1#
input_length#
output_length#
forecast_len#
samples_dir#
size#
sampler#
batch_indices = None#
time_steps = None#
forecast_step_counts = None#
current_epoch = None#
num_prognostic_vars#
initialize_batch()#
__len__()#
set_epoch(epoch)#
batches_per_epoch()#
_get_batch_samples(sample_indices)#

Load multiple cached .pt samples and concatenate along time dimension.

__getitem__(_)#

Returns batch with tensors shaped (batch, channels, time, lat, lon) using autoregressive logic.

_get_ocean_sample(idx)#

Load a cached sample from disk.

class credit.datasets.om4_multistep_batcher.Predict_Ocean_Batcher(conf, forecast_windows, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)#

Bases: Ocean_MultiStep_Batcher

Ocean dataset that uses fixed forecast windows instead of continuous sampling. Accepts a list of [start, end] datetime string pairs for forecasting.

valid_indices = []#
forecast_windows#
size#
forecast_len = 0#
sampler#
batch_indices#
batch_size#
current_epoch = 0#
batch_call_count = 0#
_convert_windows_to_indices()#

Convert datetime string windows to time indices.

_build_valid_indices()#

Build list of valid sample indices from forecast windows. Each window gets ONE starting index (the first valid one).
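
The two steps above (window conversion, then one starting index per window) might look roughly like the following standard-library sketch. It assumes a sorted list of timestamps standing in for the dataset's time axis and a `%Y-%m-%d` string format. Both helper names, the date format, and the sample time axis are assumptions, not the class's actual implementation.

```python
from bisect import bisect_left, bisect_right
from datetime import datetime, timedelta


def windows_to_indices(times, forecast_windows, fmt="%Y-%m-%d"):
    """Map [start, end] datetime strings to inclusive (lo, hi) time indices."""
    ranges = []
    for start, end in forecast_windows:
        # First time step at or after the window start.
        lo = bisect_left(times, datetime.strptime(start, fmt))
        # Last time step at or before the window end.
        hi = bisect_right(times, datetime.strptime(end, fmt)) - 1
        ranges.append((lo, hi))
    return ranges


def first_valid_indices(ranges, n_times):
    """Keep one starting index per window, skipping windows with no data."""
    return [lo for lo, hi in ranges if lo <= hi and lo < n_times]


# Hypothetical time axis: ten steps, five days apart, starting 2000-01-01.
times = [datetime(2000, 1, 1) + timedelta(days=5 * i) for i in range(10)]
windows = [["2000-01-05", "2000-01-20"], ["2000-02-01", "2000-02-15"]]

ranges = windows_to_indices(times, windows)
valid = first_valid_indices(ranges, len(times))
```

With one starting index per window, the dataset length equals the number of windows that overlap the available time axis.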

__len__()#

Return number of valid samples across all windows.

initialize_batch()#

Initialize batch indices from valid indices.

credit.datasets.om4_multistep_batcher.conf#