credit.trainers.utils
Functions
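cleanup() |
cycle(dl) |
accum_log(log, new_logs) |
inject_flat_var_keys(conf) | Inject Gen1-compatible variable keys into conf["data"] for metrics/loss.
inject_postblock_info(conf) | Inject post-block indices into conf["postblock"] for proper initialization of Gen1 postblocks.
load_dataset(conf, is_train) | Build a MultiSourceDataset for train or validation.
load_dataloader(conf, dataset, rank, world_size, is_train) | Build a DataLoader with DistributedMultiStepBatchSampler.
load_model_states_and_optimizer(conf, model, device) | Load model weights, optimizer, scheduler, and gradient scaler.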
Module Contents
- credit.trainers.utils.cleanup()
- credit.trainers.utils.cycle(dl)
- credit.trainers.utils.accum_log(log, new_logs)
- credit.trainers.utils.inject_flat_var_keys(conf: dict) -> None
Inject Gen1-compatible variable keys into conf["data"] for metrics/loss.
LatWeightedMetrics and VariableTotalLoss2D expect flat lists at:
  - conf["data"]["variables"]
  - conf["data"]["surface_variables"]
  - conf["data"]["diagnostic_variables"]
These are derived from the nested Gen2 source config.
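The flattening described above can be illustrated with a minimal sketch. It is not the CREDIT implementation: the helper name `flatten_var_keys` and the nested layout `conf["data"]["source"][<name>]` are assumptions made only for illustration, since this page does not document the Gen2 source schema. Only the three output keys come from the description above.

```python
def flatten_var_keys(conf: dict) -> None:
    """Hypothetical stand-in for inject_flat_var_keys; mutates conf in place."""
    data = conf["data"]
    # The three flat keys named in the description above.
    flat = {
        "variables": [],
        "surface_variables": [],
        "diagnostic_variables": [],
    }
    # ASSUMPTION: each source is a dict under conf["data"]["source"] that may
    # carry any of the three lists (or None). The real Gen2 layout may differ.
    for source in data.get("source", {}).values():
        for key, names in flat.items():
            for name in source.get(key) or []:
                if name not in names:  # keep first-seen order, drop duplicates
                    names.append(name)
    data.update(flat)


# Made-up example config with two sources.
conf = {
    "data": {
        "source": {
            "ERA5": {
                "variables": ["U", "V"],
                "surface_variables": ["SP"],
                "diagnostic_variables": None,
            },
            "obs": {"variables": ["V", "T"]},
        }
    }
}
flatten_var_keys(conf)
print(conf["data"]["variables"])
```

Like the documented function, the sketch returns None and writes its results back into conf, so it would need to run before any metric or loss object reads the flat keys.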
- credit.trainers.utils.inject_postblock_info(conf: dict) -> None
Inject post-block indices into conf["postblock"] for proper initialization of Gen1 postblocks.
- credit.trainers.utils.load_dataset(conf: dict, is_train: bool) -> credit.datasets.multi_source.MultiSourceDataset
Build a MultiSourceDataset for train or validation.
- credit.trainers.utils.load_dataloader(conf: dict, dataset: credit.datasets.multi_source.MultiSourceDataset, rank: int, world_size: int, is_train: bool) -> torch.utils.data.DataLoader
Build a DataLoader with DistributedMultiStepBatchSampler.
- credit.trainers.utils.load_model_states_and_optimizer(conf, model, device)
Load model weights, optimizer, scheduler, and gradient scaler.