credit.parallel.fsdp2
=====================

.. py:module:: credit.parallel.fsdp2

.. autoapi-nested-parse::

   FSDP2 wrapping for CREDIT v2 models.

   Uses ``torch.distributed._composable.fsdp.fully_shard`` (FSDP2) instead of
   the legacy ``FullyShardedDataParallel`` (FSDP1). FSDP2 composes with tensor
   parallelism (TP) and does not require a module wrapper: parameters are
   sharded in place as DTensors.

   Sharding granularity for WXFormer v2:

   - Each Transformer encoder block (one per depth layer per level)
   - Each UpBlock / UpBlockPS decoder block
   - The full model (outermost shard)

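   Checkpoint I/O:

   - Use ``torch.distributed.checkpoint.state_dict.get_model_state_dict`` /
     ``set_model_state_dict`` (DCP) rather than ``torch.save`` / ``torch.load``.
   - Helpers are provided: ``fsdp2_state_dict`` / ``fsdp2_load_state_dict`` for
     model weights and ``fsdp2_optimizer_state_dict`` /
     ``fsdp2_load_optimizer_state_dict`` for optimizer state.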



Attributes
----------

.. autoapisummary::

   credit.parallel.fsdp2.logger


Functions
---------

.. autoapisummary::

   credit.parallel.fsdp2.apply_fsdp2
   credit.parallel.fsdp2._fix_spectral_norm_dtype
   credit.parallel.fsdp2._build_mp_policy
   credit.parallel.fsdp2.fsdp2_is_applied
   credit.parallel.fsdp2._has_fsdp2_shard
   credit.parallel.fsdp2._apply_activation_checkpointing
   credit.parallel.fsdp2._is_shardable
   credit.parallel.fsdp2.fsdp2_state_dict
   credit.parallel.fsdp2.fsdp2_load_state_dict
   credit.parallel.fsdp2.fsdp2_optimizer_state_dict
   credit.parallel.fsdp2.fsdp2_load_optimizer_state_dict


Module Contents
---------------

.. py:data:: logger

.. py:function:: apply_fsdp2(model: torch.nn.Module, dp_mesh, conf: dict) -> torch.nn.Module

   Apply FSDP2 to model using the data-parallel submesh.

   Shards Transformer and UpBlock/UpBlockPS submodules first, then
   wraps the whole model.

   :param model: Raw (or TP-converted) model.
   :param dp_mesh: 1-D DeviceMesh for the data-parallel dimension.
                   Pass None to shard over the default global mesh.
   :param conf: Full training config dict (``trainer.amp`` is read to build
                the mixed-precision policy).

   :returns: The same model object, with ``fully_shard`` applied in place.
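
   A minimal sketch of that bottom-up order, not this module's code:
   ``shard_blocks_then_root``, ``shard_fn`` and ``is_block`` are hypothetical
   names, and the sharding call is injected so the order can be checked
   without a process group. In real use ``shard_fn`` would presumably be
   something like ``functools.partial(fully_shard, mesh=dp_mesh)``.

   .. code-block:: python

      from typing import Callable, List

      import torch.nn as nn


      def shard_blocks_then_root(
          model: nn.Module,
          shard_fn: Callable[[nn.Module], None],
          is_block: Callable[[nn.Module], bool],
      ) -> nn.Module:
          """Hypothetical helper: shard opted-in blocks first, then the root."""
          # modules() is a pre-order traversal (parents before children), so
          # reversing it visits any nested block before the block containing it.
          for module in reversed(list(model.modules())):
              if module is not model and is_block(module):
                  shard_fn(module)
          # The outermost shard covers every parameter not owned by a block.
          shard_fn(model)
          return model


      class Block(nn.Module):  # hypothetical stand-in for a Transformer/UpBlock
          def __init__(self) -> None:
              super().__init__()
              self.proj = nn.Linear(8, 8)


      order: List[str] = []
      shard_blocks_then_root(
          nn.Sequential(Block(), Block(), nn.Linear(8, 2)),
          shard_fn=lambda m: order.append(type(m).__name__),
          is_block=lambda m: isinstance(m, Block),
      )
      # order == ["Block", "Block", "Sequential"]

   Reversing the traversal follows the usual FSDP2 pattern of calling
   ``fully_shard`` on submodules before their parents, so the final root call
   only picks up the parameters no block owns.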


.. py:function:: _fix_spectral_norm_dtype(model: torch.nn.Module, param_dtype: torch.dtype) -> None

   Cast spectral norm u/v buffers to match the FSDP2 parameter dtype.

   SpectralNorm registers ``weight_u`` and ``weight_v`` as fp32 buffers.
   When FSDP2 casts ``weight_orig`` to bfloat16, the power-iteration step
   ``torch.mv(weight_mat, v)`` fails with a dtype mismatch.
   This function walks all modules and casts the matching buffers in place.
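
   A minimal sketch of such a buffer cast, assuming the modules use
   ``torch.nn.utils.spectral_norm`` (which keeps ``weight_orig`` as a parameter
   and ``weight_u`` / ``weight_v`` as buffers). ``cast_spectral_norm_buffers``
   is a hypothetical name, not this module's function.

   .. code-block:: python

      import torch
      import torch.nn as nn
      from torch.nn.utils import spectral_norm


      def cast_spectral_norm_buffers(model: nn.Module, dtype: torch.dtype) -> int:
          """Hypothetical helper: cast SpectralNorm u/v buffers; return count changed."""
          changed = 0
          for module in model.modules():
              # Snapshot the buffers because setattr below mutates module._buffers.
              for name, buf in list(module.named_buffers(recurse=False)):
                  if name in ("weight_u", "weight_v") and buf.dtype != dtype:
                      # Assigning to a registered buffer name replaces that buffer.
                      setattr(module, name, buf.to(dtype))
                      changed += 1
          return changed


      model = nn.Sequential(spectral_norm(nn.Linear(4, 4)), nn.ReLU())
      n_cast = cast_spectral_norm_buffers(model, torch.bfloat16)

   Only the buffers are touched; ``weight_orig`` keeps its dtype so FSDP2 can
   apply its own parameter casting.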


.. py:function:: _build_mp_policy(conf: dict)

   Build the FSDP2 ``MixedPrecisionPolicy`` from the config.

   SpectralNorm registers ``weight_orig`` as a parameter and runs power
   iteration with fp32 ``u``/``v`` buffers. Any ``MixedPrecisionPolicy`` that
   casts parameters or inputs therefore creates a dtype conflict inside
   ``torch.autocast`` (``at::autocast::prioritize`` fails on mixed fp32/bf16
   tensors in ``torch.mv``).

   Strategy:

   - ``use_spectral_norm=True``: return ``None`` (no policy). FSDP2 sharding
     still provides memory savings; the trainer disables manual autocast in
     fsdp2 mode, so all compute runs in fp32. Set ``fsdp2_mp_policy`` to opt
     in to a policy anyway.
   - ``use_spectral_norm=False``: use a bfloat16 ``MixedPrecisionPolicy``
     (full AMP).
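
   The selection rule can be sketched as below, assuming PyTorch's FSDP2
   ``MixedPrecisionPolicy`` dataclass (exported from ``torch.distributed.fsdp``
   in newer releases and from ``torch.distributed._composable.fsdp`` in older
   ones). ``choose_mp_policy`` and its ``force_policy`` flag are hypothetical
   stand-ins for reading ``use_spectral_norm`` and ``fsdp2_mp_policy`` from the
   config; the real keys, nesting and policy fields may differ.

   .. code-block:: python

      from typing import Optional

      import torch

      try:  # public location in newer PyTorch releases
          from torch.distributed.fsdp import MixedPrecisionPolicy
      except ImportError:  # older releases only expose the _composable path
          from torch.distributed._composable.fsdp import MixedPrecisionPolicy


      def choose_mp_policy(
          use_spectral_norm: bool, force_policy: bool = False
      ) -> Optional[MixedPrecisionPolicy]:
          """Hypothetical sketch of the strategy described above."""
          if use_spectral_norm and not force_policy:
              # No casting at all: sharding still saves memory, compute stays fp32.
              return None
          # Parameters are all-gathered and used in bf16 for forward/backward.
          return MixedPrecisionPolicy(param_dtype=torch.bfloat16)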


.. py:function:: fsdp2_is_applied(model: torch.nn.Module) -> bool

   Return True if ``fully_shard`` was actually applied to this model.

   The config can request fsdp2 while the wrapper skips sharding
   (``dp_size <= 1``), so AMP and grad-scaler decisions must inspect the
   model, not the config.
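
   One possible model-based check, assuming PyTorch's FSDP2 behaviour of
   swapping a module's class for a dynamically created ``FSDPModule``
   subclass when ``fully_shard`` is applied. ``is_fully_sharded`` is a
   hypothetical name; this module's own detection logic may differ.

   .. code-block:: python

      import torch.nn as nn

      try:  # public location in newer PyTorch releases
          from torch.distributed.fsdp import FSDPModule
      except ImportError:  # older releases only expose the _composable path
          from torch.distributed._composable.fsdp import FSDPModule


      def is_fully_sharded(model: nn.Module) -> bool:
          """Hypothetical check: fully_shard makes the module an FSDPModule subclass."""
          return isinstance(model, FSDPModule)

   Because the check looks at the object rather than the config, a run that
   requested fsdp2 but skipped sharding (``dp_size <= 1``) reports ``False``.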


.. py:function:: _has_fsdp2_shard(module: torch.nn.Module) -> bool

   Return True if module opted into per-block FSDP2 sharding.

   Blocks opt in by setting ``self._fsdp2_shard = True`` in ``__init__``
   (exposed as an init arg on the wxformer blocks) or by declaring it as a
   class attribute; instance lookup covers both.
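
   Both opt-in styles resolve through a single ``getattr``, because attribute
   lookup on an instance falls back to its class. A minimal sketch;
   ``has_shard_flag``, the two block classes and the ``fsdp2_shard`` init
   argument are hypothetical.

   .. code-block:: python

      import torch.nn as nn


      def has_shard_flag(module: nn.Module) -> bool:
          """Hypothetical equivalent: instance lookup also finds class attributes."""
          return bool(getattr(module, "_fsdp2_shard", False))


      class BlockByInit(nn.Module):
          def __init__(self, fsdp2_shard: bool = True) -> None:
              super().__init__()
              self._fsdp2_shard = fsdp2_shard  # opt-in via init argument
              self.proj = nn.Linear(4, 4)


      class BlockByClass(nn.Module):
          _fsdp2_shard = True  # opt-in via class attribute

          def __init__(self) -> None:
              super().__init__()
              self.proj = nn.Linear(4, 4)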


.. py:function:: _apply_activation_checkpointing(model: torch.nn.Module) -> None

   Apply non-reentrant activation checkpointing (AC) to every module that
   declares ``_fsdp2_shard = True``.

   Uses ``apply_activation_checkpointing`` so each wrapped module is replaced
   in place inside its parent. Must be called before ``fully_shard`` so that
   the ``CheckpointWrapper`` is what gets sharded.
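
   A minimal sketch of the in-place wrapping with PyTorch's
   ``apply_activation_checkpointing`` helper. The ``Block`` class is
   hypothetical, and ``CheckpointImpl.NO_REENTRANT`` is passed explicitly to
   make the non-reentrant choice visible.

   .. code-block:: python

      from functools import partial

      import torch
      import torch.nn as nn
      from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
          CheckpointImpl,
          CheckpointWrapper,
          apply_activation_checkpointing,
          checkpoint_wrapper,
      )


      class Block(nn.Module):
          _fsdp2_shard = True  # hypothetical opted-in block

          def __init__(self) -> None:
              super().__init__()
              self.proj = nn.Linear(8, 8)

          def forward(self, x: torch.Tensor) -> torch.Tensor:
              return torch.relu(self.proj(x))


      model = nn.Sequential(Block(), Block(), nn.Linear(8, 2))
      apply_activation_checkpointing(
          model,
          checkpoint_wrapper_fn=partial(
              checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
          ),
          # Only opted-in modules are replaced, in place inside their parent.
          check_fn=lambda m: getattr(m, "_fsdp2_shard", False),
      )
      # Activations inside each Block are recomputed during backward.
      out = model(torch.randn(3, 8, requires_grad=True))
      out.sum().backward()

   After the call, ``model[0]`` and ``model[1]`` are ``CheckpointWrapper``
   instances sitting in the original ``Sequential``; ``fully_shard`` would be
   applied to those wrappers afterwards.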


.. py:function:: _is_shardable(module: torch.nn.Module, ac_enabled: bool) -> bool

   Return True if module should receive its own FSDP2 shard.

   With AC enabled, opted-in blocks are wrapped in ``CheckpointWrapper``.
   Only the wrapper is sharded, not the inner module; otherwise
   ``model.modules()`` visits both and ``fully_shard`` is called twice on the
   same parameters, corrupting the mesh state.
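
   The double visit can be reproduced without FSDP2: ``CheckpointWrapper``
   forwards missing attributes to the module it wraps, so the wrapper and the
   inner block both report ``_fsdp2_shard``. A hedged sketch of the rule, with
   hypothetical ``Block`` and ``is_shardable`` names:

   .. code-block:: python

      import torch.nn as nn
      from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
          CheckpointWrapper,
          apply_activation_checkpointing,
      )


      class Block(nn.Module):
          _fsdp2_shard = True  # hypothetical opted-in block

          def __init__(self) -> None:
              super().__init__()
              self.proj = nn.Linear(4, 4)


      def is_shardable(module: nn.Module, ac_enabled: bool) -> bool:
          """Hypothetical sketch of the dedup rule described above."""
          opted_in = bool(getattr(module, "_fsdp2_shard", False))
          if ac_enabled:
              # Shard only the wrapper; the inner block is reached through it.
              return opted_in and isinstance(module, CheckpointWrapper)
          return opted_in


      model = nn.Sequential(Block(), Block())
      apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, Block))

      # Naive selection sees each block twice (wrapper + inner module).
      naive = [m for m in model.modules() if getattr(m, "_fsdp2_shard", False)]
      deduped = [m for m in model.modules() if is_shardable(m, ac_enabled=True)]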


.. py:function:: fsdp2_state_dict(model: torch.nn.Module) -> dict

   Gather a full (unsharded) state dict from an FSDP2 model.

   Collective: call on all ranks. With ``cpu_offload=True``, the gathered
   tensors are streamed to CPU and the state dict is returned on rank 0 only
   (other ranks get an empty dict). This avoids a full-model-size GPU memory
   spike on every rank at each checkpoint save.
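
   A sketch of a gather-and-save step with PyTorch's Distributed Checkpoint
   state-dict API (``get_model_state_dict`` with
   ``StateDictOptions(full_state_dict=True, cpu_offload=True)``).
   ``save_full_model_state`` and its ``path`` argument are hypothetical; under
   FSDP2 it must be entered by every rank of the process group.

   .. code-block:: python

      from typing import Any, Dict

      import torch
      import torch.distributed as dist
      import torch.nn as nn
      from torch.distributed.checkpoint.state_dict import (
          StateDictOptions,
          get_model_state_dict,
      )


      def save_full_model_state(model: nn.Module, path: str) -> Dict[str, Any]:
          """Hypothetical sketch: gather full weights on rank 0 and save them there."""
          # Collective: every rank must enter this call, or the all-gathers hang.
          state = get_model_state_dict(
              model,
              options=StateDictOptions(full_state_dict=True, cpu_offload=True),
          )
          # With full_state_dict + cpu_offload, only rank 0 gets a non-empty dict.
          if not dist.is_initialized() or dist.get_rank() == 0:
              torch.save(state, path)
          return state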


.. py:function:: fsdp2_load_state_dict(model: torch.nn.Module, state_dict: dict) -> None

   Load a full state dict into an FSDP2 model.


.. py:function:: fsdp2_optimizer_state_dict(model: torch.nn.Module, optimizer) -> dict

   Gather a full (unsharded) optimizer state dict from an FSDP2 model.

   A raw ``optimizer.state_dict()`` under FSDP2 contains only this rank's
   DTensor shards, so saving it from rank 0 silently drops every other rank's
   optimizer state. Collective: call on all ranks. With ``cpu_offload=True``
   the gathered state is returned on rank 0 only (as CPU tensors), so
   non-saving ranks never materialize the full optimizer state (two moment
   buffers per parameter for AdamW, about twice the model size) on GPU.
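
   The gather can be sketched with PyTorch's ``get_optimizer_state_dict``,
   which returns full tensors keyed by parameter FQN when
   ``full_state_dict=True``. ``gather_full_optim_state`` is a hypothetical
   name, not this module's implementation.

   .. code-block:: python

      from typing import Any, Dict

      import torch
      import torch.nn as nn
      from torch.distributed.checkpoint.state_dict import (
          StateDictOptions,
          get_optimizer_state_dict,
      )


      def gather_full_optim_state(
          model: nn.Module, optimizer: torch.optim.Optimizer
      ) -> Dict[str, Any]:
          """Hypothetical sketch: reassemble every rank's optimizer shards."""
          # optimizer.state_dict() would hold only this rank's DTensor shards;
          # this collective all-gathers them into full, FQN-keyed tensors.
          return get_optimizer_state_dict(
              model,
              optimizer,
              options=StateDictOptions(full_state_dict=True, cpu_offload=True),
          )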


.. py:function:: fsdp2_load_optimizer_state_dict(model: torch.nn.Module, optimizer, state_dict: dict) -> None

   Load a full optimizer state dict into an FSDP2-sharded optimizer.

   Params that never received gradients (unused modules, frozen layers) have
   no saved state, because AdamW creates per-param state lazily on the first
   step with a non-None grad. For example, WXFormer's ``cube_embedding`` with
   patch sizes of 1 is allocated but never called. ``set_optimizer_state_dict``,
   however, requires an entry for every optimizer param and raises
   ``KeyError`` otherwise. This function therefore synthesizes lazy-init
   (zero) state for those params, which matches their condition before the
   first gradient.
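
   The synthesis step can be sketched on a full, FQN-keyed optimizer state
   dict, assuming AdamW without ``amsgrad`` (state keys ``step``,
   ``exp_avg``, ``exp_avg_sq``). ``fill_missing_adamw_state`` and the toy
   model are hypothetical; this is not the module's actual code.

   .. code-block:: python

      from typing import Any, Dict, List

      import torch
      import torch.nn as nn


      def fill_missing_adamw_state(
          model: nn.Module, optim_state: Dict[str, Any]
      ) -> List[str]:
          """Hypothetical sketch: give never-stepped params AdamW's lazy-init state."""
          params = dict(model.named_parameters())
          state = optim_state.setdefault("state", {})
          filled: List[str] = []
          for group in optim_state["param_groups"]:
              for fqn in group["params"]:
                  if fqn in state:
                      continue
                  # .shape is the global (unsharded) shape, also for DTensor params.
                  shape, dtype = tuple(params[fqn].shape), params[fqn].dtype
                  state[fqn] = {
                      "step": torch.tensor(0.0),  # no optimizer step taken yet
                      "exp_avg": torch.zeros(shape, dtype=dtype),  # first moment
                      "exp_avg_sq": torch.zeros(shape, dtype=dtype),  # second moment
                  }
                  filled.append(fqn)
          return filled


      # Toy example: "unused" never received a gradient, so it has no saved state.
      model = nn.ModuleDict({"used": nn.Linear(3, 2), "unused": nn.Linear(3, 2)})
      optim_state = {
          "state": {
              "used.weight": {"step": torch.tensor(1.0),
                              "exp_avg": torch.ones(2, 3),
                              "exp_avg_sq": torch.ones(2, 3)},
              "used.bias": {"step": torch.tensor(1.0),
                            "exp_avg": torch.ones(2),
                            "exp_avg_sq": torch.ones(2)},
          },
          "param_groups": [{"lr": 1e-3, "params": ["used.weight", "used.bias",
                                                   "unused.weight", "unused.bias"]}],
      }
      filled = fill_missing_adamw_state(model, optim_state)

   Existing entries are left untouched, so the helper is safe to call on a
   checkpoint that already covers every parameter.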


