credit.models.reset
Attributes
- rank

Functions
- reset_model(fsdp_model, model[, ...]) – Resets a model wrapped with FullyShardedDataParallel (FSDP) and optionally returns the unwrapped model or weights.
Module Contents
- credit.models.reset.reset_model(fsdp_model: torch.distributed.fsdp.FullyShardedDataParallel, model: Callable, return_unwrapped=False, return_weights=False)
Resets a model wrapped with FullyShardedDataParallel (FSDP) and optionally returns the unwrapped model or weights.
- Parameters:
fsdp_model (FSDP) – The model wrapped with FSDP or a custom TorchFSDPModel.
model (Callable) – A callable to instantiate the original model before FSDP wrapping.
return_unwrapped (bool, optional) – If True, returns the unwrapped model before DDP/FSDP wrapping. Defaults to False.
return_weights (bool, optional) – If True, returns the state dictionary of the model’s weights. Defaults to False.
- Raises:
TypeError – If fsdp_model is not of type FSDP or TorchFSDPModel.
- Returns:
ddp_model (DDP): The model wrapped with DistributedDataParallel (DDP) for distributed training, returned by default.
nonwrapped_model (nn.Module): The original, non-wrapped model, if return_unwrapped=True.
state_dict (dict): The state dictionary of the model’s weights, if return_weights=True.
- Return type:
DDP by default; nn.Module if return_unwrapped=True; dict if return_weights=True.
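The reset described above can be sketched as follows. This is a minimal, hypothetical illustration, not the CREDIT implementation. It assumes the reset gathers the full (unsharded) state dict from the FSDP model, loads it into a fresh model built by calling `model`, and then returns either that state dict, the bare model, or a DDP-wrapped copy. The helper names `gather_full_state_dict`, `rebuild_from_state_dict` and `reset_model_sketch`, the `device_ids` argument, and the rule that `return_weights` takes precedence over `return_unwrapped` are all assumptions. The custom TorchFSDPModel type is not modeled here.

```python
from typing import Callable, Dict, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP


def gather_full_state_dict(fsdp_model: FSDP) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: collect the unsharded state dict from an FSDP model."""
    # FULL_STATE_DICT all-gathers the shards; offload_to_cpu keeps GPU memory free,
    # and rank0_only=False gives every rank its own full copy.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()


def rebuild_from_state_dict(
    state_dict: Dict[str, torch.Tensor],
    model_factory: Callable[[], nn.Module],
    return_unwrapped: bool = False,
    return_weights: bool = False,
    device_ids: Optional[Sequence[int]] = None,
) -> Union[DDP, nn.Module, Dict[str, torch.Tensor]]:
    """Hypothetical helper: load weights into a fresh model and pick the return value."""
    model = model_factory()  # fresh, unwrapped instance of the original architecture
    model.load_state_dict(state_dict)
    if return_weights:  # assumed to take precedence over return_unwrapped
        return model.state_dict()
    if return_unwrapped:
        return model
    # DDP needs an initialized default process group; device_ids=None suits CPU models.
    return DDP(model, device_ids=device_ids)


def reset_model_sketch(
    fsdp_model: nn.Module,
    model_factory: Callable[[], nn.Module],
    return_unwrapped: bool = False,
    return_weights: bool = False,
) -> Union[DDP, nn.Module, Dict[str, torch.Tensor]]:
    """Hypothetical stand-in for reset_model; accepts only FSDP instances."""
    if not isinstance(fsdp_model, FSDP):
        raise TypeError(f"expected an FSDP-wrapped model, got {type(fsdp_model).__name__}")
    state_dict = gather_full_state_dict(fsdp_model)
    return rebuild_from_state_dict(state_dict, model_factory, return_unwrapped, return_weights)
```

Calling `reset_model_sketch` on a real FSDP model requires an initialized process group (for example via `torch.distributed.init_process_group`), since both the full-state-dict gather and the DDP constructor communicate across ranks. Splitting out `rebuild_from_state_dict` keeps the flag handling testable without a distributed setup.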
- credit.models.reset.rank = -1