credit.parallel.tensor_parallel
Tensor Parallelism for CREDIT v2 models.
Applies column/row parallelism to transformer blocks that opt in by declaring two class attributes:
```python
class MyBlock(nn.Module):
    _tp_col = "proj_up"   # attribute path for the column-parallel layer
    _tp_row = "proj_out"  # attribute path for the row-parallel layer
```
Paths may be dotted (e.g. "layers.1") to address layers nested inside a Sequential or other container.
The column-parallel layer receives the full input and produces a sharded output (no all_reduce). The row-parallel layer receives the sharded input and issues an all_reduce SUM so the rest of the graph sees the full output.
Supported layer types: nn.Conv2d (plain 1×1 kernels only, with default stride, padding, dilation and groups) and nn.Linear.
- Gradient flow:
TpColConv2d / TpColLinear: no all_reduce (output is sharded, consumed by the paired Row layer).
TpRowConv2d / TpRowLinear: all_reduce SUM over tp_group before returning.
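The column/row pairing works because the column-parallel layer splits the hidden units across ranks and the row-parallel layer consumes exactly the matching slice, so each rank produces a partial output and the SUM reduction restores the full one. The plain-Python sketch below checks this numerically with `tp` simulated ranks. It is not the CREDIT implementation: every name in it is hypothetical, and the ReLU between the two layers is an illustrative assumption (any element-wise activation can be applied to each shard independently).

```python
# Pure-Python sketch (no torch): a two-layer block y = W_row @ relu(W_col @ x)
# computed once unsharded and once split over `tp` simulated ranks.

def matvec(w, x):
    """Multiply a matrix stored as a list of rows by a vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]

def relu(v):
    """Element-wise activation; it acts on each hidden unit independently."""
    return [max(0, a) for a in v]

def full_forward(w_col, w_row, x):
    return matvec(w_row, relu(matvec(w_col, x)))

def tp_forward(w_col, w_row, x, tp):
    hidden = len(w_col)
    if hidden % tp != 0:
        raise ValueError(f"hidden size {hidden} is not divisible by tp={tp}")
    k = hidden // tp
    partials = []
    for rank in range(tp):
        lo, hi = rank * k, (rank + 1) * k
        # Column-parallel: full x in, this rank's k hidden units out (no reduction).
        h_shard = relu(matvec(w_col[lo:hi], x))
        # Row-parallel: the matching k input columns of W_row give a partial output.
        w_row_shard = [row[lo:hi] for row in w_row]
        partials.append(matvec(w_row_shard, h_shard))
    # all_reduce SUM: element-wise sum of the partial outputs across ranks.
    return [sum(vals) for vals in zip(*partials)]

w_col = [[1, 2, 0], [0, 1, -1], [2, 0, 1], [1, 1, 1]]  # hidden=4, C_in=3
w_row = [[1, 0, 2, 1], [0, 3, 1, 2]]                   # C_out=2, hidden=4
x = [1, 2, 3]
print(full_forward(w_col, w_row, x), tp_forward(w_col, w_row, x, tp=2))  # [21, 17] [21, 17]
```

Note that no communication is needed between the two layers: only the single SUM at the end, which is why the column-parallel side can skip its own all_reduce.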
- Usage:
```python
from credit.parallel import apply_tensor_parallel

model = apply_tensor_parallel(model, tp_mesh)
```
Attributes
| Name | Description |
|---|---|
| logger | |
Classes
| Name | Description |
|---|---|
| TpColConv2d | Column-parallel 1×1 Conv2d: each rank owns out_channels // tp output channels. |
| TpRowConv2d | Row-parallel 1×1 Conv2d: each rank owns in_channels // tp input channels. |
| TpColLinear | Column-parallel Linear: each rank owns out_features // tp output neurons. |
| TpRowLinear | Row-parallel Linear: each rank owns in_features // tp input neurons. |
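Functions
| Name | Description |
|---|---|
| _assert_plain_1x1(conv, cls_name) | Require a plain 1×1 conv: the Tp wrappers rebuild the layer with default stride/padding/dilation/groups. |
| _tp_group_from_mesh(tp_mesh) | Extract the underlying ProcessGroup from a 1D DeviceMesh. |
| _rgetattr(obj, path) | Get a nested attribute/index using a dotted path. |
| _rsetattr(obj, path, val) | Set a nested attribute/index using a dotted path. |
| _to_col_parallel(layer, tp_group) | |
| _to_row_parallel(layer, tp_group) | |
| apply_tensor_parallel(model, tp_mesh) | Walk model and apply TP to all blocks that declare _tp_col/_tp_row. |
| sync_replicated_gradients(model, tp_group) | Average gradients of replicated (non-TP-sharded) params across the TP group. |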
Module Contents
- credit.parallel.tensor_parallel.logger
- credit.parallel.tensor_parallel._assert_plain_1x1(conv: torch.nn.Conv2d, cls_name: str) → None
Require a plain 1×1 conv: the Tp wrappers rebuild the layer with default stride/padding/dilation/groups, so anything non-default would be silently dropped rather than replicated.
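A minimal sketch of such a check, assuming it compares the attributes nn.Conv2d stores after construction (kernel_size, stride, padding and dilation as 2-tuples, groups as an int) against their defaults. The function name, the ValueError, and the duck-typed stand-ins are hypothetical; the real helper may raise a different exception. One limitation of this sketch: a string padding such as "valid" would be rejected even though it is equivalent for a 1×1 kernel.

```python
from types import SimpleNamespace

# Defaults that nn.Conv2d stores after construction for a plain 1x1 conv.
PLAIN_1X1 = {
    "kernel_size": (1, 1),
    "stride": (1, 1),
    "padding": (0, 0),
    "dilation": (1, 1),
    "groups": 1,
}

def assert_plain_1x1(conv, cls_name):
    """Hypothetical duck-typed check: raise if any setting is non-default."""
    bad = {
        name: getattr(conv, name)
        for name, want in PLAIN_1X1.items()
        if getattr(conv, name) != want
    }
    if bad:
        raise ValueError(f"{cls_name} needs a plain 1x1 Conv2d, got non-default settings {bad}")

# Stand-ins exposing the same attribute names as nn.Conv2d, so no torch is needed.
plain = SimpleNamespace(**PLAIN_1X1)
strided = SimpleNamespace(**{**PLAIN_1X1, "stride": (2, 2)})
assert_plain_1x1(plain, "TpColConv2d")  # passes silently
```

Failing loudly here is the point: rebuilding a strided or grouped conv as a default 1×1 layer would change the model's output without any error.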
- class credit.parallel.tensor_parallel.TpColConv2d(conv: torch.nn.Conv2d, tp_group)
Bases: torch.nn.Module
Column-parallel 1×1 Conv2d: each rank owns out_channels // tp output channels.
Input: full (B, C_in, H, W). Output: sharded (B, C_out // tp, H, W); no all_reduce is needed here.
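The per-rank ownership can be sketched as index arithmetic. This assumes each rank takes one contiguous block of output channels in rank order and that C_out must be divisible by tp; both are assumptions about TpColConv2d, and the helper names are hypothetical. A Conv2d weight has shape (C_out, C_in, 1, 1), so column parallelism slices dimension 0 of the weight and the bias.

```python
def shard_bounds(size, tp, rank):
    """Return the [start, stop) slice of a dimension of length `size` owned by `rank`.

    Assumes `size` is divisible by `tp`; this sketch raises otherwise.
    """
    if size % tp != 0:
        raise ValueError(f"dimension {size} is not divisible by tp={tp}")
    if not 0 <= rank < tp:
        raise ValueError(f"rank {rank} is outside [0, {tp})")
    n = size // tp
    return rank * n, (rank + 1) * n

def col_conv_shard_shapes(c_out, c_in, tp, rank):
    """Weight and bias shapes a column-parallel 1x1 conv keeps on `rank`.

    In torch this would correspond to weight[start:stop] and bias[start:stop].
    """
    start, stop = shard_bounds(c_out, tp, rank)
    return (stop - start, c_in, 1, 1), (stop - start,)

print(shard_bounds(64, tp=4, rank=1))               # (16, 32)
print(col_conv_shard_shapes(64, 32, tp=4, rank=1))  # ((16, 32, 1, 1), (16,))
```

Because the ranges tile [0, C_out) without overlap, concatenating the per-rank outputs along the channel dimension would reproduce the unsharded output.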
- tp_group
- conv
- forward(x)
- class credit.parallel.tensor_parallel.TpRowConv2d(conv: torch.nn.Conv2d, tp_group)
Bases: torch.nn.Module
Row-parallel 1×1 Conv2d: each rank owns in_channels // tp input channels.
Input: sharded (B, C_in // tp, H, W), from the paired TpColConv2d. Output: full (B, C_out, H, W) after all_reduce SUM over tp_group.
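The sync_replicated_gradients entry describes the row-parallel bias as replicated on every TP rank. Because the reduction is a SUM, a replicated bias must enter the result exactly once. This plain-Python sketch shows that adding it after the reduction matches the unsharded layer, while adding it on every rank before the reduction counts it tp times. Whether TpRowConv2d adds its bias after the all_reduce (or on only one rank) is an assumption here, and the function names are hypothetical.

```python
def all_reduce_sum(per_rank):
    """Simulate all_reduce SUM: every rank ends up holding the element-wise total."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def row_parallel_outputs(partials, bias, bias_before_reduce=False):
    """Per-rank outputs of a row-parallel layer whose bias is replicated on every rank."""
    if bias_before_reduce:
        # Wrong placement: each rank adds the bias, so the SUM counts it tp times.
        partials = [[p + b for p, b in zip(part, bias)] for part in partials]
        return all_reduce_sum(partials)
    # Add the replicated bias once, after the reduction.
    return [[v + b for v, b in zip(out, bias)] for out in all_reduce_sum(partials)]

partials = [[1, 2], [3, 4]]  # partial outputs from tp=2 ranks
bias = [10, 20]
print(row_parallel_outputs(partials, bias))                           # [[14, 26], [14, 26]]
print(row_parallel_outputs(partials, bias, bias_before_reduce=True))  # [[24, 46], [24, 46]]
```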
- tp_group
- conv
- forward(x)
- class credit.parallel.tensor_parallel.TpColLinear(linear: torch.nn.Linear, tp_group)
Bases: torch.nn.Module
Column-parallel Linear: each rank owns out_features // tp output neurons.
Input: full (*, C_in). Output: sharded (*, C_out // tp); no all_reduce is needed here.
- tp_group
- linear
- forward(x)
- class credit.parallel.tensor_parallel.TpRowLinear(linear: torch.nn.Linear, tp_group)
Bases: torch.nn.Module
Row-parallel Linear: each rank owns in_features // tp input neurons.
Input: sharded (*, C_in // tp), from the paired TpColLinear. Output: full (*, C_out) after all_reduce SUM over tp_group.
- tp_group
- linear
- forward(x)
- credit.parallel.tensor_parallel._tp_group_from_mesh(tp_mesh)
Extract the underlying ProcessGroup from a 1D DeviceMesh.
- credit.parallel.tensor_parallel._rgetattr(obj, path: str)
Get a nested attribute/index using a dotted path.
- credit.parallel.tensor_parallel._rsetattr(obj, path: str, val) → None
Set a nested attribute/index using a dotted path.
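A minimal sketch of dotted-path access, assuming purely numeric path parts are treated as indices so that "layers.1" reaches the second entry of a container. nn.Sequential and nn.ModuleList both support integer indexing and item assignment, so the same logic would apply to them; the stand-in class and the function names here are hypothetical.

```python
def rgetattr(obj, path):
    """Follow a dotted path; purely numeric parts are treated as indices (e.g. "layers.1")."""
    for part in path.split("."):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj

def rsetattr(obj, path, val):
    """Set the attribute or index named by the last part of a dotted path."""
    parent_path, _, last = path.rpartition(".")
    parent = rgetattr(obj, parent_path) if parent_path else obj
    if last.isdigit():
        parent[int(last)] = val
    else:
        setattr(parent, last, val)

class Block:
    """Plain stand-in for a module holding a list-like container."""
    def __init__(self):
        self.layers = ["conv_a", "conv_b", "conv_c"]

block = Block()
rsetattr(block, "layers.1", "tp_conv_b")
print(rgetattr(block, "layers.1"))  # tp_conv_b
```

Splitting the path with rpartition means the setter only resolves the parent, then performs a single attribute or item assignment, which is what swapping a layer in place requires.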
- credit.parallel.tensor_parallel._to_col_parallel(layer: torch.nn.Module, tp_group)
- credit.parallel.tensor_parallel._to_row_parallel(layer: torch.nn.Module, tp_group)
- credit.parallel.tensor_parallel.apply_tensor_parallel(model: torch.nn.Module, tp_mesh) → torch.nn.Module
Walk model and apply TP to all blocks that declare _tp_col/_tp_row.
Any nn.Module subclass can opt in by setting two class attributes:
```python
class MyBlock(nn.Module):
    _tp_col = "proj_up"   # dotted path to the column-parallel layer
    _tp_row = "proj_out"  # dotted path to the row-parallel layer
```
Paths may address layers inside Sequentials (e.g. "layers.1"). Supported layer types: nn.Conv2d (1×1 only) and nn.Linear.
Converts in-place. Safe to call before apply_fsdp2.
- Parameters:
model – The model to convert.
tp_mesh – 1-D DeviceMesh for the tensor-parallel dimension.
- Returns:
model (same object, modified in-place).
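The discovery step of such a walk can be sketched without torch. This assumes opted-in blocks are found by reading the two class attributes from each module's type, and that declaring only one of them is an error; both are assumptions, and find_tp_blocks is a hypothetical name. With torch, `modules` would be model.modules().

```python
def find_tp_blocks(modules):
    """Collect (module, col_path, row_path) for every module whose class opts in.

    Returning a list (not a generator) lets the caller swap layers afterwards
    without mutating the module tree mid-iteration.
    """
    found = []
    for module in modules:
        cls = type(module)
        col = getattr(cls, "_tp_col", None)
        row = getattr(cls, "_tp_row", None)
        if col is None and row is None:
            continue  # block did not opt in
        if col is None or row is None:
            raise ValueError(f"{cls.__name__} must declare both _tp_col and _tp_row")
        found.append((module, col, row))
    return found

class Mlp:
    _tp_col = "proj_up"
    _tp_row = "proj_out"

class Norm:
    pass

blocks = find_tp_blocks([Norm(), Mlp(), Norm()])
print([(type(m).__name__, col, row) for m, col, row in blocks])  # [('Mlp', 'proj_up', 'proj_out')]
```

Each hit would then have its two layers replaced via dotted-path set/get, presumably using _rgetattr, _rsetattr and the _to_col_parallel/_to_row_parallel helpers listed above; how exactly those helpers are combined is a guess.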
- credit.parallel.tensor_parallel.sync_replicated_gradients(model: torch.nn.Module, tp_group) → None
Average gradients of replicated (non-TP-sharded) params across the TP group.
The Tp col/row weights are genuinely sharded per rank and must NOT be synced. Everything else (embeddings, norms, non-TP blocks, and the replicated row-parallel biases) holds an identical copy on every TP rank. In exact arithmetic, given identical inputs, their gradients would also be identical, but with data=none nothing enforces this, and nondeterministic kernels can make the replicas drift apart over a long run. Averaging the gradients at the accumulation boundary pins the replicas together.
See credit.parallel.collectives.allreduce_grads_avg for the bucketing and DTensor handling. TP peers share the same dp coordinate, so their DTensor shard shapes are identical.
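The selection-and-average logic can be simulated in plain Python with one gradient dict per TP rank. This is a sketch under stated assumptions: parameter names are hypothetical, the set of sharded names is passed in explicitly, and the real function delegates the communication to allreduce_grads_avg rather than looping over ranks.

```python
def sync_replicated(grads_per_rank, sharded_names):
    """Average replicated gradients across simulated TP ranks, in place.

    grads_per_rank: one {param_name: list of floats} dict per TP rank.
    sharded_names: parameters that genuinely differ per rank (the Tp col/row
    weights); these are skipped.
    """
    tp = len(grads_per_rank)
    for name in grads_per_rank[0]:
        if name in sharded_names:
            continue
        per_rank = [g[name] for g in grads_per_rank]
        avg = [sum(vals) / tp for vals in zip(*per_rank)]
        for g in grads_per_rank:
            g[name] = list(avg)
    return grads_per_rank

grads = [
    {"norm.weight": [1.0, 2.0], "mlp.proj_up.weight": [5.0]},  # rank 0
    {"norm.weight": [1.0, 4.0], "mlp.proj_up.weight": [7.0]},  # rank 1, drifted replica
]
sync_replicated(grads, sharded_names={"mlp.proj_up.weight"})
print(grads[0]["norm.weight"], grads[1]["norm.weight"])  # [1.0, 3.0] [1.0, 3.0]
```

Averaging (rather than summing) keeps the update magnitude unchanged when the replicas already agree, so the sync is a no-op in the ideal case and only corrects drift.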