credit.trainers.base_trainer#

Content:
  • EMATracker

  • BaseTrainer (abstract)
    • train_one_epoch (abstract)

    • validate (abstract)

    • fit

    • _save_checkpoint (internal)

Attributes#

Classes#

EMATracker

Exponential moving average of model weights.

BaseTrainer

Helper class that provides a standard way to create an ABC using inheritance.

Module Contents#

credit.trainers.base_trainer._SummaryWriter = None#
credit.trainers.base_trainer.logger#
class credit.trainers.base_trainer.EMATracker(model: torch.nn.Module, decay: float = 0.9999)#

Exponential moving average of model weights.

Maintains a shadow copy of the model parameters:

shadow = decay * shadow + (1 - decay) * param

Uses adaptive decay so short runs are not dominated by initial random weights:

effective_decay = min(max_decay, (1 + step) / (10 + step))

This ramps from 0.1 at step 0 up to max_decay, where it is capped, so validation reflects recent training weights regardless of run length.
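To make the warm-up concrete, the sketch below evaluates the formula above at a few step counts. The function name `effective_decay` is hypothetical and only mirrors the formula; it is not taken from the CREDIT source.

```python
def effective_decay(step: int, max_decay: float = 0.9999) -> float:
    """Warm-up decay from the formula above, capped at max_decay."""
    return min(max_decay, (1 + step) / (10 + step))


# Print the decay at a few representative step counts.
for step in (0, 10, 100, 1_000, 100_000):
    print(f"step={step:>6}  decay={effective_decay(step):.4f}")
```

With max_decay = 0.9999 the cap is reached after roughly 90,000 steps. Shorter runs therefore never use the full decay, which is why a smaller max_decay suits them.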

Usage:

    ema = EMATracker(model, decay=0.9999)
    # after each optimizer.step():
    ema.update(model)
    # before validation:
    ema.swap(model)  # model now holds EMA weights
    …validate…
    ema.swap(model)  # restore training weights

Typical max_decay: 0.9999 for long runs, 0.999 for short runs.
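The update and swap semantics described above can be illustrated with a toy tracker. This is a minimal sketch under stated assumptions, not the CREDIT implementation: `ToyEMATracker` is a hypothetical name, plain floats stand in for model tensors, the step counter is assumed to advance after the decay is computed, and any special handling of buffers is omitted.

```python
from collections import OrderedDict


class ToyEMATracker:
    """Illustrative EMA over a dict of floats (stand-in for model tensors)."""

    def __init__(self, params: dict, decay: float = 0.9999):
        self.decay = decay
        self.step = 0
        self.shadow = OrderedDict(params)  # shadow starts as a copy

    def update(self, params: dict) -> None:
        # Warm-up decay, capped at the configured maximum.
        d = min(self.decay, (1 + self.step) / (10 + self.step))
        for name, value in params.items():
            self.shadow[name] = d * self.shadow[name] + (1 - d) * value
        self.step += 1  # assumption: the counter advances after the update

    def swap(self, params: dict) -> None:
        # Exchange live and shadow values in place; a second call restores them.
        for name in params:
            params[name], self.shadow[name] = self.shadow[name], params[name]


weights = {"w": 0.0}
ema = ToyEMATracker(weights, decay=0.999)
for _ in range(3):
    weights["w"] += 1.0  # pretend optimizer.step() changed the weight
    ema.update(weights)

ema.swap(weights)        # weights now holds the EMA value
ema_value = weights["w"]
ema.swap(weights)        # training value restored
```

Because swap exchanges values rather than copying them, no third buffer is needed and calling it twice is a no-op.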

decay = 0.9999#
step = 0#
shadow: collections.OrderedDict#
static _is_spectral_norm_buffer(key: str, state: dict) -> bool#
update(model: torch.nn.Module)#
swap(model: torch.nn.Module)#

Swap model weights with EMA shadow weights (and vice-versa).

state_dict()#
load_state_dict(d)#
class credit.trainers.base_trainer.BaseTrainer(model: torch.nn.Module, rank: int, conf: Dict[str, Any])#

Bases: abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.

model#
rank#
device#
conf#
save_loc#
mode#
distributed#
start_epoch#
epochs#
skip_validation#
load_weights#
use_scheduler#
scheduler_type#
amp#
grad_max_norm#
batches_per_epoch#
valid_batches_per_epoch#
ensemble_size#
save_best_weights#
save_backup_weights#
stopping_patience#
save_every_epoch#
stop_after_epoch#
num_epoch#
save_metric_vars#
train_one_epoch_mode#
training_metric#
direction#
static _model_state_dict(model: torch.nn.Module) -> dict#

Return state dict, unwrapping DDP .module wrapper if present.

static _load_model_state_dict(model: torch.nn.Module, state_dict: dict) -> None#

Load state dict, unwrapping DDP .module wrapper if present.
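The unwrapping pattern behind these two helpers can be sketched without PyTorch. The names `unwrap`, `TinyModel`, and `Wrapper` are hypothetical stand-ins; the real helpers operate on torch.nn.Module and DistributedDataParallel, and their exact logic is not shown in this reference.

```python
def unwrap(model):
    """Return model.module if the model is wrapped, else the model itself."""
    return getattr(model, "module", model)


class TinyModel:
    """Stand-in for torch.nn.Module with a dict-based state."""

    def __init__(self):
        self.state = {"weight": 1.0}

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)


class Wrapper:
    """Stand-in for DistributedDataParallel, which exposes .module."""

    def __init__(self, module):
        self.module = module


wrapped = Wrapper(TinyModel())
saved = unwrap(wrapped).state_dict()              # read from the inner model
unwrap(wrapped).load_state_dict({"weight": 2.0})  # write to the inner model
```

Going through the inner module keeps checkpoint keys free of the `module.` prefix that a DDP-wrapped state dict carries, so the same checkpoint loads in both distributed and single-process runs.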

abstractmethod train_one_epoch(epoch: int, trainloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, scaler: torch.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any]) -> Dict[str, float]#
abstractmethod validate(epoch: int, valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, metrics: Dict[str, Any]) -> Dict[str, float]#
_log_batch_progress(epoch: int, results_dict: dict, optimizer: torch.optim.Optimizer | None, pbar, phase: str = 'train') -> None#

Update a tqdm progress bar with rolling-mean batch metrics.

Reads any keys in results_dict that start with <phase>_ and formats them as a space-separated description string. Always appends the current learning rate.

Parameters:
  • epoch – Current epoch number.

  • results_dict – Dict mapping metric names to lists of per-batch values.

  • optimizer – Current optimizer (used to read the learning rate).

  • pbar – tqdm progress bar to update.

  • phase – Metric prefix, either "train" or "valid".
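The description string could be assembled roughly as follows. This is a sketch only: `describe_progress` is a hypothetical helper, the exact label format is assumed, and the mean is taken over all values recorded so far, which may differ from the windowing the real method uses.

```python
from statistics import fmean


def describe_progress(epoch, results_dict, lr, phase="train"):
    """Build a space-separated description from keys starting with '<phase>_'."""
    parts = [f"Epoch: {epoch}"]
    for key, values in results_dict.items():
        # Keep only metrics for the requested phase that have data.
        if key.startswith(f"{phase}_") and values:
            parts.append(f"{key}: {fmean(values):.6f}")
    parts.append(f"lr: {lr:.2e}")  # the learning rate always comes last
    return " ".join(parts)


results = {"train_loss": [0.5, 0.3], "valid_loss": [0.9]}
description = describe_progress(3, results, lr=1e-3, phase="train")
# With tqdm, the string would then be applied via pbar.set_description(description).
```

Filtering on the phase prefix lets one results dict hold both training and validation metrics without mixing them in the progress bar.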

_save_checkpoint(epoch: int, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, scaler: torch.amp.GradScaler) -> None#

Save model, optimizer, scheduler, and scaler state.

fit(conf: Dict[str, Any], train_loader: torch.utils.data.DataLoader, valid_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, train_criterion: torch.nn.Module, valid_criterion: torch.nn.Module, scaler: torch.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any], rollout_scheduler: Callable | None = None, trial: bool = False) -> Dict[str, Any]#

Run the full training loop.

Parameters:
  • conf – Full configuration dict (passed through to train_one_epoch/validate for data-related settings; trainer settings are accessed via self).

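  • train_loader – DataLoader that yields training batches.

  • valid_loader – DataLoader that yields validation batches.

  • optimizer – Optimizer used to update the model.

  • train_criterion – Loss module applied during training.

  • valid_criterion – Loss module applied during validation.

  • scaler – Gradient scaler (torch.amp.GradScaler) for mixed-precision training.

  • scheduler – Learning-rate scheduler.

  • metrics – Dict of metrics, as accepted by train_one_epoch and validate.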

  • rollout_scheduler – Optional callable to schedule rollout probability.

  • trial – Optuna trial object, or False.

Returns:

Dict with the best epoch’s results.
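The contract between fit and the two abstract methods can be sketched with a stand-in class. Everything here is illustrative: `SketchTrainer` and `ConstantTrainer` are hypothetical, the meaning of `direction` and `training_metric` is assumed from the attribute names, and the real fit also deals with checkpointing, scheduling, and early stopping, none of which is shown.

```python
from abc import ABC, abstractmethod


class SketchTrainer(ABC):
    """Stand-in showing the train/validate contract; not the CREDIT class."""

    def __init__(self, epochs, direction="min", training_metric="valid_loss"):
        self.epochs = epochs
        self.direction = direction
        self.training_metric = training_metric

    @abstractmethod
    def train_one_epoch(self, epoch): ...

    @abstractmethod
    def validate(self, epoch): ...

    def fit(self):
        best = None
        for epoch in range(self.epochs):
            # Merge both phases into one results dict for this epoch.
            results = {"epoch": epoch, **self.train_one_epoch(epoch), **self.validate(epoch)}
            score = results[self.training_metric]
            better = best is None or (
                score < best[self.training_metric]
                if self.direction == "min"
                else score > best[self.training_metric]
            )
            if better:
                best = results
        return best


class ConstantTrainer(SketchTrainer):
    """Toy subclass returning canned metrics per epoch."""

    losses = [0.9, 0.4, 0.6]

    def train_one_epoch(self, epoch):
        return {"train_loss": self.losses[epoch] / 2}

    def validate(self, epoch):
        return {"valid_loss": self.losses[epoch]}


best = ConstantTrainer(epochs=3).fit()
```

Instantiating `SketchTrainer` directly raises TypeError, which is the ABC mechanism that forces subclasses to supply both train_one_epoch and validate.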