credit.models.checkpoint
========================

.. py:module:: credit.models.checkpoint


Attributes
----------

.. autoapisummary::

   credit.models.checkpoint.model


Classes
-------

.. autoapisummary::

   credit.models.checkpoint.TorchFSDPCheckpointIO
   credit.models.checkpoint.ModelWrapper
   credit.models.checkpoint.TorchFSDPModel
   credit.models.checkpoint.OptimizerWrapper
   credit.models.checkpoint.FSDPOptimizerWrapper


Functions
---------

.. autoapisummary::

   credit.models.checkpoint.load_state_dict_error_handler
   credit.models.checkpoint.get_file_extension
   credit.models.checkpoint.copy_checkpoint
   credit.models.checkpoint.load_model_state
   credit.models.checkpoint.save_state_dict
   credit.models.checkpoint.load_state_dict
   credit.models.checkpoint.is_dtensor_checkpoint
   credit.models.checkpoint.is_safetensor_checkpoint
   credit.models.checkpoint.is_safetensors_available


Module Contents
---------------

.. py:function:: load_state_dict_error_handler(load_msg)

.. py:function:: get_file_extension(file_path)
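
   A minimal sketch of what this helper's name suggests, assuming it returns the last suffix of ``file_path`` including the leading dot. ``file_extension_sketch`` is a hypothetical name, not part of ``credit``:

   .. code-block:: python

      from pathlib import Path


      def file_extension_sketch(file_path) -> str:
          # Hypothetical stand-in for get_file_extension (assumed behavior):
          # return the last suffix, e.g. ".pt", or "" if there is none.
          return Path(file_path).suffix


      print(file_extension_sketch("checkpoint.pt"))      # .pt
      print(file_extension_sketch("model.safetensors"))  # .safetensors
      print(file_extension_sketch("a.tar.gz"))           # .gz (only the last suffix)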

.. py:function:: copy_checkpoint(checkpoint_file_path: str, number) -> None

   Copy a checkpoint file after it has been saved.

   :param checkpoint_file_path: path to the checkpoint file.
   :type checkpoint_file_path: str
   :param number: number identifying the copy.
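
   A minimal sketch of copying a saved checkpoint, assuming the copy is written next to the original with ``number`` embedded in its file name. Both the naming scheme and the name ``copy_checkpoint_sketch`` are hypothetical:

   .. code-block:: python

      import shutil
      import tempfile
      from pathlib import Path


      def copy_checkpoint_sketch(checkpoint_file_path: str, number) -> str:
          src = Path(checkpoint_file_path)
          # Hypothetical naming scheme: "checkpoint.pt", 7 -> "checkpoint_000007.pt".
          dst = src.with_name(f"{src.stem}_{int(number):06d}{src.suffix}")
          # copy2 copies the contents and tries to preserve metadata such as mtime.
          shutil.copy2(src, dst)
          return str(dst)


      with tempfile.TemporaryDirectory() as tmp:
          ckpt = Path(tmp) / "checkpoint.pt"
          ckpt.write_bytes(b"weights")
          copy_path = Path(copy_checkpoint_sketch(str(ckpt), 7))
          copied_name = copy_path.name           # "checkpoint_000007.pt"
          copied_bytes = copy_path.read_bytes()  # b"weights"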


.. py:function:: load_model_state(conf, model, device)

.. py:function:: save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None

   Save state dict to checkpoint.

   :param state_dict: state dict.
   :type state_dict: dict
   :param checkpoint_file_path: path to the checkpoint file.
   :type checkpoint_file_path: str
   :param use_safetensors: whether to use safetensors to save the checkpoint.
   :type use_safetensors: bool


.. py:function:: load_state_dict(checkpoint_file_path: pathlib.Path)

   Load state dict from checkpoint.

   :param checkpoint_file_path: path to the checkpoint file.
   :type checkpoint_file_path: Path

   :returns: state dict.
   :rtype: dict


.. py:function:: is_dtensor_checkpoint(checkpoint_file_path: str) -> bool

   Check whether the checkpoint file is a dtensor checkpoint.

   :param checkpoint_file_path: path to the checkpoint file.
   :type checkpoint_file_path: str

   :returns: whether the checkpoint file is a dtensor checkpoint.
   :rtype: bool


.. py:function:: is_safetensor_checkpoint(checkpoint_file_path: str) -> bool

   Check whether the checkpoint file is a safetensor checkpoint.

   :param checkpoint_file_path: path to the checkpoint file.
   :type checkpoint_file_path: str

   :returns: whether the checkpoint file is a safetensor checkpoint.
   :rtype: bool
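
   A minimal sketch, assuming detection is by file-name suffix only (``.safetensors`` is the conventional extension of the safetensors format). The hypothetical ``loader_for`` helper shows how such a check could select the loading routine:

   .. code-block:: python

      def is_safetensor_checkpoint_sketch(checkpoint_file_path) -> bool:
          # Assumption: only the name is inspected; the file is never opened.
          return str(checkpoint_file_path).endswith(".safetensors")


      def loader_for(checkpoint_file_path) -> str:
          # Hypothetical dispatch: name the function that would read the file.
          if is_safetensor_checkpoint_sketch(checkpoint_file_path):
              return "safetensors.torch.load_file"
          return "torch.load"


      print(loader_for("model.safetensors"))  # safetensors.torch.load_file
      print(loader_for("model.pt"))           # torch.load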


.. py:function:: is_safetensors_available() -> bool

   Check whether safetensors is available.

   :returns: whether safetensors is available.
   :rtype: bool
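
   A minimal sketch using :py:func:`importlib.util.find_spec`, which reports whether a top-level package can be found without importing it. The generic ``module_available`` helper is hypothetical:

   .. code-block:: python

      import importlib.util


      def module_available(name: str) -> bool:
          # find_spec returns None when no top-level module of that name exists.
          return importlib.util.find_spec(name) is not None


      def is_safetensors_available_sketch() -> bool:
          return module_available("safetensors")


      print(module_available("json"))  # True (standard library)

   A ``try: import safetensors`` / ``except ImportError`` block is the common alternative; it also catches packages that are installed but fail to import, at the cost of actually importing them.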


.. py:class:: TorchFSDPCheckpointIO

   .. py:method:: load_unsharded_model(model, checkpoint)


   .. py:method:: load_unsharded_optimizer(optimizer, checkpoint)


   .. py:method:: save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, rank)

      Save the model to a checkpoint, writing the file on the master process only.



   .. py:method:: save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, rank)

      Save the optimizer state to a checkpoint, writing the file on the master process only.
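
      Both ``save_unsharded_*`` methods write only on the master process. A minimal sketch of that rank gating, assuming the master is rank 0 and using :py:mod:`json` as a stand-in for the real serializer so it runs without ``torch``; ``save_on_master`` is a hypothetical name:

      .. code-block:: python

         import json
         import tempfile
         from pathlib import Path


         def save_on_master(state: dict, path: str, rank: int) -> bool:
             # Every rank calls this, but only rank 0 (assumed master) writes.
             if rank != 0:
                 return False
             Path(path).write_text(json.dumps(state))
             return True


         with tempfile.TemporaryDirectory() as tmp:
             target = Path(tmp) / "optimizer.json"
             # Simulate four ranks calling the same function.
             written = [save_on_master({"step": 10}, str(target), r) for r in range(4)]
             saved = json.loads(target.read_text())

      In a real FSDP job the full state would typically be gathered onto the master rank before this point (the ``gather_dtensor`` argument suggests such a step), so the other ranks have nothing to write.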



.. py:class:: ModelWrapper(module: torch.nn.Module)

   Bases: :py:obj:`torch.nn.Module`


   A wrapper class that defines the common interface used by FSDP.

   :param module: The model to be wrapped.
   :type module: nn.Module


   .. py:attribute:: module


   .. py:method:: unwrap()

      Unwrap the model to return the original model for checkpoint saving/loading.



   .. py:method:: forward(*args, **kwargs)


.. py:class:: TorchFSDPModel(module, *args, **kwargs)

   Bases: :py:obj:`ModelWrapper`


   A wrapper class that defines the common interface used by FSDP.

   :param module: The model to be wrapped.
   :type module: nn.Module


   .. py:attribute:: module


   .. py:method:: unwrap()

      Unwrap the model to return the original model for checkpoint saving/loading.



   .. py:method:: concat_and_reshape(x1, x2)


   .. py:method:: reshape_only(x1)

      Like :py:meth:`concat_and_reshape`, but for upper-air variables only.



.. py:class:: OptimizerWrapper(optim: torch.optim.Optimizer)

   A standard interface for optimizers wrapped by the Booster.

   :param optim: The optimizer to be wrapped.
   :type optim: Optimizer
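
   A minimal sketch of the delegation pattern such a wrapper follows, assuming each method simply forwards to the wrapped optimizer. ``TinySGD`` and ``OptimizerWrapperSketch`` are hypothetical; ``TinySGD`` is a duck-typed stand-in for :py:class:`torch.optim.Optimizer` so the example runs without ``torch``:

   .. code-block:: python

      class TinySGD:
          # Duck-typed stand-in for torch.optim.Optimizer (hypothetical).
          def __init__(self, params, lr=0.1):
              self.param_groups = [{"params": params, "lr": lr}]
              self.defaults = {"lr": lr}

          def step(self):
              for group in self.param_groups:
                  for p in group["params"]:
                      p["value"] -= group["lr"] * p["grad"]

          def zero_grad(self):
              for group in self.param_groups:
                  for p in group["params"]:
                      p["grad"] = 0.0


      class OptimizerWrapperSketch:
          # Forwards every call to the wrapped optimizer.
          def __init__(self, optim):
              self.optim = optim

          @property
          def param_groups(self):
              return self.optim.param_groups

          @property
          def defaults(self):
              return self.optim.defaults

          def step(self, *args, **kwargs):
              return self.optim.step(*args, **kwargs)

          def zero_grad(self, *args, **kwargs):
              self.optim.zero_grad(*args, **kwargs)

          def unwrap(self):
              # Return the raw optimizer, e.g. for saving its state dict.
              return self.optim


      param = {"value": 1.0, "grad": 0.5}
      opt = OptimizerWrapperSketch(TinySGD([param], lr=0.1))
      opt.step()       # value: 1.0 - 0.1 * 0.5 = 0.95
      opt.zero_grad()  # grad reset to 0.0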


   .. py:attribute:: optim


   .. py:property:: parameters


   .. py:property:: param_groups


   .. py:property:: defaults


   .. py:method:: add_param_group(*args, **kwargs)


   .. py:method:: step(*args, **kwargs)

      Performs a single optimization step.



   .. py:method:: zero_grad(*args, **kwargs)

      Clears the gradients of all optimized ``torch.Tensor`` objects.



   .. py:method:: backward(loss: torch.Tensor, *args, **kwargs)

      Performs a backward pass on the loss.



   .. py:method:: backward_by_grad(tensor: torch.Tensor, grad: torch.Tensor)


   .. py:method:: state_dict()

      Returns the optimizer state.



   .. py:method:: load_state_dict(*args, **kwargs)

      Loads the optimizer state.



   .. py:method:: clip_grad_by_value(clip_value: float, *args, **kwargs) -> None

      Clips the gradients of the optimizer's parameters at the specified value.

      :param clip_value: maximum allowed absolute value of the gradients. Gradients are clipped in the range ``[-clip_value, clip_value]``.
      :type clip_value: float or int

      .. note::

         In PyTorch 2.0 and above, you can pass ``foreach=True`` in ``kwargs`` so that ``clip_grad_value_`` uses the
         faster implementation. Please refer to the PyTorch documentation for more details.
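
      The clipping rule itself is element-wise clamping. A sketch on plain Python floats; ``clip_grad_value_sketch`` is hypothetical and, unlike ``torch.nn.utils.clip_grad_value_``, which modifies gradients in place, it returns a new list:

      .. code-block:: python

         def clip_grad_value_sketch(grads, clip_value):
             # Clamp each entry to [-clip_value, clip_value].
             lo, hi = -clip_value, clip_value
             return [min(max(g, lo), hi) for g in grads]


         print(clip_grad_value_sketch([-3.0, -0.5, 0.2, 4.0], 1.0))  # [-1.0, -0.5, 0.2, 1.0]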



   .. py:method:: clip_grad_by_norm(max_norm: Union[float, int], norm_type: Union[float, int] = 2.0, error_if_nonfinite: bool = False, *args, **kwargs) -> torch.Tensor

      Clips the gradient norm of the optimizer's parameters.

      :param max_norm: maximum norm of the gradients.
      :type max_norm: float or int
      :param norm_type: type of the p-norm to use. Can be ``'inf'`` for the infinity norm. Default: 2.0
      :type norm_type: float or int
      :param error_if_nonfinite: if True, an error is raised if the total norm is non-finite. Default: False
      :type error_if_nonfinite: bool

      .. note::

         In PyTorch 2.0 and above, you can pass ``foreach=True`` in ``kwargs`` so that ``clip_grad_norm_`` uses the
         faster implementation. Please refer to the PyTorch documentation for more details.
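
      A sketch of the norm-based rule on plain Python floats, following the form used by ``torch.nn.utils.clip_grad_norm_``: compute the total p-norm over all gradient entries as one vector, then multiply every entry by ``max_norm / (total_norm + eps)``, capped at 1. The name ``clip_grad_norm_sketch`` is hypothetical, and ``eps=1e-6`` mirrors PyTorch rather than anything confirmed in ``credit``:

      .. code-block:: python

         import math


         def clip_grad_norm_sketch(grads, max_norm, norm_type=2.0, eps=1e-6):
             norm_type = float(norm_type)  # accepts 2, 2.0, "inf", ...
             if norm_type == math.inf:
                 # Infinity norm: the largest absolute entry.
                 total_norm = max(abs(g) for g in grads)
             else:
                 total_norm = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
             # Cap the coefficient at 1 so gradients already within max_norm are unchanged.
             clip_coef = min(max_norm / (total_norm + eps), 1.0)
             return [g * clip_coef for g in grads], total_norm


         clipped, total = clip_grad_norm_sketch([3.0, 4.0], max_norm=1.0)
         print(total)    # 5.0 (the norm before clipping)
         print(clipped)  # approximately [0.6, 0.8]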



   .. py:method:: scale_loss(loss: torch.Tensor)
      :abstractmethod:


      Scales the loss for mixed precision training.

      Note: Only available for optimizers with mixed precision training.

      :param loss: The loss to be scaled.
      :type loss: Tensor



   .. py:method:: unscale_grad()
      :abstractmethod:


      Unscale the gradients for mixed precision training.

      Note: Only available for optimizers with mixed precision training.
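
      Why scaling helps: in float16, small gradients underflow to zero. Multiplying the loss by a factor ``S`` multiplies every gradient by ``S`` (chain rule), keeping them representable, and dividing by ``S`` before the optimizer step restores their size. A sketch of this arithmetic on plain floats, using the :py:mod:`struct` half-precision format ``"e"`` to emulate float16. The static scale of 1024 is hypothetical; dynamic scalers such as PyTorch's ``GradScaler`` adjust it during training:

      .. code-block:: python

         import struct


         def to_fp16(x: float) -> float:
             # Round-trip a Python float through IEEE 754 half precision.
             return struct.unpack("e", struct.pack("e", x))[0]


         grad = 1e-8     # true gradient, below float16's smallest subnormal (~6e-8)
         scale = 1024.0  # hypothetical static loss scale

         lost = to_fp16(grad)          # 0.0: the unscaled gradient underflows
         kept = to_fp16(grad * scale)  # non-zero: the scaled gradient survives
         restored = kept / scale       # unscaling, done in higher precision
         print(lost, restored)         # 0.0 and a value close to 1e-8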



   .. py:method:: unwrap()

      Unwrap the optimizer for checkpoint saving/loading.



.. py:class:: FSDPOptimizerWrapper(optimizer, model)

   Bases: :py:obj:`OptimizerWrapper`


   A standard interface for optimizers wrapped by the Booster.

   :param optimizer: The optimizer to be wrapped.
   :type optimizer: Optimizer
   :param model: The model associated with the optimizer.


   .. py:attribute:: model


   .. py:method:: unwrap_model() -> torch.nn.Module


.. py:data:: model

