credit.parallel.mesh
====================

.. py:module:: credit.parallel.mesh

.. autoapi-nested-parse::

   DeviceMesh construction for CREDIT v2 parallelism.

   Builds a logical process mesh from a trainer.parallelism config block.
   Supports up to 3 parallel dimensions: data (FSDP2/DDP), tensor (TP),
   and domain (spatial sharding).

   Example configs::

       # FSDP2 only, 8 GPUs
       parallelism: {data: fsdp2, tensor: 1, domain: 1}
       # mesh: (8,) named ["dp"]

       # FSDP2 + TP=2, 8 GPUs → dp=4, tp=2
       parallelism: {data: fsdp2, tensor: 2, domain: 1}
       # mesh: (4, 2) named ["dp", "tp"]

       # FSDP2 + TP=2 + domain=2, 8 GPUs → dp=2, tp=2, domain=2
       parallelism: {data: fsdp2, tensor: 2, domain: 2}
       # mesh: (2, 2, 2) named ["dp", "tp", "domain"]



Attributes
----------

.. autoapisummary::

   credit.parallel.mesh.logger


Functions
---------

.. autoapisummary::

   credit.parallel.mesh.dp_world_size
   credit.parallel.mesh.build_device_mesh
   credit.parallel.mesh.data_parallel_coords
   credit.parallel.mesh.parse_parallelism_conf


Module Contents
---------------

.. py:data:: logger

.. py:function:: dp_world_size(parallelism_conf: dict, world_size: int) -> int

   Number of data-parallel replicas for a given world size.

   Single source of truth for the dp-size arithmetic (build_device_mesh,
   data_parallel_coords, and the model wrapper must all agree).
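
   A minimal sketch of the arithmetic, assuming the dp size is
   ``world_size // (tensor * domain)`` as the example configs in the module
   docstring suggest. ``sketch_dp_world_size`` is a hypothetical stand-in, not
   the function's actual source, and the default of 1 for missing keys is an
   assumption:

   .. code-block:: python

      def sketch_dp_world_size(parallelism_conf: dict, world_size: int) -> int:
          """Illustrative only: ranks left for data parallelism after TP and domain."""
          tensor = int(parallelism_conf.get("tensor", 1))  # assumed default
          domain = int(parallelism_conf.get("domain", 1))  # assumed default
          model_parallel = tensor * domain
          if world_size % model_parallel != 0:
              raise ValueError(
                  f"world_size={world_size} is not divisible by "
                  f"tensor * domain = {model_parallel}"
              )
          return world_size // model_parallel


      # The three 8-GPU example configs from the module docstring.
      print(sketch_dp_world_size({"data": "fsdp2", "tensor": 1, "domain": 1}, 8))  # 8
      print(sketch_dp_world_size({"data": "fsdp2", "tensor": 2, "domain": 1}, 8))  # 4
      print(sketch_dp_world_size({"data": "fsdp2", "tensor": 2, "domain": 2}, 8))  # 2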

   :raises ValueError: if world_size is not divisible by tensor * domain.


.. py:function:: build_device_mesh(parallelism_conf: dict, device: str = 'cuda')

   Build a DeviceMesh from a parallelism config block.
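
   The mapping from config to mesh shape can be sketched in plain Python. This
   is an illustration under stated assumptions, not the function's source:
   ``sketch_mesh_layout`` is a hypothetical helper, and it assumes that the
   ``dp`` dimension is always kept while ``tp`` and ``domain`` dimensions of
   size 1 are dropped, which matches the three example configs in the module
   docstring but is not confirmed for other cases:

   .. code-block:: python

      def sketch_mesh_layout(parallelism_conf: dict, world_size: int):
          """Return (shape, names) for a mesh ordered dp, tp, domain."""
          tp = int(parallelism_conf.get("tensor", 1))      # assumed default
          domain = int(parallelism_conf.get("domain", 1))  # assumed default
          if world_size % (tp * domain) != 0:
              raise ValueError("world_size is not divisible by tensor * domain")
          dp = world_size // (tp * domain)

          dims = [("dp", dp)]  # assumption: dp is always present
          if tp > 1:
              dims.append(("tp", tp))
          if domain > 1:
              dims.append(("domain", domain))

          shape = tuple(size for _, size in dims)
          names = tuple(name for name, _ in dims)
          return shape, names


      print(sketch_mesh_layout({"data": "fsdp2", "tensor": 1, "domain": 1}, 8))
      # ((8,), ('dp',))
      print(sketch_mesh_layout({"data": "fsdp2", "tensor": 2, "domain": 1}, 8))
      # ((4, 2), ('dp', 'tp'))
      print(sketch_mesh_layout({"data": "fsdp2", "tensor": 2, "domain": 2}, 8))
      # ((2, 2, 2), ('dp', 'tp', 'domain'))

   A shape and name tuple of this form is what PyTorch's ``init_device_mesh``
   accepts as its mesh shape and ``mesh_dim_names`` arguments.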

   :param parallelism_conf: dict with keys:
                            data   (str): "fsdp2" | "ddp" | "none"
                            tensor (int): TP degree, >= 1
                            domain (int): domain parallel degree, >= 1
   :param device: "cuda" (default) or "cpu" for tests

   :returns: A ``(mesh, submeshes)`` pair. ``mesh`` is the DeviceMesh (or None
             if no parallelism). ``submeshes`` is a dict mapping dim name ->
             submesh (or None if single-dim), with keys "dp" if dp > 1, "tp" if
             tp > 1, and "domain" if domain > 1.
   :rtype: tuple

   :raises ValueError: if world_size is not divisible by tensor * domain.


.. py:function:: data_parallel_coords(conf: dict)

   Return (dp_rank, dp_size) for dataset/dataloader sharding.

   .. rubric:: The sampler contract

   Dataset samples must be sharded over the **data-parallel** dimension only,
   never over the global rank. Ranks that differ only in their tensor- or
   domain-parallel coordinate MUST receive the *same* batch:

     - TP ranks compute partial outputs of the same activation; the row-parallel
       all_reduce sums them. Feeding different samples to TP peers silently sums
       partial outputs of *different* inputs — garbage activations, and the
       replicated (non-TP) parameters drift apart because nothing syncs them
       across the tp dimension.
     - Domain ranks hold different spatial shards of the same sample; the halo
       exchange passes boundary rows between them. Different samples per domain
       rank corrupt every halo.

   So a DataLoader/DistributedSampler must be built with
   ``rank=dp_rank, num_replicas=dp_size`` from this function — NOT the global
   rank/world_size — whenever tensor > 1 or domain > 1.
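
   The contract can be illustrated without torch. The sketch below uses a
   simplified strided split (similar in spirit to an unshuffled
   ``DistributedSampler``, without its padding) and compares sharding by
   ``dp_rank`` against sharding by global rank for two TP peers.
   ``shard_indices`` is a hypothetical helper, and the dp arithmetic is the one
   described for this module:

   .. code-block:: python

      def shard_indices(num_samples: int, rank: int, num_replicas: int) -> list:
          """Strided split of sample indices across num_replicas shards."""
          return list(range(num_samples))[rank::num_replicas]


      world_size, tp, domain = 8, 2, 1
      dp_size = world_size // (tp * domain)  # 4 data-parallel replicas

      correct_by_rank = {}
      wrong_by_rank = {}
      for g in (0, 1):  # global ranks 0 and 1 are TP peers in this layout
          dp_rank = g // (tp * domain)
          correct_by_rank[g] = shard_indices(8, dp_rank, dp_size)  # shard by dp
          wrong_by_rank[g] = shard_indices(8, g, world_size)       # shard by global rank

      print(correct_by_rank)  # {0: [0, 4], 1: [0, 4]}  same samples for both peers
      print(wrong_by_rank)    # {0: [0], 1: [1]}        peers see different samples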

   .. rubric:: Rank layout

   ``init_device_mesh`` arranges ranks row-major over (dp, tp, domain), with
   dp outermost and domain innermost. ``DomainParallelManager`` builds the same
   layout (domain groups are consecutive ranks). Hence for global rank g::

       domain_coord = g % domain
       tp_coord     = (g // domain) % tp
       dp_rank      = g // (tp * domain)
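
   As a worked example of these formulas, the sketch below decomposes every
   global rank of an 8-rank job with tp=2 and domain=2, and checks that the
   row-major index is recovered from the coordinates. ``mesh_coords`` is a
   hypothetical helper written for this illustration:

   .. code-block:: python

      def mesh_coords(g: int, tp: int, domain: int) -> tuple:
          """Return (dp_rank, tp_coord, domain_coord) for global rank g."""
          domain_coord = g % domain
          tp_coord = (g // domain) % tp
          dp_rank = g // (tp * domain)
          return dp_rank, tp_coord, domain_coord


      tp, domain = 2, 2
      for g in range(8):
          dp_rank, tp_coord, domain_coord = mesh_coords(g, tp, domain)
          # Row-major inverse: dp outermost, domain innermost.
          assert g == (dp_rank * tp + tp_coord) * domain + domain_coord
          print(g, (dp_rank, tp_coord, domain_coord))
      # 0 (0, 0, 0)
      # 1 (0, 0, 1)
      # 2 (0, 1, 0)
      # 3 (0, 1, 1)
      # 4 (1, 0, 0)
      # 5 (1, 0, 1)
      # 6 (1, 1, 0)
      # 7 (1, 1, 1)

   Global ranks 0 to 3 share ``dp_rank == 0`` and so must read the same batch;
   ranks 4 to 7 share ``dp_rank == 1``.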

   :returns: ``(dp_rank, dp_size)``: the data-parallel coordinate of this rank
             and the number of data-parallel replicas. Falls back to ``(0, 1)``
             when torch distributed is not initialized.
   :rtype: tuple[int, int]


.. py:function:: parse_parallelism_conf(conf: dict) -> dict

   Extract and validate the parallelism block from a trainer config.

   Returns a normalized parallelism dict with keys: data, tensor, domain.
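
   A rough sketch of what such normalization might look like, using only the
   constraints documented for ``build_device_mesh``. Everything here is an
   assumption made for illustration: the helper name
   ``sketch_parse_parallelism_conf``, the ``conf["trainer"]["parallelism"]``
   lookup path, the defaults for missing keys, and the use of ``ValueError``
   for invalid values:

   .. code-block:: python

      VALID_DATA_MODES = ("fsdp2", "ddp", "none")


      def sketch_parse_parallelism_conf(conf: dict) -> dict:
          """Illustrative only: pull out and check data/tensor/domain."""
          block = conf.get("trainer", {}).get("parallelism") or {}
          out = {
              "data": str(block.get("data", "none")).lower(),  # assumed default
              "tensor": int(block.get("tensor", 1)),            # assumed default
              "domain": int(block.get("domain", 1)),            # assumed default
          }
          if out["data"] not in VALID_DATA_MODES:
              raise ValueError(f"data must be one of {VALID_DATA_MODES}")
          for key in ("tensor", "domain"):
              if out[key] < 1:
                  raise ValueError(f"{key} must be >= 1, got {out[key]}")
          return out


      conf = {"trainer": {"parallelism": {"data": "fsdp2", "tensor": 2}}}
      print(sketch_parse_parallelism_conf(conf))
      # {'data': 'fsdp2', 'tensor': 2, 'domain': 1}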


