train
Attributes
- description
Classes
- Objective – Optuna objective class for hyperparameter optimization.
Functions
- load_dataset_and_sampler_zscore_only – Load the Z-score only dataset and sampler for training or validation.
- load_model_states_and_optimizer – Load the model states, optimizer, scheduler, and gradient scaler.
- main – Main function to set up training and validation processes.
Module Contents
- train.load_dataset_and_sampler_zscore_only(conf, all_ERA_files, surface_files, dyn_forcing_files, diagnostic_files, world_size, rank, is_train)
Load the Z-score-only dataset and its distributed sampler for training or validation.
- Parameters:
conf (dict) – Configuration dictionary containing dataset and training parameters.
all_ERA_files (list) – List of ERA file paths.
surface_files (list) – List of surface file paths.
dyn_forcing_files (list) – List of dynamic forcing file paths.
diagnostic_files (list) – List of diagnostic file paths.
world_size (int) – Number of processes participating in the job.
rank (int) – Rank of the current process.
is_train (bool) – If True, load the training split; otherwise, load the validation split.
- Returns:
A tuple containing the dataset and the distributed sampler.
- Return type:
tuple
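The function returns a dataset paired with a distributed sampler. The stdlib-only sketch below illustrates, under stated assumptions, the two ideas its name and return value point to: z-score standardization, (x - mean) / std, with precomputed statistics, and how a distributed sampler splits sample indices across `world_size` ranks. `zscore` and `shard_indices` are hypothetical helpers, not part of this module. `shard_indices` imitates the default `drop_last=False` behaviour of `torch.utils.data.distributed.DistributedSampler` (pad by wrapping around, then stride by rank), but its shuffled permutation will not match torch's generator, which is also reseeded per epoch via `set_epoch`.

```python
import math
import random


def zscore(values, mean, std):
    """Standardize each value as (x - mean) / std using precomputed statistics."""
    return [(v - mean) / std for v in values]


def shard_indices(num_samples, world_size, rank, shuffle, seed=0):
    """Return the dataset indices that process `rank` would iterate over.

    Imitates DistributedSampler with drop_last=False: optionally shuffle with a
    seed shared by all ranks, pad by wrapping around so the total is divisible
    by world_size, then give each rank every world_size-th index.
    """
    indices = list(range(num_samples))
    if shuffle:
        # A common seed keeps the permutation identical on every rank,
        # so the shards stay disjoint before padding.
        random.Random(seed).shuffle(indices)
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    if num_samples:
        # Repeat the list as often as needed, then trim to `total`.
        indices = (indices * math.ceil(total / num_samples))[:total]
    return indices[rank:total:world_size]


if __name__ == "__main__":
    print(zscore([2.0, 4.0, 6.0, 8.0], mean=5.0, std=2.0))  # [-1.5, -0.5, 0.5, 1.5]
    for r in range(4):
        # 10 samples over 4 ranks: 3 each, with indices 0 and 1 reused as padding.
        print(r, shard_indices(10, world_size=4, rank=r, shuffle=False))
    # 0 [0, 4, 8]
    # 1 [1, 5, 9]
    # 2 [2, 6, 0]
    # 3 [3, 7, 1]
```

Because of the padding, a few samples are seen twice per epoch whenever the dataset size is not divisible by `world_size`, which slightly skews metrics averaged across ranks.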
- train.load_model_states_and_optimizer(conf, model, device)
Load the model states, optimizer, scheduler, and gradient scaler.
- Parameters:
conf (dict) – Configuration dictionary containing training parameters.
model (torch.nn.Module) – The model to be trained.
device (torch.device) – The device (CPU or GPU) where the model is located.
- Returns:
A tuple containing the updated configuration, model, optimizer, scheduler, and scaler.
- Return type:
tuple
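Resuming training usually means restoring model, optimizer, scheduler, and scaler state from one checkpoint. The sketch below shows that pattern with a hypothetical `Stateful` stand-in for objects exposing `state_dict()` and `load_state_dict()`, as PyTorch modules, optimizers, learning-rate schedulers, and gradient scalers such as `torch.cuda.amp.GradScaler` do. The checkpoint key names and the `save_checkpoint`/`restore` helpers are assumptions for illustration, not the layout this module reads. In real code the dictionary would be written with `torch.save` and read back with `torch.load(path, map_location=device)`.

```python
class Stateful:
    """Stand-in for any object with state_dict()/load_state_dict()."""

    def __init__(self, **state):
        self.state = dict(state)

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state):
        self.state = dict(state)


# Hypothetical checkpoint layout: one entry per component plus the epoch.
CHECKPOINT_KEYS = {
    "model": "model_state_dict",
    "optimizer": "optimizer_state_dict",
    "scheduler": "scheduler_state_dict",
    "scaler": "scaler_state_dict",
}


def save_checkpoint(epoch, **components):
    """Collect the state of each named component into one dictionary."""
    ckpt = {"epoch": epoch}
    for name, obj in components.items():
        ckpt[CHECKPOINT_KEYS[name]] = obj.state_dict()
    return ckpt


def restore(ckpt, **components):
    """Load whatever state the checkpoint holds; return the epoch to resume at."""
    for name, obj in components.items():
        key = CHECKPOINT_KEYS[name]
        if key in ckpt:  # e.g. an older checkpoint saved without a scaler
            obj.load_state_dict(ckpt[key])
    # Resume after the last completed epoch; start from 0 with no checkpoint.
    return ckpt.get("epoch", -1) + 1


if __name__ == "__main__":
    model, opt = Stateful(w=1.0), Stateful(lr=0.1)
    ckpt = save_checkpoint(4, model=model, optimizer=opt)
    fresh_model, fresh_opt = Stateful(w=0.0), Stateful(lr=1.0)
    print(restore(ckpt, model=fresh_model, optimizer=fresh_opt))  # 5
    print(fresh_model.state, fresh_opt.state)  # {'w': 1.0} {'lr': 0.1}
```

Restoring the optimizer and scheduler, not just the weights, keeps optimizer buffers (such as momentum) and the learning-rate schedule continuous across a restart. Returning the resume epoch is a choice of this sketch only.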
- train.main(rank, world_size, conf, backend, trial=False)
Main function to set up training and validation processes.
- Parameters:
rank (int) – Rank of the current process.
world_size (int) – Number of processes participating in the job.
conf (dict) – Configuration dictionary containing model, data, and training parameters.
backend (str) – Backend to be used for distributed training.
trial (bool, optional) – Whether this run is part of an Optuna trial. Defaults to False.
- Returns:
The result of the training process.
- Return type:
Any
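`main` takes `rank` as its first positional argument, which matches the calling convention of `torch.multiprocessing.spawn(fn, args=..., nprocs=...)`: it starts `nprocs` processes and calls `fn(i, *args)` in each, with `i` running from 0 to `nprocs - 1`. Whether this module launches `main` that way is an assumption. The stdlib sketch below reproduces only that calling convention, running the ranks one after another in a single process; `launch` and `fake_main` are hypothetical. Real ranks must run concurrently, because `torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)` waits until every rank has joined.

```python
def launch(fn, nprocs, args=()):
    """Call fn(rank, *args) for rank = 0 .. nprocs - 1 and collect the results.

    Same calling convention as torch.multiprocessing.spawn, but sequential and
    in-process, so it cannot host a real process group.
    """
    return [fn(rank, *args) for rank in range(nprocs)]


def fake_main(rank, world_size, conf, backend, trial=False):
    """Hypothetical stand-in for train.main that reports what each rank sees."""
    return {
        "rank": rank,
        "world_size": world_size,
        "backend": backend,
        "trial": trial,
        # Rank 0 is conventionally the one that logs and writes checkpoints.
        "is_chief": rank == 0,
    }


if __name__ == "__main__":
    conf = {"trainer": {}}  # hypothetical minimal configuration
    world_size = 2
    for result in launch(fake_main, world_size, args=(world_size, conf, "gloo")):
        print(result["rank"], result["is_chief"], result["trial"])
    # 0 True False
    # 1 False False
```

`torch.multiprocessing.spawn` itself returns no per-rank results; the list returned here exists only to make the sketch inspectable. Since `args` is positional, `trial` keeps its default of False under this convention.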
- class train.Objective(config, metric='val_loss', device='cpu')
Bases: echo.src.base_objective.BaseObjective
Optuna objective class for hyperparameter optimization.
- config
Configuration dictionary containing training parameters.
- Type:
dict
- metric
Metric to optimize. Defaults to 'val_loss'.
- Type:
str
- device
Device for training. Defaults to 'cpu'.
- Type:
str
- train(trial, conf)
Train the model using the given trial configuration.
- Parameters:
trial (optuna.trial.Trial) – Optuna trial object.
conf (dict) – Configuration dictionary for the current trial.
- Returns:
The result of the training process.
- Return type:
Any
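`Objective.train` receives an Optuna trial and a configuration and returns the training result, with `metric` naming the quantity to optimize. The sketch below shows the general shape of such an objective under stated assumptions: it is not the `echo` `BaseObjective` interface, the `trainer.learning_rate` key and its search range are hypothetical, and `run_training` is a placeholder for the real training call. `StubTrial` implements only `suggest_float(name, low, high, *, step=None, log=False)`, matching the signature of Optuna's `Trial.suggest_float`, so the example runs without Optuna installed.

```python
import copy


class ObjectiveSketch:
    """Hypothetical objective mirroring the documented attributes."""

    def __init__(self, config, metric="val_loss", device="cpu"):
        self.config = config
        self.metric = metric
        self.device = device

    def train(self, trial, conf):
        # Work on a copy so one trial's overrides cannot leak into the next.
        conf = copy.deepcopy(conf)
        conf["trainer"]["learning_rate"] = trial.suggest_float(
            "learning_rate", 1e-5, 1e-2, log=True
        )
        results = run_training(conf, device=self.device)
        # Optuna only needs a single number back: the metric being optimized.
        return results[self.metric]


def run_training(conf, device):
    """Placeholder for the real training call; returns fake metrics."""
    return {"val_loss": 10 * conf["trainer"]["learning_rate"]}


class StubTrial:
    """Returns preset values through an Optuna-compatible suggest_float."""

    def __init__(self, params):
        self.params = params

    def suggest_float(self, name, low, high, *, step=None, log=False):
        value = self.params[name]
        if not low <= value <= high:
            raise ValueError(f"{name}={value} outside [{low}, {high}]")
        return value


if __name__ == "__main__":
    base = {"trainer": {"learning_rate": 0.1}}
    objective = ObjectiveSketch(base)
    print(objective.train(StubTrial({"learning_rate": 1e-3}), base))
    print(base["trainer"]["learning_rate"])  # 0.1, unchanged by the trial
```

With Optuna installed, the same method could be driven by `optuna.create_study(direction="minimize").optimize(lambda t: objective.train(t, base), n_trials=20)`; minimizing suits a loss metric such as `val_loss`.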
- train.description = 'Train a segmengation model on a hologram data set'