credit.parallel.collectives
===========================

.. py:module:: credit.parallel.collectives

.. autoapi-nested-parse::

   Shared gradient all-reduce machinery for the parallel package.

   Used by both ``sync_domain_gradients`` (``domain.py``) and
   ``sync_replicated_gradients`` (``tensor_parallel.py``), so the subtle DTensor
   handling lives in exactly one place.



Functions
---------

.. autoapisummary::

   credit.parallel.collectives.all_reduce_avg
   credit.parallel.collectives.allreduce_grads_avg


Module Contents
---------------

.. py:function:: all_reduce_avg(tensor, group=None) -> None

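   Average ``tensor`` in place across ``group``, on NCCL and gloo alike.

   ``ReduceOp.AVG`` is NCCL-only, so on gloo (CPU multi-rank runs,
   ``--backend gloo``) the average is computed as a ``ReduceOp.SUM`` all_reduce
   followed by a division by the group's world size.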
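
   The fallback can be sketched as below. This is a minimal illustration under
   stated assumptions, not the module's actual code: the name
   ``all_reduce_avg_sketch`` is hypothetical, and choosing the path with
   ``torch.distributed.get_backend`` is one plausible way to detect NCCL.

   .. code-block:: python

      import torch
      import torch.distributed as dist


      def all_reduce_avg_sketch(tensor, group=None):
          """Average ``tensor`` in place across ``group`` (hypothetical sketch)."""
          # Assumption: the backend name is enough to decide which path to take.
          if dist.get_backend(group) == "nccl":
              # NCCL supports averaging directly inside the collective.
              dist.all_reduce(tensor, op=dist.ReduceOp.AVG, group=group)
          else:
              # gloo has no AVG op: sum across ranks, then divide by the group size.
              dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
              tensor.div_(dist.get_world_size(group))

   The function must be called on every rank of the group; the result is written
   back into ``tensor`` and nothing is returned.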


.. py:function:: allreduce_grads_avg(grads, group) -> None

   Average gradients across a process group with minimal NCCL calls.

   Plain dense grads are flattened into one bucket per (dtype, device), so the
   sync is a handful of large all_reduces instead of one per parameter.

   DTensor grads (FSDP2) are instead reduced in place on their local shards,
   because shards can be 0-sized or oddly strided per rank, which breaks the
   flatten/unflatten round-trip. The per-shard all_reduces are issued
   asynchronously and waited on together, so they cost one latency rather than
   one per parameter.

   :param grads: iterable of gradient tensors (dense or DTensor).
   :param group: process group to average over.
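
   The two strategies can be sketched as follows. This is a minimal illustration
   under stated assumptions, not the module's actual code: both function names
   are hypothetical, the flatten step uses ``torch.cat`` where the real
   implementation may use something else, and the second function takes plain
   tensors standing in for DTensor local shards.

   .. code-block:: python

      from collections import defaultdict

      import torch
      import torch.distributed as dist


      def bucketed_allreduce_avg_sketch(grads, group=None):
          """Average dense grads with one all_reduce per (dtype, device) bucket."""
          world_size = dist.get_world_size(group)

          # Group grads so that each bucket can be concatenated into one tensor.
          buckets = defaultdict(list)
          for g in grads:
              if g is not None:
                  buckets[(g.dtype, g.device)].append(g)

          for bucket in buckets.values():
              # Flatten: one contiguous 1-D buffer holding every grad in the bucket.
              flat = torch.cat([g.reshape(-1) for g in bucket])
              # SUM + divide works on both NCCL and gloo.
              dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=group)
              flat.div_(world_size)

              # Unflatten: copy each averaged slice back into its original grad.
              offset = 0
              for g in bucket:
                  n = g.numel()
                  g.copy_(flat[offset:offset + n].view_as(g))
                  offset += n


      def async_allreduce_avg_sketch(tensors, group=None):
          """Average tensors in place, overlapping the per-tensor all_reduces."""
          world_size = dist.get_world_size(group)
          tensors = list(tensors)

          # Issue every collective first, without blocking on any of them.
          works = [
              dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group, async_op=True)
              for t in tensors
          ]
          # Wait for all of them together, then finish the average locally.
          for work in works:
              work.wait()
          for t in tensors:
              t.div_(world_size)

   Both sketches assume every rank in the group passes its tensors in the same
   order, since collectives are matched across ranks by call order.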


