credit.parallel.fsdp2
FSDP2 wrapping for CREDIT v2 models.
Uses torch.distributed._composable.fsdp.fully_shard (FSDP2) instead of the legacy FullyShardedDataParallel (FSDP1). FSDP2 composes with tensor parallelism (TP) and does not require a module wrapper: parameters are sharded in place as DTensors.
- Sharding granularity for WXFormer v2:
  - Each Transformer encoder block (one per depth layer per level)
  - Each UpBlock / UpBlockPS decoder block
  - The full model (outermost shard)
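- Checkpoint I/O:
  - Use torch.distributed.checkpoint.state_dict.get_model_state_dict / set_model_state_dict (DCP) rather than torch.save/load.
  - Helpers are provided: fsdp2_state_dict / fsdp2_load_state_dict for model state, and fsdp2_optimizer_state_dict / fsdp2_load_optimizer_state_dict for optimizer state.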
Attributes
| logger |
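Functions
| apply_fsdp2(model, dp_mesh, conf) | Apply FSDP2 to model using the data-parallel submesh. |
| _fix_spectral_norm_dtype(model, param_dtype) | Cast spectral norm u/v buffers to match the FSDP2 parameter dtype. |
| _build_mp_policy(conf) | Build FSDP2 MixedPrecision policy from config. |
| fsdp2_is_applied(model) | Return True if fully_shard was actually applied to this model. |
| _has_fsdp2_shard(module) | Return True if module opted into per-block FSDP2 sharding. |
| _apply_activation_checkpointing(model) | Apply no-reentrant AC to all modules that declared _fsdp2_shard = True. |
| _is_shardable(module, ac_enabled) | Return True if module should receive its own FSDP2 shard. |
| fsdp2_state_dict(model) | Gather a full (unsharded) state dict from an FSDP2 model. |
| fsdp2_load_state_dict(model, state_dict) | Load a full state dict into an FSDP2 model. |
| fsdp2_optimizer_state_dict(model, optimizer) | Gather a full (unsharded) optimizer state dict from an FSDP2 model. |
| fsdp2_load_optimizer_state_dict(model, optimizer, state_dict) | Load a full optimizer state dict into an FSDP2-sharded optimizer. |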
Module Contents
- credit.parallel.fsdp2.logger
- credit.parallel.fsdp2.apply_fsdp2(model: torch.nn.Module, dp_mesh, conf: dict) → torch.nn.Module
Apply FSDP2 to model using the data-parallel submesh.
Shards Transformer and UpBlock/UpBlockPS submodules first, then wraps the whole model.
- Parameters:
model – Raw (or TP-converted) model.
dp_mesh – 1-D DeviceMesh for the data-parallel dimension. Pass None to shard over the default global mesh.
conf – Full training config dict (reads trainer.amp for mp_policy).
- Returns:
model with fully_shard applied (in-place, returns same object).
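The traversal described above (opted-in blocks first, whole model last) can be sketched as follows. This is a minimal illustration, not CREDIT's code: shard_blocks_then_root, is_block, and ToyBlock are hypothetical, and shard_fn stands in for fully_shard so the ordering can be checked without a process group. In a real run shard_fn might be functools.partial(fully_shard, mesh=dp_mesh, mp_policy=policy), leaving out mp_policy when the policy is None.

```python
from typing import Callable

import torch.nn as nn


def shard_blocks_then_root(
    model: nn.Module,
    shard_fn: Callable[[nn.Module], None],
    is_block: Callable[[nn.Module], bool],
) -> nn.Module:
    """Hypothetical sketch: shard each opted-in block, then the whole model."""
    # Materialize the list first so shard_fn cannot disturb the traversal.
    blocks = [m for m in model.modules() if m is not model and is_block(m)]
    for block in blocks:
        shard_fn(block)  # one shard unit per encoder/decoder block
    shard_fn(model)  # outermost unit picks up every parameter not owned by a block
    return model  # same object, mirroring the in-place contract


class ToyBlock(nn.Module):
    _fsdp2_shard = True  # hypothetical opt-in flag

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)
```

Blocks go first because fully_shard on a parent only manages the parameters that no already-sharded submodule owns; sharding the root first would fold every block's parameters into one unit.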
- credit.parallel.fsdp2._fix_spectral_norm_dtype(model: torch.nn.Module, param_dtype: torch.dtype) → None
Cast spectral norm u/v buffers to match the FSDP2 parameter dtype.
SpectralNorm registers weight_u and weight_v as fp32 buffers. When FSDP2 casts weight_orig to bfloat16, the power-iteration call torch.mv(weight_mat, v) fails with a dtype mismatch. This function walks all modules and casts the matching buffers in place.
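A minimal sketch of that walk, assuming matching buffers can be recognized by the _u / _v suffixes that torch.nn.utils.spectral_norm gives them (weight_u, weight_v). The function name is hypothetical, and CREDIT's helper may select buffers differently.

```python
import torch
import torch.nn as nn


def cast_spectral_norm_buffers(model: nn.Module, param_dtype: torch.dtype) -> None:
    """Hypothetical sketch: cast every *_u / *_v float buffer to param_dtype."""
    for module in model.modules():
        # recurse=False: only buffers owned directly by this module.
        for name, buf in list(module.named_buffers(recurse=False)):
            if name.endswith(("_u", "_v")) and buf.is_floating_point() and buf.dtype != param_dtype:
                # Assigning a tensor to a registered buffer name keeps it registered.
                setattr(module, name, buf.to(param_dtype))
```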
- credit.parallel.fsdp2._build_mp_policy(conf: dict)
Build FSDP2 MixedPrecision policy from config.
SpectralNorm registers weight_orig as a parameter and performs power iteration using fp32 u/v buffers. Any MixedPrecisionPolicy that casts parameters or inputs creates a dtype conflict inside torch.autocast (at::autocast::prioritize fails on mixed fp32/bf16 tensors in torch.mv).
Strategy:
  - use_spectral_norm=True: return None (no policy). FSDP2 sharding still provides memory savings; the trainer disables manual autocast for fsdp2 mode, so all compute runs in fp32. Override via fsdp2_mp_policy to opt in.
  - use_spectral_norm=False: use a bfloat16 MixedPrecisionPolicy (full AMP).
- credit.parallel.fsdp2.fsdp2_is_applied(model: torch.nn.Module) → bool
Return True if fully_shard was actually applied to this model.
The config can request fsdp2 while the wrapper skips it (dp_size <= 1), so AMP/scaler decisions must check the model, not the config.
- credit.parallel.fsdp2._has_fsdp2_shard(module: torch.nn.Module) → bool
Return True if module opted into per-block FSDP2 sharding.
Blocks opt in by setting self._fsdp2_shard = True in __init__ (exposed as an init arg on the wxformer blocks) or by declaring it as a class attribute; instance attribute lookup falls back to the class, so it covers both.
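Both opt-in styles resolve through an ordinary attribute lookup, as this sketch shows; the block classes and the fsdp2_shard init argument are hypothetical stand-ins for the wxformer blocks.

```python
import torch.nn as nn


def has_fsdp2_shard(module: nn.Module) -> bool:
    # getattr checks the instance first, then the class; nn.Module raises
    # AttributeError for unknown names, so the default covers non-opted-in modules.
    return bool(getattr(module, "_fsdp2_shard", False))


class ClassOptInBlock(nn.Module):
    _fsdp2_shard = True  # class-attribute opt-in


class InitOptInBlock(nn.Module):
    def __init__(self, fsdp2_shard: bool = False) -> None:
        super().__init__()
        self._fsdp2_shard = fsdp2_shard  # init-arg opt-in
```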
- credit.parallel.fsdp2._apply_activation_checkpointing(model: torch.nn.Module) → None
Apply no-reentrant activation checkpointing (AC) to all modules that declared _fsdp2_shard = True.
Uses apply_activation_checkpointing so replacements happen in place in the parent modules. Must be called before fully_shard so that the CheckpointWrapper is what gets sharded.
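A sketch of the wrapping step with PyTorch's checkpoint_wrapper utilities (note their private _checkpoint module path), assuming the opt-in flag is read with a getattr default; apply_ac_to_opted_in and ToyBlock are hypothetical names.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def has_fsdp2_shard(module: nn.Module) -> bool:
    return bool(getattr(module, "_fsdp2_shard", False))


def apply_ac_to_opted_in(model: nn.Module) -> None:
    """Hypothetical sketch: wrap every opted-in block in a no-reentrant CheckpointWrapper."""
    wrapper_fn = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    # Each matching child is replaced inside its parent; the root itself is not wrapped.
    apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper_fn, check_fn=has_fsdp2_shard)


class ToyBlock(nn.Module):
    _fsdp2_shard = True

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.proj(x))
```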
- credit.parallel.fsdp2._is_shardable(module: torch.nn.Module, ac_enabled: bool) → bool
Return True if module should receive its own FSDP2 shard.
With AC enabled, shard_type blocks are wrapped in CheckpointWrapper. We shard the CheckpointWrapper, not the inner module — otherwise model.modules() visits both and fully_shard is called twice on the same parameters, corrupting the mesh state.
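The double visit, and a filter that avoids it, can be sketched as follows. CheckpointWrapper forwards unknown attribute lookups to the wrapped module, so a naive flag check matches both the wrapper and the block inside it. is_shardable here is a hypothetical reconstruction, not CREDIT's code.

```python
import functools

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    CheckpointWrapper,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


def has_fsdp2_shard(module: nn.Module) -> bool:
    return bool(getattr(module, "_fsdp2_shard", False))


def is_shardable(module: nn.Module, ac_enabled: bool) -> bool:
    if ac_enabled:
        # Only the wrapper qualifies; the inner block is reached through it.
        return isinstance(module, CheckpointWrapper) and has_fsdp2_shard(module._checkpoint_wrapped_module)
    return has_fsdp2_shard(module)


class ToyBlock(nn.Module):
    _fsdp2_shard = True

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)
```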
- credit.parallel.fsdp2.fsdp2_state_dict(model: torch.nn.Module) → dict
Gather a full (unsharded) state dict from an FSDP2 model.
Collective: call on all ranks.
cpu_offload=True streams the gathered tensors to CPU and returns the state dict on rank 0 only (other ranks get an empty dict), avoiding a full-model-size GPU memory spike on every rank at each checkpoint save.
- credit.parallel.fsdp2.fsdp2_load_state_dict(model: torch.nn.Module, state_dict: dict) → None
Load a full state dict into an FSDP2 model.
- credit.parallel.fsdp2.fsdp2_optimizer_state_dict(model: torch.nn.Module, optimizer) → dict
Gather a full (unsharded) optimizer state dict from an FSDP2 model.
A raw optimizer.state_dict() under FSDP2 contains only this rank's DTensor shards; saving it from rank 0 silently drops every other rank's optimizer state. Collective: call on all ranks. cpu_offload=True returns the gathered state on rank 0 only (as CPU tensors), so non-saving ranks never materialize the ~3x-model-size optimizer state on GPU.
- credit.parallel.fsdp2.fsdp2_load_optimizer_state_dict(model: torch.nn.Module, optimizer, state_dict: dict) → None
Load a full optimizer state dict into an FSDP2-sharded optimizer.
Params that never received gradients (unused modules, frozen layers) have no saved state: AdamW creates per-param state lazily on the first step with a non-None grad.
set_optimizer_state_dict, however, requires an entry for every optimizer param and raises KeyError otherwise (e.g. WXFormer's cube_embedding with patch sizes of 1 is allocated but never called). The helper therefore synthesizes lazy-init (zero) state for those params, which is equivalent to their condition before the first gradient.
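A sketch of the lazy-init synthesis for AdamW, assuming the FQN-keyed layout that get_optimizer_state_dict produces (state keyed by parameter name, param_groups listing names). fill_missing_adamw_state is hypothetical; the zero moments and step 0 mirror the values AdamW initializes just before a param's first update.

```python
import torch
import torch.nn as nn


def fill_missing_adamw_state(model: nn.Module, osd: dict, amsgrad: bool = False) -> dict:
    """Hypothetical sketch: add zero AdamW state for params the checkpoint lacks."""
    params = dict(model.named_parameters())
    state = osd.setdefault("state", {})
    for group in osd["param_groups"]:
        for fqn in group["params"]:
            if fqn in state:
                continue  # keep real saved state untouched
            p = params[fqn]
            # p.shape is the global shape, also when p is a sharded DTensor.
            entry = {
                "step": torch.tensor(0.0),
                "exp_avg": torch.zeros(p.shape, dtype=p.dtype),
                "exp_avg_sq": torch.zeros(p.shape, dtype=p.dtype),
            }
            if amsgrad:
                entry["max_exp_avg_sq"] = torch.zeros(p.shape, dtype=p.dtype)
            state[fqn] = entry
    return osd
```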