credit.models.wxformer.sdl_inference_wrapper
============================================

.. py:module:: credit.models.wxformer.sdl_inference_wrapper


Attributes
----------

.. autoapisummary::

   credit.models.wxformer.sdl_inference_wrapper.logger
   credit.models.wxformer.sdl_inference_wrapper.parser


Classes
-------

.. autoapisummary::

   credit.models.wxformer.sdl_inference_wrapper.SDLWrapper


Functions
---------

.. autoapisummary::

   credit.models.wxformer.sdl_inference_wrapper.run_multiscale_control_experiment


Module Contents
---------------

.. py:data:: logger

.. py:class:: SDLWrapper(pretrained_model: torch.nn.Module, channel_config: Dict = None)

   Bases: :py:obj:`torch.nn.Module`


   Specialized wrapper for hurricane track stylization with latent vector control.

   Capabilities:

   - Directional bias injection (steering-flow modification)
   - Intensity-dependent noise scaling
   - Storage and retrieval of latent vectors Z for exact forecast reproduction
   - Interpolation between latent vectors for smooth ensemble exploration


   .. py:attribute:: model


   .. py:attribute:: _noise_layers
      :value: []



   .. py:attribute:: _original_factors


   .. py:attribute:: _stored_latents
      :type:  Dict[str, Dict]


   .. py:attribute:: _current_latents
      :type:  Optional[List[torch.Tensor]]
      :value: None



   .. py:attribute:: _current_timestep_map
      :type:  Optional[List[int]]
      :value: None



   .. py:attribute:: _capture_enabled
      :value: False



   .. py:method:: _collect_noise_layers() -> List[torch.nn.Module]

      Collect all PixelNoiseInjection layers from the model.



   .. py:method:: get_noise_factors() -> List[float]

      Get current noise factors from all layers.



   .. py:method:: set_noise_factors(factors: Union[float, List[float]])

      Set noise factors for all layers.



   .. py:method:: reset_to_original()

      Reset to original pretrained noise factors.
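
      The get/set/reset trio above can be pictured with the minimal sketch below. It is not the ``credit`` implementation: ``FakeNoiseLayer`` and ``NoiseFactorController`` are hypothetical stand-ins, and the sketch assumes each noise-injection layer exposes a scalar ``noise_factor`` attribute and that a scalar passed to ``set_noise_factors`` is broadcast to every layer.

      .. code-block:: python

         from typing import List, Union


         class FakeNoiseLayer:
             """Hypothetical stand-in for a PixelNoiseInjection layer."""

             def __init__(self, noise_factor: float):
                 self.noise_factor = noise_factor  # assumed attribute name


         class NoiseFactorController:
             """Sketch of get/set/reset semantics for per-layer noise factors."""

             def __init__(self, layers: List[FakeNoiseLayer]):
                 self.layers = layers
                 # Snapshot the pretrained values once so they can be restored later.
                 self._original_factors = [layer.noise_factor for layer in layers]

             def get_noise_factors(self) -> List[float]:
                 return [layer.noise_factor for layer in self.layers]

             def set_noise_factors(self, factors: Union[float, List[float]]) -> None:
                 if isinstance(factors, (int, float)):
                     # A scalar is broadcast to every layer (assumption).
                     factors = [float(factors)] * len(self.layers)
                 if len(factors) != len(self.layers):
                     raise ValueError(
                         f"expected {len(self.layers)} factors, got {len(factors)}"
                     )
                 for layer, value in zip(self.layers, factors):
                     layer.noise_factor = value

             def reset_to_original(self) -> None:
                 self.set_noise_factors(list(self._original_factors))

      Keeping the snapshot inside the controller means ``reset_to_original`` stays valid no matter how many times the factors are changed in between.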



   .. py:method:: set_encoder_noise_factors(factors: Union[float, List[float]])

      Set encoder noise factors.



   .. py:method:: set_decoder_noise_factors(factors: Union[float, List[float]])

      Set decoder noise factors.



   .. py:method:: set_decoder_modulation(target_channels: List[int] = None, weight: float = 2.0)

      Set decoder modulation weights for specific channels.



   .. py:method:: set_decoder_style_vector(channel_weights: Dict[int, float])

      Modify the model's style transformation weights.



   .. py:method:: set_manual_factors(large_scale: float, medium_scale: float, fine_scale: float)

      Manually set decoder noise factors for the large-, medium-, and fine-scale layers.



   .. py:method:: enable_latent_capture()

      Enable capturing of latent vectors during forward pass.



   .. py:method:: disable_latent_capture()

      Disable latent vector capturing and restore original forward methods.
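
      The wording of ``disable_latent_capture`` suggests that capture works by temporarily replacing layer ``forward`` methods. The sketch below shows that patch-and-restore pattern on plain Python objects; ``NoisyLayer``, ``LatentCapture``, and the choice to record the injected noise delta are illustrative assumptions, not the wrapper's actual code.

      .. code-block:: python

         import random
         from typing import List


         class NoisyLayer:
             """Hypothetical layer that adds scaled Gaussian noise to a scalar input."""

             def __init__(self, noise_factor: float, seed: int = 0):
                 self.noise_factor = noise_factor
                 self.rng = random.Random(seed)

             def forward(self, x: float) -> float:
                 return x + self.noise_factor * self.rng.gauss(0.0, 1.0)


         class LatentCapture:
             """Patch each layer's forward to record its noise, then restore it."""

             def __init__(self, layers: List[NoisyLayer]):
                 self.layers = layers
                 self.captured: List[float] = []
                 self._patched = set()

             def enable(self) -> None:
                 for layer in self.layers:
                     if id(layer) in self._patched:
                         continue  # already wrapped; avoid double patching
                     original = layer.forward  # bound method of the class

                     def wrapped(x, _orig=original):
                         out = _orig(x)
                         self.captured.append(out - x)  # injected noise delta
                         return out

                     layer.forward = wrapped  # instance attribute shadows the method
                     self._patched.add(id(layer))

             def disable(self) -> None:
                 for layer in self.layers:
                     if id(layer) in self._patched:
                         del layer.forward  # re-exposes the original class method
                 self._patched.clear()

      Patching the instance rather than the class keeps other layers of the same type unaffected, and deleting the instance attribute is enough to restore the original behaviour.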



   .. py:method:: store_latents(name: str)

      Store the captured latent vectors with a given name.



   .. py:method:: get_stored_latents(name: str) -> Optional[Dict]

      Retrieve stored latent vectors by name.



   .. py:method:: list_stored_latents() -> List[str]

      List all stored latent vector names.



   .. py:method:: clear_stored_latents(name: Optional[str] = None)

      Clear stored latents (all or specific name).
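
      The store/get/list/clear methods behave like a small name-to-dict registry. A hypothetical sketch of those semantics follows; unlike the real ``store_latents``, it takes the latents as an explicit argument because there is no capture machinery here, and the defensive deep copy is an assumed design choice.

      .. code-block:: python

         import copy
         from typing import Dict, List, Optional


         class LatentStore:
             """Minimal sketch of a name -> latent-dict registry."""

             def __init__(self):
                 self._stored_latents: Dict[str, Dict] = {}

             def store_latents(self, name: str, latents: Dict) -> None:
                 # Deep-copy so later in-place edits to `latents` do not leak in.
                 self._stored_latents[name] = copy.deepcopy(latents)

             def get_stored_latents(self, name: str) -> Optional[Dict]:
                 return self._stored_latents.get(name)  # None if unknown

             def list_stored_latents(self) -> List[str]:
                 return list(self._stored_latents)  # insertion order

             def clear_stored_latents(self, name: Optional[str] = None) -> None:
                 if name is None:
                     self._stored_latents.clear()  # clear everything
                 else:
                     self._stored_latents.pop(name, None)  # clear one entry, if present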



   .. py:method:: reset_latent_capture()

      Reset the latent capture state; useful for recovering after a capture goes wrong.



   .. py:method:: interpolate_latents(name1: str, name2: str, t: float) -> Dict

      Linear interpolation between two stored latent vectors:

      .. math::

         Z_t = (1 - t)\,Z_1 + t\,Z_2

      :param name1: First latent vector identifier
      :param name2: Second latent vector identifier
      :param t: Interpolation parameter in [0, 1]; ``t = 0`` yields the latents of ``name1`` and ``t = 1`` those of ``name2``
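
      A sketch of the interpolation formula, assuming (hypothetically) that a stored latent dict has the layout ``{"latents_by_timestep": {step: [per-layer value, ...]}}``. The arithmetic uses only scalar multiplication and addition, so the same code would apply element-wise to ``torch.Tensor`` values.

      .. code-block:: python

         from typing import Dict


         def interpolate_latents(z1: Dict, z2: Dict, t: float) -> Dict:
             """Return Z_t = (1 - t) * Z1 + t * Z2 for every timestep and layer."""
             if not 0.0 <= t <= 1.0:
                 raise ValueError("t must lie in [0, 1]")
             steps1 = z1["latents_by_timestep"]
             steps2 = z2["latents_by_timestep"]
             if steps1.keys() != steps2.keys():
                 raise ValueError("latents cover different forecast steps")
             mixed = {}
             for step, layers1 in steps1.items():
                 layers2 = steps2[step]
                 if len(layers1) != len(layers2):
                     raise ValueError(f"step {step}: layer counts differ")
                 # Element-wise blend; also valid for torch.Tensor entries.
                 mixed[step] = [(1.0 - t) * a + t * b for a, b in zip(layers1, layers2)]
             return {"latents_by_timestep": mixed}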



   .. py:method:: calculate_mslp_and_append(y_arr, datetime_ref, latlons, surface_geopotential, conf, ensemble_size=1)

      Calculate Mean Sea Level Pressure (MSLP) and append it to the input tensor.



   .. py:method:: process_pressure_interp(y_phys, y_pred_phys, batch, latlons, surface_geopotential, conf)

      Process MSLP for entire batches of truth and prediction tensors.



   .. py:method:: forward(*args, **kwargs)

      Forward pass through the model.



   .. py:method:: _forward_with_latent_control(x, forecast_step: int, use_latents: Optional[Dict] = None)

      Internal forward pass with optional latent control.

      :param x: Input tensor
      :param forecast_step: Current forecast step (0-indexed)
      :param use_latents: Dict with 'latents_by_timestep' containing noise deltas



   .. py:method:: rollout_forecast(data_loader, state_transformer=None, ensemble_size: int = 1, history_len: int = 1, device: str = 'cuda', metrics_fn=None, capture_latents: bool = False, store_latents_as: Optional[str] = None, use_stored_latents: Optional[str] = None, use_interpolated_latents: Optional[Tuple[str, str, float]] = None, conf: Optional[Dict] = None) -> Dict

      Unified hurricane forecast rollout with integrated latent vector control.

      :param data_loader: DataLoader containing hurricane data batches
      :param state_transformer: Transformer for normalization/denormalization
      :param ensemble_size: Number of ensemble members
      :param history_len: Length of input history
      :param device: Device to run inference on
      :param metrics_fn: Function to compute metrics per step
      :param conf: Configuration dictionary (required for MSLP calculation)
      :param capture_latents: If True, capture latent vectors during this rollout
      :param store_latents_as: Name to store captured latents (implies capture_latents=True)
      :param use_stored_latents: Name of stored latents to use for exact reproduction
      :param use_interpolated_latents: Tuple of (name1, name2, t) for interpolation

      :returns: Dict with 'predictions', 'truth', 'metrics', and optional 'latent_name'
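
      How the latent-control arguments combine can be summarized by the hypothetical helper below. Only the rule that ``store_latents_as`` implies ``capture_latents=True`` comes from the parameter list; treating ``use_stored_latents`` and ``use_interpolated_latents`` as mutually exclusive, and the returned tuple format, are assumptions.

      .. code-block:: python

         from typing import Optional, Tuple


         def resolve_latent_mode(
             capture_latents: bool = False,
             store_latents_as: Optional[str] = None,
             use_stored_latents: Optional[str] = None,
             use_interpolated_latents: Optional[Tuple[str, str, float]] = None,
         ):
             """Hypothetical helper mirroring the documented rollout options."""
             # Documented behaviour: naming a store target implies capture.
             capture = capture_latents or store_latents_as is not None
             if use_stored_latents is not None and use_interpolated_latents is not None:
                 # Assumption: the two replay sources are mutually exclusive.
                 raise ValueError("pass use_stored_latents or use_interpolated_latents, not both")
             if use_stored_latents is not None:
                 source = ("stored", use_stored_latents)
             elif use_interpolated_latents is not None:
                 source = ("interpolated", use_interpolated_latents)
             else:
                 source = None  # no replay requested (assumed: fresh noise is drawn)
             return capture, source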



   .. py:method:: generate_interpolation_sequence(dataset, name1: str, name2: str, num_steps: int = 5, state_transformer=None, device: str = 'cuda', conf: Optional[Dict] = None, **rollout_kwargs) -> List[Dict]

      Generate sequence of forecasts along interpolation path.

      :param dataset: Dataset (deep-copied for each interpolation point)
      :param name1: Start latent vector
      :param name2: End latent vector
      :param num_steps: Number of interpolation points (default: 5)
      :param state_transformer: Transformer for normalization
      :param device: Device to run on
      :param conf: Configuration dictionary
      :param \*\*rollout_kwargs: Additional arguments for rollout_forecast

      :returns: List of forecast dicts, one per interpolation point
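
      The docstring does not say how the ``num_steps`` interpolation points are spaced. The sketch below assumes evenly spaced values of ``t`` including both endpoints and passes each as ``use_interpolated_latents=(name1, name2, t)``; ``interpolation_schedule``, ``run_interpolation``, and the ``rollout`` callable's signature are hypothetical.

      .. code-block:: python

         import copy
         from typing import Any, Callable, List


         def interpolation_schedule(num_steps: int = 5) -> List[float]:
             """Evenly spaced t values in [0, 1], endpoints included (assumed spacing)."""
             if num_steps < 2:
                 raise ValueError("num_steps must be at least 2 to include both endpoints")
             return [i / (num_steps - 1) for i in range(num_steps)]


         def run_interpolation(
             rollout: Callable[..., Any], dataset: Any, name1: str, name2: str, num_steps: int = 5
         ) -> List[Any]:
             """Call a rollout function once per interpolation point."""
             results = []
             for t in interpolation_schedule(num_steps):
                 # Fresh copy per point, mirroring the documented deep copy.
                 results.append(
                     rollout(copy.deepcopy(dataset), use_interpolated_latents=(name1, name2, t))
                 )
             return results

      With the default ``num_steps=5`` this schedule visits ``t`` = 0.0, 0.25, 0.5, 0.75, and 1.0.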



   .. py:method:: scale_latents(name: str, beta: float) -> Dict

      Scale stored latent vectors by a factor beta:

      .. math::

         Z_{\text{scaled}} = \beta Z

      :param name: Latent vector identifier
      :param beta: Scaling factor

      :returns: Scaled latent dict
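
      A minimal sketch of the scaling rule under the same hypothetical ``latents_by_timestep`` layout (not confirmed by the source). With ``beta = 1`` the result equals the stored latents, and with ``beta = 0`` every entry becomes zero.

      .. code-block:: python

         from typing import Dict


         def scale_latents(z: Dict, beta: float) -> Dict:
             """Return Z_scaled = beta * Z for every timestep and layer (assumed layout)."""
             return {
                 "latents_by_timestep": {
                     step: [beta * value for value in layers]
                     for step, layers in z["latents_by_timestep"].items()
                 }
             }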



   .. py:method:: generate_scaled_ensemble(dataset, base_latent_name: str, beta_values: List[float], state_transformer=None, device: str = 'cuda', conf: Optional[Dict] = None, **rollout_kwargs) -> List[Dict]

      Generate ensemble with different scaling factors applied to base latent.

      :param dataset: Dataset for forecasting
      :param base_latent_name: Name of base latent vector to scale
      :param beta_values: List of scaling factors to apply
      :param state_transformer: Transformer for normalization
      :param device: Device to run on
      :param conf: Configuration dictionary
      :param \*\*rollout_kwargs: Additional arguments for rollout_forecast

      :returns: List of forecast dicts, one per beta value



   .. py:method:: scale_latents_multilevel(name: str, beta_per_layer: List[float]) -> Dict

      Scale stored latent vectors with different beta for each layer.

      :param name: Latent vector identifier
      :param beta_per_layer: List of scaling factors, one per noise-injection layer;
                             its length must match the number of layers in ``self._noise_layers``.

      :returns: Scaled latent dict with 'latents_by_timestep' structure
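
      The per-layer variant differs from ``scale_latents`` only in that layer ``i`` at every forecast step is multiplied by ``beta_per_layer[i]``. The sketch below assumes, hypothetically, that each timestep holds a list of per-layer values in the same order as ``self._noise_layers``, and it enforces the documented length requirement.

      .. code-block:: python

         from typing import Dict, List


         def scale_latents_multilevel(z: Dict, beta_per_layer: List[float]) -> Dict:
             """Scale layer i of every timestep by beta_per_layer[i] (assumed layout)."""
             scaled = {}
             for step, layers in z["latents_by_timestep"].items():
                 if len(layers) != len(beta_per_layer):
                     raise ValueError(
                         f"step {step}: {len(layers)} layers but {len(beta_per_layer)} betas"
                     )
                 scaled[step] = [beta * value for beta, value in zip(beta_per_layer, layers)]
             return {"latents_by_timestep": scaled}

      Setting one entry of ``beta_per_layer`` to zero while leaving the others at one isolates the contribution of a single noise-injection layer.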



.. py:function:: run_multiscale_control_experiment(experiments, wrapper, dataset, state_transformer, conf, lat_centers, lon_centers, device='cuda', ensemble_size=1) -> Dict[str, Dict]

.. py:data:: parser

