credit.domain_parallel.convert
Model conversion for domain parallelism.
Walks a model’s module tree and replaces layers that need inter-GPU communication with domain-parallel equivalents. Local operations (1x1 convolutions, channel-wise norms, activations, etc.) are unchanged.
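The sketch below shows on a single axis why a kernel that spans several rows needs neighbour data once H is split across ranks, while a 1-row kernel does not. It is a pure-Python illustration, not the library's code. A list stands in for the H axis, and slicing a neighbour shard's rows stands in for the GPU-to-GPU halo exchange. The helper names `conv_same` and `sharded_conv_same` are hypothetical. The sketch assumes stride 1, an odd kernel length, zero "same" padding, and equal-height shards that are at least as tall as the halo.

```python
from typing import List, Sequence


def conv_same(x: Sequence[float], w: Sequence[float]) -> List[float]:
    """Zero-padded 'same' cross-correlation along one axis (odd-length kernel)."""
    r = len(w) // 2  # halo width: rows needed on each side of an output row
    padded = [0.0] * r + list(x) + [0.0] * r
    return [sum(w[j] * padded[i + j] for j in range(len(w))) for i in range(len(x))]


def sharded_conv_same(x: Sequence[float], w: Sequence[float], num_shards: int) -> List[float]:
    """Same result computed shard by shard, each shard borrowing r halo rows."""
    r = len(w) // 2
    n = len(x) // num_shards  # assumes len(x) % num_shards == 0 and n >= r
    shards = [list(x[s * n:(s + 1) * n]) for s in range(num_shards)]
    out: List[float] = []
    for s, local in enumerate(shards):
        # Rows "received" from the neighbours; zeros at the global top/bottom edge.
        # For a 1-row kernel (r == 0) both halos are empty: no communication.
        top = shards[s - 1][n - r:] if s > 0 else [0.0] * r
        bottom = shards[s + 1][:r] if s < num_shards - 1 else [0.0] * r
        ext = top + local + bottom
        # Unpadded correlation over the extended shard yields exactly n rows.
        out.extend(sum(w[j] * ext[i + j] for j in range(len(w))) for i in range(n))
    return out
```

For example, with `x = [0.0, 1.0, ..., 7.0]` and `w = [1.0, 2.0, 1.0]`, `sharded_conv_same(x, w, 4)` returns the same list as `conv_same(x, w)`. Dropping the halo rows would corrupt the outputs next to every shard boundary.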
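Attributes
- logger
Functions
- _needs_halo_conv2d(conv): Check if a Conv2d needs halo exchange (kernel > 1 along H).
- _needs_halo_conv3d(conv): Check if a Conv3d needs halo exchange (kernel > 1 along H).
- _needs_halo_conv_transpose3d(conv): Check if a ConvTranspose3d needs halo exchange (kernel > stride along H).
- _needs_halo_conv_transpose2d(conv): Check if a ConvTranspose2d needs halo exchange (kernel > stride along H).
- _replace_module(parent, name, old_module, new_module): Replace a child module in the parent.
- convert_to_domain_parallel(model, manager, shard_dim=-2, custom_converters=None): Convert a model to use domain-parallel layers.
- validate_sharding_constraints(model, local_h, window_sizes=None): Validate that the local spatial dimension is compatible with the model.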
Module Contents
- credit.domain_parallel.convert.logger
- credit.domain_parallel.convert._needs_halo_conv2d(conv)
Check if a Conv2d needs halo exchange (kernel > 1 along H).
- credit.domain_parallel.convert._needs_halo_conv3d(conv)
Check if a Conv3d needs halo exchange (kernel > 1 along H dim).
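The Conv2d and Conv3d rules above both reduce to one comparison on the kernel extent along the sharded axis. Below is a minimal sketch. The name `needs_halo_conv` is hypothetical and is not the module's private helper. The function works with any object that exposes a `kernel_size` tuple, as `torch.nn.Conv2d` and `torch.nn.Conv3d` do. It assumes that the kernel's spatial dims line up with the input's trailing spatial dims, so the H entry sits at index `shard_dim` (-2) for both `(kH, kW)` and `(kD, kH, kW)`.

```python
from typing import Any


def needs_halo_conv(conv: Any, shard_dim: int = -2) -> bool:
    """True if the kernel spans more than one row along the sharded axis.

    kernel_size is (kH, kW) for Conv2d and (kD, kH, kW) for Conv3d, so a
    negative shard_dim such as -2 selects kH in both cases.
    """
    return conv.kernel_size[shard_dim] > 1
```

Under this rule, `nn.Conv2d(8, 8, kernel_size=(1, 3))` needs no halo because it only mixes columns within a row. `nn.Conv3d(8, 8, kernel_size=3)` does need one.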
- credit.domain_parallel.convert._needs_halo_conv_transpose3d(conv)
Check if a ConvTranspose3d needs halo exchange (kernel > stride along H).
- credit.domain_parallel.convert._needs_halo_conv_transpose2d(conv)
Check if a ConvTranspose2d needs halo exchange (kernel > stride along H).
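The transposed rules compare the kernel with the stride rather than with 1. In a transposed convolution, input row i writes to output rows i*s - p + j for kernel taps j = 0..k-1 (PyTorch's relation with dilation 1 and no output_padding). When k <= s, adjacent input rows write to disjoint output rows, so a shard's output depends only on its own input rows. When k > s, the footprints overlap across a shard boundary. The sketch below expresses that reasoning. The helper names are hypothetical, and the predicate is duck-typed on the `kernel_size` and `stride` tuples that PyTorch's transposed conv modules expose.

```python
from typing import Any, Set


def output_rows(i: int, kernel: int, stride: int, padding: int = 0) -> Set[int]:
    """Output rows that input row i contributes to in a transposed conv (dilation 1)."""
    start = i * stride - padding
    return set(range(start, start + kernel))


def needs_halo_conv_transpose(conv: Any, shard_dim: int = -2) -> bool:
    """True if adjacent input rows write to overlapping output rows along H."""
    return conv.kernel_size[shard_dim] > conv.stride[shard_dim]
```

For example, `output_rows(0, 2, 2) & output_rows(1, 2, 2)` is empty, so a kernel-2, stride-2 upsampler stays local. `output_rows(0, 4, 2) & output_rows(1, 4, 2)` is `{2, 3}`, so a kernel-4, stride-2 layer must exchange halos.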
- credit.domain_parallel.convert._replace_module(parent, name, old_module, new_module)
Replace a child module in the parent.
- credit.domain_parallel.convert.convert_to_domain_parallel(model, manager, shard_dim=-2, custom_converters=None)
Convert a model to use domain-parallel layers.
Walks the module tree and replaces:
  - nn.Conv2d (kernel > 1 in H) -> DomainParallelConv2d
  - nn.Conv3d (kernel > 1 in H) -> DomainParallelConv3d
  - nn.ConvTranspose2d (kernel > stride in H) -> DomainParallelConvTranspose2d
  - nn.ConvTranspose3d (kernel > stride in H) -> DomainParallelConvTranspose3d
  - nn.GroupNorm -> DomainParallelGroupNorm
Leaves unchanged:
  - 1x1 Conv2d (FeedForward, projections)
  - Custom LayerNorm (channel-wise, no spatial reduction)
  - Attention modules (windowed, local within shard)
  - PixelShuffle, activations, Dropout
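GroupNorm is on the replace list because its mean and variance are taken over spatial positions, and those positions are now split across ranks. The channel-wise LayerNorm normalizes each position on its own, so it stays local. One standard way to get exact global statistics is for each rank to compute (count, sum, sum of squares) and sum those across ranks, which is what an all-reduce would do. The pure-Python sketch below shows only that arithmetic. It is an assumption about the approach, not a description of how DomainParallelGroupNorm is implemented, and the function names are hypothetical.

```python
from typing import Iterable, List, Tuple

Stats = Tuple[int, float, float]  # (count, sum, sum of squares) for one rank


def local_stats(values: Iterable[float]) -> Stats:
    """Partial sums over one rank's slice of a group's spatial positions."""
    vals = list(values)
    return len(vals), sum(vals), sum(v * v for v in vals)


def global_mean_var(per_rank: List[Stats]) -> Tuple[float, float]:
    """Combine per-rank partial sums (the all-reduce step) into mean and variance."""
    n = sum(s[0] for s in per_rank)
    s1 = sum(s[1] for s in per_rank)
    s2 = sum(s[2] for s in per_rank)
    mean = s1 / n
    # Biased (population) variance, as normalization layers use. The
    # E[x^2] - E[x]^2 form can lose precision when |mean| is large; a
    # Welford/Chan-style merge is the numerically safer alternative.
    var = max(s2 / n - mean * mean, 0.0)
    return mean, var
```

For example, splitting `[0, 1, ..., 7]` into two ranks of four values gives mean 3.5 and variance 5.25, the same as computing over all eight values at once.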
For custom module types that wrap a Conv2d internally (e.g. PeriodicConv2d), pass a custom_converters dict so the whole wrapper is replaced at once and its inner Conv2d children are not double-processed:
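    from credit.domain_parallel.convert import convert_to_domain_parallel
    from credit.domain_parallel.layers import DomainParallelPeriodicConv2d
    from credit.models.unet_diffusion import PeriodicConv2d

    # model (nn.Module) and manager (DomainParallelManager) are built elsewhere.
    convert_to_domain_parallel(
        model, manager, custom_converters={
            PeriodicConv2d: lambda m: DomainParallelPeriodicConv2d(m),
        }
    )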
- Parameters:
  - model – The nn.Module to convert.
  - manager – DomainParallelManager instance.
  - shard_dim – Spatial dimension being sharded (-2 for H in BCHW).
  - custom_converters – Optional dict mapping module type -> callable(module) that returns the replacement module. Custom types are checked before the built-in rules, and their children are skipped.
- Returns:
  The model with replaced layers (modified in place).
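The walk described above can be sketched as a recursive pass over `named_children()` that swaps children with `setattr`. Swapping a child is also the essence of a helper like `_replace_module`. This is a hedged approximation, not the module's code. The name `convert_tree` and the rule list of `(predicate, factory)` pairs are hypothetical, and the sketch never replaces the root module itself. On a real `nn.Module`, `setattr(parent, name, new_child)` re-registers the submodule, and this includes the numerically named children of `nn.Sequential`.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical rule format: (predicate, factory), e.g.
# (is_halo_conv2d, lambda m: make_dp_conv2d(m, manager)) with made-up names.
Rule = Tuple[Callable[[Any], bool], Callable[[Any], Any]]


def convert_tree(
    module: Any,
    rules: List[Rule],
    custom_converters: Optional[Dict[type, Callable[[Any], Any]]] = None,
) -> Any:
    """Replace matching descendants of `module` in place and return it."""
    custom_converters = custom_converters or {}
    # Snapshot the children so replacements do not disturb the iteration.
    for name, child in list(module.named_children()):
        custom = next(
            (fn for typ, fn in custom_converters.items() if isinstance(child, typ)),
            None,
        )
        if custom is not None:
            # Custom types win, and their subtree is skipped, so an inner
            # Conv2d inside the wrapper is not converted a second time.
            setattr(module, name, custom(child))
            continue
        for matches, factory in rules:
            if matches(child):
                setattr(module, name, factory(child))
                break
        else:
            # No rule matched (a container or a local layer): recurse into it.
            convert_tree(child, rules, custom_converters)
    return module
```

Checking `custom_converters` before the built-in rules, and not descending into the converted wrapper, is what keeps a `PeriodicConv2d`'s inner `Conv2d` from also being matched by the generic Conv2d rule.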
- credit.domain_parallel.convert.validate_sharding_constraints(model, local_h, window_sizes=None)
Validate that the local spatial dimension is compatible with the model.
Checks that, at every encoder level, the local H (local_h after that level's stride-based downsampling) is divisible by the level's window size.
- Parameters:
  - model – The model to validate against.
  - local_h – Local H dimension per domain-parallel rank.
  - window_sizes – List of window sizes at each encoder level. If None, the function attempts to extract them from the model.
- Returns:
  List of warning messages (empty if all checks pass).
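The divisibility check can be sketched by tracking the local height level by level. This sketch assumes a uniform downsampling factor between encoder levels (2 by default) and takes the window sizes explicitly instead of reading them from a model. `check_local_h` and `downsample` are hypothetical names, and the messages are illustrative rather than the real function's output.

```python
from typing import List, Sequence


def check_local_h(local_h: int, window_sizes: Sequence[int], downsample: int = 2) -> List[str]:
    """Return warnings for levels whose local height does not tile into windows."""
    warnings: List[str] = []
    h = local_h
    for level, window in enumerate(window_sizes):
        if level > 0:
            # Each deeper level is assumed to be downsampled by `downsample` in H.
            if h % downsample != 0:
                warnings.append(
                    f"level {level}: local H {h} is not divisible by downsample factor {downsample}"
                )
            h //= downsample
        # Windowed attention needs whole windows inside each shard.
        if h % window != 0:
            warnings.append(
                f"level {level}: local H {h} is not divisible by window size {window}"
            )
    return warnings
```

With `local_h=64` and windows `[8, 8, 8]`, the per-level heights are 64, 32 and 16, so no warnings are returned. With `local_h=48`, the third level is 12 rows tall, which is not a multiple of 8, so one warning is returned.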