train_les
=========

.. py:module:: train_les


Classes
-------

.. autoapisummary::

   train_les.Objective


Functions
---------

.. autoapisummary::

   train_les.load_dataset_and_sampler
   train_les.load_model_states_and_optimizer
   train_les.main
   train_les.primary_main


Module Contents
---------------

.. py:function:: load_dataset_and_sampler(conf, param_interior, world_size, rank, is_train)

   Load the dataset, using Z-score normalization only, together with its sampler for training or validation.

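   The summary line combines two ideas, Z-score normalization and a rank-aware sampler, which can be sketched in plain Python. This is a hedged illustration, not this module's implementation. ``zscore`` and ``shard_indices`` are hypothetical helpers. Here the mean and standard deviation are computed from the input, whereas real pipelines usually precompute them from the training set. The sharding assumes a sampler that behaves like ``torch.utils.data.DistributedSampler`` with ``shuffle=False`` and ``drop_last=False``.

   .. code-block:: python

      import math
      from statistics import mean, pstdev


      def zscore(values):
          """Standardize values to zero mean and unit (population) standard deviation."""
          mu = mean(values)
          sigma = pstdev(values)
          return [(v - mu) / sigma for v in values]


      def shard_indices(n_samples, world_size, rank):
          """Indices seen by one rank, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
          if n_samples <= 0:
              raise ValueError("n_samples must be positive")
          indices = list(range(n_samples))
          total_size = math.ceil(n_samples / world_size) * world_size
          padding = total_size - n_samples
          # Wrap around to the first indices so every rank gets the same count.
          indices += (indices * math.ceil(padding / n_samples))[:padding]
          # Each rank takes every world_size-th index, starting at its own rank.
          return indices[rank:total_size:world_size]


      print(zscore([2, 4, 4, 4, 5, 5, 7, 9]))
      # [-1.5, -0.5, -0.5, -0.5, 0.0, 0.0, 1.0, 2.0]
      print([shard_indices(10, 4, r) for r in range(4)])
      # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]

   With 10 samples and ``world_size=4``, every rank receives 3 indices. Ranks 2 and 3 wrap around to samples 0 and 1 so that all shards stay the same size.
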

.. py:function:: load_model_states_and_optimizer(conf, model, device)

   Load the model states, optimizer, scheduler, and gradient scaler.

   :param conf: Configuration dictionary containing training parameters.
   :type conf: dict
   :param model: The model to be trained.
   :type model: torch.nn.Module
   :param device: The device (CPU or GPU) where the model is located.
   :type device: torch.device

   :returns: A tuple containing the updated configuration, model, optimizer, scheduler, and scaler.
   :rtype: tuple

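   The sketch below shows one common way a function like this restores saved states when resuming training. It is an assumption, not this module's code. The checkpoint keys (``"epoch"``, ``"model_state_dict"``, ``"optimizer_state_dict"``, ``"scheduler_state_dict"``, ``"scaler_state_dict"``) and the ``conf["trainer"]["load_weights"]`` and ``conf["trainer"]["start_epoch"]`` entries are hypothetical. ``StatefulStub`` stands in for the PyTorch objects, which expose the same ``load_state_dict`` method. The tuple order is assumed to follow the order listed under ``:returns:``. In real code the checkpoint would typically be read with ``torch.load(path, map_location=device)``.

   .. code-block:: python

      class StatefulStub:
          """Stand-in for a model, optimizer, scheduler, or gradient scaler."""

          def __init__(self):
              self.state = {}

          def load_state_dict(self, state):
              # PyTorch objects copy the saved state into themselves; mimic that.
              self.state = dict(state)


      def restore_states(conf, checkpoint, model, optimizer, scheduler, scaler):
          """Hypothetical resume step returning (conf, model, optimizer, scheduler, scaler)."""
          trainer = conf["trainer"]
          if trainer.get("load_weights", False):
              model.load_state_dict(checkpoint["model_state_dict"])
              optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
              scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
              scaler.load_state_dict(checkpoint["scaler_state_dict"])
              # Resume from the epoch after the last saved one.
              trainer["start_epoch"] = checkpoint["epoch"] + 1
          else:
              trainer["start_epoch"] = 0
          return conf, model, optimizer, scheduler, scaler


      checkpoint = {
          "epoch": 4,
          "model_state_dict": {"weight": 1.0},
          "optimizer_state_dict": {"lr": 1e-3},
          "scheduler_state_dict": {"last_epoch": 4},
          "scaler_state_dict": {"scale": 65536.0},
      }
      conf, model, optimizer, scheduler, scaler = restore_states(
          {"trainer": {"load_weights": True}},
          checkpoint,
          StatefulStub(), StatefulStub(), StatefulStub(), StatefulStub(),
      )
      print(conf["trainer"]["start_epoch"])  # 5

   Returning the configuration alongside the objects lets the caller pick up values such as the starting epoch without a second lookup in the checkpoint.
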

.. py:function:: main(rank, world_size, conf, backend, trial=False)

   Set up and run training and validation for the current process.

   :param rank: Rank of the current process.
   :type rank: int
   :param world_size: Number of processes participating in the job.
   :type world_size: int
   :param conf: Configuration dictionary containing model, data, and training parameters.
   :type conf: dict
   :param backend: Backend to be used for distributed training.
   :type backend: str
   :param trial: Whether this run is part of an Optuna trial. Defaults to False.
   :type trial: bool, optional

   :returns: The result of the training process.
   :rtype: Any

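   The ``rank``, ``world_size``, and ``backend`` parameters correspond to ``torch.distributed`` process-group settings. This section does not document the launcher. If the script is launched with ``torchrun`` (an assumption), the rank and world size arrive as the ``RANK`` and ``WORLD_SIZE`` environment variables. The hypothetical helper ``launch_settings`` below reads them and picks a backend using the usual convention: NCCL for GPU training and Gloo otherwise.

   .. code-block:: python

      import os


      def launch_settings(env=None, cuda_available=False):
          """Return (rank, world_size, backend) for a hypothetical per-process call to main."""
          env = os.environ if env is None else env
          # torchrun sets RANK and WORLD_SIZE; fall back to a single-process run.
          rank = int(env.get("RANK", 0))
          world_size = int(env.get("WORLD_SIZE", 1))
          # NCCL is the usual backend for GPU training; Gloo also runs on CPU.
          backend = "nccl" if cuda_available else "gloo"
          return rank, world_size, backend


      rank, world_size, backend = launch_settings(
          {"RANK": "1", "WORLD_SIZE": "4"}, cuda_available=True
      )
      print(rank, world_size, backend)  # 1 4 nccl

   In real code, ``cuda_available`` would typically come from ``torch.cuda.is_available()``, and the three values would then be passed as ``main(rank, world_size, conf, backend)``.
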

.. py:class:: Objective(config, metric='val_loss', device='cpu')

   Bases: :py:obj:`echo.src.base_objective.BaseObjective`


   Optuna objective class for hyperparameter optimization.

   .. attribute:: config

      Configuration dictionary containing training parameters.

      :type: dict

   .. attribute:: metric

      Metric to optimize, defaults to "val_loss".

      :type: str

   .. attribute:: device

      Device for training, defaults to "cpu".

      :type: str


   .. py:method:: train(trial, conf)

      Train the model using the given trial configuration.

      :param trial: Optuna trial object.
      :type trial: optuna.trial.Trial
      :param conf: Configuration dictionary for the current trial.
      :type conf: dict

      :returns: The result of the training process.
      :rtype: Any

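      A typical Optuna objective samples hyperparameters from ``trial``, writes them into a copy of the configuration, trains, and reports the chosen metric. The sketch below shows that generic pattern. It is not necessarily how ``Objective.train`` and ``echo``'s ``BaseObjective`` divide the work. ``PresetTrial`` is a hypothetical stand-in for ``optuna.trial.Trial``, and its ``suggest_float`` mirrors the shape of Optuna's method. ``train_trial``, ``fake_training``, and the ``conf["trainer"]["learning_rate"]`` key are assumptions made for illustration.

      .. code-block:: python

         import copy


         class PresetTrial:
             """Stand-in for optuna.trial.Trial that replays preset values."""

             def __init__(self, params):
                 self.params = params

             def suggest_float(self, name, low, high, *, step=None, log=False):
                 # step and log only mirror Optuna's signature; a real Trial samples
                 # from [low, high], while this stub validates a preset value.
                 value = self.params[name]
                 if not low <= value <= high:
                     raise ValueError(f"{name}={value} is outside [{low}, {high}]")
                 return value


         def train_trial(trial, conf, run_training, metric="val_loss"):
             """Apply trial suggestions to a copy of conf, train, and return one metric."""
             conf = copy.deepcopy(conf)  # leave the base config untouched across trials
             conf["trainer"]["learning_rate"] = trial.suggest_float(
                 "learning_rate", 1e-5, 1e-2, log=True
             )
             results = run_training(conf)  # assumed to return a dict of metrics
             return results[metric]


         def fake_training(conf):
             # Toy stand-in for main(): the "loss" is smallest at learning_rate == 1e-3.
             return {"val_loss": abs(conf["trainer"]["learning_rate"] - 1e-3)}


         base_conf = {"trainer": {}}
         print(train_trial(PresetTrial({"learning_rate": 1e-3}), base_conf, fake_training))  # 0.0

      Deep-copying the configuration keeps one trial's suggestions from leaking into the next, which matters when the same base ``config`` is reused for every trial.
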


.. py:function:: primary_main()

