credit.domain_parallel.layers
=============================

.. py:module:: credit.domain_parallel.layers

.. autoapi-nested-parse::

   Domain-parallel layer wrappers.

   Each wrapper replaces a standard PyTorch layer with a version that handles
   halo exchange and/or distributed reductions for domain-parallel training.

   Layers that are purely local (1x1 convolutions, channel-wise normalization,
   activations, etc.) do not need wrappers and are left unchanged.



Classes
-------

.. autoapisummary::

   credit.domain_parallel.layers.DomainParallelConv2d
   credit.domain_parallel.layers.DomainParallelConv3d
   credit.domain_parallel.layers.DomainParallelConvTranspose2d
   credit.domain_parallel.layers.DomainParallelConvTranspose3d
   credit.domain_parallel.layers.DomainParallelGroupNorm
   credit.domain_parallel.layers.DomainParallelPeriodicConv2d
   credit.domain_parallel.layers.DomainParallelInterpolate


Module Contents
---------------

.. py:class:: DomainParallelConv2d(conv, shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel Conv2d with automatic halo exchange.

   Before the convolution, halo rows are exchanged along the sharding dimension
   (latitude by default) so that pixels at shard boundaries see their true
   neighbors. The convolution output then has the correct local shard size.

   For strided convolutions (as in CrossEmbedLayer), the halo width is
   ``(kernel_size - stride) // 2`` to account for the stride's effect on
   output boundaries.

   :param conv: An existing nn.Conv2d module to wrap.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).
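
   The halo logic described above can be emulated in a single process. The
   following is a minimal pure-Python sketch, not this module's implementation.
   It works on a 1-D list of rows and assumes stride 1, an odd kernel,
   equal-sized shards, and zero padding at the global edges. All function names
   are hypothetical.

   .. code-block:: python

      def halo_width(kernel_size, stride=1):
          # Formula quoted in the class description.
          return (kernel_size - stride) // 2

      def conv1d_same(rows, kernel):
          """Stride-1 'same' correlation over the full domain, zero padded."""
          half = len(kernel) // 2
          padded = [0.0] * half + list(rows) + [0.0] * half
          return [
              sum(k * padded[i + j] for j, k in enumerate(kernel))
              for i in range(len(rows))
          ]

      def sharded_conv1d(rows, kernel, num_shards):
          """Emulate halo exchange over `num_shards` equal shards."""
          n = len(rows) // num_shards
          h = halo_width(len(kernel))
          out = []
          for r in range(num_shards):
              lo, hi = r * n, (r + 1) * n
              # Rows a real implementation would receive from its neighbours.
              # Edge shards use zeros (assumption: matches global zero padding).
              top = rows[lo - h:lo] if r > 0 else [0.0] * h
              bottom = rows[hi:hi + h] if r < num_shards - 1 else [0.0] * h
              local = top + rows[lo:hi] + bottom
              # No further padding: the output length equals the shard size.
              out += [
                  sum(k * local[i + j] for j, k in enumerate(kernel))
                  for i in range(n)
              ]
          return out

      rows = [float(v) for v in range(1, 9)]
      kernel = [1.0, 2.0, 1.0]
      assert sharded_conv1d(rows, kernel, 4) == conv1d_same(rows, kernel)

   The sharded result matches the single-domain result because each shard sees
   the same neighbouring rows it would see in the unsharded tensor.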


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


   .. py:property:: weight


   .. py:property:: bias


.. py:class:: DomainParallelConv3d(conv, shard_dim=3)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel Conv3d with halo exchange along the lat dimension.

   Used for CubeEmbedding's patch-embedding Conv3d. In the 5D tensor
   (B, C, T, H, W), the sharded dimension is H, i.e. ``dim=3``
   (equivalently ``dim=-2``).

   :param conv: An existing nn.Conv3d module to wrap.
   :param shard_dim: Spatial dimension being sharded (3 for H in 5D BCTHW).
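
   As a small illustration of why ``shard_dim=3`` and ``shard_dim=-2`` name the
   same axis of a 5D tensor, the sketch below normalizes a negative dimension
   index the way Python sequence indexing does. ``normalize_dim`` is a
   hypothetical helper for illustration and is not part of this module.

   .. code-block:: python

      def normalize_dim(dim, ndim):
          """Map a possibly negative dimension index to its non-negative form."""
          if not -ndim <= dim < ndim:
              raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
          return dim % ndim

      # (B, C, T, H, W): H is axis 3, which is also axis -2.
      assert normalize_dim(-2, 5) == normalize_dim(3, 5) == 3
      # (B, C, H, W): the 2D wrappers' default of -2 is axis 2.
      assert normalize_dim(-2, 4) == 2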


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: 3



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


.. py:class:: DomainParallelConvTranspose2d(conv, shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel ConvTranspose2d with halo exchange.

   - ``kernel=2, stride=2`` (standard upsampling): no halo is needed; the
     operation is purely local.
   - ``kernel=4, stride=2, padding=1``: a halo of 1 is needed.

   :param conv: An existing nn.ConvTranspose2d module to wrap.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).
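
   The two cases above can be checked by brute force. The sketch below is an
   illustration under stated assumptions, not the wrapper's actual logic: it
   assumes dilation 1, no ``output_padding``, and that a shard owning input
   rows ``[lo, hi)`` also owns output rows ``[lo * stride, hi * stride)``.
   It uses the standard transposed-convolution relation that input row ``i``
   contributes to output rows ``i * stride - padding + k`` for each kernel tap
   ``k``. ``transpose_halo`` is a hypothetical name.

   .. code-block:: python

      def transpose_halo(kernel, stride, padding, shard_rows=4):
          """Count the neighbour input rows needed by the middle of three shards."""
          lo, hi = shard_rows, 2 * shard_rows        # input rows owned
          out_lo, out_hi = lo * stride, hi * stride  # output rows owned (assumed)
          needed = set()
          for i in range(3 * shard_rows):            # every global input row
              for k in range(kernel):
                  o = i * stride - padding + k       # output row touched by tap k
                  if out_lo <= o < out_hi:
                      needed.add(i)
          # Rows required beyond the shard on either side.
          return max(lo - min(needed), max(needed) - (hi - 1))

      assert transpose_halo(kernel=2, stride=2, padding=0) == 0  # purely local
      assert transpose_halo(kernel=4, stride=2, padding=1) == 1  # 1-row halo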


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_width


   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


.. py:class:: DomainParallelConvTranspose3d(conv, shard_dim=3)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel ConvTranspose3d with halo exchange along the lat dimension.

   - ``kernel == stride`` (e.g. Pangu PatchRecovery with
     ``patch_size=(2, 4, 4)``): no halo is needed; upsampling is purely local.
   - ``kernel > stride`` in H: halo exchange is required.

   :param conv: An existing nn.ConvTranspose3d module to wrap.
   :param shard_dim: Spatial dimension being sharded (3 for H in BCZHW).
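
   A minimal sketch of the decision described above, assuming a non-negative
   ``shard_dim`` and kernel/stride tuples ordered ``(Z, H, W)`` as in
   ``nn.ConvTranspose3d``. ``halo_needed`` is a hypothetical helper; the
   wrapper's real ``halo_width`` computation is not documented here.

   .. code-block:: python

      def halo_needed(kernel_size, stride, shard_dim=3):
          """Return True if the kernel exceeds the stride along the sharded axis."""
          # Tensor dims are (B, C, Z, H, W); kernel tuples skip B and C,
          # so tensor axis 3 (H) is kernel index 1.
          axis = shard_dim - 2
          return kernel_size[axis] > stride[axis]

      patch_size = (2, 4, 4)
      assert not halo_needed(patch_size, patch_size)   # kernel == stride
      assert halo_needed((2, 6, 4), (2, 4, 4))         # kernel > stride in H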


   .. py:attribute:: conv


   .. py:attribute:: shard_dim
      :value: 3



   .. py:attribute:: halo_width


   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


.. py:class:: DomainParallelGroupNorm(norm)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel GroupNorm with distributed statistics.

   GroupNorm computes the mean and variance per sample over each group of
   channels and all spatial positions (H, W). When H is sharded across the
   domain group, the per-shard statistics must be combined with an all_reduce
   across that group before normalization is applied.

   :param norm: An existing nn.GroupNorm module to wrap.
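
   One common way to combine statistics across shards is to all_reduce the sum,
   the sum of squares, and the element count, then derive the mean and (biased)
   variance from the totals. The sketch below emulates this in one process with
   plain lists; it is an assumption about the general technique, not a
   description of how this class reduces its statistics. All names are
   hypothetical.

   .. code-block:: python

      def local_moments(values):
          """Per-shard partial sums: (sum, sum of squares, count)."""
          return (sum(values), sum(v * v for v in values), len(values))

      def all_reduce_sum(per_rank):
          """Stand-in for an all_reduce(SUM) over the domain group."""
          return tuple(sum(parts) for parts in zip(*per_rank))

      def global_mean_var(shards):
          total, total_sq, count = all_reduce_sum(
              [local_moments(s) for s in shards]
          )
          mean = total / count
          var = total_sq / count - mean * mean  # biased variance, as in GroupNorm
          return mean, var

      # Shards may hold different numbers of rows; counts are reduced too.
      mean, var = global_mean_var([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]])
      assert abs(mean - 4.0) < 1e-12 and abs(var - 4.0) < 1e-12

   Reducing raw sums rather than per-shard means keeps the result correct when
   shards are unequal in size.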


   .. py:attribute:: num_groups


   .. py:attribute:: num_channels


   .. py:attribute:: eps


   .. py:attribute:: affine


   .. py:method:: forward(x)


.. py:class:: DomainParallelPeriodicConv2d(periodic_conv, shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel wrapper for PeriodicConv2d.

   PeriodicConv2d manually pads W with circular padding (longitude) and H with
   reflect padding (latitude) before calling an inner ``nn.Conv2d(padding=0)``.
   With domain-sharded H, reflect padding on H is wrong: each rank would
   reflect against its shard edge instead of the true pole.

   This wrapper replaces the reflect padding on H with halo exchange,
   while keeping the circular padding on W unchanged.

   :param periodic_conv: An existing PeriodicConv2d module to wrap.
                         Must have ``.conv`` (nn.Conv2d) and ``.padding`` (int) attributes.
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).
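
   The difference between shard-local reflect padding and halo exchange can be
   seen on a toy example. The sketch below uses integer row labels in place of
   tensor rows and assumes that reflect padding is still applied at the true
   global edges; what the wrapper does at the poles is not stated here, so
   treat that as an assumption. All function names are hypothetical.

   .. code-block:: python

      def reflect_pad(rows, p):
          """Mirror `p` rows at each end without repeating the edge row."""
          return rows[p:0:-1] + rows + rows[-2:-2 - p:-1]

      def circular_pad(row, p):
          """Wrap `p` (>= 1) values around each end, as for longitude."""
          return row[-p:] + row + row[:p]

      def halo_pad(shards, rank, p):
          """Pad one shard on H with neighbour rows; reflect only at global edges."""
          rows = shards[rank]
          top = shards[rank - 1][-p:] if rank > 0 else rows[p:0:-1]
          last = len(shards) - 1
          bottom = shards[rank + 1][:p] if rank < last else rows[-2:-2 - p:-1]
          return top + rows + bottom

      full = [0, 1, 2, 3, 4, 5, 6, 7]
      shards = [full[:4], full[4:]]
      reference = reflect_pad(full, 1)          # padding of the unsharded domain

      # Halo exchange reproduces the unsharded padding on every shard.
      assert halo_pad(shards, 0, 1) == reference[:6]
      assert halo_pad(shards, 1, 1) == reference[4:]
      # Shard-local reflect padding does not: row 2 appears where row 4 belongs.
      assert reflect_pad(shards[0], 1) != reference[:6]
      # W is not sharded, so circular padding stays a purely local operation.
      assert circular_pad([10, 11, 12, 13], 1) == [13, 10, 11, 12, 13, 10]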


   .. py:attribute:: conv


   .. py:attribute:: padding


   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


   .. py:property:: weight


   .. py:property:: bias


.. py:class:: DomainParallelInterpolate(size, mode='bilinear', shard_dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Domain-parallel bilinear interpolation.

   Exchanges a 1-row halo before interpolation so boundary pixels
   interpolate correctly, then trims the extra output rows.

   :param size: Target output size (H, W).
   :param mode: Interpolation mode (default: 'bilinear').
   :param shard_dim: Spatial dimension being sharded (-2 for H in BCHW).
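
   To see why a 1-row halo can be enough, the sketch below emulates the
   exchange-interpolate-trim sequence in 1-D with plain lists. It is an
   illustration only: it assumes a fixed 2x upsampling factor, the half-pixel
   sampling convention with edge clamping, equal-sized shards, and that edge
   shards replicate their outermost row in place of a missing neighbour.
   All function names are hypothetical.

   .. code-block:: python

      import math

      def upsample2x(rows):
          """1-D linear upsampling by 2 (half-pixel centres, edges clamped)."""
          n = len(rows)
          out = []
          for o in range(2 * n):
              src = min(max((o + 0.5) / 2 - 0.5, 0.0), n - 1)
              i0 = int(math.floor(src))
              i1 = min(i0 + 1, n - 1)
              w = src - i0
              out.append((1 - w) * rows[i0] + w * rows[i1])
          return out

      def sharded_upsample2x(rows, num_shards):
          """Exchange a 1-row halo, upsample locally, trim the halo's outputs."""
          n = len(rows) // num_shards
          out = []
          for r in range(num_shards):
              lo, hi = r * n, (r + 1) * n
              # Neighbour rows; edge shards replicate their own edge (assumption).
              top = rows[lo - 1] if r > 0 else rows[lo]
              bottom = rows[hi] if r < num_shards - 1 else rows[hi - 1]
              up = upsample2x([top] + rows[lo:hi] + [bottom])
              # Each halo row yields 2 output rows, so trim 2 from each side.
              out += up[2:-2]
          return out

      rows = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0]
      expected = upsample2x(rows)
      actual = sharded_upsample2x(rows, 4)
      assert all(abs(a - b) < 1e-12 for a, b in zip(actual, expected))

   With this sampling convention, every output row owned by a shard samples at
   most one row beyond the shard on either side, which is why the halo is a
   single row.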


   .. py:attribute:: size


   .. py:attribute:: mode
      :value: 'bilinear'



   .. py:attribute:: shard_dim
      :value: -2



   .. py:attribute:: halo_exchange


   .. py:method:: forward(x)


