credit.parallel.domain

Domain-parallel utility functions shared across all trainers.

Attributes

logger

Functions

get_domain_manager(model)

get_raw_model(model)

shard_spatial(tensor, manager)

unpad_shard_interp(y_pred, padding_opt, manager, ...)

shard_lat_weights(weights, target_h)

Match latitude weights to a (possibly domain-sharded) target height.

gather_spatial(tensor, manager)

Inverse of shard_spatial: all-gather H-shards across the domain group.

sync_domain_gradients(model, manager)

Average gradients across the domain-parallel group.

Module Contents

credit.parallel.domain.logger
credit.parallel.domain.get_domain_manager(model)
credit.parallel.domain.get_raw_model(model)
credit.parallel.domain.shard_spatial(tensor, manager)
credit.parallel.domain.unpad_shard_interp(y_pred, padding_opt, manager, image_h, image_w)
credit.parallel.domain.shard_lat_weights(weights, target_h)

Match latitude weights to a (possibly domain-sharded) target height.

Single source of truth for the latitude-weight sharding used by the loss and by both metrics classes. Returns the weights unchanged when the heights match; otherwise, narrows them to this rank’s domain shard. Raises ValueError when the mismatch cannot be explained by domain sharding.
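A minimal sketch of the matching rule described above, assuming 1-D weights of length H, an equal and contiguous split of H across the domain group, and the rank and group size passed in explicitly. The name `shard_lat_weights_sketch` and its `domain_rank`/`domain_size` parameters are hypothetical; the real function takes only `weights` and `target_h`, and its exact split scheme is not documented on this page.

```python
from typing import List, Sequence


def shard_lat_weights_sketch(
    weights: Sequence[float], target_h: int, domain_rank: int, domain_size: int
) -> List[float]:
    """Hypothetical sketch: match 1-D latitude weights to a target height.

    Assumes H is split into `domain_size` equal, contiguous shards and that
    rank `domain_rank` owns shard number `domain_rank`.
    """
    full_h = len(weights)
    if full_h == target_h:
        # Heights already match (no domain sharding): return the weights as-is.
        return list(weights)
    # The mismatch is only explained by sharding if the full height splits
    # evenly into `domain_size` shards of exactly `target_h` rows each.
    if full_h == target_h * domain_size:
        start = domain_rank * target_h
        return list(weights[start:start + target_h])
    raise ValueError(
        f"lat weights of height {full_h} cannot be matched to target height "
        f"{target_h} with {domain_size} domain ranks"
    )
```

Under these assumptions, with 8 latitude rows split across 4 domain ranks, rank 1 receives rows 2 and 3, while a target height of 3 raises ValueError because 8 rows cannot form 4 shards of 3.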

credit.parallel.domain.gather_spatial(tensor, manager)

Inverse of shard_spatial: all-gather H-shards across the domain group.

Used between autoregressive rollout steps: the next step’s input is assembled at full height on every domain rank and then re-sharded by shard_spatial. A plain (non-differentiable) all_gather is correct here because Reconstruct detaches y_processed between steps, so no gradient flows through the gathered tensors.
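The shard/gather round trip between rollout steps can be illustrated in a single process. This is a hypothetical simulation, not the module's implementation: `shard_rows` assumes an equal, contiguous split along H, `simulated_all_gather` stands in for `torch.distributed.all_gather` over the domain group, and a toy `+1.0` update stands in for the per-rank model step. Autograd is not modelled here.

```python
from typing import List

Row = List[float]
Field = List[Row]  # H rows of W values


def shard_rows(full: Field, rank: int, size: int) -> Field:
    """Hypothetical stand-in for shard_spatial: equal, contiguous split along H."""
    h = len(full)
    if h % size != 0:
        raise ValueError(f"height {h} is not divisible by {size} domain ranks")
    rows_per_rank = h // size
    return full[rank * rows_per_rank:(rank + 1) * rows_per_rank]


def simulated_all_gather(shards_by_rank: List[Field]) -> List[List[Field]]:
    """Emulate all_gather: each rank receives every rank's shard, in rank order."""
    return [list(shards_by_rank) for _ in shards_by_rank]


def gather_rows(gathered: List[Field]) -> Field:
    """Concatenate shards along H, inverting shard_rows."""
    return [row for shard in gathered for row in shard]


def rollout_sketch(state: Field, size: int, steps: int) -> Field:
    """Toy autoregressive rollout: shard, update locally, gather, repeat."""
    for _ in range(steps):
        shards = [shard_rows(state, r, size) for r in range(size)]
        # Placeholder for the per-rank model step: add 1.0 to every value.
        shards = [[[v + 1.0 for v in row] for row in shard] for shard in shards]
        full_per_rank = [gather_rows(g) for g in simulated_all_gather(shards)]
        # Every domain rank now holds the same full-height next-step input.
        assert all(f == full_per_rank[0] for f in full_per_rank)
        state = full_per_rank[0]
    return state
```

Gathering the shards of every rank in rank order reproduces the original field exactly, which is what makes gather_spatial the inverse of shard_spatial; the rollout then re-shards that full-height state at the start of each step.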

credit.parallel.domain.sync_domain_gradients(model, manager)

Average gradients across the domain-parallel group.

See credit.parallel.collectives.allreduce_grads_avg for the bucketing and DTensor handling.
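Averaging across the domain group amounts to an element-wise mean of each parameter's gradient over the ranks, after which every rank holds the same values. The single-process simulation below is a hypothetical illustration with made-up names; it does not reproduce the bucketing or DTensor handling of allreduce_grads_avg. In a real process group, a comparable result comes from `torch.distributed.all_reduce` with `ReduceOp.SUM` on each gradient, followed by division by the group size.

```python
from typing import Dict, List

Grads = Dict[str, List[float]]  # parameter name -> flattened gradient


def simulated_domain_grad_avg(per_rank_grads: List[Grads]) -> List[Grads]:
    """Emulate an all-reduce average over a domain group.

    Hypothetical stand-in: per_rank_grads[r] holds rank r's local gradients,
    and all ranks are assumed to share the same parameter names and shapes.
    """
    size = len(per_rank_grads)
    averaged: Grads = {}
    for name in per_rank_grads[0]:
        # Element-wise sum of this parameter's gradient across all ranks...
        summed = [sum(vals) for vals in zip(*(g[name] for g in per_rank_grads))]
        # ...divided by the group size gives the mean.
        averaged[name] = [v / size for v in summed]
    # After an all-reduce, every rank holds its own identical copy of the result.
    return [{name: list(vals) for name, vals in averaged.items()} for _ in range(size)]
```

For example, two ranks holding `{"w": [1.0, 2.0]}` and `{"w": [3.0, 6.0]}` both end up with `{"w": [2.0, 4.0]}`.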