credit.domain_parallel.sharding
===============================

.. py:module:: credit.domain_parallel.sharding

.. autoapi-nested-parse::

   Tensor sharding and gathering utilities for domain parallelism.

   Provides functions to split tensors along a spatial dimension across
   domain-parallel ranks, and to gather the shards back into full tensors
   for metrics and output.



Functions
---------

.. autoapisummary::

   credit.domain_parallel.sharding.shard_tensor
   credit.domain_parallel.sharding.gather_tensor
   credit.domain_parallel.sharding.shard_batch


Module Contents
---------------

.. py:function:: shard_tensor(x, dim=-2, manager=None)

   Shard a tensor along the given dimension across domain-parallel ranks.

   Splits the tensor into equal-sized chunks, one per domain-parallel rank,
   and returns only the chunk that belongs to the calling rank.

   :param x: Input tensor.
   :param dim: Dimension to shard along (default: ``-2``, the H axis of both BCHW and BCTHW tensors).
   :param manager: ``DomainParallelManager`` instance (uses the global singleton if ``None``).

   :returns: Local shard of the tensor.
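
   The splitting step can be pictured with the minimal sketch below. It is
   not this module's implementation: the helper ``shard_local`` is
   hypothetical and takes ``rank`` and ``world_size`` explicitly instead of
   reading them from a ``DomainParallelManager``, and raising on a
   non-divisible dimension is an assumption about how the equal-chunk
   requirement is enforced.

   .. code-block:: python

      import torch


      def shard_local(
          x: torch.Tensor, rank: int, world_size: int, dim: int = -2
      ) -> torch.Tensor:
          """Hypothetical helper: return rank ``rank``'s equal slice of ``x`` along ``dim``."""
          size = x.shape[dim]
          # Assumption: equal chunks require the sharded axis to divide evenly.
          if size % world_size != 0:
              raise ValueError(f"size {size} of dim {dim} is not divisible by {world_size}")
          chunk = size // world_size
          # narrow() returns a view covering [rank * chunk, (rank + 1) * chunk).
          return x.narrow(dim, rank * chunk, chunk)


      # Usage: a (B, C, H, W) = (2, 3, 8, 4) tensor split over 4 ranks.
      x = torch.arange(2 * 3 * 8 * 4).reshape(2, 3, 8, 4)
      shards = [shard_local(x, rank=r, world_size=4) for r in range(4)]
      print(shards[0].shape)  # torch.Size([2, 3, 2, 4])

   Because ``narrow`` returns a view, each shard in the sketch shares storage
   with ``x``, and concatenating the shards along ``dim`` in rank order
   recovers the original tensor.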


.. py:function:: gather_tensor(x, dim=-2, manager=None)

   Gather a sharded tensor from all domain-parallel ranks.

   All-gathers the local shards along the sharding dimension and concatenates
   them into the full tensor.

   :param x: Local shard tensor.
   :param dim: Dimension along which the tensor was sharded (default: ``-2``); should match the ``dim`` passed to :func:`shard_tensor`.
   :param manager: ``DomainParallelManager`` instance (uses the global singleton if ``None``).

   :returns: The full tensor, reassembled from the shards of all ranks.
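
   One way to express the gather is ``torch.distributed.all_gather``
   followed by ``torch.cat``, as in the sketch below. It assumes the
   domain-parallel ranks form a ``torch.distributed`` process group passed as
   ``group`` (``None`` meaning the default group) and that every rank holds a
   shard of the same shape; ``gather_local`` is a hypothetical name, not this
   module's API.

   .. code-block:: python

      import torch
      import torch.distributed as dist


      def gather_local(x: torch.Tensor, dim: int = -2, group=None) -> torch.Tensor:
          """Hypothetical helper: all-gather equal-shaped shards and concatenate along ``dim``."""
          world_size = dist.get_world_size(group)
          # Collectives expect contiguous buffers; narrow()-based shards may not be.
          x = x.contiguous()
          parts = [torch.empty_like(x) for _ in range(world_size)]
          # After this call, parts[r] holds the shard contributed by rank r.
          dist.all_gather(parts, x, group=group)
          return torch.cat(parts, dim=dim)

   The concatenation follows rank order, so the sketch inverts a split in
   which rank ``r`` holds chunk ``r``. Plain ``dist.all_gather`` does not
   propagate gradients, which fits the stated use of gathering for metrics
   and output.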


.. py:function:: shard_batch(batch, spatial_dims_5d=-2, spatial_dims_4d=-2, manager=None)

   Shard all spatial tensors in a training batch dictionary.

   Handles both 5D (B, C, T, H, W) and 4D (B, C, H, W) tensors.
   Non-tensor entries and 1D/2D tensors are left unchanged.

   :param batch: Dictionary of batch data, as produced by the dataloader.
   :param spatial_dims_5d: Dimension to shard in 5D tensors (default: ``-2``, the H axis).
   :param spatial_dims_4d: Dimension to shard in 4D tensors (default: ``-2``, the H axis).
   :param manager: DomainParallelManager instance.

   :returns: A new dictionary in which 4D and 5D tensors are replaced by their local shards and all other entries are carried over unchanged.
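
   The dispatch rule described above can be sketched as follows. The sketch
   is illustrative only: ``shard_batch_local`` and its explicit ``rank`` and
   ``world_size`` arguments are hypothetical stand-ins for the manager, the
   batch keys are invented, and leaving 3D tensors unsharded is an
   assumption, since the description covers only 4D, 5D, 1D/2D, and
   non-tensor entries.

   .. code-block:: python

      from typing import Any, Dict

      import torch


      def _local_slice(x: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
          # Equal-sized slice for `rank`; assumes x.shape[dim] divides by world_size.
          chunk = x.shape[dim] // world_size
          return x.narrow(dim, rank * chunk, chunk)


      def shard_batch_local(
          batch: Dict[str, Any],
          rank: int,
          world_size: int,
          dim_5d: int = -2,
          dim_4d: int = -2,
      ) -> Dict[str, Any]:
          """Hypothetical helper: shard 4D/5D tensors, pass everything else through."""
          out: Dict[str, Any] = {}
          for key, value in batch.items():
              if torch.is_tensor(value) and value.dim() == 5:  # (B, C, T, H, W)
                  out[key] = _local_slice(value, dim_5d, rank, world_size)
              elif torch.is_tensor(value) and value.dim() == 4:  # (B, C, H, W)
                  out[key] = _local_slice(value, dim_4d, rank, world_size)
              else:  # non-tensors and lower-rank tensors kept as-is (3D: assumption)
                  out[key] = value
          return out


      # Usage with invented keys: rank 1 of 4 ranks.
      batch = {
          "x": torch.zeros(1, 4, 2, 16, 32),
          "static": torch.zeros(1, 2, 16, 32),
          "lead_time": torch.tensor([6]),
          "datetime": "2020-01-01T00",
      }
      local = shard_batch_local(batch, rank=1, world_size=4)
      print(local["x"].shape)       # torch.Size([1, 4, 2, 4, 32])
      print(local["static"].shape)  # torch.Size([1, 2, 4, 32])

   Building a fresh dictionary rather than mutating ``batch`` leaves the
   caller's full-domain tensors intact, in line with the documented return of
   a new dict.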


