credit.models.reset

Attributes

rank

Functions

reset_model(fsdp_model, model[, return_unwrapped, ...])

Resets a model wrapped with FullyShardedDataParallel (FSDP) and optionally returns the unwrapped model or weights.

Module Contents

credit.models.reset.reset_model(fsdp_model: torch.distributed.fsdp.FullyShardedDataParallel, model: Callable, return_unwrapped=False, return_weights=False)

Resets a model wrapped with FullyShardedDataParallel (FSDP) and optionally returns the unwrapped model or weights.

Parameters:
  • fsdp_model (FSDP) – The model wrapped with FSDP or a custom TorchFSDPModel.

  • model (Callable) – A callable to instantiate the original model before FSDP wrapping.

  • return_unwrapped (bool, optional) – If True, returns the unwrapped model before DDP/FSDP wrapping. Defaults to False.

  • return_weights (bool, optional) – If True, returns the state dictionary of the model’s weights. Defaults to False.

Raises:

TypeError – If fsdp_model is not of type FSDP or TorchFSDPModel.

Returns:

  • ddp_model (DDP) – The model wrapped with DistributedDataParallel (DDP) for distributed training.

  • nonwrapped_model (nn.Module) – The original, non-wrapped model. Returned only if return_unwrapped=True.

  • state_dict (dict) – The state dictionary of the model's weights. Returned only if return_weights=True.

Return type:

DDP
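The docstring does not show how the reset is carried out. The following is a minimal sketch of one plausible approach, assuming the reset gathers the full (unsharded) weights from the FSDP model, loads them into a fresh instance built by the `model` callable, and wraps that instance in DDP. It relies on PyTorch's `FSDP.state_dict_type` context manager with `StateDictType.FULL_STATE_DICT`. The names `gather_full_state_dict`, `rebuild_as_ddp`, and `demo` are hypothetical, the sketch does not handle `TorchFSDPModel`, and returning a tuple when extra values are requested is an assumption, since the docstring does not say how the optional values are packaged.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.nn.parallel import DistributedDataParallel as DDP


def gather_full_state_dict(fsdp_model: FSDP) -> dict:
    """Hypothetical helper: collect the unsharded weights of an FSDP model.

    Every rank receives a full CPU copy (rank0_only=False).
    """
    if not isinstance(fsdp_model, FSDP):
        # Mirrors the documented TypeError; TorchFSDPModel is not handled here.
        raise TypeError("fsdp_model must be wrapped with FullyShardedDataParallel")
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        return fsdp_model.state_dict()


def rebuild_as_ddp(state_dict, model_factory, return_unwrapped=False, return_weights=False):
    """Hypothetical helper: fresh model from the factory, load weights, wrap in DDP."""
    fresh = model_factory()           # same architecture as before FSDP wrapping
    fresh.load_state_dict(state_dict)
    ddp_model = DDP(fresh)            # CPU module: no device_ids needed
    outputs = [ddp_model]
    if return_unwrapped:
        outputs.append(fresh)         # same object as ddp_model.module
    if return_weights:
        outputs.append(state_dict)
    # Assumption: a bare DDP model when nothing extra is requested, else a tuple.
    return outputs[0] if len(outputs) == 1 else tuple(outputs)


def make_model() -> nn.Module:
    return nn.Linear(4, 2)


def demo():
    """Single-process CPU run of the rebuild step using a gloo process group."""
    if not dist.is_initialized():
        dist.init_process_group("gloo", store=dist.HashStore(), rank=0, world_size=1)
    trained = make_model()
    ddp_model, unwrapped, weights = rebuild_as_ddp(
        trained.state_dict(), make_model, return_unwrapped=True, return_weights=True
    )
    return trained, ddp_model, unwrapped, weights
```

In a real multi-GPU job the process group would be created by the launcher (for example `torchrun`) with the NCCL backend, the state dict would come from `gather_full_state_dict`, and the DDP wrap would pass `device_ids`. The single-process gloo group in `demo` only keeps the sketch runnable on a CPU.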

credit.models.reset.rank = -1