credit.parallel.tensor_parallel#

Tensor Parallelism for CREDIT v2 models.

Applies column/row parallelism to transformer blocks that opt in by declaring two class attributes:

class MyBlock(nn.Module):
    _tp_col = "proj_up"   # attribute path for the column-parallel layer
    _tp_row = "proj_out"  # attribute path for the row-parallel layer

Paths may be dotted (e.g. "layers.1") to address layers nested inside a Sequential or other container.

The column-parallel layer receives the full input and produces a sharded output (no all_reduce). The row-parallel layer receives the sharded input and issues an all_reduce SUM so the rest of the graph sees the full output.
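The pairing described above can be checked numerically on one process. The following is a minimal sketch, not the module's implementation: it uses plain Python lists instead of tensors, omits biases, simulates the TP ranks with a loop, and stands in for all_reduce SUM with an elementwise sum. All names (`matvec`, `relu`, `w_up`, `w_out`) are illustrative assumptions.

```python
# Pure-Python sketch of a column-parallel -> row-parallel pair (no PyTorch needed).
# Every name here is illustrative and not part of credit.parallel.

def matvec(weight, x):
    """Compute weight @ x for a weight given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def relu(v):
    """Elementwise activation; it acts per element, so it commutes with sharding."""
    return [max(0, a) for a in v]

tp = 2
x = [1, -2, 3]                      # full input, C_in = 3
w_up = [[1, 0, 2],                  # shape (hidden=4, C_in=3)
        [0, 1, -1],
        [2, 1, 0],
        [-1, 0, 1]]
w_out = [[1, 2, 0, -1],             # shape (C_out=2, hidden=4)
         [0, 1, 1, 3]]

# Reference: the unsharded computation.
full = matvec(w_out, relu(matvec(w_up, x)))

# Column-parallel: each rank keeps hidden // tp output rows of w_up.
# Row-parallel: each rank keeps the matching hidden // tp input columns of w_out.
shard = len(w_up) // tp
partials = []
for rank in range(tp):
    lo, hi = rank * shard, (rank + 1) * shard
    h_local = relu(matvec(w_up[lo:hi], x))         # sharded activation, no communication
    w_out_local = [row[lo:hi] for row in w_out]
    partials.append(matvec(w_out_local, h_local))  # partial sum of the full output

# Stand-in for all_reduce SUM over the TP group.
reduced = [sum(vals) for vals in zip(*partials)]
```

Each simulated rank produces only a partial sum, which is why the reduction has to happen in the row-parallel layer before the rest of the graph consumes the output.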

Supported layer types: nn.Conv2d (kernel 1×1 only) and nn.Linear.

Gradient flow:
  • TpColConv2d / TpColLinear: no all_reduce (output is sharded, consumed by the paired Row layer).

  • TpRowConv2d / TpRowLinear: all_reduce SUM over tp_group before returning.

Usage:

from credit.parallel import apply_tensor_parallel
model = apply_tensor_parallel(model, tp_mesh)

Attributes#

Classes#

TpColConv2d

Column-parallel 1×1 Conv2d: each rank owns out_channels // tp output channels.

TpRowConv2d

Row-parallel 1×1 Conv2d: each rank owns in_channels // tp input channels.

TpColLinear

Column-parallel Linear: each rank owns out_features // tp output neurons.

TpRowLinear

Row-parallel Linear: each rank owns in_features // tp input neurons.

Functions#

_assert_plain_1x1(→ None)

Require a plain 1×1 conv: the Tp wrappers rebuild the layer with default stride/padding/dilation/groups.

_tp_group_from_mesh(tp_mesh)

Extract the underlying ProcessGroup from a 1D DeviceMesh.

_rgetattr(obj, path)

Get a nested attribute/index using a dotted path.

_rsetattr(→ None)

Set a nested attribute/index using a dotted path.

_to_col_parallel(layer, tp_group)

_to_row_parallel(layer, tp_group)

apply_tensor_parallel(→ torch.nn.Module)

Walk model and apply TP to all blocks that declare _tp_col/_tp_row.

sync_replicated_gradients(→ None)

Average gradients of replicated (non-TP-sharded) params across the TP group.

Module Contents#

credit.parallel.tensor_parallel.logger#
credit.parallel.tensor_parallel._assert_plain_1x1(conv: torch.nn.Conv2d, cls_name: str) None#

Require a plain 1×1 conv: the Tp wrappers rebuild the layer with default stride/padding/dilation/groups, so anything non-default would be silently dropped rather than replicated.
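A hedged sketch of what such a guard might look like is given here. It is not the actual body of `_assert_plain_1x1`: the function name `check_plain_1x1`, the choice of `ValueError`, and the message format are assumptions. The attribute names and default values are the ones `nn.Conv2d` stores as tuples; a `SimpleNamespace` stands in for the conv so the example runs without PyTorch.

```python
from types import SimpleNamespace

def check_plain_1x1(conv, cls_name):
    """Raise ValueError unless conv is a 1x1 conv with default hyperparameters.

    Hypothetical stand-in for the module's guard; only the attributes named in
    the description (stride, padding, dilation, groups) plus kernel_size are checked.
    """
    expected = {
        "kernel_size": (1, 1),
        "stride": (1, 1),
        "padding": (0, 0),
        "dilation": (1, 1),
        "groups": 1,
    }
    for name, want in expected.items():
        got = getattr(conv, name)
        if got != want:
            raise ValueError(f"{cls_name} requires {name}={want}, got {got}")

# Stand-ins carrying the same attributes an nn.Conv2d instance exposes.
plain = SimpleNamespace(kernel_size=(1, 1), stride=(1, 1), padding=(0, 0),
                        dilation=(1, 1), groups=1)
strided = SimpleNamespace(kernel_size=(1, 1), stride=(2, 2), padding=(0, 0),
                          dilation=(1, 1), groups=1)

check_plain_1x1(plain, "TpColConv2d")        # passes silently
try:
    check_plain_1x1(strided, "TpColConv2d")
    rejected = False
except ValueError:
    rejected = True                          # non-default stride is refused
```

Failing loudly here is preferable to wrapping the layer, because a rebuilt conv with default settings would compute something different from the original without any error.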

class credit.parallel.tensor_parallel.TpColConv2d(conv: torch.nn.Conv2d, tp_group)#

Bases: torch.nn.Module

Column-parallel 1×1 Conv2d: each rank owns out_channels // tp output channels.

Input: full (B, C_in, H, W). Output: sharded (B, C_out // tp, H, W) — no all_reduce needed here.

tp_group#
conv#
forward(x)#
class credit.parallel.tensor_parallel.TpRowConv2d(conv: torch.nn.Conv2d, tp_group)#

Bases: torch.nn.Module

Row-parallel 1×1 Conv2d: each rank owns in_channels // tp input channels.

Input: sharded (B, C_in // tp, H, W) — from paired TpColConv2d. Output: full (B, C_out, H, W) after all_reduce SUM over tp_group.

tp_group#
conv#
forward(x)#
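To make the shape contract of the Conv2d pair concrete, here is a small sketch that computes the per-rank activation shapes for a TpColConv2d feeding a TpRowConv2d. The helper name `tp_conv_pair_shapes` and the requirement that the intermediate channel count divide evenly by the TP size are assumptions for illustration; the documentation above only states the `// tp` shard sizes.

```python
def tp_conv_pair_shapes(batch, c_in, c_mid, c_out, height, width, tp):
    """Per-rank activation shapes through a column-parallel -> row-parallel 1x1 conv pair.

    c_mid is the channel count between the two layers (the column layer's
    out_channels and the row layer's in_channels). Illustrative helper only.
    """
    if tp < 1 or c_mid % tp != 0:
        # Assumption: uneven shards are not supported in this sketch.
        raise ValueError(f"c_mid={c_mid} must be divisible by tp={tp}")
    sharded = (batch, c_mid // tp, height, width)
    return {
        "col_in": (batch, c_in, height, width),    # full input on every rank
        "col_out": sharded,                        # sharded, no all_reduce
        "row_in": sharded,                         # consumed as-is by the row layer
        "row_out": (batch, c_out, height, width),  # full again after all_reduce SUM
    }

shapes = tp_conv_pair_shapes(batch=2, c_in=64, c_mid=256, c_out=64,
                             height=32, width=32, tp=4)
```

With four TP ranks, each rank holds a quarter of the 256 intermediate channels, while the input to the column layer and the output of the row layer keep their full channel counts.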
class credit.parallel.tensor_parallel.TpColLinear(linear: torch.nn.Linear, tp_group)#

Bases: torch.nn.Module

Column-parallel Linear: each rank owns out_features // tp output neurons.

Input: full (*, C_in). Output: sharded (*, C_out // tp); no all_reduce is needed here.

tp_group#
linear#
forward(x)#
class credit.parallel.tensor_parallel.TpRowLinear(linear: torch.nn.Linear, tp_group)#

Bases: torch.nn.Module

Row-parallel Linear: each rank owns in_features // tp input neurons.

Input: sharded (*, C_in // tp), from the paired TpColLinear. Output: full (*, C_out) after all_reduce SUM over tp_group.

tp_group#
linear#
forward(x)#
credit.parallel.tensor_parallel._tp_group_from_mesh(tp_mesh)#

Extract the underlying ProcessGroup from a 1D DeviceMesh.

credit.parallel.tensor_parallel._rgetattr(obj, path: str)#

Get a nested attribute/index using a dotted path.

credit.parallel.tensor_parallel._rsetattr(obj, path: str, val) None#

Set a nested attribute/index using a dotted path.
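One plausible way to resolve such dotted paths is sketched here. It is an assumption about the approach rather than the module's code: the names `rgetattr_sketch` and `rsetattr_sketch` are hypothetical, purely numeric segments are treated as indices, and a plain list stands in for an nn.Sequential so the example runs without PyTorch.

```python
def rgetattr_sketch(obj, path):
    """Follow a dotted path; numeric segments index into the current container."""
    for part in path.split("."):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj

def rsetattr_sketch(obj, path, val):
    """Replace the object at a dotted path by resolving its parent first."""
    parent_path, _, last = path.rpartition(".")
    parent = rgetattr_sketch(obj, parent_path) if parent_path else obj
    if last.isdigit():
        parent[int(last)] = val
    else:
        setattr(parent, last, val)

class Block:
    """Stand-in for a block that opts in to tensor parallelism."""
    def __init__(self):
        self.proj_up = "up"
        self.layers = ["first", "second"]   # a list stands in for nn.Sequential

block = Block()
rsetattr_sketch(block, "layers.1", "wrapped")   # swap the nested layer in place
rsetattr_sketch(block, "proj_up", "wrapped_up") # swap a direct attribute
```

Resolving the parent and then assigning the final segment is what lets a wrapper replace a layer nested inside a container without rebuilding the container.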

credit.parallel.tensor_parallel._to_col_parallel(layer: torch.nn.Module, tp_group)#
credit.parallel.tensor_parallel._to_row_parallel(layer: torch.nn.Module, tp_group)#
credit.parallel.tensor_parallel.apply_tensor_parallel(model: torch.nn.Module, tp_mesh) torch.nn.Module#

Walk model and apply TP to all blocks that declare _tp_col/_tp_row.

Any nn.Module subclass can opt in by setting two class attributes:

class MyBlock(nn.Module):
    _tp_col = "proj_up"   # dotted path to the column-parallel layer
    _tp_row = "proj_out"  # dotted path to the row-parallel layer

Paths may address layers inside Sequentials (e.g. "layers.1"). Supported layer types: nn.Conv2d (1×1 only) and nn.Linear.

Converts in-place. Safe to call before apply_fsdp2.

Parameters:
  • model – The model to convert.

  • tp_mesh – 1-D DeviceMesh for the tensor-parallel dimension.

Returns:

model (same object, modified in-place).

credit.parallel.tensor_parallel.sync_replicated_gradients(model: torch.nn.Module, tp_group) None#

Average gradients of replicated (non-TP-sharded) params across the TP group.

The Tp col/row weights are genuinely sharded per rank and must NOT be synced. Everything else (embeddings, norms, non-TP blocks, and the replicated row-parallel biases) holds an identical copy on every TP rank. Their gradients are identical in exact arithmetic given identical inputs, but with data=none nothing enforces that, and nondeterministic kernels drift the replicas apart over a long run. Averaging at the accumulation boundary pins the replicas together.

See credit.parallel.collectives.allreduce_grads_avg for the bucketing and DTensor handling. TP peers share the same dp coordinate, so their DTensor shard shapes are identical.
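The averaging rule can be illustrated without a process group. The sketch below is a single-process simulation under stated assumptions: each TP rank is a dict mapping parameter names to scalar gradients, the set of sharded names is passed in explicitly, and an arithmetic mean stands in for the averaging all-reduce. The function name `average_replicated_grads` and the parameter names are hypothetical; bucketing and DTensor handling are not modeled.

```python
def average_replicated_grads(per_rank_grads, sharded_names):
    """Average gradients of replicated params across simulated TP ranks, in place.

    per_rank_grads: list with one {param_name: grad} dict per simulated rank.
    sharded_names: names of genuinely sharded Tp col/row weights, which are skipped.
    """
    tp = len(per_rank_grads)
    for name in per_rank_grads[0]:
        if name in sharded_names:
            continue  # sharded weights differ per rank by design; never sync them
        mean = sum(grads[name] for grads in per_rank_grads) / tp
        for grads in per_rank_grads:
            grads[name] = mean  # every replica now holds the same gradient

# Two simulated ranks whose replicated gradient has drifted slightly apart.
ranks = [
    {"norm.weight": 0.50, "proj_up.weight": 1.0},
    {"norm.weight": 0.75, "proj_up.weight": -3.0},
]
average_replicated_grads(ranks, sharded_names={"proj_up.weight"})
```

After the call, both ranks hold 0.625 for the replicated parameter, while the sharded weight's per-rank gradients are left untouched.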