credit.models.swin_wrf
======================

.. py:module:: credit.models.swin_wrf


Attributes
----------

.. autoapisummary::

   credit.models.swin_wrf.logger


Classes
-------

.. autoapisummary::

   credit.models.swin_wrf.CubeEmbedding
   credit.models.swin_wrf.DownBlock
   credit.models.swin_wrf.UpBlock
   credit.models.swin_wrf.UTransformer
   credit.models.swin_wrf.WRFTransformer


Functions
---------

.. autoapisummary::

   credit.models.swin_wrf.apply_spectral_norm
   credit.models.swin_wrf.get_pad3d
   credit.models.swin_wrf.get_pad2d


Module Contents
---------------

.. py:data:: logger

.. py:function:: apply_spectral_norm(model)

   Add spectral normalization to every convolutional and linear layer in ``model``.
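
   A minimal sketch of the idea. It assumes the helper wraps each ``nn.Conv*`` and
   ``nn.Linear`` layer with :func:`torch.nn.utils.spectral_norm`. The name
   ``apply_spectral_norm_sketch`` and the exact set of layer types are hypothetical,
   not taken from this module.

   .. code-block:: python

      import torch
      import torch.nn as nn
      from torch.nn.utils import spectral_norm

      # Layer types to normalize (assumption: the real helper may cover a different set).
      _SN_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                   nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
                   nn.Linear)


      def apply_spectral_norm_sketch(model: nn.Module) -> nn.Module:
          """Hypothetical re-implementation: wrap conv/linear weights with spectral norm in place."""
          # Collect targets first so the module tree is not modified while iterating.
          targets = [m for m in model.modules() if isinstance(m, _SN_TYPES)]
          for module in targets:
              spectral_norm(module)  # registers weight_orig plus power-iteration buffers
          return model


      net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 4 * 4, 2))
      apply_spectral_norm_sketch(net)
      y = net(torch.randn(1, 3, 4, 4))

   Activation layers such as ``nn.ReLU`` have no weights and are left untouched.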


.. py:function:: get_pad3d(input_resolution, window_size)

   Compute the padding needed to make each dimension of ``input_resolution`` divisible by the corresponding entry of ``window_size``.

   :param input_resolution: (Pl, Lat, Lon)
   :type input_resolution: tuple[int]
   :param window_size: (Pl, Lat, Lon)
   :type window_size: tuple[int]

   :returns: Padding amounts ``(padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)``.
   :rtype: tuple[int]
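
   A minimal sketch, assuming the missing cells on each axis are split as evenly as
   possible, with any odd cell going to the trailing side. ``get_pad3d_sketch`` is a
   hypothetical name; the output order follows the documented return value.

   .. code-block:: python

      def get_pad3d_sketch(input_resolution, window_size):
          """Hypothetical re-implementation: pad each axis up to a multiple of its window."""
          pl, lat, lon = input_resolution
          win_pl, win_lat, win_lon = window_size

          def split(size, win):
              remainder = size % win
              if remainder == 0:
                  return 0, 0
              total = win - remainder          # cells needed to reach the next multiple of win
              before = total // 2
              return before, total - before    # an odd cell goes to the trailing side

          padding_front, padding_back = split(pl, win_pl)
          padding_top, padding_bottom = split(lat, win_lat)
          padding_left, padding_right = split(lon, win_lon)
          # Last dimension (Lon) first, the order torch.nn.functional.pad expects.
          return padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back


      pads = get_pad3d_sketch((13, 181, 360), (2, 6, 12))  # -> (0, 0, 2, 3, 0, 1)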


.. py:function:: get_pad2d(input_resolution, window_size)

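   Two-dimensional counterpart of :py:func:`get_pad3d`: compute the padding needed to make each
   dimension of ``input_resolution`` divisible by the corresponding entry of ``window_size``.

   :param input_resolution: (Lat, Lon)
   :type input_resolution: tuple[int]
   :param window_size: (Lat, Lon)
   :type window_size: tuple[int]

   :returns: Padding amounts ``(padding_left, padding_right, padding_top, padding_bottom)``.
   :rtype: tuple[int]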
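
   A minimal sketch of computing 2D padding and applying it with
   :func:`torch.nn.functional.pad`, assuming the same even split as for the 3D case.
   ``get_pad2d_sketch`` is a hypothetical name, not this module's function.

   .. code-block:: python

      import torch
      import torch.nn.functional as F


      def get_pad2d_sketch(input_resolution, window_size):
          """Hypothetical re-implementation: even split of the missing cells on each axis."""
          pads = []
          # Walk Lon first, then Lat, so the result is in F.pad order (last dim first).
          for size, win in zip(reversed(input_resolution), reversed(window_size)):
              total = (win - size % win) % win   # 0 when size is already a multiple of win
              pads += [total // 2, total - total // 2]
          return tuple(pads)


      x = torch.randn(1, 4, 181, 360)                # (B, C, Lat, Lon)
      padding = get_pad2d_sketch((181, 360), (7, 7))
      x_padded = F.pad(x, padding)                   # constant (zero) padding by default

   After padding, both spatial sizes (182 and 364) are multiples of 7, so the grid
   tiles exactly into 7 x 7 windows.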


.. py:class:: CubeEmbedding(img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm)

   Bases: :py:obj:`torch.nn.Module`


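   Cube (patch) embedding: project a ``(T, Lat, Lon)`` input into patches of size
   ``patch_size`` with ``embed_dim`` channels.

   :param img_size: Input size (T, Lat, Lon).
   :param patch_size: Patch size (T, Lat, Lon).
   :param in_chans: Number of input channels.
   :param embed_dim: Number of embedding channels.
   :param norm_layer: Normalization layer applied to the embeddings (default ``nn.LayerNorm``).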
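
   A minimal sketch, assuming ``proj`` is a ``nn.Conv3d`` whose kernel size equals its
   stride (non-overlapping patches) and that the norm layer acts on the channel axis.
   ``CubeEmbeddingSketch`` is hypothetical and may differ from the actual class.

   .. code-block:: python

      import torch
      import torch.nn as nn


      class CubeEmbeddingSketch(nn.Module):
          """Hypothetical re-implementation: Conv3d patchify followed by LayerNorm over channels."""

          def __init__(self, img_size, patch_size, in_chans, embed_dim, norm_layer=nn.LayerNorm):
              super().__init__()
              self.patches_resolution = [s // p for s, p in zip(img_size, patch_size)]
              self.embed_dim = embed_dim
              # kernel_size == stride gives non-overlapping (T, Lat, Lon) patches.
              self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
              self.norm = norm_layer(embed_dim) if norm_layer is not None else None

          def forward(self, x: torch.Tensor) -> torch.Tensor:
              # x: (B, C_in, T, Lat, Lon) -> (B, D, T', Lat', Lon')
              x = self.proj(x)
              if self.norm is not None:
                  # LayerNorm normalizes the last dim, so move channels there and back.
                  x = self.norm(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
              return x


      emb = CubeEmbeddingSketch(img_size=(2, 32, 64), patch_size=(2, 4, 4), in_chans=5, embed_dim=16)
      out = emb(torch.randn(1, 5, 2, 32, 64))   # shape (1, 16, 1, 8, 16)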


   .. py:attribute:: img_size


   .. py:attribute:: patches_resolution


   .. py:attribute:: embed_dim


   .. py:attribute:: proj


   .. py:method:: forward(x: torch.Tensor)


.. py:class:: DownBlock(in_chans: int, out_chans: int, num_groups: int, num_residuals: int = 2)

   Bases: :py:obj:`torch.nn.Module`


   Downsampling block used by :py:class:`UTransformer`.

   :param in_chans: Number of input channels.
   :type in_chans: int
   :param out_chans: Number of output channels.
   :type out_chans: int
   :param num_groups: Number of groups to separate the channels into.
   :type num_groups: int
   :param num_residuals: Number of residual layers (default 2).
   :type num_residuals: int
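
   A minimal sketch, assuming a common U-Net-style design: a stride-2 convolution
   (``conv``) followed by a residual stack of convolution, group normalization, and SiLU
   layers (``b``). The kernel sizes, the activation, and the name ``DownBlockSketch`` are
   assumptions, not read from this module.

   .. code-block:: python

      import torch
      import torch.nn as nn


      class DownBlockSketch(nn.Module):
          """Hypothetical block: stride-2 conv, then a residual conv/GroupNorm/SiLU stack."""

          def __init__(self, in_chans: int, out_chans: int, num_groups: int, num_residuals: int = 2):
              super().__init__()
              # Halve Lat and Lon (rounding up for odd sizes) while changing the channel count.
              self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=2, padding=1)
              layers = []
              for _ in range(num_residuals):
                  layers += [
                      nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1),
                      nn.GroupNorm(num_groups, out_chans),  # out_chans must divide by num_groups
                      nn.SiLU(),
                  ]
              self.b = nn.Sequential(*layers)

          def forward(self, x: torch.Tensor) -> torch.Tensor:
              x = self.conv(x)
              return x + self.b(x)  # skip connection around the residual stack


      block = DownBlockSketch(in_chans=16, out_chans=32, num_groups=8)
      out = block(torch.randn(2, 16, 45, 90))   # shape (2, 32, 23, 45)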


   .. py:attribute:: conv


   .. py:attribute:: b


   .. py:method:: forward(x)


.. py:class:: UpBlock(in_chans, out_chans, num_groups, num_residuals=2)

   Bases: :py:obj:`torch.nn.Module`


   Upsampling block used by :py:class:`UTransformer`.

   :param in_chans: Number of input channels.
   :type in_chans: int
   :param out_chans: Number of output channels.
   :type out_chans: int
   :param num_groups: Number of groups to separate the channels into.
   :type num_groups: int
   :param num_residuals: Number of residual layers (default 2).
   :type num_residuals: int
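
   A minimal sketch, assuming ``conv`` is a 2x transposed convolution and ``b`` is a
   residual stack of convolution, group normalization, and SiLU layers. These layer
   choices and the name ``UpBlockSketch`` are assumptions.

   .. code-block:: python

      import torch
      import torch.nn as nn


      class UpBlockSketch(nn.Module):
          """Hypothetical block: 2x transposed conv, then a residual conv/GroupNorm/SiLU stack."""

          def __init__(self, in_chans, out_chans, num_groups, num_residuals=2):
              super().__init__()
              # kernel_size == stride == 2 doubles Lat and Lon exactly.
              self.conv = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)
              layers = []
              for _ in range(num_residuals):
                  layers += [
                      nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1),
                      nn.GroupNorm(num_groups, out_chans),
                      nn.SiLU(),
                  ]
              self.b = nn.Sequential(*layers)

          def forward(self, x):
              x = self.conv(x)
              return x + self.b(x)  # skip connection around the residual stack


      block = UpBlockSketch(in_chans=64, out_chans=32, num_groups=8)
      out = block(torch.randn(2, 64, 23, 45))   # shape (2, 32, 46, 90)

   In this sketch an odd size does not survive a stride-2 down/up round trip:
   45 rows become 23 and then 46, so the result would need cropping or resizing.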


   .. py:attribute:: conv


   .. py:attribute:: b


   .. py:method:: forward(x)


.. py:class:: UTransformer(embed_dim, num_groups, input_resolution, num_heads, window_size, depth, drop_path)

   Bases: :py:obj:`torch.nn.Module`


   U-Transformer built from a downsampling stage (``down``), a stack of ``depth``
   transformer blocks (``layer``), and an upsampling stage (``up``).

   :param embed_dim: Patch embedding dimension.
   :type embed_dim: int
   :param num_groups: Number of groups to separate the channels into.
   :type num_groups: int | tuple[int]
   :param input_resolution: (Lat, Lon).
   :type input_resolution: tuple[int]
   :param num_heads: Number of attention heads.
   :type num_heads: int
   :param window_size: Window size.
   :type window_size: int | tuple[int]
   :param depth: Number of blocks.
   :type depth: int
   :param drop_path: Stochastic depth (drop path) rate.
   :type drop_path: float
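
   A minimal sketch of one plausible data flow, assuming a U-Net-style skip connection:
   downsample, zero-pad so the grid tiles into windows, run the transformer stage
   channels-last, crop the padding off, concatenate with the pre-padding features, and
   upsample. ``u_transformer_forward_sketch`` is hypothetical, and ``nn.Identity`` and
   plain convolutions stand in for the real transformer stage and blocks.

   .. code-block:: python

      import torch
      import torch.nn as nn
      import torch.nn.functional as F


      def u_transformer_forward_sketch(x, down, layer, up, padding):
          """Hypothetical flow: down -> pad -> transformer -> crop -> concat -> up."""
          pad_left, pad_right, pad_top, pad_bottom = padding
          x = down(x)                                   # (B, C, Lat, Lon)
          shortcut = x
          x = F.pad(x, padding)                         # make Lat/Lon multiples of the window
          _, _, lat, lon = x.shape
          x = layer(x.permute(0, 2, 3, 1))              # channels-last for the transformer stage
          x = x.permute(0, 3, 1, 2)
          x = x[:, :, pad_top:lat - pad_bottom, pad_left:lon - pad_right]  # undo the padding
          x = torch.cat([shortcut, x], dim=1)           # skip connection doubles the channels
          return up(x)


      # Stand-ins (hypothetical): Identity replaces the transformer stage.
      C = 8
      down = nn.Conv2d(C, C, kernel_size=3, stride=2, padding=1)
      up = nn.ConvTranspose2d(2 * C, C, kernel_size=2, stride=2)
      # (1, 1, 0, 1) is an even-split padding for a 20 x 40 grid and 7 x 7 windows.
      out = u_transformer_forward_sketch(torch.randn(1, C, 40, 80), down, nn.Identity(), up, (1, 1, 0, 1))

   Because of the concatenation, the upsampling stage in this sketch takes ``2 * C``
   input channels.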


   .. py:attribute:: padding


   .. py:attribute:: pad


   .. py:attribute:: down


   .. py:attribute:: layer


   .. py:attribute:: up


   .. py:method:: forward(x)


.. py:class:: WRFTransformer(param_interior, param_outside, time_encode_dim=12, num_groups=32, num_heads=8, depth=48, window_size=7, use_spectral_norm=True, interp=True, drop_path=0, padding_conf=None, post_conf=None, **kwargs)

   Bases: :py:obj:`credit.models.base_model.BaseModel`


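   :param img_size: T, Lat, Lon.
   :type img_size: Sequence[int], optional
   :param patch_size: T, Lat, Lon.
   :type patch_size: Sequence[int], optional
   :param in_chans: Number of input channels.
   :type in_chans: int, optional
   :param out_chans: Number of output channels.
   :type out_chans: int, optional
   :param dim: Number of embedding channels.
   :type dim: int, optional
   :param num_groups: Number of groups to separate the channels into.
   :type num_groups: Sequence[int] | int, optional
   :param num_heads: Number of attention heads.
   :type num_heads: int, optional
   :param window_size: Local window size.
   :type window_size: int | tuple[int], optional
   :param time_encode_dim: Size of the time encoding (stored as ``time_encode``; default 12).
   :type time_encode_dim: int, optional
   :param depth: Number of transformer blocks (default 48).
   :type depth: int, optional
   :param use_spectral_norm: Whether to apply :py:func:`apply_spectral_norm` (default ``True``).
   :type use_spectral_norm: bool, optional
   :param drop_path: Stochastic depth (drop path) rate.
   :type drop_path: float, optional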


   .. py:attribute:: time_encode
      :value: 12



   .. py:attribute:: use_interp
      :value: True



   .. py:attribute:: use_spectral_norm
      :value: True



   .. py:attribute:: use_padding


   .. py:attribute:: use_post_block


   .. py:attribute:: cube_embedding_inside


   .. py:attribute:: cube_embedding_outside


   .. py:attribute:: total_dim


   .. py:attribute:: u_transformer


   .. py:attribute:: fc


   .. py:attribute:: patch_size


   .. py:attribute:: input_resolution


   .. py:attribute:: out_chans


   .. py:attribute:: img_size


   .. py:attribute:: film


   .. py:method:: _match_spatial(src: torch.Tensor, ref: torch.Tensor)
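
      ``_match_spatial`` is undocumented. Judging only by its name and signature, it
      plausibly resizes ``src`` to the spatial size of ``ref``. The sketch below assumes
      bilinear interpolation over the last two dimensions; both that choice and the name
      ``match_spatial_sketch`` are hypothetical.

      .. code-block:: python

         import torch
         import torch.nn.functional as F


         def match_spatial_sketch(src: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
             """Hypothetical helper: resize src so its last two dims equal ref's (Lat, Lon)."""
             if src.shape[-2:] == ref.shape[-2:]:
                 return src  # already aligned, nothing to do
             return F.interpolate(src, size=tuple(ref.shape[-2:]), mode="bilinear", align_corners=False)


         src = torch.randn(1, 8, 46, 90)   # e.g. after a stride-2 down/up round trip
         ref = torch.randn(1, 3, 45, 90)
         out = match_spatial_sketch(src, ref)   # shape (1, 8, 45, 90)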


   .. py:method:: forward(x: torch.Tensor, x_outside: torch.Tensor, x_extra: torch.Tensor)


