credit.parallel.tensor_parallel
===============================

.. py:module:: credit.parallel.tensor_parallel

.. autoapi-nested-parse::

   Tensor Parallelism for CREDIT v2 models.

   Applies column/row parallelism to transformer blocks that opt in by declaring
   two class attributes::

       class MyBlock(nn.Module):
           _tp_col = "proj_up"   # attribute path for the column-parallel layer
           _tp_row = "proj_out"  # attribute path for the row-parallel layer

   Paths may be dotted (e.g. ``"layers.1"``) to address layers nested inside an
   ``nn.Sequential`` or other container.

   The column-parallel layer receives the full input and produces a sharded output
   (no all_reduce). The row-parallel layer receives the sharded input and issues an
   all_reduce SUM so the rest of the graph sees the full output.

   Supported layer types: ``nn.Conv2d`` (kernel 1×1 only) and ``nn.Linear``.

   Gradient flow:
     - TpColConv2d / TpColLinear: no all_reduce (output is sharded,
       consumed by the paired Row layer).
     - TpRowConv2d / TpRowLinear: all_reduce SUM over tp_group before returning.
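
   The column/row split described above can be checked with plain Python lists
   standing in for tensors. This is only a sketch of the arithmetic, not the
   module's implementation: the helpers ``matvec`` and ``relu``, the two-rank
   setup, the weights, and the element-wise activation between the two layers
   are all assumptions made for illustration, and the all_reduce SUM is
   simulated by adding the per-rank partial outputs.

   .. code-block:: python

      # Illustrative only: lists stand in for tensors, a loop stands in for TP ranks.

      def matvec(weight, x):
          """Multiply a row-major matrix (list of rows) by a vector."""
          return [sum(w * v for w, v in zip(row, x)) for row in weight]

      def relu(v):
          """Element-wise activation, so it can run on a sharded activation."""
          return [max(0.0, a) for a in v]

      x = [1.0, -2.0]                                           # full input
      w_up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # column-parallel layer
      w_out = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 0.0, 1.0]]     # row-parallel layer

      # Reference: the unsharded forward pass.
      full = matvec(w_out, relu(matvec(w_up, x)))

      tp = 2
      shard = len(w_up) // tp
      partials = []
      for rank in range(tp):
          lo, hi = rank * shard, (rank + 1) * shard
          # Column-parallel: this rank keeps a slice of the output features.
          hidden = relu(matvec(w_up[lo:hi], x))       # sharded, no communication
          # Row-parallel: this rank keeps the matching slice of the input features.
          w_out_local = [row[lo:hi] for row in w_out]
          partials.append(matvec(w_out_local, hidden))  # partial sum of the output

      # Simulated all_reduce SUM across the two ranks.
      reduced = [sum(vals) for vals in zip(*partials)]
      assert reduced == full

   Each rank's partial result is a full-shaped but incomplete output, which is
   why the sum has to happen before anything downstream reads it.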

   Usage::

       from credit.parallel import apply_tensor_parallel
       model = apply_tensor_parallel(model, tp_mesh)



Attributes
----------

.. autoapisummary::

   credit.parallel.tensor_parallel.logger


Classes
-------

.. autoapisummary::

   credit.parallel.tensor_parallel.TpColConv2d
   credit.parallel.tensor_parallel.TpRowConv2d
   credit.parallel.tensor_parallel.TpColLinear
   credit.parallel.tensor_parallel.TpRowLinear


Functions
---------

.. autoapisummary::

   credit.parallel.tensor_parallel._assert_plain_1x1
   credit.parallel.tensor_parallel._tp_group_from_mesh
   credit.parallel.tensor_parallel._rgetattr
   credit.parallel.tensor_parallel._rsetattr
   credit.parallel.tensor_parallel._to_col_parallel
   credit.parallel.tensor_parallel._to_row_parallel
   credit.parallel.tensor_parallel.apply_tensor_parallel
   credit.parallel.tensor_parallel.sync_replicated_gradients


Module Contents
---------------

.. py:data:: logger

.. py:function:: _assert_plain_1x1(conv: torch.nn.Conv2d, cls_name: str) -> None

   Require a plain 1×1 conv: the Tp wrappers rebuild the layer with default
   stride/padding/dilation/groups, so anything non-default would be silently
   dropped rather than replicated.
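
   A check of this kind might look like the sketch below. The function name
   ``check_plain_1x1``, the choice of ``ValueError``, and the message are
   hypothetical; the only thing taken from PyTorch is that ``nn.Conv2d``
   exposes ``kernel_size``, ``stride``, ``padding`` and ``dilation`` as tuples
   and ``groups`` as an int.

   .. code-block:: python

      def check_plain_1x1(conv, cls_name):
          """Raise ValueError unless ``conv`` is a default-configured 1x1 conv.

          Hypothetical stand-in for the private helper; it only reads
          attributes, so any object that has them can be checked.
          """
          expected = {
              "kernel_size": (1, 1),
              "stride": (1, 1),
              "padding": (0, 0),
              "dilation": (1, 1),
              "groups": 1,
          }
          # Collect every setting that differs from the default a rebuild would use.
          bad = {
              name: getattr(conv, name)
              for name, default in expected.items()
              if getattr(conv, name) != default
          }
          if bad:
              raise ValueError(
                  f"{cls_name} needs a plain 1x1 Conv2d, got non-default {bad}"
              )

   Failing loudly here is the point: a wrapper that rebuilt a strided or
   grouped conv with default settings would still run, just with different
   numerics.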


.. py:class:: TpColConv2d(conv: torch.nn.Conv2d, tp_group)

   Bases: :py:obj:`torch.nn.Module`


   Column-parallel 1×1 Conv2d: each rank owns ``out_channels // tp`` output
   channels, where ``tp`` is the size of the TP group.

   - Input: full ``(B, C_in, H, W)``.
   - Output: sharded ``(B, C_out // tp, H, W)``; no all_reduce is needed here.


   .. py:attribute:: tp_group


   .. py:attribute:: conv


   .. py:method:: forward(x)


.. py:class:: TpRowConv2d(conv: torch.nn.Conv2d, tp_group)

   Bases: :py:obj:`torch.nn.Module`


   Row-parallel 1×1 Conv2d: each rank owns ``in_channels // tp`` input channels.

   - Input: sharded ``(B, C_in // tp, H, W)``, produced by the paired ``TpColConv2d``.
   - Output: full ``(B, C_out, H, W)`` after all_reduce SUM over ``tp_group``.


   .. py:attribute:: tp_group


   .. py:attribute:: conv


   .. py:method:: forward(x)


.. py:class:: TpColLinear(linear: torch.nn.Linear, tp_group)

   Bases: :py:obj:`torch.nn.Module`


   Column-parallel Linear: each rank owns ``out_features // tp`` output features.

   - Input: full ``(*, C_in)``.
   - Output: sharded ``(*, C_out // tp)``; no all_reduce is needed here.


   .. py:attribute:: tp_group


   .. py:attribute:: linear


   .. py:method:: forward(x)


.. py:class:: TpRowLinear(linear: torch.nn.Linear, tp_group)

   Bases: :py:obj:`torch.nn.Module`


   Row-parallel Linear: each rank owns ``in_features // tp`` input features.

   - Input: sharded ``(*, C_in // tp)``, produced by the paired ``TpColLinear``.
   - Output: full ``(*, C_out)`` after all_reduce SUM over ``tp_group``.


   .. py:attribute:: tp_group


   .. py:attribute:: linear


   .. py:method:: forward(x)


.. py:function:: _tp_group_from_mesh(tp_mesh)

   Extract the underlying ProcessGroup from a 1-D DeviceMesh.


.. py:function:: _rgetattr(obj, path: str)

   Get a nested attribute/index using a dotted path.


.. py:function:: _rsetattr(obj, path: str, val) -> None

   Set a nested attribute/index using a dotted path.
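
   One way to resolve such dotted paths is sketched below for both the getter
   and the setter. The names ``rgetattr``, ``rsetattr`` and ``_step`` are
   illustrative, and treating an all-digit segment as a container index is an
   assumption about how paths like ``"layers.1"`` are meant to resolve; the
   actual helpers may differ.

   .. code-block:: python

      from functools import reduce
      from types import SimpleNamespace

      def _step(obj, part):
          # Assumption: numeric segments index into a container,
          # everything else is a plain attribute lookup.
          return obj[int(part)] if part.isdigit() else getattr(obj, part)

      def rgetattr(obj, path):
          """Follow a dotted path such as "layers.1" from ``obj``."""
          return reduce(_step, path.split("."), obj)

      def rsetattr(obj, path, val):
          """Replace whatever the dotted path points at with ``val``."""
          parent_path, _, last = path.rpartition(".")
          parent = rgetattr(obj, parent_path) if parent_path else obj
          if last.isdigit():
              parent[int(last)] = val
          else:
              setattr(parent, last, val)

      # A namespace and a list stand in for a block and its Sequential.
      block = SimpleNamespace(proj_up="up", layers=["norm", "proj_out"])
      rsetattr(block, "layers.1", "row-parallel wrapper")
      rsetattr(block, "proj_up", "column-parallel wrapper")

   Because ``nn.Sequential`` supports integer indexing and item assignment,
   the same two functions would also work on real modules.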


.. py:function:: _to_col_parallel(layer: torch.nn.Module, tp_group)

.. py:function:: _to_row_parallel(layer: torch.nn.Module, tp_group)

.. py:function:: apply_tensor_parallel(model: torch.nn.Module, tp_mesh) -> torch.nn.Module

   Walk model and apply TP to all blocks that declare ``_tp_col``/``_tp_row``.

   Any ``nn.Module`` subclass can opt in by setting two class attributes::

       class MyBlock(nn.Module):
           _tp_col = "proj_up"   # dotted path to the column-parallel layer
           _tp_row = "proj_out"  # dotted path to the row-parallel layer

   Paths may address layers inside ``nn.Sequential`` containers (e.g. ``"layers.1"``).
   Supported layer types: ``nn.Conv2d`` (1×1 only) and ``nn.Linear``.

   Converts the model in place. Safe to call before ``apply_fsdp2``.

   :param model: The model to convert.
   :param tp_mesh: 1-D DeviceMesh for the tensor-parallel dimension.

   :returns: The same ``model`` object, modified in place.
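
   The discovery step of such a walk could look roughly like this. The
   generator ``find_tp_blocks`` is a hypothetical name, not part of this
   module; it assumes only that the model offers ``named_modules()`` as
   ``torch.nn.Module`` does, and it reads the two attributes from the class
   because the text above describes them as class attributes.

   .. code-block:: python

      def find_tp_blocks(model):
          """Yield (qualified_name, module, col_path, row_path) for opted-in blocks.

          Hypothetical sketch: a block counts as opted in only when its class
          declares both ``_tp_col`` and ``_tp_row``.
          """
          for name, module in model.named_modules():
              cls = type(module)
              col = getattr(cls, "_tp_col", None)
              row = getattr(cls, "_tp_row", None)
              if col is not None and row is not None:
                  yield name, module, col, row

   A converter would then resolve ``col`` and ``row`` relative to each yielded
   module and swap in the matching column-parallel and row-parallel wrappers.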


.. py:function:: sync_replicated_gradients(model: torch.nn.Module, tp_group) -> None

   Average gradients of replicated (non-TP-sharded) params across the TP group.

   The Tp col/row weights are genuinely sharded per rank and must NOT be
   synced. Everything else (embeddings, norms, non-TP blocks, and the
   replicated row-parallel biases) holds an identical copy on every TP rank.
   In exact arithmetic the gradients of these replicas are identical as long
   as every rank sees identical inputs, but with ``data=none`` nothing enforces
   identical inputs, and nondeterministic kernels let the replicas drift apart
   over a long run. Averaging at the accumulation boundary pins the replicas
   together.

   See ``credit.parallel.collectives.allreduce_grads_avg`` for the bucketing and
   DTensor handling. TP peers share the same dp coordinate, so their DTensor
   shard shapes are identical.
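
   The effect of the averaging can be illustrated without any distributed
   setup. In the sketch below, dictionaries of float lists stand in for the
   per-rank gradients; the function ``average_replicated``, the parameter
   names, and the explicit ``sharded_names`` set are all made up for the
   example and say nothing about how the real function identifies sharded
   parameters or performs the reduction.

   .. code-block:: python

      def average_replicated(grads_per_rank, sharded_names):
          """Average each replicated gradient across simulated ranks, in place.

          grads_per_rank: one {param_name: list_of_floats} dict per TP rank.
          sharded_names: names of TP-sharded params, which are left untouched.
          """
          world = len(grads_per_rank)
          for name in grads_per_rank[0]:
              if name in sharded_names:
                  continue  # each rank owns a different slice; do not mix them
              columns = zip(*(rank[name] for rank in grads_per_rank))
              mean = [sum(col) / world for col in columns]
              for rank in grads_per_rank:
                  rank[name] = list(mean)

      ranks = [
          {"norm.weight": [1.0, 2.0], "proj_up.weight": [10.0]},
          {"norm.weight": [3.0, 2.0], "proj_up.weight": [20.0]},
      ]
      average_replicated(ranks, sharded_names={"proj_up.weight"})

   After the call both simulated ranks hold the same ``norm.weight`` gradient,
   while the sharded ``proj_up.weight`` gradients keep their per-rank values.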


