credit.domain_parallel.convert

Model conversion for domain parallelism.

Walks a model’s module tree and replaces layers that need inter-GPU communication with domain-parallel equivalents. Local operations (1x1 convolutions, channel-wise norms, activations, etc.) are unchanged.

Attributes

logger

Functions

_needs_halo_conv2d(conv)

Check if a Conv2d needs halo exchange (kernel > 1 along H).

_needs_halo_conv3d(conv)

Check if a Conv3d needs halo exchange (kernel > 1 along H dim).

_needs_halo_conv_transpose3d(conv)

Check if a ConvTranspose3d needs halo exchange (kernel > stride along H).

_needs_halo_conv_transpose2d(conv)

Check if a ConvTranspose2d needs halo exchange (kernel > stride along H).

_replace_module(parent, name, old_module, new_module)

Replace the child module registered under name on parent with new_module.

convert_to_domain_parallel(model, manager[, ...])

Convert a model to use domain-parallel layers.

validate_sharding_constraints(model, local_h[, ...])

Validate that the local spatial dimension is compatible with the model.

Module Contents

credit.domain_parallel.convert.logger

credit.domain_parallel.convert._needs_halo_conv2d(conv)

Check if a Conv2d needs halo exchange (kernel > 1 along H).
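As a minimal sketch of what such a check might look like, assuming sharding along H and using only the public attributes of nn.Conv2d: the function name below is a hypothetical stand-in, not the private helper itself, and the real helper may also account for dilation or padding.

```python
import torch.nn as nn


def needs_halo_conv2d(conv: nn.Conv2d) -> bool:
    """Hypothetical stand-in: True if the kernel spans more than one row in H."""
    # nn.Conv2d stores kernel_size as an (H, W) tuple, so index 0 is the H extent.
    return conv.kernel_size[0] > 1


pointwise = nn.Conv2d(8, 8, kernel_size=1)
square = nn.Conv2d(8, 8, kernel_size=3, padding=1)
width_only = nn.Conv2d(8, 8, kernel_size=(1, 3), padding=(0, 1))

print(needs_halo_conv2d(pointwise))   # False: 1x1 kernel reads only its own pixel
print(needs_halo_conv2d(square))      # True: 3x3 kernel reads neighbouring rows
print(needs_halo_conv2d(width_only))  # False: kernel extends along W only
```

A kernel that is wider than one pixel only along W stays local under this assumption, because every rank holds complete rows.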

credit.domain_parallel.convert._needs_halo_conv3d(conv)

Check if a Conv3d needs halo exchange (kernel > 1 along H dim).

credit.domain_parallel.convert._needs_halo_conv_transpose3d(conv)

Check if a ConvTranspose3d needs halo exchange (kernel > stride along H).
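The kernel > stride condition can be illustrated with a short sketch. It assumes H is the second-to-last spatial axis of the (D, H, W) tuples that nn.ConvTranspose3d stores, and the function name is a hypothetical stand-in for the private helper. When the kernel equals the stride (with dilation 1), each input row writes to its own disjoint block of output rows, so nothing crosses a shard boundary.

```python
import torch.nn as nn


def needs_halo_conv_transpose3d(conv: nn.ConvTranspose3d) -> bool:
    """Hypothetical stand-in: True if output blocks overlap along H."""
    # kernel_size and stride are (D, H, W) tuples; index -2 selects H.
    return conv.kernel_size[-2] > conv.stride[-2]


# Kernel equals stride along H: a pure 2x upsampling with no overlap.
upsample = nn.ConvTranspose3d(4, 4, kernel_size=(1, 2, 2), stride=(1, 2, 2))

# Kernel larger than stride along H: neighbouring input rows contribute
# to the same output rows, so shard edges need data from adjacent ranks.
overlapping = nn.ConvTranspose3d(
    4, 4, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1)
)

print(needs_halo_conv_transpose3d(upsample))     # False
print(needs_halo_conv_transpose3d(overlapping))  # True
```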

credit.domain_parallel.convert._needs_halo_conv_transpose2d(conv)

Check if a ConvTranspose2d needs halo exchange (kernel > stride along H).

credit.domain_parallel.convert._replace_module(parent, name, old_module, new_module)

Replace the child module registered under name on parent with new_module.
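In plain PyTorch, swapping a child usually comes down to a setattr on the parent. The sketch below shows that idiom with a hypothetical three-argument helper, swap_child. The documented helper also receives old_module, and how it uses that argument is not described here, so this is an illustration of the general technique, not the actual implementation.

```python
import torch.nn as nn


def swap_child(parent: nn.Module, name: str, new_module: nn.Module) -> None:
    """Hypothetical helper: register new_module under an existing child name."""
    # nn.Module.__setattr__ stores Module values in parent._modules[name],
    # which also works for the numeric child names of nn.Sequential.
    setattr(parent, name, new_module)


block = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
swap_child(block, "0", nn.Identity())

print(block)
```

The remaining children keep their names and order, so the parent's forward pass does not need to change.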

credit.domain_parallel.convert.convert_to_domain_parallel(model, manager, shard_dim=-2, custom_converters=None)

Convert a model to use domain-parallel layers.

Walks the module tree and replaces:
  • nn.Conv2d (kernel > 1 in H) -> DomainParallelConv2d

  • nn.Conv3d (kernel > 1 in H) -> DomainParallelConv3d

  • nn.ConvTranspose2d (kernel > stride in H) -> DomainParallelConvTranspose2d

  • nn.ConvTranspose3d (kernel > stride in H) -> DomainParallelConvTranspose3d

  • nn.GroupNorm -> DomainParallelGroupNorm

Leaves unchanged:
  • 1x1 Conv2d (FeedForward, projections)

  • Custom LayerNorm (channel-wise, no spatial reduction)

  • Attention modules (windowed, local within shard)

  • PixelShuffle, activations, Dropout
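The replacement rules above can be sketched as a single predicate over standard torch.nn layer types. This is an illustration only: needs_communication is a hypothetical name, H is assumed to sit at index -2 of the kernel and stride tuples, and the actual conversion may apply further conditions.

```python
import torch.nn as nn


def needs_communication(module: nn.Module) -> bool:
    """Hypothetical predicate mirroring the replacement rules listed above."""
    if isinstance(module, (nn.Conv2d, nn.Conv3d)):
        return module.kernel_size[-2] > 1
    if isinstance(module, (nn.ConvTranspose2d, nn.ConvTranspose3d)):
        return module.kernel_size[-2] > module.stride[-2]
    # GroupNorm statistics span the full spatial extent, hence all shards.
    return isinstance(module, nn.GroupNorm)


model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),           # spatial kernel
    nn.GroupNorm(4, 16),                                  # spatial reduction
    nn.SiLU(),                                            # pointwise activation
    nn.Conv2d(16, 16, kernel_size=1),                     # 1x1 projection
    nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2),   # kernel == stride
    nn.PixelShuffle(2),
)

flagged = [name for name, m in model.named_modules() if needs_communication(m)]
print(flagged)  # ['0', '1']
```

Only the 3x3 convolution and the GroupNorm are flagged; the other layers read data that already lives on the local shard.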

For custom module types that wrap a Conv2d internally (e.g. PeriodicConv2d), pass a custom_converters dict so the whole wrapper is replaced at once and its inner Conv2d children are not double-processed:

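    from credit.models.unet_diffusion import PeriodicConv2d
    from credit.domain_parallel.layers import DomainParallelPeriodicConv2d
    from credit.domain_parallel.convert import convert_to_domain_parallel

    # model and manager are assumed to have been created already.
    convert_to_domain_parallel(
        model,
        manager,
        custom_converters={
            # Replace the whole wrapper; its inner Conv2d is not visited.
            PeriodicConv2d: lambda m: DomainParallelPeriodicConv2d(m),
        },
    )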

Parameters:
  • model – The nn.Module to convert.

  • manager – DomainParallelManager instance.

  • shard_dim – Spatial dimension being sharded (-2 for H in BCHW).

  • custom_converters – Optional dict mapping module type -> callable(module) returning the replacement module. Custom types are checked before the built-in rules, and their children are skipped.

Returns:

The model with replaced layers (modified in-place).
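The interaction between custom converters and the built-in rules can be sketched with stand-in classes. Everything named below (WrappedConv, ParallelWrappedConv, ParallelConv2d, convert_sketch) is hypothetical and performs no communication. The sketch also assumes an exact type lookup for custom converters, whereas the real function may match differently, and it omits the manager and shard_dim arguments.

```python
import torch.nn as nn


class WrappedConv(nn.Module):
    """Hypothetical wrapper that owns an inner Conv2d."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


class ParallelWrappedConv(nn.Module):
    """Hypothetical replacement for the whole wrapper (no halo logic here)."""

    def __init__(self, original):
        super().__init__()
        self.original = original

    def forward(self, x):
        return self.original(x)


class ParallelConv2d(nn.Module):
    """Hypothetical stand-in for a domain-parallel convolution."""

    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)


def convert_sketch(model, custom_converters=None):
    custom_converters = custom_converters or {}
    # Snapshot the children so replacing them does not disturb iteration.
    for name, child in list(model.named_children()):
        converter = custom_converters.get(type(child))
        if converter is not None:
            setattr(model, name, converter(child))
            continue  # custom types win, and their children are skipped
        if isinstance(child, nn.Conv2d) and child.kernel_size[-2] > 1:
            setattr(model, name, ParallelConv2d(child))
            continue
        convert_sketch(child, custom_converters)  # recurse into containers
    return model  # same object, modified in place


model = nn.Sequential(WrappedConv(), nn.Conv2d(4, 4, 3, padding=1), nn.ReLU())
converted = convert_sketch(model, {WrappedConv: ParallelWrappedConv})
print(converted)
```

Because the wrapper is replaced before the walk descends into it, its inner Conv2d is never wrapped a second time by the built-in rule.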

credit.domain_parallel.convert.validate_sharding_constraints(model, local_h, window_sizes=None)

Validate that the local spatial dimension is compatible with the model.

Checks that local_h is divisible by all window sizes at all encoder levels, accounting for stride-based downsampling.

Parameters:
  • model – The model to validate against.

  • local_h – Local H dimension per domain-parallel rank.

  • window_sizes – List of window sizes at each encoder level. If None, attempts to extract from the model.

Returns:

List of warning messages (empty if all checks pass).
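The divisibility check can be illustrated with a small, model-free sketch. It assumes that every encoder level downsamples H by a fixed factor of 2, which is an assumption made for this example and not something the documentation states. The function name and the message format are hypothetical as well.

```python
def check_window_divisibility(local_h, window_sizes, downsample=2):
    """Hypothetical check: collect a warning for each level where H does not fit."""
    warnings = []
    h = local_h
    for level, window in enumerate(window_sizes):
        if h % window != 0:
            warnings.append(
                f"level {level}: local H {h} is not divisible by window size {window}"
            )
        # Assumed stride-based downsampling between encoder levels.
        h //= downsample
    return warnings


ok = check_window_divisibility(64, [8, 8, 4])   # 64, 32, 16 all divide evenly
bad = check_window_divisibility(40, [8, 8, 4])  # 20 % 8 != 0 and 10 % 4 != 0

print(ok)
for message in bad:
    print(message)
```

Returning a list of messages, as the documented function does, lets the caller decide whether a mismatch should be logged or treated as an error.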