credit.trainers
Submodules
Attributes
- logger
- trainer_types
Classes
- TrainerERA5 – Helper class that provides a standard way to create an ABC using inheritance.
- TrainerEnsemble – Helper class that provides a standard way to create an ABC using inheritance.
- Trainer404 – Helper class that provides a standard way to create an ABC using inheritance.
Functions
- load_trainer(conf, load_weights=False)
Package Contents
- class credit.trainers.TrainerERA5(model: torch.nn.Module, rank: int)
Bases: credit.trainers.base_trainer.BaseTrainer
Helper class that provides a standard way to create an ABC using inheritance.
- train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)
Trains the model for one epoch.
- Parameters:
epoch (int) – Current epoch number.
conf (dict) – Configuration dictionary containing training settings.
trainloader (DataLoader) – DataLoader for the training dataset.
optimizer (torch.optim.Optimizer) – Optimizer used for training.
criterion (callable) – Loss function used for training.
scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed precision training.
scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.
metrics (callable) – Function to compute metrics for evaluation.
- Returns:
Dictionary containing training metrics and loss for the epoch.
- Return type:
dict
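The documented arguments of train_one_epoch typically fit together as in the following minimal PyTorch sketch. It is not the CREDIT implementation: train_one_epoch_sketch is a hypothetical standalone function (the real method reads the model from the trainer instance), and it assumes that each batch is an (x, y) tensor pair, that metrics(pred, y) returns a dict of floats, that the scheduler is stepped once per epoch, and that a hypothetical conf key "amp" toggles autocast.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(model, epoch, conf, trainloader, optimizer,
                           criterion, scaler, scheduler, metrics):
    """Hypothetical single-epoch loop; not CREDIT's TrainerERA5 code.

    Assumptions: batches are (x, y) tensor pairs, metrics(pred, y) returns a
    dict of floats, and conf["amp"] (a hypothetical key) enables autocast.
    `epoch` is kept only for signature parity.
    """
    model.train()
    use_amp = bool(conf.get("amp", False))
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    totals = {"train_loss": 0.0}
    n_batches = 0
    for x, y in trainloader:
        optimizer.zero_grad(set_to_none=True)
        # Forward pass and loss under (optional) mixed precision.
        with torch.autocast(device_type=device_type, enabled=use_amp):
            pred = model(x)
            loss = criterion(pred, y)
        # The scaler scales the loss for backward and unscales before stepping;
        # a GradScaler built with enabled=False passes everything through.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        totals["train_loss"] += loss.item()
        for name, value in metrics(pred.detach(), y).items():
            key = f"train_{name}"
            totals[key] = totals.get(key, 0.0) + float(value)
        n_batches += 1
    # Assumption: the scheduler is stepped once per epoch, not per batch.
    scheduler.step()
    # Average each accumulated quantity over the number of batches.
    return {k: v / max(n_batches, 1) for k, v in totals.items()}


# Usage on a toy regression problem (CPU, mixed precision disabled).
torch.manual_seed(0)
model = nn.Linear(3, 1)
loader = DataLoader(TensorDataset(torch.randn(16, 3), torch.randn(16, 1)), batch_size=4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
scaler = torch.cuda.amp.GradScaler(enabled=False)


def mae(pred, y):
    return {"mae": (pred - y).abs().mean().item()}


results = train_one_epoch_sketch(model, 0, {}, loader, optimizer, nn.MSELoss(),
                                 scaler, scheduler, mae)
```

Returning per-batch averages keyed with a "train_" prefix is one convenient way to produce the "dictionary containing training metrics and loss" described above; the actual key names used by CREDIT are not documented on this page.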
- validate(epoch, conf, valid_loader, criterion, metrics)
Validates the model on the validation dataset.
- Parameters:
epoch (int) – Current epoch number.
conf (dict) – Configuration dictionary containing validation settings.
valid_loader (DataLoader) – DataLoader for the validation dataset.
criterion (callable) – Loss function used for validation.
metrics (callable) – Function to compute metrics for evaluation.
- Returns:
Dictionary containing validation metrics and loss for the epoch.
- Return type:
dict
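A validation pass with the same arguments usually differs from training only in disabling gradients and switching the model to evaluation mode. The sketch below illustrates that under stated assumptions; validate_sketch is hypothetical, not CREDIT's method, and it assumes (x, y) batches and a metrics callable returning a dict of floats.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # no autograd graph is needed for evaluation
def validate_sketch(model, epoch, conf, valid_loader, criterion, metrics):
    """Hypothetical validation pass; not CREDIT's implementation.

    Assumes batches are (x, y) pairs and metrics(pred, y) returns a dict of
    floats. `epoch` and `conf` are kept only for signature parity.
    """
    model.eval()  # switch dropout/batch-norm layers to inference behavior
    totals = {"valid_loss": 0.0}
    n_batches = 0
    for x, y in valid_loader:
        pred = model(x)
        totals["valid_loss"] += criterion(pred, y).item()
        for name, value in metrics(pred, y).items():
            key = f"valid_{name}"
            totals[key] = totals.get(key, 0.0) + float(value)
        n_batches += 1
    # Average each accumulated quantity over the number of batches.
    return {k: v / max(n_batches, 1) for k, v in totals.items()}


# Usage: with equal-sized batches, the mean of per-batch MSEs equals the full MSE.
torch.manual_seed(0)
model = nn.Linear(3, 1)
X, Y = torch.randn(16, 3), torch.randn(16, 1)
loader = DataLoader(TensorDataset(X, Y), batch_size=4)


def mae(pred, y):
    return {"mae": (pred - y).abs().mean().item()}


results = validate_sketch(model, 0, {}, loader, nn.MSELoss(), mae)
```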
- class credit.trainers.TrainerEnsemble(model: torch.nn.Module, rank: int)
Bases: credit.trainers.base_trainer.BaseTrainer
Helper class that provides a standard way to create an ABC using inheritance.
- train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)
Trains the model for one epoch.
- Parameters:
epoch (int) – Current epoch number.
conf (dict) – Configuration dictionary containing training settings.
trainloader (DataLoader) – DataLoader for the training dataset.
optimizer (torch.optim.Optimizer) – Optimizer used for training.
criterion (callable) – Loss function used for training.
scaler (torch.cuda.amp.GradScaler) – Gradient scaler for mixed precision training.
scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.
metrics (callable) – Function to compute metrics for evaluation.
- Returns:
Dictionary containing training metrics and loss for the epoch.
- Return type:
dict
- validate(epoch, conf, valid_loader, criterion, metrics)
Validates the model on the validation dataset.
- Parameters:
epoch (int) – Current epoch number.
conf (dict) – Configuration dictionary containing validation settings.
valid_loader (DataLoader) – DataLoader for the validation dataset.
criterion (callable) – Loss function used for validation.
metrics (callable) – Function to compute metrics for evaluation.
- Returns:
Dictionary containing validation metrics and loss for the epoch.
- Return type:
dict
- class credit.trainers.Trainer404(model: torch.nn.Module, rank: int)
Bases: credit.trainers.base_trainer.BaseTrainer
Helper class that provides a standard way to create an ABC using inheritance.
- train_one_epoch(epoch, conf, trainloader, optimizer, criterion, scaler, scheduler, metrics)
Train the model for one epoch.
- Parameters:
epoch (int) – The current epoch number.
conf (Dict[str, Any]) – The configuration dictionary.
trainloader (torch.utils.data.DataLoader) – The training data loader.
optimizer (torch.optim.Optimizer) – The optimizer.
criterion (torch.nn.Module) – The loss function.
scaler (torch.cuda.amp.GradScaler) – The gradient scaler for mixed precision training.
scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning rate scheduler.
metrics (Dict[str, Any]) – The metrics to track during training.
- Returns:
A dictionary containing the training results.
- Return type:
Dict[str, float]
- validate(epoch, conf, valid_loader, criterion, metrics)
Validate the model on the validation set.
- Parameters:
epoch (int) – The current epoch number.
conf (Dict[str, Any]) – The configuration dictionary.
valid_loader (torch.utils.data.DataLoader) – The validation data loader.
criterion (torch.nn.Module) – The loss function.
metrics (Dict[str, Any]) – The metrics to track during validation.
- Returns:
A dictionary containing the validation results.
- Return type:
Dict[str, float]
- fit(conf, train_loader, valid_loader, optimizer, train_criterion, valid_criterion, scaler, scheduler, metrics, rollout_scheduler=None, trial=False)
Fit the model to the data.
- Parameters:
conf (Dict[str, Any]) – Configuration dictionary.
train_loader (DataLoader) – DataLoader for training data.
valid_loader (DataLoader) – DataLoader for validation data.
optimizer (Optimizer) – The optimizer to use for training.
train_criterion (torch.nn.Module) – Loss function for training.
valid_criterion (torch.nn.Module) – Loss function for validation.
scaler (GradScaler) – Gradient scaler for mixed precision training.
scheduler (_LRScheduler) – Learning rate scheduler.
metrics (Dict[str, Any]) – Dictionary of metrics to track during training.
rollout_scheduler (Optional[callable]) – Function to schedule rollout probability, if applicable.
trial (bool) – Whether this is a trial run (e.g., for hyperparameter tuning).
- Returns:
Dictionary containing the best results from training.
- Return type:
Dict[str, Any]
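The documented contract of fit (train and validate each epoch, then return the best results) can be pictured with the pure-Python sketch below. Everything in it is an assumption rather than CREDIT's code: fit_sketch and StubTrainer are hypothetical, the epoch count is read from a hypothetical conf["trainer"]["epochs"] key, "best" is taken to mean lowest "valid_loss", rollout_scheduler is assumed to map an epoch index to a probability, and trial is accepted but unused.

```python
from typing import Any, Callable, Dict, Optional


def fit_sketch(trainer, conf: Dict[str, Any], train_loader, valid_loader,
               optimizer, train_criterion, valid_criterion, scaler, scheduler,
               metrics, rollout_scheduler: Optional[Callable[[int], float]] = None,
               trial: bool = False) -> Dict[str, Any]:
    """Hypothetical epoch loop that keeps the best validation result.

    Assumptions (not CREDIT's code): the epoch count lives under
    conf["trainer"]["epochs"], "best" means lowest "valid_loss", and
    rollout_scheduler maps an epoch index to a probability. `trial` is
    accepted for signature parity but unused here.
    """
    best: Dict[str, Any] = {}
    for epoch in range(conf["trainer"]["epochs"]):
        results: Dict[str, Any] = {"epoch": epoch}
        if rollout_scheduler is not None:
            # Record the (assumed) rollout probability used for this epoch.
            results["rollout_p"] = rollout_scheduler(epoch)
        results.update(trainer.train_one_epoch(
            epoch, conf, train_loader, optimizer, train_criterion,
            scaler, scheduler, metrics))
        results.update(trainer.validate(
            epoch, conf, valid_loader, valid_criterion, metrics))
        # Keep the epoch whose validation loss is lowest so far.
        if not best or results["valid_loss"] < best["valid_loss"]:
            best = results
    return best


class StubTrainer:
    """Hypothetical stand-in that replays fixed validation losses."""

    def __init__(self, valid_losses):
        self.valid_losses = valid_losses

    def train_one_epoch(self, epoch, *args):
        return {"train_loss": 1.0}

    def validate(self, epoch, *args):
        return {"valid_loss": self.valid_losses[epoch]}


# Usage: the second epoch (index 1) has the lowest validation loss.
best = fit_sketch(StubTrainer([0.9, 0.4, 0.6]), {"trainer": {"epochs": 3}},
                  None, None, None, None, None, None, None, None,
                  rollout_scheduler=lambda epoch: 1.0 - 0.1 * epoch)
```

Keeping the whole per-epoch results dict for the best epoch, rather than only the loss, is one way to satisfy the Dict[str, Any] return type; a real trainer may also save checkpoints or stop early, which this sketch omits.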
- credit.trainers.logger
- credit.trainers.trainer_types
- credit.trainers.load_trainer(conf, load_weights=False)