credit.domain_parallel.layers
Domain-parallel layer wrappers.
Each wrapper replaces a standard PyTorch layer with a version that handles halo exchange and/or distributed reductions for domain-parallel training.
Layers that are purely local (1x1 convolutions, channel-wise normalization, activations, etc.) do not need wrappers and are left unchanged.
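Classes
| Class | Description |
|---|---|
| DomainParallelConv2d | Domain-parallel Conv2d with automatic halo exchange. |
| DomainParallelConv3d | Domain-parallel Conv3d with halo exchange along the latitude dimension. |
| DomainParallelConvTranspose2d | Domain-parallel ConvTranspose2d with halo exchange. |
| DomainParallelConvTranspose3d | Domain-parallel ConvTranspose3d with halo exchange along the latitude dimension. |
| DomainParallelGroupNorm | Domain-parallel GroupNorm with distributed statistics. |
| DomainParallelPeriodicConv2d | Domain-parallel wrapper for PeriodicConv2d. |
| DomainParallelInterpolate | Domain-parallel bilinear interpolation. |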
Module Contents
- class credit.domain_parallel.layers.DomainParallelConv2d(conv, shard_dim=-2)
Bases: torch.nn.Module
Domain-parallel Conv2d with automatic halo exchange.
Before the convolution, the wrapper exchanges halo rows along the sharding dimension (latitude by default) so that pixels at shard boundaries see their true neighbors. After the convolution, the output has the correct local shard size.
For strided convolutions (as in CrossEmbedLayer), the halo width is (kernel_size - stride) // 2, to account for the stride's effect on which input rows the boundary outputs read.
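The halo rule can be checked without a cluster. The sketch below is a framework-free simulation, not this module's implementation: it treats the sharded H axis as a 1-D list, splits it evenly across simulated ranks, pads each shard with `(kernel_size - stride) // 2` rows borrowed from its neighbors (zeros at the global edges, standing in for the convolution's zero padding), and compares the concatenated local outputs with one global convolution. The helper names are hypothetical, and the sketch assumes `kernel_size - stride` is even and each shard's row count is a multiple of the stride.

```python
def conv1d_valid(x, w, stride=1):
    """Cross-correlation without padding, as nn.Conv2d computes along one axis."""
    k = len(w)
    return [sum(w[t] * x[i + t] for t in range(k))
            for i in range(0, len(x) - k + 1, stride)]


def halo_width(kernel_size, stride=1):
    """Rows borrowed from each neighbor, per the (kernel_size - stride) // 2 rule."""
    return (kernel_size - stride) // 2


def exchange_halo(shards, h):
    """Pad each shard with h neighbor rows; zeros stand in for padding at global edges."""
    padded = []
    for r, shard in enumerate(shards):
        prev = shards[r - 1] if r > 0 else [0] * len(shard)
        nxt = shards[r + 1] if r < len(shards) - 1 else [0] * len(shard)
        padded.append(prev[len(prev) - h:] + shard + nxt[:h])
    return padded


def sharded_conv(x, w, stride, world_size):
    """Split x into equal shards, exchange halos, convolve each shard locally."""
    n = len(x) // world_size
    shards = [x[r * n:(r + 1) * n] for r in range(world_size)]
    out = []
    for padded in exchange_halo(shards, halo_width(len(w), stride)):
        out.extend(conv1d_valid(padded, w, stride))
    return out


def global_conv(x, w, stride):
    """Reference: one convolution over the whole axis with padding (k - s) // 2."""
    p = halo_width(len(w), stride)
    return conv1d_valid([0] * p + x + [0] * p, w, stride)


x = [(7 * i) % 11 for i in range(16)]
# Strided, CrossEmbedLayer-style case: kernel 4, stride 2 -> halo of 1 row.
assert sharded_conv(x, [1, -2, 3, 1], 2, 4) == global_conv(x, [1, -2, 3, 1], 2)
```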
- Parameters:
conv – An existing nn.Conv2d module to wrap.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- conv
- shard_dim = -2
- halo_exchange
- forward(x)
- property weight
- property bias
- class credit.domain_parallel.layers.DomainParallelConv3d(conv, shard_dim=3)
Bases: torch.nn.Module
Domain-parallel Conv3d with halo exchange along the latitude dimension.
Used for CubeEmbedding’s patch-embedding Conv3d. The sharded dimension in the 5D tensor (B, C, T, H, W) is H (dim=-2 or dim=3).
- Parameters:
conv – An existing nn.Conv3d module to wrap.
shard_dim – Spatial dimension being sharded (3 for H in 5D BCTHW).
- conv
- shard_dim = 3
- halo_exchange
- forward(x)
- class credit.domain_parallel.layers.DomainParallelConvTranspose2d(conv, shard_dim=-2)
Bases: torch.nn.Module
Domain-parallel ConvTranspose2d with halo exchange.
With kernel_size=2, stride=2 (the standard 2x upsample), no halo is needed and the operation is purely local. With kernel_size=4, stride=2, padding=1, a halo of 1 row is needed.
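Both cases can be simulated along one sharded axis without PyTorch. Each simulated rank halo-pads its rows, runs a transposed convolution with no padding (a scatter-add of each input row through the kernel), and keeps the `stride * local_rows` output rows it owns. The trim offset `halo * stride + padding` is derived for this sketch under the assumption that `kernel_size - 2 * padding == stride` (exact upsampling by `stride`); the helper names are hypothetical and this is not the wrapper's code.

```python
def conv_transpose1d_full(x, w, stride):
    """Transposed convolution with padding=0: scatter-add each input row."""
    out = [0] * ((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        for t, wt in enumerate(w):
            out[i * stride + t] += xi * wt
    return out


def global_conv_transpose(x, w, stride, padding):
    """Reference: ConvTranspose semantics crop `padding` rows from each end."""
    full = conv_transpose1d_full(x, w, stride)
    return full[padding:len(full) - padding]


def sharded_conv_transpose(x, w, stride, padding, halo, world_size):
    """Halo-pad each shard, run the full transposed conv, keep this rank's rows."""
    n = len(x) // world_size
    shards = [x[r * n:(r + 1) * n] for r in range(world_size)]
    zeros = [0] * n  # zero rows at the global edges contribute nothing
    out = []
    for r, shard in enumerate(shards):
        prev = shards[r - 1] if r > 0 else zeros
        nxt = shards[r + 1] if r < world_size - 1 else zeros
        full = conv_transpose1d_full(prev[n - halo:] + shard + nxt[:halo], w, stride)
        start = halo * stride + padding  # first output row owned by this rank
        out.extend(full[start:start + n * stride])
    return out


x = [i % 5 + 1 for i in range(12)]
w = [1, 2, 3, 4]  # kernel_size=4, stride=2, padding=1
assert sharded_conv_transpose(x, w, 2, 1, 1, 3) == global_conv_transpose(x, w, 2, 1)
assert sharded_conv_transpose(x, w, 2, 1, 0, 3) != global_conv_transpose(x, w, 2, 1)
```

The second assertion shows the kernel_size=4 case going wrong at shard boundaries when the halo is omitted.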
- Parameters:
conv – An existing nn.ConvTranspose2d module to wrap.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- conv
- shard_dim = -2
- halo_width
- halo_exchange
- forward(x)
- class credit.domain_parallel.layers.DomainParallelConvTranspose3d(conv, shard_dim=3)
Bases: torch.nn.Module
Domain-parallel ConvTranspose3d with halo exchange along the latitude dimension.
When kernel_size equals stride (e.g. Pangu's PatchRecovery with patch_size=(2, 4, 4)), no halo is needed: the upsampling is purely local. When kernel_size exceeds stride along H, halo exchange is required.
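The locality condition follows from which output rows each input row writes to. In this sketch (hypothetical helpers and shard sizes, dilation=1, output_padding=0), `conv_transpose_out_len` is PyTorch's ConvTranspose output-size formula. When kernel_size equals stride, the footprints of neighboring input rows are disjoint and the per-rank outputs simply concatenate along H; when kernel_size exceeds stride, neighboring footprints overlap, and those shared rows are what the halo must supply.

```python
def conv_transpose_out_len(n, kernel, stride, padding=0):
    """PyTorch ConvTranspose output length with dilation=1 and output_padding=0."""
    return (n - 1) * stride - 2 * padding + kernel


def footprint(i, kernel, stride):
    """Output rows (before padding is cropped) that input row i writes to."""
    return set(range(i * stride, i * stride + kernel))


def is_local(kernel, stride):
    """True if neighboring input rows never write to the same output row."""
    return not (footprint(0, kernel, stride) & footprint(1, kernel, stride))


# Patch recovery along H with kernel == stride == 4 (hypothetical shard sizes).
local_rows = [8, 8, 8]  # H rows held by each of three ranks
local_out = [conv_transpose_out_len(n, 4, 4) for n in local_rows]
assert sum(local_out) == conv_transpose_out_len(sum(local_rows), 4, 4)  # 96 rows
assert is_local(4, 4)      # kernel == stride: no halo needed
assert not is_local(4, 2)  # kernel > stride: footprints overlap, halo needed
```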
- Parameters:
conv – An existing nn.ConvTranspose3d module to wrap.
shard_dim – Spatial dimension being sharded (3 for H in BCZHW).
- conv
- shard_dim = 3
- halo_width
- halo_exchange
- forward(x)
- class credit.domain_parallel.layers.DomainParallelGroupNorm(norm)
Bases: torch.nn.Module
Domain-parallel GroupNorm with distributed statistics.
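GroupNorm computes a mean and variance for each group of channels, taken over all of that group's channels and spatial positions (H, W). With H sharded across ranks, each rank sees only part of the spatial extent, so the statistics must be all-reduced across the domain-parallel group before normalization is applied.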
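A minimal sketch of the reduction for one (sample, group) slice, assuming the statistics are formed from a per-rank sum, sum of squares, and element count that are then summed across ranks. `all_reduce_sum` is a hypothetical stand-in for `torch.distributed.all_reduce` with the SUM op over the domain group; the wrapper may use a different but equivalent formulation.

```python
def local_stats(values):
    """Partial statistics one rank computes over its own rows."""
    return (sum(values), sum(v * v for v in values), len(values))


def all_reduce_sum(per_rank):
    """Stand-in for all_reduce(SUM): every rank ends up with the element-wise total."""
    return tuple(sum(col) for col in zip(*per_rank))


def group_mean_var(stats):
    """Mean and biased variance (what GroupNorm normalizes with) from summed stats."""
    total, total_sq, count = stats
    mean = total / count
    return mean, total_sq / count - mean * mean


# One group's values on 6 latitude rows x 2 longitudes, sharded over 3 ranks.
rows = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
shards = [rows[0:2], rows[2:4], rows[4:6]]
per_rank = [local_stats([v for row in shard for v in row]) for shard in shards]
mean, var = group_mean_var(all_reduce_sum(per_rank))
# A single rank's local statistics would be wrong: its mean is 2.5, not 6.5.
```

Computing the variance as E[x²] − mean² is compact but can lose precision when the mean is large relative to the spread; a two-pass or Welford-style merge is more robust.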
- Parameters:
norm – An existing nn.GroupNorm module to wrap.
- num_groups
- num_channels
- eps
- affine
- forward(x)
- class credit.domain_parallel.layers.DomainParallelPeriodicConv2d(periodic_conv, shard_dim=-2)
Bases: torch.nn.Module
Domain-parallel wrapper for PeriodicConv2d.
PeriodicConv2d manually pads W with circular padding (longitude) and H with reflect padding (latitude) before calling an inner nn.Conv2d(padding=0). With H sharded, applying reflect padding locally is wrong: each rank would mirror its rows at its own shard edges, which, apart from the outer edges of the first and last ranks, are interior latitudes rather than the true poles.
This wrapper replaces the reflect padding on H with halo exchange, while keeping the circular padding on W unchanged.
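The sketch below simulates the padding on a small latitude-longitude grid held as nested lists. The globally padded grid (reflect on H, circular on W, as PeriodicConv2d does) serves as the reference; each simulated rank takes halo rows from its neighbors, and only the first and last ranks reflect at their outer (polar) edges. That polar behavior, the helper names, and the requirement that each shard hold more than `padding` rows are assumptions of this sketch, not documented behavior of the wrapper.

```python
def circular_pad_cols(rows, p):
    """Wrap p columns around in longitude (W), like mode='circular'."""
    return [row[len(row) - p:] + row + row[:p] for row in rows]


def reflect_pad_rows(rows, p):
    """Mirror p rows at each end without repeating the edge, like mode='reflect'."""
    return rows[p:0:-1] + rows + rows[-2:-p - 2:-1]


def pad_shard(shards, r, p):
    """Hypothetical wrapper behavior: halo rows inside, reflect only at the poles."""
    shard = shards[r]
    top = shards[r - 1][len(shards[r - 1]) - p:] if r > 0 else shard[p:0:-1]
    bottom = shards[r + 1][:p] if r < len(shards) - 1 else shard[-2:-p - 2:-1]
    return circular_pad_cols(top + shard + bottom, p)


grid = [[10 * i + j for j in range(4)] for i in range(6)]  # 6 lat rows x 4 lon cols
p, n = 1, 2
shards = [grid[r * n:(r + 1) * n] for r in range(3)]
reference = circular_pad_cols(reflect_pad_rows(grid, p), p)
for r in range(3):
    # Each rank's padded block equals its slice of the globally padded grid.
    assert pad_shard(shards, r, p) == reference[r * n:r * n + n + 2 * p]
# Reflecting locally on an interior rank mirrors a shard edge, not a pole.
assert circular_pad_cols(reflect_pad_rows(shards[1], p), p) != reference[2:6]
```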
- Parameters:
periodic_conv – An existing PeriodicConv2d module to wrap. Must have .conv (nn.Conv2d) and .padding (int) attributes.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- conv
- padding
- shard_dim = -2
- halo_exchange
- forward(x)
- property weight
- property bias
- class credit.domain_parallel.layers.DomainParallelInterpolate(size, mode='bilinear', shard_dim=-2)
Bases: torch.nn.Module
Domain-parallel bilinear interpolation.
Exchanges a 1-row halo along the sharded dimension before interpolating, so that output rows near shard boundaries blend with the true neighboring rows, then trims the extra output rows produced from the halo.
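The sketch below simulates the halo-and-trim scheme for 1-D linear interpolation along H with an integer upsampling factor and `align_corners=False` (half-pixel centers, source index clamped at 0, as `F.interpolate` does). Each simulated rank pads one row on each side, upsamples the padded rows, and drops the `factor` output rows generated by each halo row. Repeating the edge row at the global boundaries, the factor-based interface, and the helper names are assumptions of this sketch; the wrapper itself takes a target `size`.

```python
def linear_resize(x, out_len):
    """1-D linear resize matching align_corners=False (half-pixel centers)."""
    n = len(x)
    scale = n / out_len
    out = []
    for o in range(out_len):
        src = max((o + 0.5) * scale - 0.5, 0.0)  # clamp negative source indices
        i0 = min(int(src), n - 1)
        i1 = min(i0 + 1, n - 1)
        lam = src - i0
        out.append((1.0 - lam) * x[i0] + lam * x[i1])
    return out


def sharded_resize(x, factor, world_size):
    """Halo-pad each shard by one row, upsample, and trim the halo's output rows."""
    n = len(x) // world_size
    shards = [x[r * n:(r + 1) * n] for r in range(world_size)]
    out = []
    for r, shard in enumerate(shards):
        # Assumed: at global edges repeat the edge row, which reproduces the
        # clamping the global resize applies there.
        top = [shards[r - 1][-1]] if r > 0 else [shard[0]]
        bottom = [shards[r + 1][0]] if r < world_size - 1 else [shard[-1]]
        padded = top + shard + bottom
        local = linear_resize(padded, len(padded) * factor)
        out.extend(local[factor:factor + n * factor])
    return out


x = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
assert sharded_resize(x, 2, 4) == linear_resize(x, 16)
```

With a factor of 2 every source coordinate is an exact multiple of 0.25, so the sharded and global results agree bit for bit; for non-dyadic factors they would agree only up to floating-point rounding.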
- Parameters:
size – Target output size (H, W).
mode – Interpolation mode (default: ‘bilinear’).
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- size
- mode = 'bilinear'
- shard_dim = -2
- halo_exchange
- forward(x)