credit.domain_parallel.layers

Domain-parallel layer wrappers.

Each wrapper replaces a standard PyTorch layer with a version that handles halo exchange and/or distributed reductions for domain-parallel training.

Layers that are purely local (1x1 convolutions, channel-wise normalization, activations, etc.) do not need wrappers and are left unchanged.

Classes

DomainParallelConv2d

Domain-parallel Conv2d with automatic halo exchange.

DomainParallelConv3d

Domain-parallel Conv3d with halo exchange along the lat dimension.

DomainParallelConvTranspose2d

Domain-parallel ConvTranspose2d with halo exchange.

DomainParallelConvTranspose3d

Domain-parallel ConvTranspose3d with halo exchange along the lat dimension.

DomainParallelGroupNorm

Domain-parallel GroupNorm with distributed statistics.

DomainParallelPeriodicConv2d

Domain-parallel wrapper for PeriodicConv2d.

DomainParallelInterpolate

Domain-parallel bilinear interpolation.

Module Contents

class credit.domain_parallel.layers.DomainParallelConv2d(conv, shard_dim=-2)

Bases: torch.nn.Module

Domain-parallel Conv2d with automatic halo exchange.

Before the convolution, exchanges halo rows along the sharding dimension (latitude by default) so boundary pixels see correct neighbors. After the convolution, the output has the correct local shard size.

For strided convolutions (such as those in CrossEmbedLayer), the halo width is (kernel_size - stride) // 2, which accounts for the stride’s effect on output boundaries.

Parameters:
  • conv – An existing nn.Conv2d module to wrap.

  • shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
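The halo-width rule described above can be checked in a single process, without any distributed communication. The sketch below is an illustration only, not this module's implementation: it uses plain Python lists in place of tensors, a 1D signal in place of the H dimension, and hypothetical helper names (`halo_width`, `conv1d_valid`, `sharded_conv1d`, `global_conv1d`). It assumes zero padding equal to the halo width at the global boundaries and shard lengths divisible by the stride.

```python
def halo_width(kernel_size, stride=1):
    """Rows needed from each neighbor, per the (kernel_size - stride) // 2 rule."""
    return (kernel_size - stride) // 2


def conv1d_valid(signal, kernel, stride=1):
    """Unpadded 1D cross-correlation, standing in for a convolution along H."""
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(0, len(signal) - k + 1, stride)
    ]


def sharded_conv1d(shards, kernel, stride=1):
    """Simulate halo exchange: each shard borrows edge rows from its neighbors."""
    h = halo_width(len(kernel), stride)
    out = []
    for rank, shard in enumerate(shards):
        if rank > 0:
            prev = shards[rank - 1]
            top = prev[len(prev) - h:]      # last h rows of the previous shard
        else:
            top = [0] * h                   # zero padding at the global edge
        if rank < len(shards) - 1:
            bottom = shards[rank + 1][:h]   # first h rows of the next shard
        else:
            bottom = [0] * h
        out.append(conv1d_valid(top + shard + bottom, kernel, stride))
    return out


def global_conv1d(signal, kernel, stride=1):
    """Reference: the same convolution applied to the unsharded signal."""
    h = halo_width(len(kernel), stride)
    return conv1d_valid([0] * h + signal + [0] * h, kernel, stride)


signal = [1, 2, 3, 4, 5, 6, 7, 8]
shards = [signal[:4], signal[4:]]

# One stride-1 case (kernel 3) and one strided case (kernel 4, stride 2).
results = {}
for kernel, stride in [([1, 2, 1], 1), ([1, 1, 1, 1], 2)]:
    local = sharded_conv1d(shards, kernel, stride)
    merged = [v for part in local for v in part]
    results[(len(kernel), stride)] = (merged, global_conv1d(signal, kernel, stride))
    assert merged == results[(len(kernel), stride)][1]
```

In both cases the concatenated per-shard outputs equal the unsharded result, and each shard produces exactly its share of the output rows.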

conv
shard_dim = -2
halo_exchange
forward(x)
property weight
property bias

class credit.domain_parallel.layers.DomainParallelConv3d(conv, shard_dim=3)

Bases: torch.nn.Module

Domain-parallel Conv3d with halo exchange along the lat dimension.

Used for CubeEmbedding’s patch-embedding Conv3d. The sharded dimension in the 5D tensor (B, C, T, H, W) is H (dim=-2 or dim=3).

Parameters:
  • conv – An existing nn.Conv3d module to wrap.

  • shard_dim – Spatial dimension being sharded (3 for H in 5D BCTHW).
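The equivalence between dim=-2 and dim=3 for a 5D tensor follows from ordinary negative-index normalization. A minimal sketch, using a hypothetical helper `normalize_dim` that is not part of this module:

```python
def normalize_dim(dim, ndim):
    """Map a possibly negative dimension index to its non-negative equivalent."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}D tensor")
    return dim % ndim


layout_5d = "BCTHW"   # layout of the 5D input described above
layout_4d = "BCHW"    # layout used by the 2D wrappers

h_dim_5d = normalize_dim(-2, len(layout_5d))
h_dim_4d = normalize_dim(-2, len(layout_4d))

assert layout_5d[h_dim_5d] == "H"
assert layout_4d[h_dim_4d] == "H"
```

Note that the same negative index, -2, selects H in both layouts, whereas the positive index differs (3 for BCTHW, 2 for BCHW).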

conv
shard_dim = 3
halo_exchange
forward(x)

class credit.domain_parallel.layers.DomainParallelConvTranspose2d(conv, shard_dim=-2)

Bases: torch.nn.Module

Domain-parallel ConvTranspose2d with halo exchange.

With kernel=2, stride=2 (the standard upsample), no halo is needed and the operation is purely local. With kernel=4, stride=2, padding=1, a halo of 1 row is needed.

Parameters:
  • conv – An existing nn.ConvTranspose2d module to wrap.

  • shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
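The two halo widths quoted above can be reproduced by brute force from the index relation of a transposed convolution, in which input row i contributes to output rows i * stride - padding + j for j in 0..kernel_size-1. The sketch below is not this module's `halo_width` logic; `transposed_conv_halo` is a hypothetical helper, and it assumes kernel_size - 2 * padding == stride, so that a shard of n input rows owns exactly n * stride output rows.

```python
def transposed_conv_halo(kernel_size, stride, padding, shard_start, shard_stop):
    """Count how many rows outside [shard_start, shard_stop) the shard's outputs read."""
    needed = set()
    # Output rows owned by this shard under the assumption stated above.
    for o in range(shard_start * stride, shard_stop * stride):
        # Scan a window of candidate input rows wide enough to cover the kernel.
        for i in range(shard_start - kernel_size, shard_stop + kernel_size):
            if 0 <= o + padding - i * stride < kernel_size:
                needed.add(i)
    before = max(0, shard_start - min(needed))
    after = max(0, max(needed) - (shard_stop - 1))
    return max(before, after)


# A shard holding input rows 4..7 of a larger domain.
halo_k2 = transposed_conv_halo(kernel_size=2, stride=2, padding=0,
                               shard_start=4, shard_stop=8)
halo_k4 = transposed_conv_halo(kernel_size=4, stride=2, padding=1,
                               shard_start=4, shard_stop=8)
```

Under these assumptions `halo_k2` is 0 and `halo_k4` is 1, matching the two cases described in the docstring.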

conv
shard_dim = -2
halo_width
halo_exchange
forward(x)

class credit.domain_parallel.layers.DomainParallelConvTranspose3d(conv, shard_dim=3)

Bases: torch.nn.Module

Domain-parallel ConvTranspose3d with halo exchange along the lat dimension.

When kernel == stride (e.g. Pangu PatchRecovery with patch_size=(2,4,4)), no halo is needed and the upsampling is purely local. When kernel > stride along H, halo exchange is required.

Parameters:
  • conv – An existing nn.ConvTranspose3d module to wrap.

  • shard_dim – Spatial dimension being sharded (3 for H in BCZHW).
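Only the kernel and stride along the sharded dimension matter for the rule above. The following sketch shows that per-dimension check; `needs_halo` is a hypothetical helper, and it assumes a non-negative shard_dim indexing a BCZHW tensor, so that the matching entry of a (Z, H, W) tuple sits at shard_dim - 2.

```python
def needs_halo(kernel_size, stride, shard_dim=3):
    """Return True when the kernel exceeds the stride along the sharded dim.

    kernel_size and stride are (Z, H, W) tuples; shard_dim indexes the 5D tensor.
    """
    spatial_index = shard_dim - 2  # skip the batch and channel dimensions
    return kernel_size[spatial_index] > stride[spatial_index]


# kernel == stride in every dimension, as in the patch_size=(2, 4, 4) case.
patch_recovery = needs_halo((2, 4, 4), (2, 4, 4))

# Hypothetical layer whose kernel overlaps along H only.
overlap_in_h = needs_halo((2, 6, 4), (2, 4, 4))

# Hypothetical layer whose kernel overlaps along Z only; Z is not sharded here.
overlap_in_z = needs_halo((4, 4, 4), (2, 4, 4))
```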

conv
shard_dim = 3
halo_width
halo_exchange
forward(x)

class credit.domain_parallel.layers.DomainParallelGroupNorm(norm)

Bases: torch.nn.Module

Domain-parallel GroupNorm with distributed statistics.

GroupNorm computes the mean and variance over the channels in each group and all spatial positions (H, W). With domain-sharded H, each rank holds only part of the spatial extent, so the statistics must be all-reduced across the domain group before normalization is applied.

Parameters:
  • norm – An existing nn.GroupNorm module to wrap.
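One common way to obtain global statistics from sharded data is for each rank to contribute a sum, a sum of squares, and an element count, and to sum those across ranks. The sketch below simulates that in a single process for one channel group; it is an assumption about the general technique, not a description of this class's internals. `local_stats`, `all_reduce_sum`, and `mean_var` are hypothetical names, and the all-reduce is replaced by a plain sum over a list of per-rank tuples.

```python
import math


def local_stats(values):
    """Per-rank partial statistics: (sum, sum of squares, element count)."""
    return (sum(values), sum(v * v for v in values), len(values))


def all_reduce_sum(per_rank):
    """Stand-in for a distributed all-reduce: element-wise sum over ranks."""
    return tuple(sum(parts) for parts in zip(*per_rank))


def mean_var(total, total_sq, count):
    """Mean and biased variance (the estimator GroupNorm uses)."""
    mean = total / count
    var = total_sq / count - mean * mean
    return mean, var


# One channel group whose elements are split across two ranks along H.
shard0 = [1.0, 2.0, 3.0, 4.0]
shard1 = [5.0, 6.0, 7.0, 8.0]

reduced = all_reduce_sum([local_stats(shard0), local_stats(shard1)])
mean, var = mean_var(*reduced)

eps = 1e-5
normalized0 = [(v - mean) / math.sqrt(var + eps) for v in shard0]

# Using only rank 0's data gives a different mean, hence a wrong normalization.
local_only_mean, _ = mean_var(*local_stats(shard0))
```

The sum-of-squares form keeps the example short; it can lose precision when the mean is large relative to the spread, so a production implementation might combine per-rank means and variances instead.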

num_groups
num_channels
eps
affine
forward(x)

class credit.domain_parallel.layers.DomainParallelPeriodicConv2d(periodic_conv, shard_dim=-2)

Bases: torch.nn.Module

Domain-parallel wrapper for PeriodicConv2d.

PeriodicConv2d manually applies circular padding to W (longitude) and reflect padding to H (latitude) before calling an inner nn.Conv2d(padding=0). With domain-sharded H, reflect padding on H is wrong: ranks would reflect against the shard edge instead of the true pole.

This wrapper replaces the reflect padding on H with halo exchange, while keeping the circular padding on W unchanged.

Parameters:
  • periodic_conv – An existing PeriodicConv2d module to wrap. Must have .conv (nn.Conv2d) and .padding (int) attributes.

  • shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
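The difference between reflecting at a shard edge and receiving rows from a neighbor can be seen with a 1D column of row indices. This is a single-process sketch with hypothetical helpers (`reflect_pad`, `halo_pad`); it assumes that the outermost ranks still reflect at the true poles, which the description above implies but does not state explicitly, and it leaves out the circular padding on W, which is local to each rank.

```python
def reflect_pad(rows, pad):
    """Mirror padding that does not repeat the edge row (reflect-style)."""
    top = rows[pad:0:-1]
    bottom = rows[-2:-2 - pad:-1]
    return top + rows + bottom


def halo_pad(shards, rank, pad):
    """Pad one shard with neighbor rows inside the domain, reflection at the poles."""
    shard = shards[rank]
    reflected = reflect_pad(shard, pad)
    if rank == 0:
        top = reflected[:pad]                       # true pole: reflect
    else:
        prev = shards[rank - 1]
        top = prev[len(prev) - pad:]                # rows from the previous rank
    if rank == len(shards) - 1:
        bottom = reflected[len(reflected) - pad:]   # true pole: reflect
    else:
        bottom = shards[rank + 1][:pad]             # rows from the next rank
    return top + shard + bottom


rows = list(range(8))            # latitude row indices 0..7
shards = [rows[:4], rows[4:]]
pad = 1

reference = reflect_pad(rows, pad)        # padding of the unsharded column
with_halo = [halo_pad(shards, r, pad) for r in range(len(shards))]
naive = reflect_pad(shards[0], pad)       # reflecting at the shard edge

# Each halo-padded shard matches the corresponding window of the reference.
assert with_halo[0] == reference[0:6]
assert with_halo[1] == reference[4:10]
# The naive version pads rank 0's lower edge with row 2 instead of row 4.
assert naive != reference[0:6]
```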

conv
padding
shard_dim = -2
halo_exchange
forward(x)
property weight
property bias

class credit.domain_parallel.layers.DomainParallelInterpolate(size, mode='bilinear', shard_dim=-2)

Bases: torch.nn.Module

Domain-parallel bilinear interpolation.

Exchanges a 1-row halo before interpolation so boundary pixels interpolate correctly, then trims the extra output rows.

Parameters:
  • size – Target output size (H, W).

  • mode – Interpolation mode (default: 'bilinear').

  • shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
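The exchange-then-trim procedure described above can be illustrated in 1D. The sketch below is not this class's implementation: it uses plain lists, an integer upsampling factor, and a half-pixel-center convention with edge clamping (an assumption, since the wrapper's alignment setting is not stated above). `linear_upsample` and `sharded_upsample` are hypothetical helpers, and the outermost ranks replicate their edge row in place of a halo, which reproduces the clamping of the unsharded case.

```python
import math


def linear_upsample(rows, scale):
    """1D linear upsampling with half-pixel centers and edge clamping."""
    n = len(rows)
    out = []
    for o in range(n * scale):
        src = max((o + 0.5) / scale - 0.5, 0.0)   # source coordinate of output o
        lo = min(int(src), n - 1)
        hi = min(lo + 1, n - 1)
        w = src - lo
        out.append((1 - w) * rows[lo] + w * rows[hi])
    return out


def sharded_upsample(shards, rank, scale):
    """Add a 1-row halo on each side, upsample, then trim the extra output rows."""
    shard = shards[rank]
    top = shards[rank - 1][-1] if rank > 0 else shard[0]
    bottom = shards[rank + 1][0] if rank < len(shards) - 1 else shard[-1]
    up = linear_upsample([top] + shard + [bottom], scale)
    return up[scale:len(up) - scale]   # each halo row produced `scale` outputs


rows = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0]
shards = [rows[:4], rows[4:]]
scale = 2

reference = linear_upsample(rows, scale)
merged = [v for r in range(len(shards)) for v in sharded_upsample(shards, r, scale)]
naive = [v for shard in shards for v in linear_upsample(shard, scale)]

assert all(math.isclose(a, b) for a, b in zip(merged, reference))
# Without the halo, rows next to the shard boundary are clamped and differ.
assert not all(math.isclose(a, b) for a, b in zip(naive, reference))
```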

size
mode = 'bilinear'
shard_dim = -2
halo_exchange
forward(x)