credit.domain_parallel.convert
==============================

.. py:module:: credit.domain_parallel.convert

.. autoapi-nested-parse::

   Model conversion for domain parallelism.

   Walks a model's module tree and replaces layers that need inter-GPU
   communication with domain-parallel equivalents. Local operations
   (1x1 convolutions, channel-wise norms, activations, etc.) are unchanged.
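
   A small framework-free sketch shows why a kernel wider than 1 along H needs
   communication. It treats each row of H as a single number and splits the rows
   across two hypothetical ranks. It then compares a 3-tap convolution computed in
   three ways: globally, shard-locally with zero padding, and shard-locally after
   a one-row halo exchange. The helpers ``conv_valid``, ``conv_same``, ``split_h``
   and ``halo_exchange`` are illustrative names, not part of this module. The real
   layers exchange tensors between GPUs, not Python lists.

   .. code-block:: python

      def conv_valid(rows, w):
          """Correlate w with rows, keeping only fully overlapping positions."""
          k = len(w)
          return [sum(w[j] * rows[i + j] for j in range(k))
                  for i in range(len(rows) - k + 1)]

      def conv_same(rows, w):
          """Zero-pad len(w)//2 rows on each side ('same' length for odd kernels)."""
          r = len(w) // 2
          return conv_valid([0] * r + list(rows) + [0] * r, w)

      def split_h(rows, n_ranks):
          """Contiguous shards along H, one per (hypothetical) rank."""
          size = len(rows) // n_ranks
          return [rows[k * size:(k + 1) * size] for k in range(n_ranks)]

      def halo_exchange(shards, r):
          """Pad each shard with r rows from its neighbours (zeros at global edges)."""
          last = len(shards) - 1
          padded = []
          for k, s in enumerate(shards):
              top = shards[k - 1][len(shards[k - 1]) - r:] if k > 0 else [0] * r
              bottom = shards[k + 1][:r] if k < last else [0] * r
              padded.append(top + s + bottom)
          return padded

      x = [1, 2, 3, 4, 5, 6, 7, 8]   # 8 rows of H
      w = [1, 2, 1]                  # kernel of size 3 along H
      full = conv_same(x, w)         # [4, 8, 12, 16, 20, 24, 28, 23]
      shards = split_h(x, 2)
      # No communication: wrong at the shard boundary -> [4, 8, 12, 11, 16, 24, 28, 23]
      naive = [v for s in shards for v in conv_same(s, w)]
      # One-row halo from each neighbour, then a 'valid' convolution per shard.
      halo = [v for s in halo_exchange(shards, len(w) // 2) for v in conv_valid(s, w)]

   Only the rows next to a shard boundary differ. A 1x1 kernel has a halo width of
   ``len(w) // 2 == 0``, which is why such layers can be left unchanged.
   
   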



Attributes
----------

.. autoapisummary::

   credit.domain_parallel.convert.logger


Functions
---------

.. autoapisummary::

   credit.domain_parallel.convert._needs_halo_conv2d
   credit.domain_parallel.convert._needs_halo_conv3d
   credit.domain_parallel.convert._needs_halo_conv_transpose3d
   credit.domain_parallel.convert._needs_halo_conv_transpose2d
   credit.domain_parallel.convert._replace_module
   credit.domain_parallel.convert.convert_to_domain_parallel
   credit.domain_parallel.convert.validate_sharding_constraints


Module Contents
---------------

.. py:data:: logger

.. py:function:: _needs_halo_conv2d(conv)

   Check if a Conv2d needs halo exchange (kernel > 1 along H).


.. py:function:: _needs_halo_conv3d(conv)

   Check if a Conv3d needs halo exchange (kernel > 1 along H dim).
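
   The H-kernel rule can be expressed on plain ``kernel_size`` tuples. PyTorch
   stores ``nn.Conv2d.kernel_size`` as ``(kH, kW)`` and ``nn.Conv3d.kernel_size``
   as ``(kD, kH, kW)``, so H sits at a different index in each. This sketch works
   on tuples rather than modules. The ``halo_rows`` helper is hypothetical and
   assumes symmetric "same" padding. Neither function describes the internals of
   ``_needs_halo_conv2d`` or ``_needs_halo_conv3d``.

   .. code-block:: python

      from typing import Sequence

      # Index of H inside kernel_size: (kH, kW) for 2D, (kD, kH, kW) for 3D.
      H_INDEX = {2: 0, 3: 1}

      def needs_halo(kernel_size: Sequence[int]) -> bool:
          """True when the kernel spans more than one row along H."""
          return kernel_size[H_INDEX[len(kernel_size)]] > 1

      def halo_rows(kernel_size: Sequence[int], dilation: int = 1) -> int:
          """Rows needed from each neighbour, assuming symmetric 'same' padding."""
          k = kernel_size[H_INDEX[len(kernel_size)]]
          return dilation * (k - 1) // 2

   For example, a ``(3, 1, 1)`` Conv3d kernel only spans depth, so under this rule
   it stays local even though one of its kernel dimensions is larger than 1.
   
   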


.. py:function:: _needs_halo_conv_transpose3d(conv)

   Check if a ConvTranspose3d needs halo exchange (kernel > stride along H).


.. py:function:: _needs_halo_conv_transpose2d(conv)

   Check if a ConvTranspose2d needs halo exchange (kernel > stride along H).
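
   The kernel > stride rule for transposed convolutions can be checked with a 1D
   sketch along H. It assumes zero padding and no ``output_padding``. Each input
   row is scattered into ``len(w)`` output rows spaced ``stride`` apart. When the
   kernel equals the stride, every output row receives exactly one contribution,
   so the shard-local outputs concatenate to the global result. When the kernel is
   larger, contributions overlap across the shard boundary and must be summed
   across ranks. ``overlap_add`` stands in for that communication. It and the other
   helper names are illustrative only. They do not describe how
   ``DomainParallelConvTranspose2d`` or ``DomainParallelConvTranspose3d``
   exchange data.

   .. code-block:: python

      def conv_transpose_1d(x, w, stride):
          """Scatter each input row into len(w) output rows spaced `stride` apart."""
          out = [0] * ((len(x) - 1) * stride + len(w))
          for i, xi in enumerate(x):
              for j, wj in enumerate(w):
                  out[i * stride + j] += xi * wj
          return out

      def shard_local(x, w, stride, n_ranks):
          """Run each shard independently and concatenate (no communication)."""
          size = len(x) // n_ranks
          out = []
          for r in range(n_ranks):
              out += conv_transpose_1d(x[r * size:(r + 1) * size], w, stride)
          return out

      def overlap_add(x, w, stride, n_ranks):
          """Place each shard's output at its global offset and sum the overlaps."""
          size = len(x) // n_ranks
          out = [0] * ((len(x) - 1) * stride + len(w))
          for r in range(n_ranks):
              local = conv_transpose_1d(x[r * size:(r + 1) * size], w, stride)
              for i, v in enumerate(local):
                  out[r * size * stride + i] += v
          return out

      x = [1, 2, 3, 4]  # 4 rows of H split over 2 ranks
      # kernel == stride: [1, 1, 2, 2, 3, 3, 4, 4] either way
      same_k = (conv_transpose_1d(x, [1, 1], 2), shard_local(x, [1, 1], 2, 2))
      # kernel > stride: global [1, 1, 3, 2, 5, 3, 7, 4, 4]; naive concatenation
      # gives [1, 1, 3, 2, 2, 3, 3, 7, 4, 4], while overlap-add matches global.
      big_k = (conv_transpose_1d(x, [1, 1, 1], 2),
               shard_local(x, [1, 1, 1], 2, 2),
               overlap_add(x, [1, 1, 1], 2, 2))
   
   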


.. py:function:: _replace_module(parent, name, old_module, new_module)

   Replace ``old_module``, registered as child ``name`` of ``parent``, with ``new_module``.


.. py:function:: convert_to_domain_parallel(model, manager, shard_dim=-2, custom_converters=None)

   Convert a model to use domain-parallel layers.

   Walks the module tree and replaces:

   - ``nn.Conv2d`` (kernel > 1 in H) -> ``DomainParallelConv2d``
   - ``nn.Conv3d`` (kernel > 1 in H) -> ``DomainParallelConv3d``
   - ``nn.ConvTranspose2d`` (kernel > stride in H) -> ``DomainParallelConvTranspose2d``
   - ``nn.ConvTranspose3d`` (kernel > stride in H) -> ``DomainParallelConvTranspose3d``
   - ``nn.GroupNorm`` -> ``DomainParallelGroupNorm``

   Leaves unchanged:

   - 1x1 ``Conv2d`` (FeedForward, projections)
   - Custom ``LayerNorm`` (channel-wise, no spatial reduction)
   - Attention modules (windowed, local within each shard)
   - ``PixelShuffle``, activations, ``Dropout``

   For custom module types that wrap a Conv2d internally (e.g. ``PeriodicConv2d``),
   pass a ``custom_converters`` dict so that the whole wrapper is replaced at once
   and its inner Conv2d children are not processed a second time::

       from credit.models.unet_diffusion import PeriodicConv2d
       from credit.domain_parallel.layers import DomainParallelPeriodicConv2d

       convert_to_domain_parallel(
           model, manager,
           custom_converters={
               PeriodicConv2d: lambda m: DomainParallelPeriodicConv2d(m),
           }
       )

   :param model: The nn.Module to convert.
   :param manager: DomainParallelManager instance.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).
   :param custom_converters: Optional dict mapping module type -> callable(module)
                             returning the replacement module. Custom types are checked before
                             the built-in rules, and their children are skipped.

   :returns: The model with replaced layers (modified in-place).
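
   The traversal order described above can be sketched without PyTorch. Custom
   converters are tried first and their matches are not descended into. Then
   the built-in rule applies, and everything else is searched recursively. The
   classes ``Module``, ``Conv``, ``PeriodicConv`` and ``DPConv`` are hypothetical
   stand-ins for ``nn.Module``, ``nn.Conv2d``, ``PeriodicConv2d`` and a
   domain-parallel layer. Matching custom types with ``isinstance`` is also an
   assumption. In PyTorch the walk would use ``named_children()``, and a
   replacement such as ``_replace_module`` would typically call
   ``setattr(parent, name, new_module)``. That is an assumption about this
   module, not something its documentation states.

   .. code-block:: python

      class Module:
          """Hypothetical stand-in for nn.Module: named children kept in a dict."""
          def __init__(self, **children):
              self.children = dict(children)

      class Conv(Module):
          """Stand-in for nn.Conv2d; only the kernel height along H matters here."""
          def __init__(self, kh):
              super().__init__()
              self.kh = kh

      class PeriodicConv(Module):
          """Stand-in for a wrapper such as PeriodicConv2d that owns an inner Conv."""
          def __init__(self, kh):
              super().__init__(conv=Conv(kh))

      class DPConv(Module):
          """Stand-in for a domain-parallel replacement; keeps the original layer."""
          def __init__(self, inner):
              super().__init__()
              self.inner = inner

      def convert(module, custom_converters=None):
          custom_converters = custom_converters or {}
          for name, child in list(module.children.items()):
              # 1. Custom converters first: replace the whole wrapper, skip its children.
              custom = next((fn for cls, fn in custom_converters.items()
                             if isinstance(child, cls)), None)
              if custom is not None:
                  module.children[name] = custom(child)
              # 2. Built-in rule: only kernels wider than 1 along H are swapped.
              elif isinstance(child, Conv) and child.kh > 1:
                  module.children[name] = DPConv(child)
              # 3. Everything else is kept and searched recursively.
              else:
                  convert(child, custom_converters)
          return module

      model = Module(a=Conv(3), b=Conv(1), c=PeriodicConv(3), d=Module(e=Conv(3)))
      convert(model, {PeriodicConv: lambda m: DPConv(m)})

      # Without a custom converter the wrapper survives and only its inner Conv
      # is swapped, i.e. the wrapper is converted piecemeal.
      plain = convert(Module(c=PeriodicConv(3)))

   After conversion, ``a`` and ``d.e`` are replaced and ``b`` (a 1x1 kernel) is
   kept. ``c`` is replaced as a unit, and its inner ``Conv`` is never visited.
   
   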


.. py:function:: validate_sharding_constraints(model, local_h, window_sizes=None)

   Validate that the local spatial dimension is compatible with the model.

   Checks that local_h is divisible by all window sizes at all encoder levels,
   accounting for stride-based downsampling.

   :param model: The model to validate against.
   :param local_h: Local H dimension per domain-parallel rank.
   :param window_sizes: List of window sizes at each encoder level.
                        If None, attempts to extract from the model.

   :returns: List of warning messages (empty if all checks pass).
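
   Here is a minimal sketch of the per-level arithmetic. It assumes each encoder
   level downsamples H by a fixed stride (``downsample=2`` is a hypothetical
   default). The real function takes its window sizes from the argument or the
   model, and the strides would come from the model. For example, ``local_h = 48``
   with a window size of 8 at three levels gives 48 -> 24 -> 12, and 12 is not
   divisible by 8. ``check_local_h`` is an illustrative name, not this module's
   API.

   .. code-block:: python

      def check_local_h(local_h, window_sizes, downsample=2):
          """Return warnings; a fixed per-level `downsample` stride is assumed."""
          warnings = []
          h = local_h
          for level, win in enumerate(window_sizes):
              # Windowed attention needs whole windows inside each shard.
              if h % win != 0:
                  warnings.append(
                      f"level {level}: local H {h} is not divisible by window size {win}")
              # Before moving to the next level, H must divide evenly by the stride.
              if level < len(window_sizes) - 1:
                  if h % downsample != 0:
                      warnings.append(
                          f"level {level}: local H {h} is not divisible by stride {downsample}")
                  h //= downsample
          return warnings

      ok = check_local_h(64, [8, 8, 8])    # 64 -> 32 -> 16: all divisible by 8
      bad = check_local_h(48, [8, 8, 8])   # 48 -> 24 -> 12: fails at level 2
   
   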


