credit.trainers.utils

Functions

cleanup()

cycle(dl)

accum_log(log, new_logs)

inject_flat_var_keys(→ None)

Inject Gen1-compatible variable keys into conf["data"] for metrics/loss.

inject_postblock_info(→ None)

Inject post-block indices into conf["postblock"] for proper initialization of Gen1 postblocks.

load_dataset(...)

Build a MultiSourceDataset for train or validation.

load_dataloader(→ torch.utils.data.DataLoader)

Build a DataLoader with DistributedMultiStepBatchSampler.

load_model_states_and_optimizer(conf, model, device)

Load model weights, optimizer, scheduler, and gradient scaler.

Module Contents

credit.trainers.utils.cleanup()
credit.trainers.utils.cycle(dl)
credit.trainers.utils.accum_log(log, new_logs)
credit.trainers.utils.inject_flat_var_keys(conf: dict) → None

Inject Gen1-compatible variable keys into conf["data"] for metrics/loss.

LatWeightedMetrics and VariableTotalLoss2D expect flat lists at:

- conf["data"]["variables"]
- conf["data"]["surface_variables"]
- conf["data"]["diagnostic_variables"]

These are derived from the nested Gen2 source config.
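A minimal sketch of the flattening, assuming a hypothetical Gen2 layout in which each entry under `conf["data"]["source"]` holds `variables` → `prognostic`/`diagnostic` → `vars_3D`/`vars_2D` lists. These nested key names and the function name `inject_flat_var_keys_sketch` are illustrative assumptions, not the actual CREDIT schema; only the three output keys come from the docstring above.

```python
def inject_flat_var_keys_sketch(conf: dict) -> None:
    """Hypothetical: build flat Gen1 variable lists from an assumed nested Gen2 layout."""
    data = conf["data"]
    upper_air, surface, diagnostic = [], [], []
    # ASSUMED layout: data["source"][<name>]["variables"]["prognostic"|"diagnostic"]
    for source_cfg in data.get("source", {}).values():
        variables = source_cfg.get("variables", {})
        prognostic = variables.get("prognostic") or {}
        diag = variables.get("diagnostic") or {}
        upper_air.extend(prognostic.get("vars_3D", []))
        surface.extend(prognostic.get("vars_2D", []))
        diagnostic.extend(diag.get("vars_3D", []))
        diagnostic.extend(diag.get("vars_2D", []))
    # Flat keys read by LatWeightedMetrics and VariableTotalLoss2D
    data["variables"] = upper_air
    data["surface_variables"] = surface
    data["diagnostic_variables"] = diagnostic


# Example with a single hypothetical source
conf = {
    "data": {
        "source": {
            "ERA5": {
                "variables": {
                    "prognostic": {"vars_3D": ["U", "V"], "vars_2D": ["SP"]},
                    "diagnostic": {"vars_2D": ["TP"]},
                }
            }
        }
    }
}
inject_flat_var_keys_sketch(conf)
```

The function mutates `conf` in place and returns `None`, mirroring the documented signature.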

credit.trainers.utils.inject_postblock_info(conf: dict) → None

Inject post-block indices into conf["postblock"] for proper initialization of Gen1 postblocks.
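A sketch of one way such indices could be computed, assuming (hypothetically) that the model output stacks each upper-air variable over `levels` channels, followed by one channel per surface and then per diagnostic variable. The `"var_index"` key, the helper names, the `levels` argument, and this channel ordering are all assumptions for illustration; the indices that CREDIT's Gen1 postblocks actually read may differ.

```python
from typing import Dict, List


def channel_index_map(upper_air: List[str], surface: List[str],
                      diagnostic: List[str], levels: int) -> Dict[str, List[int]]:
    """Map each variable name to its channel indices in the ASSUMED output ordering."""
    index: Dict[str, List[int]] = {}
    offset = 0
    for name in upper_air:  # each upper-air variable spans `levels` channels
        index[name] = list(range(offset, offset + levels))
        offset += levels
    for name in list(surface) + list(diagnostic):  # single-level variables
        index[name] = [offset]
        offset += 1
    return index


def inject_postblock_info_sketch(conf: dict, levels: int) -> None:
    """Hypothetical: store the index map under conf["postblock"]["var_index"]."""
    data = conf["data"]
    conf.setdefault("postblock", {})["var_index"] = channel_index_map(
        data["variables"], data["surface_variables"],
        data["diagnostic_variables"], levels)
```

For example, with upper-air `["U", "V"]`, surface `["SP"]`, diagnostic `["TP"]` and `levels=3`, `U` maps to channels 0 to 2, `V` to 3 to 5, `SP` to 6 and `TP` to 7.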

credit.trainers.utils.load_dataset(conf: dict, is_train: bool) → credit.datasets.multi_source.MultiSourceDataset

Build a MultiSourceDataset for train or validation.

credit.trainers.utils.load_dataloader(conf: dict, dataset: credit.datasets.multi_source.MultiSourceDataset, rank: int, world_size: int, is_train: bool) → torch.utils.data.DataLoader

Build a DataLoader with DistributedMultiStepBatchSampler.
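The `DistributedMultiStepBatchSampler` interface is not documented on this page, so the sketch below is a hypothetical stand-in (`ShardedBatchSampler`) showing only the general distributed batch-sampler pattern: every rank draws the same seeded permutation, keeps a disjoint strided shard, and yields lists of dataset indices. The multi-step (rollout) behaviour of the real sampler is not reproduced, and all names and parameters here are assumptions.

```python
import random
from typing import Iterator, List


class ShardedBatchSampler:
    """Hypothetical stand-in: shard indices across ranks, then group them into batches."""

    def __init__(self, num_samples: int, batch_size: int, rank: int, world_size: int,
                 shuffle: bool = True, seed: int = 0, drop_last: bool = True):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Same seed + epoch on every rank gives the same permutation, so shards stay disjoint
        self.epoch = epoch

    def _rank_indices(self) -> List[int]:
        indices = list(range(self.num_samples))
        if self.shuffle:
            random.Random(self.seed + self.epoch).shuffle(indices)
        # Trim so that every rank receives the same number of indices
        usable = len(indices) - len(indices) % self.world_size
        return indices[self.rank:usable:self.world_size]

    def __iter__(self) -> Iterator[List[int]]:
        shard = self._rank_indices()
        for start in range(0, len(shard), self.batch_size):
            batch = shard[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                break
            yield batch

    def __len__(self) -> int:
        n = self.num_samples // self.world_size
        return n // self.batch_size if self.drop_last else -(-n // self.batch_size)
```

An object like this is passed as `torch.utils.data.DataLoader(dataset, batch_sampler=sampler)`; with `batch_sampler` set, `DataLoader` requires `batch_size`, `shuffle`, `sampler`, and `drop_last` to stay at their defaults. Calling `set_epoch(epoch)` before each epoch reshuffles consistently across ranks.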

credit.trainers.utils.load_model_states_and_optimizer(conf, model, device)

Load model weights, optimizer, scheduler, and gradient scaler.
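A minimal sketch of the restore step, assuming the checkpoint is a dictionary (for example, the result of `torch.load(path, map_location=device)`) whose entries use the hypothetical keys `model_state_dict`, `optimizer_state_dict`, `scheduler_state_dict`, `scaler_state_dict`, and `epoch`. The real function also reads `conf` to decide what to load; that logic is not shown here. PyTorch modules, optimizers, LR schedulers, and `GradScaler` all expose `load_state_dict`, so the sketch relies only on that method.

```python
from typing import Any, Dict


def restore_training_state(checkpoint: Dict[str, Any], model, optimizer=None,
                           scheduler=None, scaler=None) -> int:
    """Hypothetical: restore components from a checkpoint dict; return the next epoch."""
    # ASSUMED key names; adjust to whatever the checkpoint writer actually uses
    model.load_state_dict(checkpoint["model_state_dict"])
    # Optional components are restored only if both the object and its saved state exist
    for obj, key in ((optimizer, "optimizer_state_dict"),
                     (scheduler, "scheduler_state_dict"),
                     (scaler, "scaler_state_dict")):
        if obj is not None and key in checkpoint:
            obj.load_state_dict(checkpoint[key])
    # Resume after the last completed epoch (0 if none was recorded)
    return checkpoint.get("epoch", -1) + 1
```

Restoring the optimizer, scheduler, and scaler alongside the weights lets training resume with the same learning-rate schedule and loss-scaling state rather than restarting them from scratch.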