credit.distributed
==================

.. py:module:: credit.distributed


Functions
---------

.. autoapisummary::

   credit.distributed.setup
   credit.distributed.get_rank_info
   credit.distributed.should_not_checkpoint
   credit.distributed.distributed_model_wrapper


Module Contents
---------------

.. py:function:: setup(rank, world_size, mode, backend='nccl')

   Initializes the distributed process group.

   :param rank: The rank of the process within the distributed setup.
   :type rank: int
   :param world_size: The total number of processes in the distributed setup.
   :type world_size: int
   :param mode: The mode of operation (e.g., 'fsdp', 'ddp').
   :type mode: str
   :param backend: The communication backend for the process group (for example, 'nccl' for CUDA GPUs or 'gloo' for CPU). Defaults to 'nccl'.
   :type backend: str, optional
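
   As a hedged illustration of what usually surrounds process-group
   initialization: PyTorch's default ``env://`` rendezvous reads
   ``MASTER_ADDR``, ``MASTER_PORT``, ``RANK``, and ``WORLD_SIZE`` from the
   environment. The helper ``rendezvous_env`` below is hypothetical and not part
   of ``credit.distributed``; the address and port defaults are example values.
   It only sketches how a launcher might prepare these variables before calling
   ``setup``.

   .. code-block:: python

      def rendezvous_env(rank, world_size, master_addr="127.0.0.1", master_port=29500):
          """Build the variables read by an env:// rendezvous (hypothetical helper)."""
          if world_size < 1:
              raise ValueError("world_size must be at least 1")
          if not 0 <= rank < world_size:
              raise ValueError("rank must satisfy 0 <= rank < world_size")
          # Environment variables are always strings, so convert the integers.
          return {
              "MASTER_ADDR": master_addr,
              "MASTER_PORT": str(master_port),
              "RANK": str(rank),
              "WORLD_SIZE": str(world_size),
          }


      # Example: the variables for rank 0 in a 4-process job on one host.
      env = rendezvous_env(rank=0, world_size=4)

   A launcher could merge the returned dictionary into ``os.environ`` before the
   process group is created.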


.. py:function:: get_rank_info(trainer_mode)

   Gets rank and size information for distributed training.

   :param trainer_mode: The mode of training (e.g., 'fsdp', 'ddp').
   :type trainer_mode: str

   :returns: A tuple ``(LOCAL_RANK, WORLD_RANK, WORLD_SIZE)``, each an int.
   :rtype: tuple[int, int, int]
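
   A minimal sketch of how such a tuple can be derived, assuming the job was
   started by a launcher such as ``torchrun``, which exports ``LOCAL_RANK``,
   ``RANK``, and ``WORLD_SIZE``. The function ``read_rank_env`` is hypothetical
   and is not the implementation of ``get_rank_info``, which takes a
   ``trainer_mode`` argument and may obtain these values differently.

   .. code-block:: python

      import os


      def read_rank_env(environ=os.environ):
          """Return (local_rank, world_rank, world_size) from launcher variables."""
          # Fall back to single-process values when a variable is not set.
          local_rank = int(environ.get("LOCAL_RANK", 0))
          world_rank = int(environ.get("RANK", 0))
          world_size = int(environ.get("WORLD_SIZE", 1))
          return local_rank, world_rank, world_size


      # Example: the second GPU on a node, rank 5 of 8 processes overall.
      example = read_rank_env({"LOCAL_RANK": "1", "RANK": "5", "WORLD_SIZE": "8"})

   Passing the mapping as an argument keeps the sketch easy to test without
   modifying the real process environment.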


.. py:function:: should_not_checkpoint(module)

.. py:function:: distributed_model_wrapper(conf, neural_network, device)

   Wraps the neural network model for distributed training.

   Supported modes are ``'fsdp'``, ``'ddp'``, ``'domain_parallel'``, and
   ``'fsdp+domain_parallel'``.

   In the domain-parallel modes, the model's ``Conv2d``, ``Conv3d``,
   ``ConvTranspose2d``, and ``GroupNorm`` layers are replaced with domain-parallel
   equivalents that handle halo exchange and distributed normalization. In
   ``'fsdp+domain_parallel'`` mode, the domain-parallel conversion runs first, and
   the FSDP wrapping then uses the data-parallel subgroup.

   :param conf: The configuration dictionary containing training settings.
   :type conf: dict
   :param neural_network: The neural network model to be wrapped.
   :type neural_network: torch.nn.Module
   :param device: The device on which the model will be trained.
   :type device: torch.device

   :returns: The wrapped model ready for distributed training.
   :rtype: torch.nn.Module
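
   The ordering described above can be sketched in plain Python. The function
   ``wrapping_steps`` and its step names are hypothetical; the sketch covers only
   the steps this description names and does not reproduce the actual wrapping
   logic or the configuration keys read from ``conf``.

   .. code-block:: python

      SUPPORTED_MODES = ("fsdp", "ddp", "domain_parallel", "fsdp+domain_parallel")


      def wrapping_steps(mode):
          """List, in order, the documented steps for a mode (hypothetical helper)."""
          if mode not in SUPPORTED_MODES:
              raise ValueError(f"unsupported mode: {mode!r}")
          steps = []
          # Domain-parallel conversion always comes before any FSDP wrapping.
          if "domain_parallel" in mode:
              steps.append("domain_parallel_conversion")
          if mode == "fsdp+domain_parallel":
              # In the combined mode, FSDP uses the data-parallel subgroup.
              steps.append("fsdp_wrap_on_data_parallel_subgroup")
          elif mode == "fsdp":
              steps.append("fsdp_wrap")
          elif mode == "ddp":
              steps.append("ddp_wrap")
          return steps


      combined = wrapping_steps("fsdp+domain_parallel")

   For the combined mode, the conversion step appears before the FSDP step,
   which matches the order given in the description.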


