credit.domain_parallel#
Domain parallelism for CREDIT weather models.
Shards high-resolution input data along spatial dimensions (latitude by default) across multiple GPUs, enabling training on data too large for a single GPU. Inspired by PhysicsNeMo’s ShardTensor framework.
Key components:
- DomainParallelManager: Process group creation and coordination
- HaloExchange: Differentiable boundary communication for convolutions
- Domain-parallel layers: Conv2d, Conv3d, ConvTranspose2d, GroupNorm wrappers
- convert_to_domain_parallel: Automatic model conversion
- shard_tensor / gather_tensor: Input/output distribution
Usage:
    from credit.domain_parallel import (
        initialize_domain_parallel,
        get_domain_parallel_manager,
        convert_to_domain_parallel,
        shard_tensor,
        gather_tensor,
        shard_batch,
    )
Submodules#
Classes#
DomainParallelManager | Manages process groups for domain parallelism. |
HaloExchange | Halo exchange layer for domain-parallel operations. |
DomainParallelConv2d | Domain-parallel Conv2d with automatic halo exchange. |
DomainParallelConv3d | Domain-parallel Conv3d with halo exchange along the lat dimension. |
DomainParallelConvTranspose2d | Domain-parallel ConvTranspose2d with halo exchange. |
DomainParallelConvTranspose3d | Domain-parallel ConvTranspose3d with halo exchange along the lat dimension. |
DomainParallelGroupNorm | Domain-parallel GroupNorm with distributed statistics. |
DomainParallelInterpolate | Domain-parallel bilinear interpolation. |
DomainParallelPeriodicConv2d | Domain-parallel wrapper for PeriodicConv2d. |
Functions#
initialize_domain_parallel | Initialize the global DomainParallelManager singleton. |
get_domain_parallel_manager | Get the global DomainParallelManager singleton. |
convert_to_domain_parallel | Convert a model to use domain-parallel layers. |
validate_sharding_constraints | Validate that the local spatial dimension is compatible with the model. |
shard_tensor | Shard a tensor along the given dimension across domain-parallel ranks. |
gather_tensor | Gather sharded tensor from all domain-parallel ranks. |
shard_batch | Shard all spatial tensors in a training batch dictionary. |
Package Contents#
- class credit.domain_parallel.DomainParallelManager(world_size, domain_parallel_size, shard_dim=-2)#
Manages process groups for domain parallelism.
- Parameters:
world_size – Total number of GPUs.
domain_parallel_size – Number of GPUs per domain-parallel group.
shard_dim – Which spatial dimension to shard. -2 means latitude (H) in a (B, C, H, W) tensor.
- world_size#
- domain_parallel_size#
- data_parallel_size#
- shard_dim = -2#
- _domain_group_idx#
- _domain_rank#
- _dp_rank#
- _domain_group = None#
- _dp_group = None#
- property domain_group#
Process group for domain-parallel communication (halo exchange, reductions).
- property data_parallel_group#
Process group for data-parallel communication (gradient sync).
- property domain_rank#
Rank within the domain-parallel group (0 to domain_parallel_size-1).
- property domain_world_size#
Number of ranks in the domain-parallel group.
- property dp_rank#
Rank within the data-parallel group.
- property dp_world_size#
Number of ranks in the data-parallel group.
- property is_first_domain_rank#
True if this is the first rank in its domain group (north edge).
- property is_last_domain_rank#
True if this is the last rank in its domain group (south edge).
- neighbor_ranks()#
Returns (prev_rank, next_rank) global ranks for halo exchange.
Returns None for non-existent neighbors at edges.
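The relationship between global ranks, domain ranks, and halo neighbors can be sketched in plain Python. This is a minimal illustration, not the library's implementation: it assumes that consecutive global ranks form one domain-parallel group and that data_parallel_size equals world_size // domain_parallel_size, neither of which is stated above. The helper name rank_layout is hypothetical.

```python
def rank_layout(global_rank, world_size, domain_parallel_size):
    """Describe where one global rank sits (assumed contiguous layout)."""
    if world_size % domain_parallel_size != 0:
        raise ValueError("world_size must be a multiple of domain_parallel_size")
    # Assumption: ranks [0..d-1] form group 0, [d..2d-1] form group 1, ...
    domain_group_idx, domain_rank = divmod(global_rank, domain_parallel_size)
    is_first = domain_rank == 0
    is_last = domain_rank == domain_parallel_size - 1
    return {
        "data_parallel_size": world_size // domain_parallel_size,
        "domain_group_idx": domain_group_idx,
        "domain_rank": domain_rank,
        # Edge ranks have no neighbor on their outer side.
        "neighbor_ranks": (
            None if is_first else global_rank - 1,
            None if is_last else global_rank + 1,
        ),
    }


# 8 GPUs split into 2 domain groups of 4 ranks each.
layout = rank_layout(global_rank=5, world_size=8, domain_parallel_size=4)
print(layout)
```

Under these assumptions, rank 5 is the second rank of the second domain group and exchanges halos with global ranks 4 and 6.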
- credit.domain_parallel.initialize_domain_parallel(world_size, domain_parallel_size, shard_dim=-2)#
Initialize the global DomainParallelManager singleton.
- Parameters:
world_size – Total number of GPUs.
domain_parallel_size – Number of GPUs per domain group.
shard_dim – Spatial dimension to shard (-2 for lat in BCHW).
- Returns:
DomainParallelManager instance.
- credit.domain_parallel.get_domain_parallel_manager()#
Get the global DomainParallelManager singleton.
- Returns:
DomainParallelManager or None if not initialized.
- class credit.domain_parallel.HaloExchange(halo_width, dim=-2)#
Bases: torch.nn.Module
Halo exchange layer for domain-parallel operations.
Pads the input tensor with halo rows from neighboring ranks along the sharding dimension. The operation is differentiable.
- Parameters:
halo_width – Number of rows to exchange on each side.
dim – Tensor dimension to exchange along (default: -2 for lat in BCHW).
- halo_width#
- dim = -2#
- forward(x)#
- static trim(x, halo_before, halo_after, dim=-2)#
Trim halo rows from the output after a convolution.
- Parameters:
x – Tensor with extra halo rows.
halo_before – Number of rows to trim from the start.
halo_after – Number of rows to trim from the end.
dim – Dimension to trim along.
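The padding and trimming steps can be illustrated without any distributed setup by simulating all ranks in one process. The sketch below is an assumption-laden toy, not the HaloExchange code: each shard is a list of rows (integers stand in for whole latitude rows), and edge ranks simply receive no rows on their outer side, since the edge behavior of the real layer is not described here. The names exchange_halos and trim_rows are hypothetical.

```python
def exchange_halos(shards, halo_width):
    """Pad every shard with halo_width rows copied from its neighbors."""
    if halo_width == 0:
        return [list(local) for local in shards]
    padded = []
    last = len(shards) - 1
    for rank, local in enumerate(shards):
        # Rows received from the previous rank: its last halo_width rows.
        before = shards[rank - 1][-halo_width:] if rank > 0 else []
        # Rows received from the next rank: its first halo_width rows.
        after = shards[rank + 1][:halo_width] if rank < last else []
        padded.append(before + list(local) + after)
    return padded


def trim_rows(rows, halo_before, halo_after):
    """Drop halo rows again, mirroring the role of HaloExchange.trim."""
    return rows[halo_before:len(rows) - halo_after]


shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
padded = exchange_halos(shards, halo_width=1)
print(padded)
```

The middle rank ends up with one extra row on each side, and trimming one row from each side recovers its original shard.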
- class credit.domain_parallel.DomainParallelConv2d(conv, shard_dim=-2)#
Bases: torch.nn.Module
Domain-parallel Conv2d with automatic halo exchange.
Before the convolution, exchanges halo rows along the sharding dimension (latitude by default) so boundary pixels see correct neighbors. After the convolution, the output has the correct local shard size.
For strided convolutions (like in CrossEmbedLayer), the halo width is (kernel_size - stride) // 2 to account for the stride’s effect on output boundaries.
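Why the halo makes a sharded convolution match the unsharded one can be checked with a one-dimensional toy. The sketch below assumes stride 1, an odd kernel, and zero padding at the global edges (the analogue of padding=kernel_size // 2). It reuses the halo-width formula quoted above, but the functions themselves are hypothetical and are not the library's code.

```python
def halo_width(kernel_size, stride=1):
    """Halo rows needed on each side, per the formula quoted above."""
    return (kernel_size - stride) // 2


def conv1d_valid(signal, kernel):
    """Plain 'valid' cross-correlation (no padding), as in nn.Conv1d."""
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(len(signal) - k + 1)
    ]


def sharded_conv1d(signal, kernel, num_ranks):
    """Convolve each shard separately after padding it with neighbor rows."""
    halo = halo_width(len(kernel))
    chunk = len(signal) // num_ranks
    out = []
    for rank in range(num_ranks):
        start, stop = rank * chunk, (rank + 1) * chunk
        # Interior boundaries use the neighbor's rows; global edges use zeros.
        left = signal[start - halo:start] if rank > 0 else [0] * halo
        right = signal[stop:stop + halo] if rank < num_ranks - 1 else [0] * halo
        out.extend(conv1d_valid(left + signal[start:stop] + right, kernel))
    return out


signal = [1, 2, 3, 4, 5, 6, 7, 8]
kernel = [1, 0, -1]
reference = conv1d_valid([0] + signal + [0], kernel)
sharded = sharded_conv1d(signal, kernel, num_ranks=4)
print(sharded == reference)
```

Each shard produces exactly chunk output rows, so concatenating the per-rank outputs reproduces the unsharded result.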
- Parameters:
conv – An existing nn.Conv2d module to wrap.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- conv#
- shard_dim = -2#
- halo_exchange#
- forward(x)#
- property weight#
- property bias#
- class credit.domain_parallel.DomainParallelConv3d(conv, shard_dim=3)#
Bases: torch.nn.Module
Domain-parallel Conv3d with halo exchange along the lat dimension.
Used for CubeEmbedding’s patch-embedding Conv3d. The sharded dimension in the 5D tensor (B, C, T, H, W) is H (dim=-2 or dim=3).
- Parameters:
conv – An existing nn.Conv3d module to wrap.
shard_dim – Spatial dimension being sharded (3 for H in 5D BCTHW).
- conv#
- shard_dim = 3#
- halo_exchange#
- forward(x)#
- class credit.domain_parallel.DomainParallelConvTranspose2d(conv, shard_dim=-2)#
Bases: torch.nn.Module
Domain-parallel ConvTranspose2d with halo exchange.
For kernel=2, stride=2 (standard upsample): no halo needed, purely local. For kernel=4, stride=2, padding=1: halo of 1 needed.
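The two cases can be reproduced with a one-dimensional transposed convolution written in plain Python. This is only a sketch for illustration: conv_transpose1d is a hypothetical helper that follows the usual output-length rule (n - 1) * stride + kernel - 2 * padding, and the shards are processed without any halo to show when that is and is not sufficient.

```python
def conv_transpose1d(x, kernel, stride, padding=0):
    """Scatter each input value into the output, then crop the padding."""
    k = len(kernel)
    full = [0] * ((len(x) - 1) * stride + k)
    for i, value in enumerate(x):
        for j, weight in enumerate(kernel):
            full[i * stride + j] += value * weight
    return full[padding:len(full) - padding]


def shard_locally(x, kernel, stride, padding, num_ranks):
    """Apply the op to each shard independently (no halo) and concatenate."""
    chunk = len(x) // num_ranks
    out = []
    for rank in range(num_ranks):
        local = x[rank * chunk:(rank + 1) * chunk]
        out.extend(conv_transpose1d(local, kernel, stride, padding))
    return out


x = [1, 2, 3, 4]

# kernel == stride: every output row depends on a single input row.
full_k2 = conv_transpose1d(x, [1, 1], stride=2)
local_k2 = shard_locally(x, [1, 1], stride=2, padding=0, num_ranks=2)

# kernel > stride: neighboring input rows overlap in the output.
full_k4 = conv_transpose1d(x, [1, 1, 1, 1], stride=2, padding=1)
local_k4 = shard_locally(x, [1, 1, 1, 1], stride=2, padding=1, num_ranks=2)

print(local_k2 == full_k2, local_k4 == full_k4)
```

With kernel=2, stride=2 the purely local result matches the full one; with kernel=4, stride=2, padding=1 the rows next to the shard boundary differ, which is the gap the halo exchange closes.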
- Parameters:
conv – An existing nn.ConvTranspose2d module to wrap.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- conv#
- shard_dim = -2#
- halo_width#
- halo_exchange#
- forward(x)#
- class credit.domain_parallel.DomainParallelConvTranspose3d(conv, shard_dim=3)#
Bases: torch.nn.Module
Domain-parallel ConvTranspose3d with halo exchange along the lat dimension.
For kernel==stride (e.g. Pangu PatchRecovery with patch_size=(2,4,4)): no halo needed — purely local upsampling. For kernel>stride in H: halo exchange required.
- Parameters:
conv – An existing nn.ConvTranspose3d module to wrap.
shard_dim – Spatial dimension being sharded (3 for H in BCZHW).
- conv#
- shard_dim = 3#
- halo_width#
- halo_exchange#
- forward(x)#
- class credit.domain_parallel.DomainParallelGroupNorm(norm)#
Bases: torch.nn.Module
Domain-parallel GroupNorm with distributed statistics.
GroupNorm computes mean and variance over (H, W) for each group of channels. With domain-sharded H, we need to all_reduce the statistics across the domain group before applying normalization.
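One common way to obtain global statistics from sharded data is to sum the local sum, sum of squares, and element count across ranks. The sketch below assumes that scheme; which quantities the library actually all-reduces is not stated here. The reduction is simulated by summing over a Python list of per-rank results, and all function names are hypothetical.

```python
import math


def local_stats(values):
    """Per-rank partial statistics: (sum, sum of squares, count)."""
    return (sum(values), sum(v * v for v in values), len(values))


def combine_stats(per_rank):
    """Stand-in for an all_reduce(SUM) over the domain group."""
    total = sum(s[0] for s in per_rank)
    total_sq = sum(s[1] for s in per_rank)
    count = sum(s[2] for s in per_rank)
    mean = total / count
    # Biased variance, E[x^2] - E[x]^2, as used by GroupNorm.
    var = total_sq / count - mean * mean
    return mean, var


def normalize(values, mean, var, eps=1e-5):
    """Normalize one shard with the shared global statistics."""
    scale = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * scale for v in values]


# One channel group whose rows are split across two ranks.
shards = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
mean, var = combine_stats([local_stats(s) for s in shards])
normalized = [normalize(s, mean, var) for s in shards]
print(mean, var)
```

Every rank normalizes its own rows with the same mean and variance, so the result matches what a single device would compute over the full (H, W) extent.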
- Parameters:
norm – An existing nn.GroupNorm module to wrap.
- num_groups#
- num_channels#
- eps#
- affine#
- forward(x)#
- class credit.domain_parallel.DomainParallelInterpolate(size, mode='bilinear', shard_dim=-2)#
Bases: torch.nn.Module
Domain-parallel bilinear interpolation.
Exchanges a 1-row halo before interpolation so boundary pixels interpolate correctly, then trims the extra output rows.
- Parameters:
size – Target output size (H, W).
mode – Interpolation mode (default: 'bilinear').
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- size#
- mode = 'bilinear'#
- shard_dim = -2#
- halo_exchange#
- forward(x)#
- class credit.domain_parallel.DomainParallelPeriodicConv2d(periodic_conv, shard_dim=-2)#
Bases: torch.nn.Module
Domain-parallel wrapper for PeriodicConv2d.
PeriodicConv2d manually pads W with circular (longitude) and H with reflect (latitude) before calling an inner nn.Conv2d(padding=0). With domain-sharded H, reflect padding on H is wrong — boundary ranks would reflect against the shard edge instead of the true pole.
This wrapper replaces the reflect padding on H with halo exchange, while keeping the circular padding on W unchanged.
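The problem with local reflect padding can be seen on a toy column of eight latitude rows split across two ranks. The sketch below is illustrative only: reflect_pad and circular_pad are hypothetical pure-Python stand-ins for the 'reflect' and 'circular' padding modes, and the handling of the true pole rows is left out because it is not described here.

```python
def reflect_pad(rows, pad):
    """Mirror rows around each end without repeating the edge row."""
    before = rows[pad:0:-1]
    after = rows[-2:-pad - 2:-1]
    return before + rows + after


def circular_pad(row, pad):
    """Wrap values around, as is appropriate for longitude."""
    if pad == 0:
        return list(row)
    return row[-pad:] + row + row[:pad]


full = list(range(8))   # latitude rows 0..7 on a single device
rank0 = full[:4]        # rank 0 holds rows 0..3
rank1 = full[4:]        # rank 1 holds rows 4..7

# Reflecting locally pads the interior boundary with row 2 ...
reflected = reflect_pad(rank0, 1)
# ... whereas the true neighbor across that boundary is row 4.
with_halo = rank0 + rank1[:1]

print(reflected[-1], with_halo[-1])
print(circular_pad([10, 11, 12, 13], 1))
```

Longitude is not sharded, so the circular padding along W stays a purely local operation.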
- Parameters:
periodic_conv – An existing PeriodicConv2d module to wrap. Must have a .conv (nn.Conv2d) and .padding (int) attribute.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
- conv#
- padding#
- shard_dim = -2#
- halo_exchange#
- forward(x)#
- property weight#
- property bias#
- credit.domain_parallel.convert_to_domain_parallel(model, manager, shard_dim=-2, custom_converters=None)#
Convert a model to use domain-parallel layers.
Walks the module tree and replaces:
- nn.Conv2d (kernel>1 in H) -> DomainParallelConv2d
- nn.Conv3d (kernel>1 in H) -> DomainParallelConv3d
- nn.ConvTranspose2d (kernel>stride in H) -> DomainParallelConvTranspose2d
- nn.ConvTranspose3d (kernel>stride in H) -> DomainParallelConvTranspose3d
- nn.GroupNorm -> DomainParallelGroupNorm
Leaves unchanged:
- 1x1 Conv2d (FeedForward, projections)
- Custom LayerNorm (channel-wise, no spatial reduction)
- Attention modules (windowed, local within shard)
- PixelShuffle, activations, Dropout
For custom module types that wrap a Conv2d internally (e.g. PeriodicConv2d), pass a custom_converters dict so the whole wrapper is replaced at once and its inner Conv2d children are not double-processed:
    from credit.models.unet_diffusion import PeriodicConv2d
    from credit.domain_parallel.layers import DomainParallelPeriodicConv2d

    convert_to_domain_parallel(
        model,
        manager,
        custom_converters={
            PeriodicConv2d: lambda m: DomainParallelPeriodicConv2d(m),
        },
    )
- Parameters:
model – The nn.Module to convert.
manager – DomainParallelManager instance.
shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
custom_converters – Optional dict mapping module type -> callable(module) returning the replacement module. Custom types are checked before the built-in rules, and their children are skipped.
- Returns:
The model with replaced layers (modified in-place).
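The built-in replacement rules listed for convert_to_domain_parallel can be restated as a small decision function. This is a simplified sketch for reading the rules, not the converter itself: it works on layer type names and H-direction kernel and stride sizes instead of real modules, ignores custom_converters, and pick_wrapper is a hypothetical name.

```python
def pick_wrapper(layer_type, kernel_h=1, stride_h=1):
    """Return the wrapper class name for a layer, or None to leave it as is."""
    if layer_type in ("Conv2d", "Conv3d"):
        # A kernel of 1 along H never reads neighboring rows.
        return "DomainParallel" + layer_type if kernel_h > 1 else None
    if layer_type in ("ConvTranspose2d", "ConvTranspose3d"):
        # kernel == stride is purely local upsampling.
        return "DomainParallel" + layer_type if kernel_h > stride_h else None
    if layer_type == "GroupNorm":
        # Statistics span the sharded H dimension, so always wrap.
        return "DomainParallelGroupNorm"
    # PixelShuffle, activations, Dropout, and so on stay unchanged.
    return None


for args in [("Conv2d", 3), ("Conv2d", 1), ("ConvTranspose2d", 2, 2),
             ("ConvTranspose2d", 4, 2), ("GroupNorm",), ("PixelShuffle",)]:
    print(args, "->", pick_wrapper(*args))
```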
- credit.domain_parallel.validate_sharding_constraints(model, local_h, window_sizes=None)#
Validate that the local spatial dimension is compatible with the model.
Checks that local_h is divisible by all window sizes at all encoder levels, accounting for stride-based downsampling.
- Parameters:
model – The model to validate against.
local_h – Local H dimension per domain-parallel rank.
window_sizes – List of window sizes at each encoder level. If None, attempts to extract from the model.
- Returns:
List of warning messages (empty if all checks pass).
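The divisibility check can be sketched as follows. This is a hedged illustration rather than the library's logic: it takes the per-level strides as an explicit strides argument (hypothetical, the real function inspects the model), and it assumes each encoder level downsamples first and then applies windowed attention.

```python
def check_local_h(local_h, window_sizes, strides):
    """Return warning strings; an empty list means local_h is compatible."""
    warnings = []
    h = local_h
    for level, (window, stride) in enumerate(zip(window_sizes, strides)):
        if h % stride != 0:
            warnings.append(
                f"level {level}: H={h} is not divisible by stride {stride}"
            )
        # Assumed order: downsample, then windowed attention at this level.
        h //= stride
        if h % window != 0:
            warnings.append(
                f"level {level}: H={h} is not divisible by window size {window}"
            )
    return warnings


# 64 local rows: 64 -> 16 (window 8) -> 8 (window 4), all divisible.
ok = check_local_h(64, window_sizes=[8, 4], strides=[4, 2])
# 60 local rows: 60 -> 15, which a window of 8 cannot tile.
bad = check_local_h(60, window_sizes=[8, 4], strides=[4, 2])
print(ok)
print(bad)
```

A practical consequence is that the number of latitude rows per rank, not just the global grid height, must be chosen with the window sizes in mind.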
- credit.domain_parallel.shard_tensor(x, dim=-2, manager=None)#
Shard a tensor along the given dimension across domain-parallel ranks.
Splits the tensor into equal chunks and returns only this rank’s chunk.
- Parameters:
x – Input tensor.
dim – Dimension to shard along (default: -2 for H in BCHW or BCTHW).
manager – DomainParallelManager instance (uses global singleton if None).
- Returns:
Local shard of the tensor.
- credit.domain_parallel.gather_tensor(x, dim=-2, manager=None)#
Gather sharded tensor from all domain-parallel ranks.
All-gathers the tensor along the sharding dimension and concatenates.
- Parameters:
x – Local shard tensor.
dim – Dimension that was sharded (default: -2).
manager – DomainParallelManager instance (uses global singleton if None).
- Returns:
Full (gathered) tensor.
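shard_tensor and gather_tensor are inverse operations along the sharded dimension, which a list-based toy makes easy to verify. The sketch below is not the library's code: rows of a list stand in for the H dimension, the all-gather is simulated by collecting every rank's shard in one process, and raising on a non-divisible size is an assumption based on the "equal chunks" wording above. The names shard_rows and gather_rows are hypothetical.

```python
def shard_rows(rows, domain_rank, domain_world_size):
    """Return the equal-sized chunk of rows owned by domain_rank."""
    if len(rows) % domain_world_size != 0:
        raise ValueError("row count must be divisible by the number of ranks")
    chunk = len(rows) // domain_world_size
    return rows[domain_rank * chunk:(domain_rank + 1) * chunk]


def gather_rows(shards):
    """Concatenate the shards in domain-rank order."""
    gathered = []
    for shard in shards:
        gathered.extend(shard)
    return gathered


rows = [[0, 1], [2, 3], [4, 5], [6, 7]]   # 4 latitude rows, 2 columns
shards = [shard_rows(rows, rank, 2) for rank in range(2)]
restored = gather_rows(shards)
print(shards)
print(restored == rows)
```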
- credit.domain_parallel.shard_batch(batch, spatial_dims_5d=-2, spatial_dims_4d=-2, manager=None)#
Shard all spatial tensors in a training batch dictionary.
Handles both 5D (B, C, T, H, W) and 4D (B, C, H, W) tensors. Non-tensor entries and 1D/2D tensors are left unchanged.
- Parameters:
batch – Dictionary of batch data (from dataloader).
spatial_dims_5d – Dimension to shard in 5D tensors (default: -2 for H).
spatial_dims_4d – Dimension to shard in 4D tensors (default: -2 for H).
manager – DomainParallelManager instance.
- Returns:
New dict with sharded tensors.