credit.trainers.base_trainer#
Contents:
- EMATracker
- BaseTrainer (abstract)
  - train_one_epoch (abstract)
  - validate (abstract)
  - fit
  - _save_checkpoint (internal)
Attributes#
Classes#
EMATracker | Exponential moving average of model weights.
BaseTrainer | Helper class that provides a standard way to create an ABC using inheritance.
Module Contents#
- credit.trainers.base_trainer._SummaryWriter = None#
- credit.trainers.base_trainer.logger#
- class credit.trainers.base_trainer.EMATracker(model: torch.nn.Module, decay: float = 0.9999)#
Exponential moving average of model weights.
- Maintains a shadow copy of the model parameters:
shadow = decay * shadow + (1 - decay) * param
- Uses adaptive decay so short runs are not dominated by initial random weights:
effective_decay = min(max_decay, (1 + step) / (10 + step))
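This ramps from 0.1 at step 0 toward 1 and is capped at max_decay (the decay constructor argument), so validation reflects recent training weights regardless of run length. With max_decay = 0.9999 the cap is reached after roughly 90,000 steps.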
- Usage:
ema = EMATracker(model, decay=0.9999)
# after each optimizer.step():
ema.update(model)
# before validation:
ema.swap(model)  # model now holds EMA weights
…validate…
ema.swap(model)  # restore training weights
Typical max_decay: 0.9999 for long runs, 0.999 for short runs.
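The two formulas above can be tried out without PyTorch. The following is a minimal sketch, not the EMATracker implementation: it uses plain floats in place of tensors, the helper names `effective_decay` and `ema_update` are hypothetical, and it assumes the decay is computed from the step count before that count is incremented.

```python
def effective_decay(step: int, max_decay: float = 0.9999) -> float:
    """Warm-up schedule from the docstring: min(max_decay, (1 + step) / (10 + step))."""
    return min(max_decay, (1 + step) / (10 + step))


def ema_update(shadow: dict, params: dict, step: int, max_decay: float = 0.9999) -> None:
    """Apply shadow = d * shadow + (1 - d) * param in place, with d from the schedule."""
    d = effective_decay(step, max_decay)
    for name, value in params.items():
        shadow[name] = d * shadow[name] + (1.0 - d) * value


shadow = {"w": 0.0}  # stand-in for an initial (random) weight
params = {"w": 1.0}  # stand-in for the current trained weight
for step in range(5):
    ema_update(shadow, params, step)
    print(step, round(effective_decay(step), 4), round(shadow["w"], 4))
```

Because the early decay values are small, the shadow value moves most of the way to the trained weight within the first few updates, which is the behavior the adaptive schedule is meant to provide for short runs.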
- decay = 0.9999#
- step = 0#
- shadow: collections.OrderedDict#
- static _is_spectral_norm_buffer(key: str, state: dict) bool#
- update(model: torch.nn.Module)#
- swap(model: torch.nn.Module)#
Swap model weights with EMA shadow weights (and vice-versa).
- state_dict()#
- load_state_dict(d)#
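The reason swap can be called twice around validation is that an exchange is its own inverse. A minimal sketch of that idea, assuming weights are stored as name-to-value mappings (plain floats stand in for tensors, and `swap_weights` is a hypothetical helper rather than the library method):

```python
def swap_weights(model_state: dict, shadow: dict) -> None:
    """Exchange every entry tracked in shadow with the model's entry of the same name."""
    for name in shadow:
        model_state[name], shadow[name] = shadow[name], model_state[name]


model_state = {"w": 1.0, "b": 2.0}  # stand-in for the training weights
shadow = {"w": 0.5, "b": 1.5}       # stand-in for the EMA weights

swap_weights(model_state, shadow)   # model_state now holds the EMA values
ema_view = dict(model_state)
swap_weights(model_state, shadow)   # second call restores the training values
```

No third copy of the weights is needed, since each side temporarily stores the other's values.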
- class credit.trainers.base_trainer.BaseTrainer(model: torch.nn.Module, rank: int, conf: Dict[str, Any])#
Bases: abc.ABC
Helper class that provides a standard way to create an ABC using inheritance.
- model#
- rank#
- device#
- conf#
- save_loc#
- mode#
- distributed#
- start_epoch#
- epochs#
- skip_validation#
- load_weights#
- use_scheduler#
- scheduler_type#
- amp#
- grad_max_norm#
- batches_per_epoch#
- valid_batches_per_epoch#
- ensemble_size#
- save_best_weights#
- save_backup_weights#
- stopping_patience#
- save_every_epoch#
- stop_after_epoch#
- num_epoch#
- save_metric_vars#
- train_one_epoch_mode#
- training_metric#
- direction#
- static _model_state_dict(model: torch.nn.Module) dict#
Return state dict, unwrapping DDP .module wrapper if present.
- static _load_model_state_dict(model: torch.nn.Module, state_dict: dict) None#
Load state dict, unwrapping DDP .module wrapper if present.
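DistributedDataParallel keeps the wrapped network on its `.module` attribute, and its own state dict keys carry a `module.` prefix, which is why checkpoints are usually read from and written to the inner module. The unwrapping step might look like the sketch below. It is an illustration only: `TinyModel` and `Wrapper` are hypothetical stand-ins for a `torch.nn.Module` and the DDP wrapper, and the helper names are not the library's.

```python
class TinyModel:
    """Stand-in for a torch.nn.Module with state_dict / load_state_dict."""

    def __init__(self):
        self._state = {"w": 1.0}

    def state_dict(self) -> dict:
        return dict(self._state)

    def load_state_dict(self, state: dict) -> None:
        self._state = dict(state)


class Wrapper:
    """Stand-in for DistributedDataParallel: the real model lives on .module."""

    def __init__(self, module):
        self.module = module


def unwrap(model):
    """Return model.module when the attribute exists, otherwise the model itself."""
    return getattr(model, "module", model)


def model_state_dict(model) -> dict:
    return unwrap(model).state_dict()


def load_model_state_dict(model, state: dict) -> None:
    unwrap(model).load_state_dict(state)


plain = TinyModel()
wrapped = Wrapper(TinyModel())
load_model_state_dict(wrapped, {"w": 3.0})  # reaches the inner model
```

With this pattern the same checkpoint can be loaded in single-process and distributed runs.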
- abstractmethod train_one_epoch(epoch: int, trainloader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, scaler: torch.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any]) Dict[str, float]#
- abstractmethod validate(epoch: int, valid_loader: torch.utils.data.DataLoader, criterion: torch.nn.Module, metrics: Dict[str, Any]) Dict[str, float]#
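A subclass has to override both abstract methods before it can be instantiated. The sketch below shows that contract with a hypothetical stand-in base class, `SketchTrainer`, which copies only the two abstract signatures and none of BaseTrainer's constructor or attributes. The toy "batches" are plain floats, and the result keys `train_loss` and `valid_loss` are assumptions.

```python
from abc import ABC, abstractmethod
from typing import Dict


class SketchTrainer(ABC):
    """Hypothetical stand-in for BaseTrainer; only the two abstract hooks are modeled."""

    @abstractmethod
    def train_one_epoch(self, epoch, trainloader, optimizer, criterion,
                        scaler, scheduler, metrics) -> Dict[str, float]: ...

    @abstractmethod
    def validate(self, epoch, valid_loader, criterion, metrics) -> Dict[str, float]: ...


class MeanLossTrainer(SketchTrainer):
    """Toy subclass: each batch is a float and the criterion maps it to a loss."""

    def train_one_epoch(self, epoch, trainloader, optimizer, criterion,
                        scaler, scheduler, metrics):
        losses = [criterion(batch) for batch in trainloader]
        return {"train_loss": sum(losses) / len(losses)}

    def validate(self, epoch, valid_loader, criterion, metrics):
        losses = [criterion(batch) for batch in valid_loader]
        return {"valid_loss": sum(losses) / len(losses)}


trainer = MeanLossTrainer()
train_results = trainer.train_one_epoch(0, [1.0, 3.0], None, abs, None, None, {})
valid_results = trainer.validate(0, [2.0, 4.0], abs, {})
```

Instantiating `SketchTrainer` directly raises `TypeError`, which is the standard behavior of `abc.ABC` classes with unimplemented abstract methods.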
- _log_batch_progress(epoch: int, results_dict: dict, optimizer: torch.optim.Optimizer | None, pbar, phase: str = 'train') None#
Update a tqdm progress bar with rolling-mean batch metrics.
Reads any keys in results_dict that start with <phase>_ and formats them as a space-separated description string. Always appends the current learning rate.
- Parameters:
epoch – Current epoch number.
results_dict – Dict mapping metric names to lists of per-batch values.
optimizer – Current optimizer (used to read the learning rate).
pbar – tqdm progress bar to update.
phase – Metric prefix, either "train" or "valid".
- _save_checkpoint(epoch: int, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, scaler: torch.amp.GradScaler) None#
Save model, optimizer, scheduler, and scaler state.
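For the progress description that _log_batch_progress builds from results_dict, a rough sketch follows. It is not the library code: `format_progress` is a hypothetical helper, the exact layout of the string is an assumption, and the learning rate is passed in as a number instead of being read from an optimizer.

```python
from statistics import fmean


def format_progress(epoch, results_dict, lr=None, phase="train"):
    """Build a space-separated description from the keys that start with '<phase>_'."""
    prefix = f"{phase}_"
    parts = [f"Epoch {epoch}"]
    for key, values in results_dict.items():
        if key.startswith(prefix) and values:
            # Mean over the batches seen so far in this epoch.
            parts.append(f"{key}: {fmean(values):.4f}")
    if lr is not None:
        parts.append(f"lr: {lr:.3e}")
    return " ".join(parts)


results = {"train_loss": [1.0, 3.0], "valid_loss": [9.0]}
description = format_progress(1, results, lr=1e-3)
print(description)
```

With tqdm, a string like this would typically be passed to `pbar.set_description(...)`.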
- fit(conf: Dict[str, Any], train_loader: torch.utils.data.DataLoader, valid_loader: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, train_criterion: torch.nn.Module, valid_criterion: torch.nn.Module, scaler: torch.amp.GradScaler, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, Any], rollout_scheduler: Callable | None = None, trial: bool = False) Dict[str, Any]#
Run the full training loop.
- Parameters:
conf – Full configuration dict (passed through to train_one_epoch/validate for data-related settings; trainer settings are accessed via self).
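train_loader – DataLoader that yields training batches.
valid_loader – DataLoader that yields validation batches.
optimizer – Optimizer that updates the model parameters.
train_criterion – Loss module used during training.
valid_criterion – Loss module used during validation.
scaler – Gradient scaler (torch.amp.GradScaler) for mixed-precision training.
scheduler – Learning-rate scheduler.
metrics – Dict of metrics, as accepted by train_one_epoch and validate.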
rollout_scheduler – Optional callable to schedule rollout probability.
trial – Optuna trial object, or False.
- Returns:
Dict with the best epoch’s results.