credit.parallel
===============

.. py:module:: credit.parallel

.. autoapi-nested-parse::

   CREDIT v2 parallelism package.

   Provides FSDP2, Tensor Parallelism (TP), and integration with Negin's
   domain parallelism — all composed via PyTorch DeviceMesh.

   Config block (``trainer.parallelism``)::

       data:   fsdp2 | ddp | none   # data-parallel mode
       tensor: int >= 1             # TP degree (1 = disabled)
       domain: int >= 1             # spatial domain shards (1 = disabled)

   Total GPUs = dp_size × tensor × domain, where
   dp_size = world_size // (tensor × domain).
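
   As a worked example of the GPU arithmetic above, the sketch below recomputes
   ``dp_size`` from the world size. ``compute_dp_size`` is a hypothetical helper
   written for illustration and is not part of ``credit.parallel``; it assumes
   the same divisibility requirement that ``build_device_mesh`` documents.

   .. code-block:: python

      def compute_dp_size(world_size: int, tensor: int = 1, domain: int = 1) -> int:
          """Hypothetical helper: dp_size = world_size // (tensor * domain)."""
          if tensor < 1 or domain < 1:
              raise ValueError("tensor and domain degrees must be >= 1")
          # Ranks that cooperate on one data-parallel replica.
          model_parallel = tensor * domain
          if world_size % model_parallel != 0:
              raise ValueError(
                  f"world_size={world_size} is not divisible by "
                  f"tensor * domain = {model_parallel}"
              )
          return world_size // model_parallel


      # 16 GPUs with TP degree 2 and 2 domain shards leave 4 data-parallel
      # replicas, and 4 * 2 * 2 == 16 recovers the total GPU count.
      print(compute_dp_size(16, tensor=2, domain=2))  # 4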

   Usage (called from ``distributed_model_wrapper_gen2``)::
       mesh, submeshes = build_device_mesh(conf["trainer"]["parallelism"])
       if submeshes.get("tp"):
           model = apply_tensor_parallel(model, submeshes["tp"])
       if submeshes.get("domain"):
           model = apply_domain_parallel(model, submeshes["domain"])
       if conf["trainer"]["parallelism"]["data"] == "fsdp2":
           model = apply_fsdp2(model, submeshes.get("dp"), conf)
       elif conf["trainer"]["parallelism"]["data"] == "ddp":
           model = apply_ddp(model, submeshes.get("dp"))



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/parallel/collectives/index
   /autoapi/credit/parallel/domain/index
   /autoapi/credit/parallel/fsdp2/index
   /autoapi/credit/parallel/mesh/index
   /autoapi/credit/parallel/tensor_parallel/index


Functions
---------

.. autoapisummary::

   credit.parallel.build_device_mesh
   credit.parallel.apply_fsdp2
   credit.parallel.apply_tensor_parallel
   credit.parallel.get_domain_manager
   credit.parallel.get_raw_model
   credit.parallel.shard_spatial
   credit.parallel.unpad_shard_interp
   credit.parallel.sync_domain_gradients


Package Contents
----------------

.. py:function:: build_device_mesh(parallelism_conf: dict, device: str = 'cuda')

   Build a DeviceMesh from a parallelism config block.

   :param parallelism_conf: Dict with keys:

      - ``data`` (str): one of ``"fsdp2"``, ``"ddp"``, ``"none"``
      - ``tensor`` (int): TP degree, >= 1
      - ``domain`` (int): domain-parallel degree, >= 1

   :param device: ``"cuda"`` (default), or ``"cpu"`` for tests.

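   :returns: A ``(mesh, submeshes)`` tuple. ``mesh`` is the DeviceMesh, or
             None if no parallelism is enabled. ``submeshes`` is a dict mapping
             each dimension name to its submesh (or None if the mesh is
             single-dimensional). Keys are present only for active dimensions:
             ``"dp"`` if dp > 1, ``"tp"`` if tp > 1, ``"domain"`` if domain > 1.
   :rtype: tuple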

   :raises ValueError: if world_size is not divisible by tensor * domain.
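
   Because ``submeshes`` only contains dimensions whose degree exceeds 1, the
   keys can be predicted from the config alone. The sketch below does that
   without touching ``torch.distributed``. ``plan_submeshes`` is hypothetical,
   and the dimension order ``("dp", "tp", "domain")`` is an assumption, not
   necessarily the order the real mesh uses.

   .. code-block:: python

      def plan_submeshes(parallelism_conf: dict, world_size: int) -> dict:
          """Hypothetical: map each active mesh dimension name to its size.

          Mirrors the documented rule that "dp", "tp" and "domain" appear only
          when their degree is greater than 1. The ordering is an assumption.
          """
          tensor = int(parallelism_conf.get("tensor", 1))
          domain = int(parallelism_conf.get("domain", 1))
          if world_size % (tensor * domain) != 0:
              raise ValueError("world_size must be divisible by tensor * domain")
          sizes = {"dp": world_size // (tensor * domain), "tp": tensor, "domain": domain}
          # Drop size-1 dimensions, matching "Keys present: ... if > 1".
          return {name: size for name, size in sizes.items() if size > 1}


      print(plan_submeshes({"data": "fsdp2", "tensor": 2, "domain": 1}, world_size=8))
      # {'dp': 4, 'tp': 2}

   If the real mesh were built this way, PyTorch's
   ``torch.distributed.device_mesh.init_device_mesh("cuda", tuple(plan.values()),
   mesh_dim_names=tuple(plan))`` would create it and each 1-D submesh could be
   sliced out by name (for example ``mesh["tp"]``); the source does not state
   that ``build_device_mesh`` does exactly this.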


.. py:function:: apply_fsdp2(model: torch.nn.Module, dp_mesh, conf: dict) -> torch.nn.Module

   Apply FSDP2 to model using the data-parallel submesh.

   Shards Transformer and UpBlock/UpBlockPS submodules first, then
   wraps the whole model.

   :param model: Raw (or TP-converted) model.
   :param dp_mesh: 1-D DeviceMesh for the data-parallel dimension.
                   Pass None to shard over the default global mesh.
   :param conf: Full training config dict (reads ``trainer.amp`` to build the
                mixed-precision policy, ``mp_policy``).

   :returns: The same model object, with ``fully_shard`` applied in place.
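
   The "inner blocks first, then the whole model" order can be sketched with
   PyTorch's FSDP2 API. This is a minimal illustration, not the CREDIT
   implementation: ``submodules_to_shard`` and ``shard_with_fsdp2`` are
   hypothetical, matching on class *names* is an assumption, and the bf16/fp32
   mixed-precision choice stands in for whatever ``trainer.amp`` actually
   selects.

   .. code-block:: python

      import torch
      import torch.nn as nn

      # Class names taken from the description above; matching on the class
      # name rather than on imported classes is an assumption of this sketch.
      SHARD_CLASS_NAMES = ("Transformer", "UpBlock", "UpBlockPS")


      def submodules_to_shard(model: nn.Module, class_names=SHARD_CLASS_NAMES) -> list:
          """Return matching submodules with every descendant before its ancestor.

          ``model.modules()`` walks the tree in pre-order (parents first), so
          reversing it puts nested matches ahead of the blocks that contain
          them, i.e. the bottom-up order FSDP2 expects for ``fully_shard``.
          """
          return [
              m for m in reversed(list(model.modules()))
              if m is not model and type(m).__name__ in class_names
          ]


      def shard_with_fsdp2(model: nn.Module, dp_mesh=None, param_dtype=torch.bfloat16):
          """Hypothetical sketch: shard inner blocks, then the root module."""
          # PyTorch >= 2.6 exports these here; older releases use
          # torch.distributed._composable.fsdp instead.
          from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

          # Hypothetical policy; the real apply_fsdp2 derives it from trainer.amp.
          mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=torch.float32)
          for block in submodules_to_shard(model):
              fully_shard(block, mesh=dp_mesh, mp_policy=mp_policy)
          # mesh=None lets FSDP2 fall back to a default mesh over all ranks.
          fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)
          return model

   Sharding the inner blocks separately gives each one its own all-gather and
   reduce-scatter unit, so only one block's full parameters need to be
   materialized at a time; the final root call covers any leftover parameters.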


.. py:function:: apply_tensor_parallel(model: torch.nn.Module, tp_mesh) -> torch.nn.Module

   Walk model and apply TP to all blocks that declare ``_tp_col``/``_tp_row``.

   Any ``nn.Module`` subclass can opt in by setting two class attributes::

       class MyBlock(nn.Module):
           _tp_col = "proj_up"   # dotted path to the column-parallel layer
           _tp_row = "proj_out"  # dotted path to the row-parallel layer

   Paths may address layers inside Sequentials (e.g. ``"layers.1"``).
   Supported layer types: ``nn.Conv2d`` (1×1 only) and ``nn.Linear``.

   Converts in place. Safe to call before :py:func:`apply_fsdp2`.

   :param model: The model to convert.
   :param tp_mesh: 1-D DeviceMesh for the tensor-parallel dimension.

   :returns: model (same object, modified in-place).
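
   The opt-in convention can be illustrated with a discovery pass that finds
   every block declaring ``_tp_col``/``_tp_row``, resolves the dotted paths, and
   checks the supported layer types. ``find_tp_targets`` is a hypothetical
   sketch of that first step only; it does not shard any weights. In the usual
   Megatron-style pairing, the column-parallel layer splits its output features
   across TP ranks and the row-parallel layer splits its input features, so the
   activation between them stays sharded.

   .. code-block:: python

      import torch.nn as nn


      def _check_tp_layer(layer: nn.Module, path: str) -> None:
          """Accept only the layer types listed as supported above."""
          if isinstance(layer, nn.Linear):
              return
          if isinstance(layer, nn.Conv2d) and layer.kernel_size == (1, 1):
              return
          raise TypeError(
              f"{path!r} is {type(layer).__name__}; expected nn.Linear or 1x1 nn.Conv2d"
          )


      def find_tp_targets(model: nn.Module) -> list:
          """Hypothetical: list (block name, column layer, row layer) triples."""
          targets = []
          for name, block in model.named_modules():
              col_path = getattr(type(block), "_tp_col", None)
              row_path = getattr(type(block), "_tp_row", None)
              if col_path is None or row_path is None:
                  continue  # this block did not opt in
              # get_submodule resolves dotted paths, including Sequential
              # indices such as "layers.0".
              col = block.get_submodule(col_path)
              row = block.get_submodule(row_path)
              _check_tp_layer(col, col_path)
              _check_tp_layer(row, row_path)
              targets.append((name, col, row))
          return targets


      class MLPBlock(nn.Module):
          # Opt-in class attributes, following the convention described above.
          _tp_col = "layers.0"
          _tp_row = "layers.2"

          def __init__(self, dim: int = 8, hidden: int = 32):
              super().__init__()
              self.layers = nn.Sequential(
                  nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim)
              )


      model = nn.Sequential(MLPBlock(), nn.Linear(8, 1))
      for name, col, row in find_tp_targets(model):
          print(name, col.out_features, row.in_features)  # 0 32 32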


.. py:function:: get_domain_manager(model)

.. py:function:: get_raw_model(model)

.. py:function:: shard_spatial(tensor, manager)

.. py:function:: unpad_shard_interp(y_pred, padding_opt, manager, image_h, image_w)

.. py:function:: sync_domain_gradients(model, manager)

   Average gradients across the domain-parallel group.

   See :py:func:`credit.parallel.collectives.allreduce_grads_avg` for the
   bucketing and DTensor handling.
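
   A minimal sketch of bucketed gradient averaging, assuming plain (non-DTensor)
   gradients: gradients are grouped by dtype and device, flattened into buckets,
   summed with one all-reduce per bucket, divided by the group size, and copied
   back. ``average_grads_bucketed`` and its ``world_size``/``allreduce_fn``
   arguments are hypothetical; the injectable all-reduce lets the logic run
   without a process group. This is not ``allreduce_grads_avg`` itself, which
   also handles DTensor gradients.

   .. code-block:: python

      import functools

      import torch
      import torch.distributed as dist


      def _flush(bucket, allreduce_fn, world_size):
          """Average one bucket of same-dtype, same-device gradients in place."""
          if not bucket:
              return
          flat = torch.cat([g.reshape(-1) for g in bucket])  # one collective per bucket
          allreduce_fn(flat)  # SUM across the group
          flat.div_(world_size)  # SUM -> mean
          offset = 0
          for g in bucket:
              n = g.numel()
              g.copy_(flat[offset:offset + n].view_as(g))
              offset += n


      def average_grads_bucketed(params, group=None, bucket_numel=1 << 22,
                                 world_size=None, allreduce_fn=None):
          """Hypothetical sketch: average ``.grad`` tensors across ``group``."""
          if world_size is None:
              world_size = dist.get_world_size(group)
          if world_size == 1:
              return  # nothing to average
          if allreduce_fn is None:
              allreduce_fn = functools.partial(
                  dist.all_reduce, op=dist.ReduceOp.SUM, group=group
              )
          pending = {}  # (dtype, device) -> gradients waiting to be reduced
          for p in params:
              g = p.grad
              if g is None:
                  continue
              bucket = pending.setdefault((g.dtype, g.device), [])
              bucket.append(g)
              if sum(t.numel() for t in bucket) >= bucket_numel:
                  _flush(bucket, allreduce_fn, world_size)
                  bucket.clear()
          for bucket in pending.values():  # reduce whatever is left over
              _flush(bucket, allreduce_fn, world_size)

   In real use, ``group`` would be the domain-parallel process group, which a
   1-D ``DeviceMesh`` exposes through ``get_group()``; bucketing trades a little
   extra copying for far fewer collective launches than one all-reduce per
   parameter.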


