credit.domain_parallel
======================

.. py:module:: credit.domain_parallel

.. autoapi-nested-parse::

   Domain parallelism for CREDIT weather models.

   Shards high-resolution input data along a spatial dimension (latitude by
   default) across multiple GPUs, enabling training on data too large to fit on
   a single GPU. Inspired by PhysicsNeMo's ShardTensor framework.

   Key components:

   - DomainParallelManager: process group creation and coordination
   - HaloExchange: differentiable boundary communication for convolutions
   - Domain-parallel layers: wrappers for Conv2d, Conv3d, ConvTranspose2d,
     ConvTranspose3d, GroupNorm, interpolation, and PeriodicConv2d
   - convert_to_domain_parallel: automatic model conversion
   - validate_sharding_constraints: compatibility check for the local shard size
   - shard_tensor / gather_tensor / shard_batch: input and output distribution

   Usage::

       from credit.domain_parallel import (
           initialize_domain_parallel,
           get_domain_parallel_manager,
           convert_to_domain_parallel,
           shard_tensor,
           gather_tensor,
           shard_batch,
       )



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/credit/domain_parallel/convert/index
   /autoapi/credit/domain_parallel/halo_exchange/index
   /autoapi/credit/domain_parallel/layers/index
   /autoapi/credit/domain_parallel/manager/index
   /autoapi/credit/domain_parallel/sharding/index


Classes
-------

.. autoapisummary::

   credit.domain_parallel.DomainParallelManager
   credit.domain_parallel.HaloExchange
   credit.domain_parallel.DomainParallelConv2d
   credit.domain_parallel.DomainParallelConv3d
   credit.domain_parallel.DomainParallelConvTranspose2d
   credit.domain_parallel.DomainParallelConvTranspose3d
   credit.domain_parallel.DomainParallelGroupNorm
   credit.domain_parallel.DomainParallelInterpolate
   credit.domain_parallel.DomainParallelPeriodicConv2d


Functions
---------

.. autoapisummary::

   credit.domain_parallel.initialize_domain_parallel
   credit.domain_parallel.get_domain_parallel_manager
   credit.domain_parallel.convert_to_domain_parallel
   credit.domain_parallel.validate_sharding_constraints
   credit.domain_parallel.shard_tensor
   credit.domain_parallel.gather_tensor
   credit.domain_parallel.shard_batch


Package Contents
----------------

.. py:class:: DomainParallelManager(world_size, domain_parallel_size, shard_dim=-2)

   Manages process groups for domain parallelism.

   :param world_size: Total number of GPUs.
   :param domain_parallel_size: Number of GPUs per domain-parallel group.
   :param shard_dim: Which spatial dimension to shard. -2 means latitude (H)
                     in a (B, C, H, W) tensor.


   .. py:attribute:: world_size


   .. py:attribute:: domain_parallel_size


   .. py:attribute:: data_parallel_size


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: _domain_group_idx


   .. py:attribute:: _domain_rank


   .. py:attribute:: _dp_rank


   .. py:attribute:: _domain_group
      :value: None



   .. py:attribute:: _dp_group
      :value: None



   .. py:property:: domain_group

      Process group for domain-parallel communication (halo exchange, reductions).


   .. py:property:: data_parallel_group

      Process group for data-parallel communication (gradient sync).


   .. py:property:: domain_rank

      Rank within the domain-parallel group (0 to domain_parallel_size-1).


   .. py:property:: domain_world_size

      Number of ranks in the domain-parallel group.


   .. py:property:: dp_rank

      Rank within the data-parallel group.


   .. py:property:: dp_world_size

      Number of ranks in the data-parallel group.


   .. py:property:: is_first_domain_rank

      True if this is the first rank in its domain group (north edge).


   .. py:property:: is_last_domain_rank

      True if this is the last rank in its domain group (south edge).


   .. py:method:: neighbor_ranks()

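      Return the global ranks ``(prev_rank, next_rank)`` of this rank's
      neighbors along the sharded dimension, for use in halo exchange.

      Either entry is ``None`` when that neighbor does not exist, i.e. at the
      first or last rank of the domain group.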
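
How global ranks map onto domain and data-parallel groups is not spelled out above. The following pure-Python sketch assumes, hypothetically, that each domain group is a block of contiguous global ranks and that each data-parallel group collects the ranks with the same domain rank. ``rank_layout`` is an illustrative helper, not part of ``credit.domain_parallel``, and the real manager may assign ranks differently.

```python
from typing import List, Optional, Tuple


def rank_layout(
    global_rank: int, world_size: int, domain_parallel_size: int
) -> Tuple[List[int], List[int], int, Tuple[Optional[int], Optional[int]]]:
    """Hypothetical layout: domain groups are blocks of contiguous global ranks."""
    if world_size % domain_parallel_size != 0:
        raise ValueError("world_size must be divisible by domain_parallel_size")
    data_parallel_size = world_size // domain_parallel_size

    group_idx = global_rank // domain_parallel_size  # which domain group
    domain_rank = global_rank % domain_parallel_size  # position inside the group

    # Ranks holding the different latitude shards of the same samples.
    start = group_idx * domain_parallel_size
    domain_group = list(range(start, start + domain_parallel_size))
    # Ranks holding the same shard of different samples (gradient sync).
    dp_group = [domain_rank + i * domain_parallel_size for i in range(data_parallel_size)]

    # Neighbors along the sharded dimension; None past the first/last shard.
    prev_rank = global_rank - 1 if domain_rank > 0 else None
    next_rank = global_rank + 1 if domain_rank < domain_parallel_size - 1 else None
    return domain_group, dp_group, domain_rank, (prev_rank, next_rank)
```

Under this assumed layout, with ``world_size=8`` and ``domain_parallel_size=4``, rank 5 sits in domain group ``[4, 5, 6, 7]`` with domain rank 1, synchronizes gradients with ``[1, 5]``, and exchanges halos with ranks 4 and 6.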



.. py:function:: initialize_domain_parallel(world_size, domain_parallel_size, shard_dim=-2)

   Initialize the global DomainParallelManager singleton.

   :param world_size: Total number of GPUs.
   :param domain_parallel_size: Number of GPUs per domain group.
   :param shard_dim: Spatial dimension to shard (-2 for lat in BCHW).

   :returns: DomainParallelManager instance.


.. py:function:: get_domain_parallel_manager()

   Get the global DomainParallelManager singleton.

   :returns: DomainParallelManager or None if not initialized.


.. py:class:: HaloExchange(halo_width, dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Halo exchange layer for domain-parallel operations.

   Pads the input tensor with halo rows from neighboring ranks along
   the sharding dimension. The operation is differentiable.

   :param halo_width: Number of rows to exchange on each side.
   :param dim: Tensor dimension to exchange along (default: -2 for lat in BCHW).


   .. py:attribute:: halo_width


   .. py:attribute:: dim
      :value: -2



   .. py:method:: forward(x)


   .. py:method:: trim(x, halo_before, halo_after, dim=-2)
      :staticmethod:


      Trim halo rows from the output after a convolution.

      :param x: Tensor with extra halo rows.
      :param halo_before: Number of rows to trim from the start.
      :param halo_after: Number of rows to trim from the end.
      :param dim: Dimension to trim along.
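
To make the halo semantics concrete, here is a single-process sketch that emulates a halo exchange on a list of shards and then undoes it with a ``trim``-style slice. It assumes zero-filled halos at the two global edges (mirroring a convolution's zero padding). The real layer communicates with neighbor ranks instead, and its edge behavior may differ. ``emulate_halo_exchange`` and ``trim`` here are hypothetical stand-ins.

```python
import torch


def emulate_halo_exchange(shards, halo_width, dim=-2):
    """Pad each shard with halo_width rows from its neighbors along dim.

    Rows outside the global domain are zero-filled (an assumption).
    """
    padded = []
    for r, shard in enumerate(shards):
        edge_shape = list(shard.shape)
        edge_shape[dim] = halo_width
        zeros = shard.new_zeros(edge_shape)
        if r > 0:
            prev = shards[r - 1]
            before = prev.narrow(dim, prev.shape[dim] - halo_width, halo_width)
        else:
            before = zeros
        if r < len(shards) - 1:
            after = shards[r + 1].narrow(dim, 0, halo_width)
        else:
            after = zeros
        padded.append(torch.cat([before, shard, after], dim=dim))
    return padded


def trim(x, halo_before, halo_after, dim=-2):
    """Drop halo rows from the start and end of dim."""
    return x.narrow(dim, halo_before, x.shape[dim] - halo_before - halo_after)
```

Trimming ``halo_width`` rows from both ends of each padded shard recovers the original shard exactly.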



.. py:class:: DomainParallelConv2d(conv, shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel Conv2d with automatic halo exchange.

   Before the convolution, exchanges halo rows along the sharding dimension
   (latitude by default) so boundary pixels see correct neighbors. After
   the convolution, the output has the correct local shard size.

   For strided convolutions (like in CrossEmbedLayer), the halo width is
   (kernel_size - stride) // 2 to account for the stride's effect on
   output boundaries.

   :param conv: An existing nn.Conv2d module to wrap.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


   .. py:property:: weight


   .. py:property:: bias
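
The ``(kernel_size - stride) // 2`` halo rule for strided convolutions can be checked numerically in a single process. The sketch below assumes a CrossEmbedLayer-style convolution whose padding equals that halo width. It splits the input into four latitude shards whose height is a multiple of the stride, attaches halo rows from neighboring shards (zeros at the global edges, matching zero padding), and convolves each shard with padding in W only. It illustrates the idea rather than reproducing ``DomainParallelConv2d``'s code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
kernel, stride = 4, 2
halo = (kernel - stride) // 2  # also the padding of the original conv

x = torch.randn(1, 2, 16, 8)  # (B, C, H, W), H sharded 4 ways
weight = torch.randn(3, 2, kernel, kernel)

# Reference: the unsharded convolution, padded on both H and W.
full = F.conv2d(x, weight, stride=stride, padding=halo)

shards = torch.chunk(x, 4, dim=-2)
outputs = []
for r, shard in enumerate(shards):
    zeros = torch.zeros_like(shard[..., :halo, :])
    before = shards[r - 1][..., -halo:, :] if r > 0 else zeros
    after = shards[r + 1][..., :halo, :] if r < len(shards) - 1 else zeros
    local = torch.cat([before, shard, after], dim=-2)
    # Pad W only; H already carries its halo rows.
    local = F.pad(local, (halo, halo, 0, 0))
    outputs.append(F.conv2d(local, weight, stride=stride))

sharded = torch.cat(outputs, dim=-2)
```

Each shard produces ``H_local / stride`` output rows, so concatenating the per-shard outputs along H reproduces the unsharded result.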


.. py:class:: DomainParallelConv3d(conv, shard_dim=3)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel Conv3d with halo exchange along the lat dimension.

   Used for CubeEmbedding's patch-embedding Conv3d. The sharded dimension
   in the 5D tensor (B, C, T, H, W) is H (dim=-2 or dim=3).

   :param conv: An existing nn.Conv3d module to wrap.
   :param shard_dim: Spatial dimension being sharded (3 for H in 5D BCTHW).


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: 3



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


.. py:class:: DomainParallelConvTranspose2d(conv, shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel ConvTranspose2d with halo exchange.

   For kernel=2, stride=2 (standard upsample): no halo needed, purely local.
   For kernel=4, stride=2, padding=1: halo of 1 needed.

   :param conv: An existing nn.ConvTranspose2d module to wrap.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_width


   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)
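
The kernel=4, stride=2, padding=1 case can be checked the same way. One way to use a 1-row halo (an assumption about the approach, not necessarily the wrapper's implementation) is to run the transposed convolution on the halo-padded shard with an H padding of ``padding + stride * halo``. This crops exactly the extra output rows that belong to neighboring shards.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
kernel, stride, padding, halo = 4, 2, 1, 1

x = torch.randn(1, 2, 8, 6)  # (B, C, H, W), H sharded 4 ways
weight = torch.randn(2, 3, kernel, kernel)  # (C_in, C_out, kH, kW)

full = F.conv_transpose2d(x, weight, stride=stride, padding=padding)

shards = torch.chunk(x, 4, dim=-2)
outputs = []
for r, shard in enumerate(shards):
    zeros = torch.zeros_like(shard[..., :halo, :])
    before = shards[r - 1][..., -halo:, :] if r > 0 else zeros
    after = shards[r + 1][..., :halo, :] if r < len(shards) - 1 else zeros
    local = torch.cat([before, shard, after], dim=-2)
    # Extra H cropping removes the output rows owned by neighboring shards.
    outputs.append(
        F.conv_transpose2d(
            local, weight, stride=stride, padding=(padding + stride * halo, padding)
        )
    )

sharded = torch.cat(outputs, dim=-2)
```

With kernel=2 and stride=2, each input row produces its own two output rows, so the halo drops to zero and every shard can be upsampled independently.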


.. py:class:: DomainParallelConvTranspose3d(conv, shard_dim=3)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel ConvTranspose3d with halo exchange along the lat dimension.

   For kernel == stride in H (e.g. Pangu PatchRecovery with patch_size=(2, 4, 4)),
   no halo is needed: the upsampling is purely local.
   For kernel > stride in H, halo exchange is required.

   :param conv: An existing nn.ConvTranspose3d module to wrap.
   :param shard_dim: Spatial dimension being sharded (3 for H in BCZHW).


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: 3



   .. py:attribute:: halo_width


   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)
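
The kernel == stride case needs no communication because each output voxel receives contributions from exactly one input position along every spatial axis. The following single-process check assumes a Pangu-style patch_size=(2, 4, 4) and a (B, C, Z, H, W) layout sharded on H (dim 3).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
patch = (2, 4, 4)  # kernel == stride in Z, H and W

x = torch.randn(1, 4, 2, 8, 6)  # (B, C, Z, H, W)
weight = torch.randn(4, 3, *patch)  # (C_in, C_out, kZ, kH, kW)

full = F.conv_transpose3d(x, weight, stride=patch)

# Upsample each latitude shard independently, then reassemble along H (dim 3).
shards = torch.chunk(x, 4, dim=3)
sharded = torch.cat(
    [F.conv_transpose3d(s, weight, stride=patch) for s in shards], dim=3
)
```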


.. py:class:: DomainParallelGroupNorm(norm)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel GroupNorm with distributed statistics.

   GroupNorm computes a mean and variance per sample over each group of
   channels and all spatial positions (H, W). With H sharded, each rank holds
   only part of those positions, so the statistics must be all-reduced across
   the domain group before normalization is applied.

   :param norm: An existing nn.GroupNorm module to wrap.


   .. py:attribute:: num_groups


   .. py:attribute:: num_channels


   .. py:attribute:: eps


   .. py:attribute:: affine


   .. py:method:: forward(x)
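
The distributed statistics can be emulated in one process. In the sketch below, each shard contributes per-group partial sums and sums of squares. Adding them across shards stands in for the ``torch.distributed.all_reduce`` over the domain group that a real implementation would perform. The affine scale and shift are omitted, and ``sharded_group_norm`` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def sharded_group_norm(shards, num_groups, eps=1e-5):
    """GroupNorm over H-sharded inputs using globally reduced statistics."""
    batch = shards[0].shape[0]
    total, total_sq, count = 0.0, 0.0, 0
    for shard in shards:
        # (B, G, elements of this group held by this shard)
        g = shard.reshape(batch, num_groups, -1)
        total = total + g.sum(dim=-1)
        total_sq = total_sq + (g * g).sum(dim=-1)
        count += g.shape[-1]
    # After the emulated all-reduce, every shard uses the same global statistics.
    mean = total / count
    var = total_sq / count - mean * mean  # biased variance, as GroupNorm uses
    out = []
    for shard in shards:
        g = shard.reshape(batch, num_groups, -1)
        normed = (g - mean[..., None]) / torch.sqrt(var[..., None] + eps)
        out.append(normed.reshape(shard.shape))
    return out
```

Concatenating the outputs along H matches ``F.group_norm`` on the unsharded tensor. Normalizing each shard with only its local statistics would not.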


.. py:class:: DomainParallelInterpolate(size, mode='bilinear', shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel interpolation (bilinear by default).

   Exchanges a 1-row halo before interpolation so boundary pixels
   interpolate correctly, then trims the extra output rows.

   :param size: Target output size (H, W).
   :param mode: Interpolation mode (default: 'bilinear').
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).


   .. py:attribute:: size


   .. py:attribute:: mode
      :value: 'bilinear'



   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)
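
A single-process check for 2x bilinear upsampling (``align_corners=False``) shows why a 1-row halo suffices. The sketch assumes edge-replicated halos at the two global boundaries, which reproduces ``F.interpolate``'s clamping at the borders. After the padded shard is upsampled, the two output rows on each side that came from the halo are trimmed. The scale factor and edge handling are assumptions, not the layer's documented behavior.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 16, 6)  # (B, C, H, W), H sharded 4 ways
full = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)

shards = torch.chunk(x, 4, dim=-2)
outputs = []
for r, shard in enumerate(shards):
    # 1-row halo from neighbors; replicate the edge row at the global boundaries.
    before = shards[r - 1][..., -1:, :] if r > 0 else shard[..., :1, :]
    after = shards[r + 1][..., :1, :] if r < len(shards) - 1 else shard[..., -1:, :]
    local = torch.cat([before, shard, after], dim=-2)
    up = F.interpolate(local, scale_factor=2, mode="bilinear", align_corners=False)
    outputs.append(up[..., 2:-2, :])  # each halo row became two output rows

sharded = torch.cat(outputs, dim=-2)
```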


.. py:class:: DomainParallelPeriodicConv2d(periodic_conv, shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel wrapper for PeriodicConv2d.

   PeriodicConv2d manually pads W with circular padding (longitude) and H with
   reflect padding (latitude) before calling an inner nn.Conv2d(padding=0).
   With H sharded across ranks, reflect padding on H is wrong: each rank would
   reflect against its own shard edge instead of the true pole, so rows at
   interior shard boundaries would never see their neighbors' data.

   This wrapper replaces the reflect padding on H with halo exchange,
   while keeping the circular padding on W unchanged.

   :param periodic_conv: An existing PeriodicConv2d module to wrap.
                         Must have ``.conv`` (nn.Conv2d) and ``.padding`` (int) attributes.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).


   .. py:attribute:: conv


   .. py:attribute:: padding


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


   .. py:property:: weight


   .. py:property:: bias
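
The padding substitution can be illustrated in one process. The sketch assumes a PeriodicConv2d-style layer that circularly pads W and reflect-pads H by ``padding`` rows before an unpadded convolution. In the sharded version, interior shard edges take halo rows from neighbors. The two shards that hold the poles keep reflect padding locally, which is an assumption about how the poles are treated. All names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad = 1  # PeriodicConv2d-style padding for a 3x3 kernel
x = torch.randn(1, 2, 12, 8)  # (B, C, lat, lon); lat is sharded 3 ways
weight = torch.randn(4, 2, 2 * pad + 1, 2 * pad + 1)

# Unsharded reference: circular padding in longitude, reflect padding in latitude.
ref = F.pad(x, (pad, pad, 0, 0), mode="circular")
ref = F.pad(ref, (0, 0, pad, pad), mode="reflect")
full = F.conv2d(ref, weight)

shards = torch.chunk(x, 3, dim=-2)
outputs = []
for r, shard in enumerate(shards):
    if r > 0:
        before = shards[r - 1][..., -pad:, :]  # halo rows from the northern neighbor
    else:
        before = shard[..., 1:pad + 1, :].flip(-2)  # reflect at the north pole
    if r < len(shards) - 1:
        after = shards[r + 1][..., :pad, :]  # halo rows from the southern neighbor
    else:
        after = shard[..., -pad - 1:-1, :].flip(-2)  # reflect at the south pole
    local = torch.cat([before, shard, after], dim=-2)
    local = F.pad(local, (pad, pad, 0, 0), mode="circular")  # longitude wraps locally
    outputs.append(F.conv2d(local, weight))

sharded = torch.cat(outputs, dim=-2)
```

Longitude is never sharded, so the circular padding on W can stay entirely local to each rank.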


.. py:function:: convert_to_domain_parallel(model, manager, shard_dim=-2, custom_converters=None)

   Convert a model to use domain-parallel layers.

   Walks the module tree and replaces:

   - nn.Conv2d (kernel>1 in H) -> DomainParallelConv2d
   - nn.Conv3d (kernel>1 in H) -> DomainParallelConv3d
   - nn.ConvTranspose2d (kernel>stride in H) -> DomainParallelConvTranspose2d
   - nn.ConvTranspose3d (kernel>stride in H) -> DomainParallelConvTranspose3d
   - nn.GroupNorm -> DomainParallelGroupNorm

   Leaves unchanged:

   - 1x1 Conv2d (FeedForward, projections)
   - Custom LayerNorm (channel-wise, no spatial reduction)
   - Attention modules (windowed, local within each shard)
   - PixelShuffle, activations, Dropout

   For custom module types that wrap a Conv2d internally (e.g. PeriodicConv2d),
   pass a custom_converters dict so the whole wrapper is replaced at once and
   its inner Conv2d children are not double-processed::

       from credit.models.unet_diffusion import PeriodicConv2d
       from credit.domain_parallel.layers import DomainParallelPeriodicConv2d

       convert_to_domain_parallel(
           model, manager,
           custom_converters={
               PeriodicConv2d: lambda m: DomainParallelPeriodicConv2d(m),
           }
       )

   :param model: The nn.Module to convert.
   :param manager: DomainParallelManager instance.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).
   :param custom_converters: Optional dict mapping module type -> callable(module)
                             returning the replacement module. Custom types are checked before
                             the built-in rules, and their children are skipped.

   :returns: The model with replaced layers (modified in-place).
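
The conversion rules can be pictured as a recursive walk over ``named_children`` that swaps modules in place with ``setattr``. The sketch below is a simplified, hypothetical re-implementation for illustration only. ``HaloConv2d`` is a placeholder standing in for DomainParallelConv2d, only the Conv2d rule is shown, and custom converters are checked first so their children are not revisited.

```python
import torch.nn as nn


class HaloConv2d(nn.Module):
    """Placeholder for a domain-parallel Conv2d wrapper (hypothetical)."""

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)  # a real wrapper would exchange halos first


def convert(module: nn.Module, custom_converters=None) -> nn.Module:
    custom_converters = custom_converters or {}
    # Snapshot the children so replacements are not revisited.
    for name, child in list(module.named_children()):
        converter = next(
            (fn for cls, fn in custom_converters.items() if isinstance(child, cls)),
            None,
        )
        if converter is not None:
            setattr(module, name, converter(child))  # replace wrapper, skip its children
        elif isinstance(child, nn.Conv2d) and child.kernel_size[0] > 1:
            setattr(module, name, HaloConv2d(child))  # kernel spans >1 row in H
        else:
            convert(child, custom_converters)  # recurse into containers
    return module
```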


.. py:function:: validate_sharding_constraints(model, local_h, window_sizes=None)

   Validate that the local spatial dimension is compatible with the model.

   Checks that local_h is divisible by all window sizes at all encoder levels,
   accounting for stride-based downsampling.

   :param model: The model to validate against.
   :param local_h: Local H dimension per domain-parallel rank.
   :param window_sizes: List of window sizes at each encoder level.
                        If None, attempts to extract from the model.

   :returns: List of warning messages (empty if all checks pass).
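
The divisibility check can be sketched as follows. It assumes, hypothetically, that each encoder level downsamples H by a known stride factor before its windowed attention runs. ``check_local_h`` and ``downsample_factors`` are illustrative names, not parameters of ``validate_sharding_constraints``.

```python
def check_local_h(local_h, window_sizes, downsample_factors):
    """Return warnings when a level's local height is not a multiple of its window."""
    warnings = []
    h = local_h
    for level, (window, factor) in enumerate(zip(window_sizes, downsample_factors)):
        if h % factor != 0:
            warnings.append(f"level {level}: local H {h} not divisible by stride {factor}")
            break  # later levels would no longer have an integer height
        h //= factor
        if h % window != 0:
            warnings.append(f"level {level}: local H {h} not divisible by window {window}")
    return warnings
```

For example, a local height of 32 with strides (4, 2, 2) and 4-row windows fails only at the third level, where the height has shrunk to 2.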


.. py:function:: shard_tensor(x, dim=-2, manager=None)

   Shard a tensor along the given dimension across domain-parallel ranks.

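   Splits the tensor into one equal-sized chunk per domain-parallel rank and
   returns this rank's chunk, so the size of ``dim`` should be divisible by the
   number of domain-parallel ranks.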

   :param x: Input tensor.
   :param dim: Dimension to shard along (default: -2 for H in BCHW or BCTHW).
   :param manager: DomainParallelManager instance (uses global singleton if None).

   :returns: Local shard of the tensor.


.. py:function:: gather_tensor(x, dim=-2, manager=None)

   Gather a sharded tensor from all domain-parallel ranks.

   All-gathers the tensor along the sharding dimension and concatenates.

   :param x: Local shard tensor.
   :param dim: Dimension that was sharded (default: -2).
   :param manager: DomainParallelManager instance (uses global singleton if None).

   :returns: Full (gathered) tensor.
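
In a single process, the shard/gather round trip amounts to ``torch.chunk`` followed by ``torch.cat`` along the same dimension. The sketch below emulates the domain group as a Python list of per-rank shards. It is a stand-in for the distributed implementation, which would use an all-gather over the domain group, and ``emulate_shard`` / ``emulate_gather`` are hypothetical.

```python
import torch


def emulate_shard(x, rank, world, dim=-2):
    """Return rank's equal-sized chunk of x along dim."""
    if x.shape[dim] % world != 0:
        raise ValueError(f"size {x.shape[dim]} of dim {dim} not divisible by {world}")
    return torch.chunk(x, world, dim=dim)[rank]


def emulate_gather(local_shards, dim=-2):
    """Concatenate every rank's shard in rank order, as an all-gather would."""
    return torch.cat(local_shards, dim=dim)
```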


.. py:function:: shard_batch(batch, spatial_dims_5d=-2, spatial_dims_4d=-2, manager=None)

   Shard all spatial tensors in a training batch dictionary.

   Handles both 5D (B, C, T, H, W) and 4D (B, C, H, W) tensors.
   Non-tensor entries and 1D/2D tensors are left unchanged.

   :param batch: Dictionary of batch data (from dataloader).
   :param spatial_dims_5d: Dimension to shard in 5D tensors (default: -2 for H).
   :param spatial_dims_4d: Dimension to shard in 4D tensors (default: -2 for H).
   :param manager: DomainParallelManager instance.

   :returns: New dict with sharded tensors.
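
A single-rank sketch of the batch rule, following the description above: 4D and 5D tensors are sharded and every other entry passes through. How 3D tensors are treated is not specified above, so this sketch leaves them alone. ``emulate_shard_batch`` and the batch keys are hypothetical.

```python
import torch


def emulate_shard_batch(batch, rank, world, dim_5d=-2, dim_4d=-2):
    """Shard 4D/5D tensors in a batch dict for one rank; pass everything else through."""
    out = {}
    for key, value in batch.items():
        if torch.is_tensor(value) and value.dim() == 5:
            out[key] = torch.chunk(value, world, dim=dim_5d)[rank]
        elif torch.is_tensor(value) and value.dim() == 4:
            out[key] = torch.chunk(value, world, dim=dim_4d)[rank]
        else:
            out[key] = value  # non-tensors and 1D/2D (and, here, 3D) tensors
    return out
```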


