credit.datasets.om4_multistep_batcher
Attributes
- conf
Classes
- StandardScaler
- Ocean_MultiStep_Batcher: Ocean dataset that handles both single-step and multi-step autoregressive training.
- Ocean_Tensor_Batcher: Ocean dataset that handles both single-step and multi-step autoregressive training, loading cached .pt samples.
- Predict_Ocean_Batcher: Ocean dataset that uses fixed forecast windows instead of continuous sampling.
Functions
- load_transform(conf): Load data and return the StandardScaler normalization object.
Module Contents
- credit.datasets.om4_multistep_batcher.load_transform(conf)
Load data and return the StandardScaler normalization object. Only essentials are kept.
- class credit.datasets.om4_multistep_batcher.StandardScaler(data_mean: xarray.Dataset, data_std: xarray.Dataset, prognostic_vars: str, boundary_vars: str, wet_mask: torch.Tensor)
- prognostic_mean
- prognostic_std
- boundary_mean
- boundary_std
- wet_mask
- _prognostic_mean_np
- _prognostic_std_np
- _wet_mask_np
- _to_tensor(array: numpy.ndarray, device: torch.device) → torch.Tensor
Convert a NumPy array to a tensor on the specified device.
- normalize_prognostics(data: xarray.Dataset, fill_nan=True, fill_value=0.0) → xarray.Dataset
Normalize the prognostic variables of an xarray dataset.
- normalize_boundary(data: xarray.Dataset, fill_nan=True, fill_value=0.0) → xarray.Dataset
Normalize the boundary-condition variables of an xarray dataset.
- unnormalize_prognostics(data: xarray.Dataset) → xarray.Dataset
Unnormalize the prognostic variables of an output dataset.
- normalize_tensor_prognostics(data: torch.Tensor, fill_nan=True, fill_value=0.0) → torch.Tensor
Normalize the prognostic variables of a tensor.
- unnormalize_tensor_prognostics(data: torch.Tensor) → torch.Tensor
Unnormalize the prognostic variables of a tensor.
- normalize_numpy_prognostics(data: numpy.ndarray, fill_nan=True, fill_value=0.0) → numpy.ndarray
Normalize the prognostic variables of a NumPy array.
- unnormalize_numpy_prognostics(data: numpy.ndarray) → numpy.ndarray
Unnormalize the prognostic variables of a NumPy array.
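The NumPy normalize/unnormalize pair can be pictured as per-channel standard scaling. The sketch below is an assumption about the behavior, not the module's code: it assumes per-channel mean and standard deviation arrays of shape (C,), data shaped (C, H, W), and that `fill_nan=True` replaces NaNs (for example, land cells) with `fill_value` after scaling. The helpers `normalize_np` and `unnormalize_np` are hypothetical, and the wet mask is not modeled.

```python
import numpy as np

def normalize_np(data, mean, std, fill_nan=True, fill_value=0.0):
    """Hypothetical per-channel standard scaling for data shaped (C, H, W)."""
    # Reshape (C,) statistics to (C, 1, 1) so they broadcast over lat/lon.
    mean = np.asarray(mean, dtype=float)[:, None, None]
    std = np.asarray(std, dtype=float)[:, None, None]
    out = (data - mean) / std
    if fill_nan:
        # Assumption: NaNs (e.g. land cells) are replaced after scaling.
        out = np.where(np.isnan(out), fill_value, out)
    return out

def unnormalize_np(data, mean, std):
    """Inverse of normalize_np for points that were not NaN-filled."""
    mean = np.asarray(mean, dtype=float)[:, None, None]
    std = np.asarray(std, dtype=float)[:, None, None]
    return data * std + mean

# Two channels on a 2 x 2 grid; the NaN marks a land cell in channel 0.
x = np.array([[[10.0, 12.0], [np.nan, 14.0]],
              [[1.0, 2.0], [3.0, 4.0]]])
mean, std = [12.0, 2.5], [2.0, 1.0]
z = normalize_np(x, mean, std)
```

Note that a filled cell unnormalizes to its channel mean rather than back to NaN, so a land mask has to be reapplied after inversion; whether `wet_mask` serves that purpose here is a guess.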
- transform_array(data: torch.Tensor, fill_nan=True, fill_value=0.0) → torch.Tensor
Normalize the prognostic variables of a torch tensor. Expects shape (B, C, H, W) or (B, C, T, H, W).
- inverse_transform(data: torch.Tensor) → torch.Tensor
Unnormalize the prognostic variables of a torch tensor. Expects shape (B, C, H, W) or (B, C, T, H, W).
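Because `transform_array` and `inverse_transform` accept both (B, C, H, W) and (B, C, T, H, W), per-channel statistics have to be reshaped to the input's rank before broadcasting. A minimal NumPy sketch of that idea, assuming statistics of shape (C,) and the channel axis at position 1; `reshape_stats`, `transform`, and `inverse` are hypothetical names, and the documented methods operate on torch tensors instead.

```python
import numpy as np

def reshape_stats(stat, ndim):
    """Reshape a (C,) statistic so it broadcasts against a (B, C, ...) array."""
    # Batch axis, channel axis, then singleton axes for the remaining dims:
    # (1, C, 1, 1) for 4-D input, (1, C, 1, 1, 1) for 5-D input.
    return np.asarray(stat, dtype=float).reshape((1, -1) + (1,) * (ndim - 2))

def transform(data, mean, std):
    """Hypothetical rank-agnostic normalization (channel axis = 1)."""
    if data.ndim not in (4, 5):
        raise ValueError(f"expected 4-D or 5-D input, got {data.ndim}-D")
    return (data - reshape_stats(mean, data.ndim)) / reshape_stats(std, data.ndim)

def inverse(data, mean, std):
    """Undo transform for the same channel statistics."""
    return data * reshape_stats(std, data.ndim) + reshape_stats(mean, data.ndim)

mean, std = [0.0, 10.0, 100.0], [1.0, 2.0, 5.0]
x4 = np.random.default_rng(0).normal(size=(2, 3, 4, 5))     # (B, C, H, W)
x5 = np.random.default_rng(1).normal(size=(2, 3, 6, 4, 5))  # (B, C, T, H, W)
roundtrip_ok = [np.allclose(inverse(transform(x, mean, std), mean, std), x) for x in (x4, x5)]
```

With torch tensors the same reshaping can be done with `Tensor.reshape` before the arithmetic.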
- class credit.datasets.om4_multistep_batcher.Ocean_MultiStep_Batcher(conf, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)
Bases: torch.utils.data.Dataset
Ocean dataset that handles both single-step and multi-step autoregressive training. Returns tensors with shape (batch, channels, time, lat, lon).
- input_length
- output_length
- forecast_len
- seed = 42
- rank = 0
- world_size = 1
- shuffle = True
- batch_size = 1
- rng
- _prognostic_data
- _boundary_data
- num_prognostic_vars
- num_boundary_vars
- wet
- wet_surface
- normalize
- size
- sampler
- batch_indices = None
- time_steps = None
- forecast_step_counts = None
- current_epoch = None
- initialize_batch()
Initialize batch indices from the DistributedSampler’s indices.
- __len__()
- set_epoch(epoch)
Set the epoch for distributed training.
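`initialize_batch` draws indices from a `DistributedSampler`, and `set_epoch` exists so the order changes between epochs. The pure-Python sketch below mimics that sampler's default behavior: a shuffle seeded with `seed + epoch`, padding by repeating indices until the total divides evenly by `world_size`, then taking every `world_size`-th index starting at `rank`. It uses `random.Random` rather than `torch.randperm`, so the permutation will not match PyTorch's, and `shard_indices` is a hypothetical name.

```python
import math
import random

def shard_indices(size, rank, world_size, epoch, seed=42, shuffle=True):
    """Return this rank's sample indices for one epoch (DistributedSampler-like)."""
    indices = list(range(size))
    if shuffle:
        # Reseeding with seed + epoch gives every rank the same permutation
        # within an epoch and a different permutation across epochs.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so the total divides evenly across ranks.
    total = math.ceil(size / world_size) * world_size
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return indices[rank:total:world_size]

# Ten samples split across four ranks: each rank gets three indices.
shards = [shard_indices(10, r, world_size=4, epoch=0) for r in range(4)]
```

Without a per-epoch call that changes the epoch, the seed stays fixed and every epoch replays the same order, which is why distributed training loops call `set_epoch` before each epoch.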
- batches_per_epoch()
- _get_batch_samples(sample_indices)
Get multiple ocean samples efficiently by batching xarray operations. Returns a batch dict with stacked tensors.
- __getitem__(_)
Returns a batch with tensors shaped (batch, channels, time, lat, lon). Optimized for speed.
- _get_ocean_sample(idx)
Get a single ocean sample with proper time dimensions. Returns tensors with shape (channels, time, lat, lon).
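Per-sample tensors of shape (channels, time, lat, lon) are stacked into (batch, channels, time, lat, lon). A minimal NumPy sketch of that windowing and stacking, assuming a sample starting at time index `idx` takes `input_length` consecutive steps as input and the next `output_length` steps as target. The window layout, the helper `get_window`, and the synthetic array are assumptions, not the class's actual code.

```python
import numpy as np

def get_window(data, idx, input_length, output_length):
    """Slice one hypothetical sample from data shaped (C, T_total, H, W)."""
    end_in = idx + input_length
    end_out = end_in + output_length
    if end_out > data.shape[1]:
        raise IndexError("window runs past the end of the time axis")
    x = data[:, idx:end_in]      # inputs:  (C, input_length, H, W)
    y = data[:, end_in:end_out]  # targets: (C, output_length, H, W)
    return x, y

# Synthetic record: 3 channels, 20 time steps, 4 x 5 grid. Every value equals
# its time index, which makes the slicing easy to check.
C, T, H, W = 3, 20, 4, 5
data = np.broadcast_to(np.arange(T, dtype=float)[None, :, None, None], (C, T, H, W))

samples = [get_window(data, i, input_length=2, output_length=1) for i in (0, 5, 9)]
x_batch = np.stack([x for x, _ in samples])  # (batch, channels, time, lat, lon)
y_batch = np.stack([y for _, y in samples])
```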
- class credit.datasets.om4_multistep_batcher.Ocean_Tensor_Batcher(conf, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)
Bases: torch.utils.data.Dataset
Ocean dataset that handles both single-step and multi-step autoregressive training. Loads cached .pt samples while preserving the autoregressive logic.
- wet
- wet_surface
- seed = 42
- rank = 0
- world_size = 1
- shuffle = True
- batch_size = 1
- input_length
- output_length
- forecast_len
- samples_dir
- size
- sampler
- batch_indices = None
- time_steps = None
- forecast_step_counts = None
- current_epoch = None
- num_prognostic_vars
- initialize_batch()
- __len__()
- set_epoch(epoch)
- batches_per_epoch()
- _get_batch_samples(sample_indices)
Load multiple cached .pt samples and concatenate them along the time dimension.
- __getitem__(_)
Returns a batch with tensors shaped (batch, channels, time, lat, lon) using the autoregressive logic.
- _get_ocean_sample(idx)
Load a cached sample from disk.
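Ocean_Tensor_Batcher loads cached samples and concatenates them along the time dimension. The sketch below illustrates only that concatenation pattern. To stay dependency-light it writes and reads `.npy` files with NumPy in place of `.pt` files and `torch.save`/`torch.load`; the file-name pattern `sample_{idx:06d}.npy`, the (C, 1, H, W) per-file layout, and the helper `load_and_concat` are all hypothetical.

```python
import os
import tempfile

import numpy as np

def load_and_concat(samples_dir, indices):
    """Load one cached array per index and join them along the time axis."""
    # Hypothetical naming scheme; each cached array is (C, T_i, H, W).
    arrays = [np.load(os.path.join(samples_dir, f"sample_{i:06d}.npy")) for i in indices]
    return np.concatenate(arrays, axis=1)  # axis 1 is time

with tempfile.TemporaryDirectory() as samples_dir:
    # Write three single-step samples of shape (C=2, T=1, H=3, W=4),
    # each filled with its own index so the order is visible afterward.
    for i in range(3):
        np.save(os.path.join(samples_dir, f"sample_{i:06d}.npy"), np.full((2, 1, 3, 4), float(i)))
    window = load_and_concat(samples_dir, [0, 1, 2])
```

With PyTorch the equivalent join is `torch.cat(tensors, dim=1)` for per-sample tensors, or `dim=2` once a leading batch axis is present.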
- class credit.datasets.om4_multistep_batcher.Predict_Ocean_Batcher(conf, forecast_windows, seed=42, rank=0, world_size=1, batch_size=1, shuffle=True)
Bases: Ocean_MultiStep_Batcher
Ocean dataset that uses fixed forecast windows instead of continuous sampling. Accepts a list of [start, end] datetime-string pairs for forecasting.
- valid_indices = []
- forecast_windows
- size
- forecast_len = 0
- sampler
- batch_indices
- batch_size
- current_epoch = 0
- batch_call_count = 0
- _convert_windows_to_indices()
Convert datetime-string windows to time indices.
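Converting [start, end] datetime-string windows to time indices amounts to locating each string on the dataset's sorted time axis. A standard-library sketch, assuming ISO-8601 strings and a sorted list of `datetime` values; the helper `window_to_indices` and the `bisect` rule (first step at or after `start`, last step at or before `end`) are assumptions. The actual method may instead select on an xarray time coordinate.

```python
from bisect import bisect_left, bisect_right
from datetime import datetime, timedelta

def window_to_indices(times, start, end):
    """Map an ISO [start, end] pair to inclusive indices on a sorted time axis."""
    t0, t1 = datetime.fromisoformat(start), datetime.fromisoformat(end)
    i0 = bisect_left(times, t0)       # first step at or after start
    i1 = bisect_right(times, t1) - 1  # last step at or before end
    if i0 > i1:
        raise ValueError(f"no time steps fall inside [{start}, {end}]")
    return i0, i1

# Hypothetical time axis: five-day steps starting 2000-01-01.
times = [datetime(2000, 1, 1) + timedelta(days=5 * k) for k in range(20)]
windows = [["2000-01-03", "2000-01-21"], ["2000-02-10", "2000-03-01"]]
index_windows = [window_to_indices(times, s, e) for s, e in windows]
```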
- _build_valid_indices()
Build the list of valid sample indices from the forecast windows. Each window contributes exactly one starting index (the first valid one).
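The one-start-per-window rule can be sketched as a scan that stops at the first acceptable start. This assumes a start `s` is valid when the whole sample (`input_length` input steps plus `output_length` target steps) fits inside both the window and the time axis, optionally filtered by an extra predicate; the helper `build_valid_indices`, its parameters, and the validity rule are hypothetical.

```python
def build_valid_indices(index_windows, n_times, input_length, output_length, is_valid=None):
    """Pick one starting index per inclusive (i0, i1) window: the first valid start."""
    span = input_length + output_length
    valid = []
    for i0, i1 in index_windows:
        last = min(i1, n_times - 1)
        # Candidate starts s keep steps s .. s + span - 1 inside the window.
        for s in range(i0, last - span + 2):
            # Assumption: an optional predicate can reject starts (e.g. missing data).
            if is_valid is None or is_valid(s):
                valid.append(s)
                break  # exactly one start per window
    return valid

valid = build_valid_indices([(1, 4), (8, 12), (18, 19)], n_times=20, input_length=2, output_length=1)
```

Under these assumptions the last window is too short for one full sample and contributes nothing, so the count of valid samples can be smaller than the number of windows.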
- __len__()
Return the number of valid samples across all windows.
- initialize_batch()
Initialize batch indices from the valid indices.
- credit.datasets.om4_multistep_batcher.conf