credit.domain_parallel.sharding
Tensor sharding and gathering utilities for domain parallelism.
Provides functions to split tensors along a spatial dimension across domain-parallel ranks, and to gather them back for metrics/output.
Functions
| shard_tensor | Shard a tensor along the given dimension across domain-parallel ranks. |
| gather_tensor | Gather sharded tensor from all domain-parallel ranks. |
| shard_batch | Shard all spatial tensors in a training batch dictionary. |
Module Contents
- credit.domain_parallel.sharding.shard_tensor(x, dim=-2, manager=None)
Shard a tensor along the given dimension across domain-parallel ranks.
Splits the tensor into equal chunks and returns only this rank’s chunk.
- Parameters:
x – Input tensor.
dim – Dimension to shard along (default: -2 for H in BCHW or BCTHW).
manager – DomainParallelManager instance (uses global singleton if None).
- Returns:
Local shard of the tensor.
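The equal-chunk split described above can be pictured as plain index arithmetic. The following is a minimal sketch, not the library's implementation: `shard_bounds` is a hypothetical helper, and it assumes the sharded dimension divides evenly by the number of domain-parallel ranks (how a remainder is handled is not stated here).

```python
def shard_bounds(size, rank, world_size):
    """Return the half-open [start, stop) index range owned by `rank`.

    Hypothetical helper for illustration only. Assumes `size` (the length
    of the sharded dimension) divides evenly by `world_size`.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    if size % world_size != 0:
        raise ValueError("size must be divisible by world_size")
    chunk = size // world_size
    start = rank * chunk
    return start, start + chunk


# A field with H = 8 rows split over 4 ranks: each rank owns 2 rows.
H, world_size = 8, 4
bounds = [shard_bounds(H, r, world_size) for r in range(world_size)]
print(bounds)  # [(0, 2), (2, 4), (4, 6), (6, 8)]
```

Under these assumptions, the local shard on a given rank would correspond to slicing the tensor along `dim` with that rank's `[start, stop)` range.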
- credit.domain_parallel.sharding.gather_tensor(x, dim=-2, manager=None)
Gather sharded tensor from all domain-parallel ranks.
All-gathers the tensor along the sharding dimension and concatenates.
- Parameters:
x – Local shard tensor.
dim – Dimension that was sharded (default: -2).
manager – DomainParallelManager instance (uses global singleton if None).
- Returns:
Full (gathered) tensor.
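Gathering is the inverse of sharding: concatenating the per-rank chunks in rank order along the sharded dimension should reproduce the full field. The sketch below simulates that round trip in a single process with nested lists standing in for an (H, W) tensor. `shard_rows` and `gather_rows` are hypothetical names, and an even split is assumed; in a real run each rank would hold only its own shard and receive the others through a collective operation.

```python
def shard_rows(rows, rank, world_size):
    """Return this rank's contiguous block of rows (even split assumed)."""
    chunk = len(rows) // world_size
    return rows[rank * chunk:(rank + 1) * chunk]


def gather_rows(shards):
    """Concatenate per-rank shards in rank order, as an all-gather would."""
    full = []
    for shard in shards:  # shards[i] is the chunk held by rank i
        full.extend(shard)
    return full


# A 4 x 3 grid (H = 4, W = 3) split along H over 2 simulated ranks.
grid = [[r * 3 + c for c in range(3)] for r in range(4)]
world_size = 2
shards = [shard_rows(grid, rank, world_size) for rank in range(world_size)]
restored = gather_rows(shards)
print(restored == grid)  # True
```

The rank order of the concatenation matters: swapping two shards would still give a grid of the right shape, but with rows in the wrong place.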
- credit.domain_parallel.sharding.shard_batch(batch, spatial_dims_5d=-2, spatial_dims_4d=-2, manager=None)
Shard all spatial tensors in a training batch dictionary.
Handles both 5D (B, C, T, H, W) and 4D (B, C, H, W) tensors. Non-tensor entries and 1D/2D tensors are left unchanged.
- Parameters:
batch – Dictionary of batch data (from dataloader).
spatial_dims_5d – Dimension to shard in 5D tensors (default: -2 for H).
spatial_dims_4d – Dimension to shard in 4D tensors (default: -2 for H).
manager – DomainParallelManager instance (uses global singleton if None).
- Returns:
New dict with sharded tensors.
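To make the dispatch rule concrete, the sketch below works on shape tuples rather than real tensors and reports the shape each rank would hold. It is an illustration under assumptions, not the library's code: `local_shapes` and the batch keys are hypothetical, an even split is assumed, and anything that is not a 4D or 5D shape is passed through unchanged (the description above only names non-tensor entries and 1D/2D tensors explicitly).

```python
def local_shapes(batch_shapes, world_size, dim_5d=-2, dim_4d=-2):
    """Map each entry's global shape to the shape one rank would hold.

    Hypothetical helper operating on shape tuples only. Entries that are
    not 4D or 5D shape tuples pass through unchanged.
    """
    out = {}
    for key, shape in batch_shapes.items():
        if isinstance(shape, tuple) and len(shape) in (4, 5):
            dim = dim_5d if len(shape) == 5 else dim_4d
            local = list(shape)
            local[dim] = shape[dim] // world_size  # even split assumed
            out[key] = tuple(local)
        else:
            out[key] = shape
    return out


# Illustrative batch; the keys and sizes are made up for this example.
batch_shapes = {
    "x": (2, 16, 1, 192, 288),   # 5D: (B, C, T, H, W)
    "x_surf": (2, 7, 192, 288),  # 4D: (B, C, H, W)
    "index": (2,),               # 1D: left unchanged
    "datetime": "2020-01-01",    # non-tensor entry: left unchanged
}
local = local_shapes(batch_shapes, world_size=4)
print(local["x"])       # (2, 16, 1, 48, 288)
print(local["x_surf"])  # (2, 7, 48, 288)
```

With the default `-2` for both arguments, only the H axis shrinks (192 rows become 48 per rank across 4 ranks), while batch, channel, time, and W sizes stay the same on every rank.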