credit.parallel.fsdp2#

FSDP2 wrapping for CREDIT v2 models.

Uses torch.distributed._composable.fsdp.fully_shard (FSDP2) instead of the legacy FullyShardedDataParallel (FSDP1). FSDP2 composes with tensor parallelism (TP) and does not require a module wrapper: parameters are sharded in place as DTensors.

Sharding granularity for WXFormer v2:
  • Each Transformer encoder block (one per depth layer per level)

  • Each UpBlock / UpBlockPS decoder block

  • The full model (outermost shard)

Checkpoint I/O:
  • Use torch.distributed.checkpoint.state_dict.get_model_state_dict / set_model_state_dict (DCP) rather than torch.save/load.

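  • Helpers are provided: fsdp2_state_dict / fsdp2_load_state_dict for model state, and fsdp2_optimizer_state_dict / fsdp2_load_optimizer_state_dict for optimizer state.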

Attributes#

Functions#

apply_fsdp2(→ torch.nn.Module)

Apply FSDP2 to model using the data-parallel submesh.

_fix_spectral_norm_dtype(→ None)

Cast spectral norm u/v buffers to match the FSDP2 parameter dtype.

_build_mp_policy(conf)

Build the FSDP2 MixedPrecisionPolicy from the config.

fsdp2_is_applied(→ bool)

Return True if fully_shard was actually applied to this model.

_has_fsdp2_shard(→ bool)

Return True if module opted into per-block FSDP2 sharding.

_apply_activation_checkpointing(→ None)

Apply non-reentrant activation checkpointing (AC) to all modules that declared _fsdp2_shard = True.

_is_shardable(→ bool)

Return True if module should receive its own FSDP2 shard.

fsdp2_state_dict(→ dict)

Gather a full (unsharded) state dict from an FSDP2 model.

fsdp2_load_state_dict(→ None)

Load a full state dict into an FSDP2 model.

fsdp2_optimizer_state_dict(→ dict)

Gather a full (unsharded) optimizer state dict from an FSDP2 model.

fsdp2_load_optimizer_state_dict(→ None)

Load a full optimizer state dict into an FSDP2-sharded optimizer.

Module Contents#

credit.parallel.fsdp2.logger#
credit.parallel.fsdp2.apply_fsdp2(model: torch.nn.Module, dp_mesh, conf: dict) torch.nn.Module#

Apply FSDP2 to model using the data-parallel submesh.

Shards Transformer and UpBlock/UpBlockPS submodules first, then wraps the whole model.

Parameters:
  • model – Raw (or TP-converted) model.

  • dp_mesh – 1-D DeviceMesh for the data-parallel dimension. Pass None to shard over the default global mesh.

  • conf – Full training config dict (reads trainer.amp for mp_policy).

Returns:

The model with fully_shard applied (modified in place; the same object is returned).

credit.parallel.fsdp2._fix_spectral_norm_dtype(model: torch.nn.Module, param_dtype: torch.dtype) None#

Cast spectral norm u/v buffers to match the FSDP2 parameter dtype.

SpectralNorm registers weight_u and weight_v as fp32 buffers. When FSDP2 casts weight_orig to bfloat16, the power-iteration torch.mv(weight_mat, v) fails with a dtype mismatch. This walks all modules and casts matching buffers in-place.
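A minimal sketch of the buffer cast, assuming the legacy torch.nn.utils.spectral_norm hook (which registers weight_orig, weight_u and weight_v). The function name cast_spectral_norm_buffers is hypothetical, and matching on the exact buffer names is an assumption about how "matching buffers" are selected.

```python
import torch
import torch.nn as nn


def cast_spectral_norm_buffers(model: nn.Module, param_dtype: torch.dtype) -> None:
    """Cast weight_u / weight_v buffers to param_dtype (illustrative sketch)."""
    for module in model.modules():
        for name, buf in list(module.named_buffers(recurse=False)):
            if name in ("weight_u", "weight_v") and buf.dtype != param_dtype:
                # Assigning a tensor to an existing buffer name replaces the buffer
                setattr(module, name, buf.to(param_dtype))


layer = nn.utils.spectral_norm(nn.Linear(4, 4))
cast_spectral_norm_buffers(layer, torch.bfloat16)
```

After the call, both power-iteration buffers are bfloat16 while weight_orig keeps its original dtype until FSDP2 casts it.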

credit.parallel.fsdp2._build_mp_policy(conf: dict)#

Build the FSDP2 MixedPrecisionPolicy from the config.

SpectralNorm registers weight_orig as a parameter and performs power iteration using fp32 u/v buffers. Any MixedPrecisionPolicy that casts parameters or inputs creates a dtype conflict inside torch.autocast (at::autocast::prioritize fails on mixed fp32/bf16 tensors in torch.mv).

Strategy:
  • use_spectral_norm=True: return None (no policy). FSDP2 sharding still provides memory savings; the trainer disables manual autocast for fsdp2 mode, so all compute runs in fp32. Override via fsdp2_mp_policy to opt in.

  • use_spectral_norm=False: use bfloat16 MixedPrecisionPolicy (full AMP).

credit.parallel.fsdp2.fsdp2_is_applied(model: torch.nn.Module) bool#

Return True if fully_shard was actually applied to this model.

The config can request fsdp2 while the wrapper skips it (dp_size <= 1), so AMP/scaler decisions must check the model, not the config.

credit.parallel.fsdp2._has_fsdp2_shard(module: torch.nn.Module) bool#

Return True if module opted into per-block FSDP2 sharding.

Blocks opt in by setting self._fsdp2_shard = True in __init__ (exposed as an init arg on the wxformer blocks) or by declaring it as a class attribute; instance lookup covers both.
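Both opt-in styles can be illustrated with a small sketch. The class names and the fsdp2_shard init argument are hypothetical stand-ins for the wxformer blocks; only the _fsdp2_shard attribute name comes from the description above.

```python
import torch.nn as nn


class ClassOptIn(nn.Module):
    _fsdp2_shard = True  # class-level declaration


class InitOptIn(nn.Module):
    def __init__(self, fsdp2_shard: bool = False):
        super().__init__()
        self._fsdp2_shard = fsdp2_shard  # instance-level opt-in


def has_fsdp2_shard_sketch(module: nn.Module) -> bool:
    # Instance lookup falls back to the class attribute when none is set
    return bool(getattr(module, "_fsdp2_shard", False))
```

A module that defines neither form, such as a bare nn.Linear, is reported as not opted in.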

credit.parallel.fsdp2._apply_activation_checkpointing(model: torch.nn.Module) None#

Apply non-reentrant activation checkpointing (AC) to all modules that declared _fsdp2_shard = True.

Uses apply_activation_checkpointing so replacements happen in-place in parent modules. Must be called before fully_shard so the CheckpointWrapper is what gets sharded.
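The in-place replacement can be sketched with PyTorch's checkpoint-wrapper utilities. The Block class is a hypothetical stand-in for an opted-in block, and the check_fn shown here is an assumption about how opted-in modules are selected.

```python
from functools import partial

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    CheckpointWrapper,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class Block(nn.Module):
    _fsdp2_shard = True

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)


model = nn.Sequential(Block(), nn.Linear(4, 4))
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    # Wrap only the modules that opted in
    check_fn=lambda m: getattr(m, "_fsdp2_shard", False),
)
```

After the call, model[0] is a CheckpointWrapper around the original Block, while the plain nn.Linear at model[1] is left untouched.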

credit.parallel.fsdp2._is_shardable(module: torch.nn.Module, ac_enabled: bool) bool#

Return True if module should receive its own FSDP2 shard.

With AC enabled, shard_type blocks are wrapped in CheckpointWrapper. We shard the CheckpointWrapper, not the inner module — otherwise model.modules() visits both and fully_shard is called twice on the same parameters, corrupting the mesh state.
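The wrapper-versus-inner-module rule can be sketched as follows. This is an illustration under assumptions, not the library's code: it relies on the private _checkpoint_wrapped_module attribute of CheckpointWrapper, and Block is a hypothetical opted-in module.

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
    checkpoint_wrapper,
)


class Block(nn.Module):
    _fsdp2_shard = True

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(2, 2)

    def forward(self, x):
        return self.lin(x)


def is_shardable_sketch(module: nn.Module, ac_enabled: bool) -> bool:
    if ac_enabled:
        # Only the wrapper qualifies; the inner block is skipped
        if not isinstance(module, CheckpointWrapper):
            return False
        inner = module._checkpoint_wrapped_module
        return bool(getattr(inner, "_fsdp2_shard", False))
    return bool(getattr(module, "_fsdp2_shard", False))


wrapped = checkpoint_wrapper(Block())
picked = [m for m in wrapped.modules() if is_shardable_sketch(m, ac_enabled=True)]
```

Iterating over wrapped.modules() visits the wrapper, the inner Block and its nn.Linear, but only the wrapper is selected, so each parameter set is sharded once.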

credit.parallel.fsdp2.fsdp2_state_dict(model: torch.nn.Module) dict#

Gather a full (unsharded) state dict from an FSDP2 model.

Collective — call on all ranks. cpu_offload=True streams the gathered tensors to CPU and returns the state dict on rank 0 only (other ranks get an empty dict), avoiding an all-ranks GPU memory spike of full-model size at every checkpoint save.

credit.parallel.fsdp2.fsdp2_load_state_dict(model: torch.nn.Module, state_dict: dict) None#

Load a full state dict into an FSDP2 model.

credit.parallel.fsdp2.fsdp2_optimizer_state_dict(model: torch.nn.Module, optimizer) dict#

Gather a full (unsharded) optimizer state dict from an FSDP2 model.

A raw optimizer.state_dict() under FSDP2 contains only this rank’s DTensor shards, so saving it from rank 0 silently drops every other rank’s optimizer state. Collective: call on all ranks. cpu_offload=True returns the gathered state on rank 0 only (as CPU tensors), so non-saving ranks never materialize the ~3x-model-size optimizer state on GPU.

credit.parallel.fsdp2.fsdp2_load_optimizer_state_dict(model: torch.nn.Module, optimizer, state_dict: dict) None#

Load a full optimizer state dict into an FSDP2-sharded optimizer.

Params that never received gradients (unused modules, frozen layers) have no saved state, because AdamW creates per-param state lazily on the first step with a non-None grad. However, set_optimizer_state_dict requires an entry for every optimizer param and raises KeyError otherwise (e.g. WXFormer’s cube_embedding with patch sizes of 1 is allocated but never called). This function therefore synthesizes lazy-init (zero) state for those params, which is equivalent to their pre-first-gradient condition.
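The zero-state synthesis can be sketched on a plain (unsharded) model. This is an illustration under assumptions: the helper name is hypothetical, the saved dict is assumed to be keyed by fully qualified parameter name, and the state keys (step, exp_avg, exp_avg_sq) are those of torch.optim.AdamW without amsgrad.

```python
import torch
import torch.nn as nn


def fill_missing_adamw_state(model: nn.Module, optimizer, optim_state: dict) -> dict:
    """Add zero AdamW state for optimizer params missing from an FQN-keyed dict."""
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    state = optim_state.setdefault("state", {})
    for fqn, param in model.named_parameters():
        if id(param) in in_optimizer and fqn not in state:
            state[fqn] = {
                "step": torch.tensor(0.0),
                # For a DTensor parameter, .shape is the global (unsharded) shape
                "exp_avg": torch.zeros(param.shape, dtype=param.dtype),
                "exp_avg_sq": torch.zeros(param.shape, dtype=param.dtype),
            }
    return optim_state
```

The filled dict then has one entry per optimizer parameter, so a subsequent set_optimizer_state_dict call no longer encounters missing keys for parameters that never received a gradient.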