credit.trainers.ic_optimization#

Attributes#

Classes#

ForecastProcessor

TimeStepper

TrainerIC

Helper class that provides a standard way to create an ABC using inheritance.

Functions#

save_netcdf_increment(darray_upper_air, ...)

Save CREDIT model prediction output to a netCDF file. Optionally performs pressure-level interpolation on the output.

Module Contents#

credit.trainers.ic_optimization.logger#
credit.trainers.ic_optimization.save_netcdf_increment(darray_upper_air: xarray.DataArray, darray_single_level: xarray.DataArray, nc_filename: str, forecast_hour: int, meta_data: dict, conf: dict, name_tag: str)#

Save CREDIT model prediction output to a netCDF file. Optionally performs pressure-level interpolation on the output.

Parameters:
  • darray_upper_air (xr.DataArray) – upper air variable predictions

  • darray_single_level (xr.DataArray) – surface variable predictions

  • nc_filename (str) – file description to go into output filenames

  • forecast_hour (int) – number of hours elapsed since model initialization

  • meta_data (dict) – metadata dictionary for output variables

  • conf (dict) – configuration dictionary for training and/or rollout

class credit.trainers.ic_optimization.ForecastProcessor(conf, device)#
conf#
device#
batch_size#
ensemble_size#
lead_time_periods#
latlons#
meta_data#
process(y_pred, datetimes, save_datetimes, nametag)#
class credit.trainers.ic_optimization.TimeStepper(dataset)#
dataset#
_active = False#
__iter__()#
reset(idx=0)#

Initialize a new sample, starting from forecast step 0.

__next__()#

Advance forecast steps until forecast_len + 1.
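
The reset/iterate pattern described for TimeStepper can be illustrated with a minimal sketch. This is not the CREDIT implementation: the class name TimeStepperSketch, the forecast_len constructor argument (the real class takes a dataset), the private fields, and the (idx, step) return value are all assumptions made for illustration. It also assumes that "until forecast_len + 1" means steps 0 through forecast_len are produced.

```python
class TimeStepperSketch:
    """Hypothetical stand-in for a reset-then-iterate forecast stepper."""

    def __init__(self, forecast_len):
        # Assumption: the real class reads the forecast length from its dataset.
        self.forecast_len = forecast_len
        self._active = False  # no sample is loaded until reset() is called
        self._idx = 0
        self._step = 0

    def __iter__(self):
        return self

    def reset(self, idx=0):
        # Start a new sample at forecast step 0.
        self._idx = idx
        self._step = 0
        self._active = True

    def __next__(self):
        # Stop when inactive or once steps 0..forecast_len have been produced.
        if not self._active or self._step > self.forecast_len:
            self._active = False
            raise StopIteration
        step = self._step
        self._step += 1
        return self._idx, step


stepper = TimeStepperSketch(forecast_len=2)
before_reset = list(stepper)   # nothing is produced before reset()
stepper.reset(idx=5)
steps = list(stepper)          # forecast_len + 1 steps for sample 5
after_exhausted = list(stepper)
```

In this sketch the iterator produces nothing until reset() is called, and it becomes inactive again once the sample is exhausted, so each new sample needs its own reset().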

class credit.trainers.ic_optimization.TrainerIC(model: torch.nn.Module, rank: int)#

Bases: credit.trainers.base_trainer.BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)#

Trains the model for one epoch.

Parameters:
  • epoch (int) – Current epoch number.

  • conf (dict) – Configuration dictionary containing training settings.

  • trainloader (DataLoader) – DataLoader for the training dataset.

  • optimizer (torch.optim.Optimizer) – Optimizer used for training.

  • criterion (callable) – Loss function used for training.

  • scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed precision training.

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.

  • metrics (callable) – Function to compute metrics for evaluation.

Returns:

Dictionary containing training metrics and loss for the epoch.

Return type:

dict

validate(epoch, conf, valid_loader, criterion, metrics)#

Validates the model on the validation dataset.

Parameters:
  • epoch (int) – Current epoch number.

  • conf (dict) – Configuration dictionary containing validation settings.

  • valid_loader (DataLoader) – DataLoader for the validation dataset.

  • criterion (callable) – Loss function used for validation.

  • metrics (callable) – Function to compute metrics for evaluation.

Returns:

Dictionary containing validation metrics and loss for the epoch.

Return type:

dict
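
The relationship between an abstract base trainer and a concrete trainer that returns a metrics dictionary can be sketched with the standard library alone. Everything here is a simplified assumption rather than CREDIT code: BaseTrainerSketch and MeanLossTrainer are hypothetical names, the dictionary keys are invented, and the loops only average the loss without any optimizer, scaler, or scheduler step.

```python
from abc import ABC, abstractmethod


class BaseTrainerSketch(ABC):
    """Hypothetical abstract base; not the real BaseTrainer."""

    def __init__(self, model, rank):
        self.model = model
        self.rank = rank

    @abstractmethod
    def train_one_epoch(self, epoch, conf, trainloader, optimizer,
                        criterion, scaler, scheduler, metrics):
        ...

    @abstractmethod
    def validate(self, epoch, conf, valid_loader, criterion, metrics):
        ...


class MeanLossTrainer(BaseTrainerSketch):
    """Toy subclass: averages the loss over batches, performs no updates."""

    def train_one_epoch(self, epoch, conf, trainloader, optimizer,
                        criterion, scaler, scheduler, metrics):
        losses = [criterion(self.model(x), y) for x, y in trainloader]
        # Key names are assumptions; the source only says a dict is returned.
        return {"epoch": epoch, "train_loss": sum(losses) / len(losses)}

    def validate(self, epoch, conf, valid_loader, criterion, metrics):
        losses = [criterion(self.model(x), y) for x, y in valid_loader]
        return {"epoch": epoch, "valid_loss": sum(losses) / len(losses)}


# Toy model, data, and loss so the sketch runs without torch.
trainer = MeanLossTrainer(model=lambda x: 2 * x, rank=0)
batches = [(1, 2), (2, 5)]
squared_error = lambda pred, target: (pred - target) ** 2

train_result = trainer.train_one_epoch(
    0, {}, batches, None, squared_error, None, None, None
)
valid_result = trainer.validate(0, {}, batches, squared_error, None)
```

Because both methods are declared abstract, the base class cannot be instantiated directly; a subclass must implement train_one_epoch and validate before it can be used.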