credit.domain_parallel.sharding


Tensor sharding and gathering utilities for domain parallelism.

Provides functions to split tensors along a spatial dimension across domain-parallel ranks, and to gather them back for metrics/output.

Functions

shard_tensor(x[, dim, manager])

Shard a tensor along the given dimension across domain-parallel ranks.

gather_tensor(x[, dim, manager])

Gather a sharded tensor from all domain-parallel ranks.

shard_batch(batch[, spatial_dims_5d, spatial_dims_4d, ...])

Shard all spatial tensors in a training batch dictionary.

Module Contents

credit.domain_parallel.sharding.shard_tensor(x, dim=-2, manager=None)

Shard a tensor along the given dimension across domain-parallel ranks.

Splits the tensor into equal chunks and returns only this rank’s chunk.

Parameters:
  • x – Input tensor.

  • dim – Dimension to shard along (default: -2 for H in BCHW or BCTHW).

  • manager – DomainParallelManager instance (uses global singleton if None).

Returns:

Local shard of the tensor.
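The equal-chunk split described above can be pictured with a minimal single-process sketch. It assumes PyTorch tensors and passes the rank and world size explicitly instead of reading them from a DomainParallelManager; the helper name shard_along_dim is hypothetical, and the divisibility check is an assumption about what "equal chunks" requires, not a documented behavior of shard_tensor.

```python
import torch


def shard_along_dim(x: torch.Tensor, rank: int, world_size: int, dim: int = -2) -> torch.Tensor:
    """Hypothetical helper: return this rank's equal-sized chunk of x along dim."""
    size = x.shape[dim]
    if size % world_size != 0:
        # Assumed requirement: equal chunks exist only when the size divides evenly.
        raise ValueError(f"size {size} along dim {dim} is not divisible by {world_size}")
    # torch.chunk returns world_size equal-length views; keep the one owned by this rank.
    return torch.chunk(x, world_size, dim=dim)[rank]


# A BCTHW tensor with H = 8, split across 4 simulated ranks along dim=-2 (H).
x = torch.arange(2 * 3 * 1 * 8 * 5).reshape(2, 3, 1, 8, 5)
shards = [shard_along_dim(x, rank=r, world_size=4) for r in range(4)]
print([tuple(s.shape) for s in shards])  # four shards, each of shape (2, 3, 1, 2, 5)
```

Concatenating the shards in rank order along the same dimension reproduces the original tensor, which is the property the gather step relies on.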

credit.domain_parallel.sharding.gather_tensor(x, dim=-2, manager=None)

Gather a sharded tensor from all domain-parallel ranks.

All-gathers the local shards from every rank and concatenates them along the sharding dimension.

Parameters:
  • x – Local shard tensor.

  • dim – Dimension that was sharded (default: -2).

  • manager – DomainParallelManager instance (uses global singleton if None).

Returns:

Full (gathered) tensor.
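An all-gather followed by concatenation can be sketched with torch.distributed. This is an illustration under stated assumptions, not the module's implementation: the function name gather_along_dim and the group argument are hypothetical stand-ins for the manager, every rank is assumed to hold a shard of identical shape, and the early return when no process group is initialized is an assumed single-rank fallback.

```python
import torch
import torch.distributed as dist


def gather_along_dim(local: torch.Tensor, dim: int = -2, group=None) -> torch.Tensor:
    """Hypothetical helper: all-gather equally sized shards and concatenate along dim."""
    if not (dist.is_available() and dist.is_initialized()):
        # Assumed fallback: without a process group there is a single rank,
        # so the local shard already is the full tensor.
        return local
    local = local.contiguous()
    world_size = dist.get_world_size(group)
    # all_gather fills one pre-allocated buffer per rank, so every rank
    # must contribute a shard with the same shape and dtype.
    buffers = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(buffers, local, group=group)
    # buffers[r] holds rank r's shard; concatenating in rank order
    # restores the original ordering along the sharded dimension.
    return torch.cat(buffers, dim=dim)
```

When launched with torchrun across N ranks, each rank calls gather_along_dim(shard) and receives the full tensor. Note that torch.distributed.all_gather does not propagate gradients, which suits the metrics/output use case described in the module summary.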

credit.domain_parallel.sharding.shard_batch(batch, spatial_dims_5d=-2, spatial_dims_4d=-2, manager=None)

Shard all spatial tensors in a training batch dictionary.

Handles both 5D (B, C, T, H, W) and 4D (B, C, H, W) tensors. Non-tensor entries and 1D/2D tensors are left unchanged.

Parameters:
  • batch – Dictionary of batch data (from dataloader).

  • spatial_dims_5d – Dimension to shard in 5D tensors (default: -2 for H).

  • spatial_dims_4d – Dimension to shard in 4D tensors (default: -2 for H).

  • manager – DomainParallelManager instance.

Returns:

A new dictionary in which the 5D and 4D tensors are replaced by their local shards.
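The dispatch on tensor rank described above can be illustrated with a minimal single-process sketch. It assumes PyTorch, takes rank and world_size explicitly instead of a manager, and uses hypothetical names (shard_batch_sketch, _take_shard, and the batch keys). The documentation only states that 1D/2D tensors are left unchanged; passing 3D tensors through untouched is an additional assumption of this sketch.

```python
import torch


def _take_shard(t: torch.Tensor, rank: int, world_size: int, dim: int) -> torch.Tensor:
    """Hypothetical helper: equal-chunk split along dim, keeping this rank's chunk."""
    if t.shape[dim] % world_size != 0:
        raise ValueError(f"size {t.shape[dim]} along dim {dim} is not divisible by {world_size}")
    return torch.chunk(t, world_size, dim=dim)[rank]


def shard_batch_sketch(batch, rank, world_size, spatial_dims_5d=-2, spatial_dims_4d=-2):
    """Return a new dict in which 5D and 4D tensors are replaced by local shards."""
    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor) and value.dim() == 5:
            # (B, C, T, H, W): the default -2 selects H
            out[key] = _take_shard(value, rank, world_size, spatial_dims_5d)
        elif isinstance(value, torch.Tensor) and value.dim() == 4:
            # (B, C, H, W): the default -2 also selects H
            out[key] = _take_shard(value, rank, world_size, spatial_dims_4d)
        else:
            # Non-tensors and tensors of any other rank pass through untouched (assumed for 3D).
            out[key] = value
    return out


batch = {
    "x": torch.randn(2, 3, 1, 8, 16),   # hypothetical 5D input
    "static": torch.randn(2, 4, 8, 16),  # hypothetical 4D field
    "lead_time": torch.tensor([6, 6]),   # 1D tensor, left as-is
    "index": 7,                          # non-tensor, left as-is
}
local = shard_batch_sketch(batch, rank=1, world_size=2)
print(tuple(local["x"].shape), tuple(local["static"].shape))  # (2, 3, 1, 4, 16) (2, 4, 4, 16)
```

Because negative indices count from the last axis, the same default of -2 addresses H in both the 5D and 4D layouts, which is why both spatial_dims parameters share it.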