credit.parallel.domain
======================

.. py:module:: credit.parallel.domain

.. autoapi-nested-parse::

   Domain-parallel utility functions shared across all trainers.



Attributes
----------

.. autoapisummary::

   credit.parallel.domain.logger


Functions
---------

.. autoapisummary::

   credit.parallel.domain.get_domain_manager
   credit.parallel.domain.get_raw_model
   credit.parallel.domain.shard_spatial
   credit.parallel.domain.unpad_shard_interp
   credit.parallel.domain.shard_lat_weights
   credit.parallel.domain.gather_spatial
   credit.parallel.domain.sync_domain_gradients


Module Contents
---------------

.. py:data:: logger

.. py:function:: get_domain_manager(model)

.. py:function:: get_raw_model(model)

.. py:function:: shard_spatial(tensor, manager)

.. py:function:: unpad_shard_interp(y_pred, padding_opt, manager, image_h, image_w)

.. py:function:: shard_lat_weights(weights, target_h)

   Match latitude weights to a (possibly domain-sharded) target height.

   Single source of truth for the lat-weight sharding used by the loss and
   both metrics classes. Returns weights unchanged when heights match;
   otherwise narrows to this rank's domain shard. Raises ValueError when the
   mismatch cannot be explained by domain sharding.
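
   The three outcomes described above can be sketched on a plain Python list.
   This is a minimal illustration, not the library's implementation: the name
   ``shard_lat_weights_sketch`` and the explicit ``domain_rank`` and
   ``domain_size`` parameters are hypothetical (the real function takes only
   ``weights`` and ``target_h``), and an even, contiguous split along a 1-D
   latitude axis is assumed.

   .. code-block:: python

      def shard_lat_weights_sketch(weights, target_h, domain_rank, domain_size):
          """Illustrative only: narrow 1-D latitude weights to one rank's rows."""
          full_h = len(weights)
          if full_h == target_h:
              # Heights already match: nothing to shard.
              return weights
          if domain_size > 1 and target_h * domain_size == full_h:
              # Assumed layout: rank r owns rows [r * target_h, (r + 1) * target_h).
              start = domain_rank * target_h
              return weights[start:start + target_h]
          # The mismatch is not explained by an even domain split.
          raise ValueError(
              f"weight height {full_h} does not match target height {target_h} "
              f"and is not explained by a {domain_size}-way domain split"
          )


      weights = [0.1, 0.2, 0.3, 0.4, 0.3, 0.2, 0.1, 0.05]
      full = shard_lat_weights_sketch(weights, 8, domain_rank=0, domain_size=2)
      half = shard_lat_weights_sketch(weights, 4, domain_rank=1, domain_size=2)

   Here ``full`` is the input unchanged and ``half`` holds the last four
   weights; a target height of 3 would raise ``ValueError``.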


.. py:function:: gather_spatial(tensor, manager)

   Inverse of shard_spatial: all-gathers the H-shards across the domain group.

   Used between autoregressive rollout steps: the next step's input is
   assembled at full height on every domain rank, then re-sharded by
   shard_spatial. A plain (non-differentiable) all_gather is correct here
   because Reconstruct detaches y_processed between steps, so no gradient
   flows through the gathered tensors.
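
   The shard/gather round trip can be simulated in a single process. This is a
   sketch under stated assumptions, not the library's code: ``shard_rows`` and
   ``gather_rows`` are hypothetical helpers, nested lists stand in for tensors,
   no real collective is issued, and the height is assumed to divide evenly
   across the domain group.

   .. code-block:: python

      def shard_rows(rows, rank, size):
          """Return the contiguous block of rows owned by `rank` (even split assumed)."""
          if len(rows) % size != 0:
              raise ValueError("height must divide evenly across the domain group")
          shard_h = len(rows) // size
          return rows[rank * shard_h:(rank + 1) * shard_h]


      def gather_rows(shards):
          """Concatenate per-rank shards in rank order, as an all-gather along H would."""
          return [row for shard in shards for row in shard]


      # A 4 x 3 "field": 4 latitude rows, 3 longitude columns.
      field = [[h * 10 + w for w in range(3)] for h in range(4)]
      domain_size = 2

      # Each simulated rank keeps only its own rows ...
      shards = [shard_rows(field, r, domain_size) for r in range(domain_size)]
      # ... and gathering in rank order restores the full-height field.
      restored = gather_rows(shards)

   Because the gather only concatenates shards in rank order, ``restored``
   equals ``field``, which is the property the rollout loop relies on before
   re-sharding.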


.. py:function:: sync_domain_gradients(model, manager)

   Average gradients across the domain-parallel group.

   See credit.parallel.collectives.allreduce_grads_avg for the bucketing and
   DTensor handling.
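
   The intended effect of the averaging can be shown with a single-process
   simulation. This is only a sketch of the semantics: ``allreduce_avg`` is a
   hypothetical helper operating on flat Python lists, and it does not
   reproduce the bucketing or DTensor handling of the real collective.

   .. code-block:: python

      def allreduce_avg(per_rank_grads):
          """Simulate an averaging all-reduce: every rank ends up with the mean."""
          size = len(per_rank_grads)
          n = len(per_rank_grads[0])
          mean = [sum(g[i] for g in per_rank_grads) / size for i in range(n)]
          # Each rank receives its own copy of the averaged gradient.
          return [list(mean) for _ in range(size)]


      # One flat gradient per domain rank.
      grads = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
      synced = allreduce_avg(grads)

   After the call, every simulated rank holds the same element-wise mean, so
   replicated parameters stay identical across the domain group after the
   optimizer step.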


