credit.domain_parallel.halo_exchange
====================================

.. py:module:: credit.domain_parallel.halo_exchange

.. autoapi-nested-parse::

   Halo exchange for domain-parallel operations.

   Implements differentiable halo exchange along the sharding dimension using
   point-to-point communication (``isend``/``irecv``). The backward pass performs the
   reverse exchange so gradients flow correctly across domain boundaries.

   For latitude sharding of weather data:

   - The "previous" neighbor is the rank holding the region immediately to the north.
   - The "next" neighbor is the rank holding the region immediately to the south.
   - Edge ranks (poles) get zero-padded halos on their outer boundary,
     because ``TensorPadding`` has already applied pole reflection before sharding.
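
   The padding rule can be simulated in a single process. The sketch below is
   illustrative only: ``exchange_halos`` is a hypothetical helper, not part of
   this module; each shard is a plain Python list whose items stand in for
   latitude rows; and ``shards[r - 1]`` / ``shards[r + 1]`` are assumed to be
   the northern and southern neighbors of rank ``r``.

   .. code-block:: python

      def exchange_halos(shards, halo_width, fill=0.0):
          """Pad every shard with halo rows copied from its neighbors (hypothetical helper).

          shards[r - 1] is treated as the northern ("previous") neighbor of rank r
          and shards[r + 1] as the southern ("next") neighbor. The outer boundary
          of the first and last ranks is filled with ``fill``.
          """
          # This sketch assumes every shard has at least halo_width rows.
          assert all(len(s) >= halo_width for s in shards)
          last = len(shards) - 1
          padded = []
          for rank, rows in enumerate(shards):
              if rank > 0:
                  prev = shards[rank - 1]
                  # Southernmost rows of the northern neighbor.
                  before = prev[len(prev) - halo_width:]
              else:
                  before = [fill] * halo_width  # north pole: zero halo
              if rank < last:
                  # Northernmost rows of the southern neighbor.
                  after = shards[rank + 1][:halo_width]
              else:
                  after = [fill] * halo_width  # south pole: zero halo
              padded.append(before + list(rows) + after)
          return padded


      shards = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
      padded = exchange_halos(shards, halo_width=1)
      # padded == [[0.0, 1.0, 2.0, 3.0, 4.0],
      #            [3.0, 4.0, 5.0, 6.0, 7.0],
      #            [6.0, 7.0, 8.0, 9.0, 0.0]]

   Every shard, including the two pole shards, grows by ``2 * halo_width``
   rows; the zero fill is what keeps the pole ranks the same shape as the others.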



Classes
-------

.. autoapisummary::

   credit.domain_parallel.halo_exchange._HaloExchangeFunction
   credit.domain_parallel.halo_exchange.HaloExchange


Module Contents
---------------

.. py:class:: _HaloExchangeFunction(*args, **kwargs)

   Bases: :py:obj:`torch.autograd.Function`


   Differentiable halo exchange.

   Forward: pads the tensor with halo rows received from neighbors.
   Backward: sends gradient halos back to the ranks that contributed them.
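
   Why the reverse exchange gives correct gradients can be checked without a
   process group. The sketch below is a hypothetical single-process model, not
   this class's code: shards are plain Python lists, ``halo_forward`` pads them
   as described above, and ``halo_backward`` routes each halo's gradient back to
   the rank that supplied those rows and adds it to that rank's boundary rows.
   A dot-product (adjoint) test confirms the two maps are transposes of each other.

   .. code-block:: python

      import random


      def halo_forward(shards, h):
          """Pad each shard with h rows from each neighbor (zeros at the poles)."""
          last = len(shards) - 1
          out = []
          for r, rows in enumerate(shards):
              before = shards[r - 1][len(shards[r - 1]) - h:] if r > 0 else [0.0] * h
              after = shards[r + 1][:h] if r < last else [0.0] * h
              out.append(before + rows + after)
          return out


      def halo_backward(grads, h):
          """Reverse exchange: route each halo gradient back to the rows it was copied from."""
          last = len(grads) - 1
          result = []
          for r, g in enumerate(grads):
              n = len(g) - 2 * h
              local = g[h:h + n]  # gradient w.r.t. this rank's own rows
              if r > 0:
                  # The "after" halo of rank r - 1 was copied from our first h rows.
                  prev = grads[r - 1]
                  for i, v in enumerate(prev[len(prev) - h:]):
                      local[i] += v
              if r < last:
                  # The "before" halo of rank r + 1 was copied from our last h rows.
                  for i, v in enumerate(grads[r + 1][:h]):
                      local[n - h + i] += v
              # Gradients of the zero pole halos have no source rows and are dropped.
              result.append(local)
          return result


      def dot(a, b):
          return sum(x * y for ra, rb in zip(a, b) for x, y in zip(ra, rb))


      random.seed(0)
      h = 2
      x = [[random.random() for _ in range(5)] for _ in range(3)]
      y = [[random.random() for _ in range(5 + 2 * h)] for _ in range(3)]
      # Adjoint test: <forward(x), y> must equal <x, backward(y)>.
      lhs = dot(halo_forward(x, h), y)
      rhs = dot(x, halo_backward(y, h))
      assert abs(lhs - rhs) < 1e-9

   If the received gradient halos were dropped instead of accumulated, the
   boundary rows of each shard would miss the contribution they made to their
   neighbors' outputs and the test above would fail.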


   .. py:method:: forward(ctx, x, halo_width, dim, manager)
      :staticmethod:


      Pad ``x`` along ``dim`` with ``halo_width`` rows received from each
      neighboring rank.

      The last ``halo_width`` rows of the previous (northern) rank are
      prepended and the first ``halo_width`` rows of the next (southern) rank
      are appended. Edge ranks fill their outer halo with zeros, so on every
      rank the output is ``2 * halo_width`` rows larger than ``x`` along ``dim``.

      :param ctx: Autograd context shared with :meth:`backward`.
      :param x: Local shard of the input tensor.
      :param halo_width: Number of rows to exchange on each side.
      :param dim: Tensor dimension to exchange along.
      :param manager: Domain-parallel manager that supplies the communication context.



   .. py:method:: backward(ctx, grad_output)
      :staticmethod:


      Perform the reverse halo exchange on the incoming gradient.

      The gradient slices for the two halo regions of ``grad_output`` are sent
      back to the neighboring ranks that supplied those rows, and the slices
      received from the neighbors are added to this rank's own boundary rows.
      Gradients for the zero-padded halos at the poles have no source rows and
      are discarded.

      :param ctx: Autograd context populated in :meth:`forward`.
      :param grad_output: Gradient with respect to the padded output.
      :returns: Gradient with respect to ``x``, followed by ``None`` for the
         non-tensor inputs ``halo_width``, ``dim`` and ``manager``.



.. py:class:: HaloExchange(halo_width, dim=-2)

   Bases: :py:obj:`torch.nn.Module`


   Halo exchange layer for domain-parallel operations.

   Pads the input tensor with halo rows from neighboring ranks along
   the sharding dimension. The operation is differentiable.

   :param halo_width: Number of rows to exchange on each side.
   :param dim: Tensor dimension to exchange along (default: -2, the latitude axis of a ``(B, C, H, W)`` tensor).
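
   Choosing ``halo_width``: assuming the layer that follows is a stride-1
   convolution with an odd kernel size ``k`` and no padding of its own along
   ``dim``, ``(k - 1) // 2`` halo rows per side make the convolution output
   exactly as tall as the local shard. The arithmetic, as a small sketch with
   hypothetical helper names:

   .. code-block:: python

      def valid_conv_rows(rows, kernel_size):
          # Output rows of a stride-1 convolution with no padding along this dimension.
          return rows - kernel_size + 1


      def halo_for_kernel(kernel_size):
          # Halo rows needed on each side so the output matches the local shard height.
          if kernel_size % 2 == 0:
              raise ValueError("this sketch assumes an odd kernel size")
          return (kernel_size - 1) // 2


      local_rows = 32
      for k in (1, 3, 5, 7):
          halo_width = halo_for_kernel(k)
          padded_rows = local_rows + 2 * halo_width  # one halo on each side
          assert valid_conv_rows(padded_rows, k) == local_rows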


   .. py:attribute:: halo_width


   .. py:attribute:: dim
      :value: -2



   .. py:method:: forward(x)

      Pad ``x`` with ``halo_width`` rows from each neighboring rank along ``dim``.

   .. py:method:: trim(x, halo_before, halo_after, dim=-2)
      :staticmethod:


      Trim halo rows from the output after a convolution.

      :param x: Tensor with extra halo rows.
      :param halo_before: Number of rows to trim from the start.
      :param halo_after: Number of rows to trim from the end.
      :param dim: Dimension to trim along.
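
      A sketch of when trimming is needed, assuming each rank convolves its
      halo-padded shard with the convolution's own zero "same" padding: the
      extra output rows at each end belong to halo positions, so trimming
      ``halo_before = halo_after = halo_width`` rows leaves exactly the local
      result. The example below uses 1-D Python lists and hypothetical helpers
      (``conv_same``, ``exchange``, ``trim_rows``), and checks that the stitched
      per-rank outputs match a convolution over the unsharded data.

      .. code-block:: python

         def conv_same(rows, weights):
             """Stride-1 1-D cross-correlation with zero 'same' padding (odd kernel)."""
             r = len(weights) // 2
             padded = [0.0] * r + rows + [0.0] * r
             return [sum(w * padded[i + j] for j, w in enumerate(weights))
                     for i in range(len(rows))]


         def exchange(shards, h):
             """Single-process halo exchange with zero halos at the outer boundary."""
             last = len(shards) - 1
             return [
                 (shards[r - 1][len(shards[r - 1]) - h:] if r > 0 else [0.0] * h)
                 + rows
                 + (shards[r + 1][:h] if r < last else [0.0] * h)
                 for r, rows in enumerate(shards)
             ]


         def trim_rows(rows, halo_before, halo_after):
             """1-D analogue of trimming: drop rows from the start and the end."""
             return rows[halo_before:len(rows) - halo_after]


         global_rows = [float((3 * i) % 7) for i in range(12)]
         weights = [0.25, 0.5, 0.25]
         h = len(weights) // 2
         shards = [global_rows[i:i + 4] for i in range(0, 12, 4)]

         # Each rank: exchange halos, convolve with 'same' padding, trim the halo rows.
         local = [trim_rows(conv_same(p, weights), h, h) for p in exchange(shards, h)]
         stitched = [v for part in local for v in part]
         reference = conv_same(global_rows, weights)
         assert all(abs(a - b) < 1e-12 for a, b in zip(stitched, reference))

      Writing the slice end as ``len(rows) - halo_after`` rather than
      ``-halo_after`` keeps the helper correct when ``halo_after`` is zero.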



