train#

Content

  • load_dataset_and_sampler_zscore_only

  • load_model_states_and_optimizer

Attributes#

description

Classes#

Objective

Optuna objective class for hyperparameter optimization.

Functions#

load_dataset_and_sampler_zscore_only(conf, ...)

Load the Z-score-only dataset and its distributed sampler for training or validation.

load_model_states_and_optimizer(conf, model, device)

Load the model states, optimizer, scheduler, and gradient scaler.

main(rank, world_size, conf, backend[, trial])

Main function to set up training and validation processes.

Module Contents#

train.load_dataset_and_sampler_zscore_only(conf, all_ERA_files, surface_files, dyn_forcing_files, diagnostic_files, world_size, rank, is_train)#

Load the Z-score-only dataset and its distributed sampler for training or validation.

Parameters:
  • conf (dict) – Configuration dictionary containing dataset and training parameters.

  • all_ERA_files (list) – List of ERA file paths.

  • surface_files (list) – List of surface file paths.

  • dyn_forcing_files (list) – List of dynamic forcing file paths.

  • diagnostic_files (list) – List of diagnostic file paths.

  • world_size (int) – Number of processes participating in the job.

  • rank (int) – Rank of the current process.

  • is_train (bool) – True to load the training dataset, False to load the validation dataset.

Returns:

A tuple containing the dataset and the distributed sampler.

Return type:

tuple
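To illustrate how world_size and rank typically decide which samples a process reads, here is a minimal pure-Python sketch of rank-based index sharding. It roughly mirrors the non-shuffled behaviour of a distributed sampler such as torch.utils.data.DistributedSampler (pad the index list to a multiple of world_size, then stride by rank). The helper name shard_indices is hypothetical and is not part of train, and it is an assumption that the sampler returned here partitions data this way.

```python
import math


def shard_indices(num_samples: int, world_size: int, rank: int) -> list[int]:
    """Return the sample indices one rank would read (hypothetical helper)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    if num_samples == 0:
        return []
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    # Pad by wrapping around so that every rank gets the same number of samples.
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]


if __name__ == "__main__":
    for rank in range(4):
        print(rank, shard_indices(10, world_size=4, rank=rank))
```

With 10 samples and 4 processes, every rank receives 3 indices, and the last two ranks reuse indices 0 and 1 as padding. This is why per-rank validation metrics can count a few samples twice when the dataset size is not a multiple of world_size.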

train.load_model_states_and_optimizer(conf, model, device)#

Load the model states, optimizer, scheduler, and gradient scaler.

Parameters:
  • conf (dict) – Configuration dictionary containing training parameters.

  • model (torch.nn.Module) – The model to be trained.

  • device (torch.device) – The device (CPU or GPU) where the model is located.

Returns:

A tuple containing the updated configuration, model, optimizer, scheduler, and scaler.

Return type:

tuple

train.main(rank, world_size, conf, backend, trial=False)#

Main function to set up training and validation processes.

Parameters:
  • rank (int) – Rank of the current process.

  • world_size (int) – Number of processes participating in the job.

  • conf (dict) – Configuration dictionary containing model, data, and training parameters.

  • backend (str) – Backend to be used for distributed training.

  • trial (bool, optional) – Flag for whether this is an Optuna trial. Defaults to False.

Returns:

The result of the training process.

Return type:

Any
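The parameter order of main, with rank first, matches the convention of spawn-style launchers such as torch.multiprocessing.spawn, which call the target as fn(i, *args) with i being the process index. The sketch below imitates that convention sequentially in plain Python so the call shape is visible without GPUs or a process group. fake_spawn and demo_main are hypothetical stand-ins, the empty conf is a placeholder, and it is an assumption that train is launched this way.

```python
from typing import Any, Callable


def fake_spawn(fn: Callable[..., Any], args: tuple, nprocs: int) -> list:
    """Sequential stand-in for a spawn-style launcher (hypothetical).

    A real launcher would start nprocs processes; this only calls
    fn(i, *args) once per process index, in order.
    """
    return [fn(i, *args) for i in range(nprocs)]


def demo_main(rank: int, world_size: int, conf: dict, backend: str, trial: bool = False) -> str:
    """Stand-in with the same parameter order as train.main; does no training."""
    return f"rank {rank}/{world_size} backend={backend} trial={trial}"


if __name__ == "__main__":
    conf = {}  # placeholder; the real configuration keys are not documented here
    world_size = 2
    # args omits rank: the launcher prepends the process index itself.
    for line in fake_spawn(demo_main, args=(world_size, conf, "nccl"), nprocs=world_size):
        print(line)
```

Note that the args tuple starts at world_size rather than rank, because the launcher supplies the rank as the first positional argument.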

class train.Objective(config, metric='val_loss', device='cpu')#

Bases: echo.src.base_objective.BaseObjective

Optuna objective class for hyperparameter optimization.

config#

Configuration dictionary containing training parameters.

Type:

dict

metric#

Metric to optimize. Defaults to “val_loss”.

Type:

str

device#

Device used for training. Defaults to “cpu”.

Type:

str

train(trial, conf)#

Train the model using the given trial configuration.

Parameters:
  • trial (optuna.trial.Trial) – Optuna trial object.

  • conf (dict) – Configuration dictionary for the current trial.

Returns:

The result of the training process.

Return type:

Any

train.description = 'Train a segmengation model on a hologram data set'#