credit.models.reset
===================

.. py:module:: credit.models.reset


Attributes
----------

.. autoapisummary::

   credit.models.reset.rank


Functions
---------

.. autoapisummary::

   credit.models.reset.reset_model


Module Contents
---------------

.. py:function:: reset_model(fsdp_model: torch.distributed.fsdp.FullyShardedDataParallel, model: Callable, return_unwrapped=False, return_weights=False)

   Resets a model wrapped with FullyShardedDataParallel (FSDP) and optionally returns the unwrapped model or weights.

   :param fsdp_model: The model wrapped with FSDP or a custom TorchFSDPModel.
   :type fsdp_model: FSDP or TorchFSDPModel
   :param model: A callable to instantiate the original model before FSDP wrapping.
   :type model: Callable
   :param return_unwrapped: If True, returns the unwrapped model before DDP/FSDP wrapping. Defaults to False.
   :type return_unwrapped: bool, optional
   :param return_weights: If True, returns the state dictionary of the model's weights. Defaults to False.
   :type return_weights: bool, optional

   :raises TypeError: If ``fsdp_model`` is not of type ``FSDP`` or ``TorchFSDPModel``.

   :returns: Depending on the flags:

             - ``ddp_model`` (DDP): the model wrapped with DistributedDataParallel (DDP) for distributed training.
             - ``nonwrapped_model`` (``nn.Module``): the original, non-wrapped model, if ``return_unwrapped=True``.
             - ``state_dict`` (``dict``): the model's state dictionary, if ``return_weights=True``.
   :rtype: DDP, ``nn.Module``, or ``dict``
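
   Assuming the reset rebuilds the original architecture from the ``model`` callable and reloads the trained weights into it, the core idea can be sketched in a single process as below. This is a minimal illustration, not CREDIT's implementation: ``rebuild_from_weights`` is a hypothetical helper, and the FSDP weight-gathering and DDP re-wrapping steps are only described in comments because they need an initialized process group.

   .. code-block:: python

      from typing import Callable, Dict, Tuple

      import torch
      from torch import nn


      def rebuild_from_weights(
          trained: nn.Module, model_fn: Callable[[], nn.Module]
      ) -> Tuple[nn.Module, Dict[str, torch.Tensor]]:
          """Hypothetical helper: copy a trained module's weights into a fresh instance."""
          # With a real FSDP model, the full (unsharded) weights would be gathered under
          # FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(...)).
          # This module is not sharded, so state_dict() already holds every parameter.
          state_dict = {k: v.detach().clone() for k, v in trained.state_dict().items()}
          # Re-instantiate the original (unwrapped) architecture and load the weights.
          fresh = model_fn()
          fresh.load_state_dict(state_dict)
          # A distributed run would now wrap `fresh` in DistributedDataParallel.
          return fresh, state_dict


      make_model = lambda: nn.Linear(4, 2)  # stands in for the ``model`` callable
      trained = make_model()
      with torch.no_grad():
          trained.weight.fill_(1.0)  # pretend these are trained weights

      fresh, weights = rebuild_from_weights(trained, make_model)

   The fresh module is a new object whose parameters match the trained ones, and ``weights`` corresponds to what ``return_weights=True`` is documented to return.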


.. py:data:: rank
   :value: -1


