credit.models.wxformer.sdl_inference_wrapper

Attributes

logger

parser

Classes

SDLWrapper

Specialized wrapper for hurricane track stylization with latent vector control.

Functions

run_multiscale_control_experiment(...) → Dict[str, Dict]

Module Contents

credit.models.wxformer.sdl_inference_wrapper.logger
class credit.models.wxformer.sdl_inference_wrapper.SDLWrapper(pretrained_model: torch.nn.Module, channel_config: Dict = None)

Bases: torch.nn.Module

Specialized wrapper for hurricane track stylization with latent vector control.

Capabilities:

  • Directional bias injection (steering flow modification)

  • Intensity-dependent noise scaling

  • Store and retrieve latent vectors Z for exact forecast reproduction

  • Interpolate between latent vectors for smooth ensemble exploration

model
_noise_layers = []
_original_factors
_stored_latents: Dict[str, Dict]
_current_latents: List[torch.Tensor] | None = None
_current_timestep_map: List[int] | None = None
_capture_enabled = False
_collect_noise_layers() → List[torch.nn.Module]

Collect all PixelNoiseInjection layers from the model.

get_noise_factors() → List[float]

Get current noise factors from all layers.

set_noise_factors(factors: float | List[float])

Set noise factors for all layers.

reset_to_original()

Reset to original pretrained noise factors.
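The get/set/reset trio above can be illustrated with a small stand-in. This is a sketch under assumptions, not the SDLWrapper code: ToyNoiseLayer and its noise_factor attribute are hypothetical, and the rule that a scalar is broadcast to every layer while a list supplies one value per layer is inferred from the float | List[float] signature.

```python
from typing import List, Union


class ToyNoiseLayer:
    """Hypothetical stand-in for a PixelNoiseInjection layer holding one scalar factor."""

    def __init__(self, noise_factor: float) -> None:
        self.noise_factor = noise_factor


class ToyNoiseFactorController:
    """Sketch of get/set/reset semantics for per-layer noise factors."""

    def __init__(self, layers: List[ToyNoiseLayer]) -> None:
        self.layers = layers
        # Snapshot the pretrained factors so reset_to_original() can restore them.
        self._original_factors = [layer.noise_factor for layer in layers]

    def get_noise_factors(self) -> List[float]:
        return [layer.noise_factor for layer in self.layers]

    def set_noise_factors(self, factors: Union[float, List[float]]) -> None:
        # Assumption: a scalar applies to every layer; a list gives one value per layer.
        if isinstance(factors, (int, float)):
            factors = [float(factors)] * len(self.layers)
        if len(factors) != len(self.layers):
            raise ValueError(f"expected {len(self.layers)} factors, got {len(factors)}")
        for layer, factor in zip(self.layers, factors):
            layer.noise_factor = factor

    def reset_to_original(self) -> None:
        self.set_noise_factors(list(self._original_factors))


ctrl = ToyNoiseFactorController([ToyNoiseLayer(0.1), ToyNoiseLayer(0.2)])
ctrl.set_noise_factors(0.5)
print(ctrl.get_noise_factors())  # [0.5, 0.5]
ctrl.reset_to_original()
print(ctrl.get_noise_factors())  # [0.1, 0.2]
```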

set_encoder_noise_factors(factors: float | List[float])

Set encoder noise factors.

set_decoder_noise_factors(factors: float | List[float])

Set decoder noise factors.

set_decoder_modulation(target_channels: List[int] = None, weight: float = 2.0)

Set decoder modulation weights for specific channels.

set_decoder_style_vector(channel_weights: Dict[int, float])

Modify the model’s style transformation weights.
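One way to picture the two channel-weighting methods above is as building a per-channel weight vector. The sketch below is hypothetical: the helper names, the explicit num_channels argument, and the assumption that untouched channels keep a neutral weight of 1.0 are illustrative choices, not taken from the wrapper.

```python
from typing import Dict, List, Optional


def build_channel_weights(
    num_channels: int,
    target_channels: Optional[List[int]] = None,
    weight: float = 2.0,
) -> List[float]:
    """Hypothetical helper mirroring set_decoder_modulation: `weight` on targets, 1.0 elsewhere."""
    weights = [1.0] * num_channels  # assumed neutral weight for untargeted channels
    for ch in target_channels or []:
        if not 0 <= ch < num_channels:
            raise IndexError(f"channel {ch} out of range for {num_channels} channels")
        weights[ch] = weight
    return weights


def weights_from_style_vector(num_channels: int, channel_weights: Dict[int, float]) -> List[float]:
    """Hypothetical helper mirroring set_decoder_style_vector: per-channel overrides."""
    weights = [1.0] * num_channels
    for ch, w in channel_weights.items():
        if not 0 <= ch < num_channels:
            raise IndexError(f"channel {ch} out of range for {num_channels} channels")
        weights[ch] = w
    return weights


print(build_channel_weights(4, [1, 3]))  # [1.0, 2.0, 1.0, 2.0]
```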

set_manual_factors(large_scale: float, medium_scale: float, fine_scale: float)

Manually set decoder noise factors at the large, medium, and fine scales.
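How three scale factors might map onto the decoder's noise layers is not documented here. The sketch below assumes, purely for illustration, that decoder noise layers are ordered coarse to fine and split into three contiguous groups; manual_scale_factors is a hypothetical helper.

```python
from typing import List


def manual_scale_factors(
    num_decoder_layers: int, large_scale: float, medium_scale: float, fine_scale: float
) -> List[float]:
    """Hypothetical mapping: assumes decoder noise layers run coarse to fine and
    splits them into three contiguous groups, one factor per group."""
    if num_decoder_layers < 1:
        raise ValueError("need at least one decoder layer")
    scales = (large_scale, medium_scale, fine_scale)
    # 3 * i // n is always 0, 1, or 2 for 0 <= i < n.
    return [scales[3 * i // num_decoder_layers] for i in range(num_decoder_layers)]


print(manual_scale_factors(6, 1.0, 0.5, 0.1))  # [1.0, 1.0, 0.5, 0.5, 0.1, 0.1]
```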

enable_latent_capture()

Enable capturing of latent vectors during the forward pass.

disable_latent_capture()

Disable latent vector capture and restore the original forward methods.

store_latents(name: str)

Store the captured latent vectors under the given name.

get_stored_latents(name: str) → Dict | None

Retrieve stored latent vectors by name.

list_stored_latents() → List[str]

List all stored latent vector names.

clear_stored_latents(name: str | None = None)

Clear stored latents: all of them if name is None, otherwise only the entry stored under name.

reset_latent_capture()

Reset the latent capture state. Useful for recovering after a capture goes wrong.
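The capture lifecycle above (patch the forward path, record draws, store them by name, then restore the original methods) can be sketched without PyTorch. Everything here is hypothetical: ToyNoiseLayer, its sample_noise method, and the {"latents": ...} storage layout are illustrative stand-ins for the wrapper's internals, and the patch-and-restore approach is one assumed way to implement "restore original forward methods".

```python
import copy
import random
from typing import Dict, List, Optional


class ToyNoiseLayer:
    """Hypothetical noise layer: forward() adds noise_factor * z, with z drawn by sample_noise()."""

    def __init__(self, noise_factor: float, size: int) -> None:
        self.noise_factor = noise_factor
        self.size = size

    def sample_noise(self) -> List[float]:
        return [random.gauss(0.0, 1.0) for _ in range(self.size)]

    def forward(self, x: List[float]) -> List[float]:
        z = self.sample_noise()
        return [xi + self.noise_factor * zi for xi, zi in zip(x, z)]


class ToyLatentRecorder:
    """Sketch of capture / store / retrieve / clear for per-layer noise vectors."""

    def __init__(self, layers: List[ToyNoiseLayer]) -> None:
        self.layers = layers
        self._current_latents: List[List[float]] = []
        self._stored_latents: Dict[str, Dict] = {}
        self._capture_enabled = False

    def enable_latent_capture(self) -> None:
        if self._capture_enabled:
            return  # avoid wrapping the same method twice
        self._current_latents = []
        for layer in self.layers:
            original = layer.sample_noise  # bound method defined on the class

            def recording(original=original) -> List[float]:
                z = original()
                self._current_latents.append(list(z))  # keep a copy of every draw
                return z

            layer.sample_noise = recording  # instance attribute shadows the class method
        self._capture_enabled = True

    def disable_latent_capture(self) -> None:
        for layer in self.layers:
            # Dropping the instance attribute re-exposes the original class method.
            layer.__dict__.pop("sample_noise", None)
        self._capture_enabled = False

    def reset_latent_capture(self) -> None:
        self.disable_latent_capture()
        self._current_latents = []

    def store_latents(self, name: str) -> None:
        self._stored_latents[name] = {"latents": copy.deepcopy(self._current_latents)}

    def get_stored_latents(self, name: str) -> Optional[Dict]:
        return self._stored_latents.get(name)

    def list_stored_latents(self) -> List[str]:
        return list(self._stored_latents)

    def clear_stored_latents(self, name: Optional[str] = None) -> None:
        if name is None:
            self._stored_latents.clear()
        else:
            self._stored_latents.pop(name, None)


random.seed(0)
layers = [ToyNoiseLayer(0.1, 3), ToyNoiseLayer(0.2, 3)]
recorder = ToyLatentRecorder(layers)
recorder.enable_latent_capture()
x = [0.0, 0.0, 0.0]
for layer in layers:
    x = layer.forward(x)
recorder.store_latents("run_a")
recorder.disable_latent_capture()
```

Deep-copying on store keeps a saved entry independent of later captures, which is what exact reproduction needs.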

interpolate_latents(name1: str, name2: str, t: float) → Dict

Linearly interpolate between two stored latent vectors: Z_t = (1 - t) * Z1 + t * Z2.

Parameters:
  • name1 – First latent vector identifier

  • name2 – Second latent vector identifier

  • t – Interpolation parameter in [0, 1]; t = 0 gives Z1 and t = 1 gives Z2
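The formula Z_t = (1 - t) * Z1 + t * Z2 applies element-wise to every stored noise value. The sketch below assumes a hypothetical plain-Python layout, {"latents_by_timestep": {step: [layer_noise, ...]}}, borrowing the 'latents_by_timestep' key named elsewhere on this page; the real wrapper stores tensors, and rejecting t outside [0, 1] is a choice made here, not documented behaviour.

```python
from typing import Dict, List


def lerp_latents(z1: Dict, z2: Dict, t: float) -> Dict:
    """Z_t = (1 - t) * Z1 + t * Z2, element-wise over every timestep and layer.

    Assumes both dicts use the hypothetical layout
    {"latents_by_timestep": {step: [layer_noise, ...]}} with matching shapes.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")
    a, b = z1["latents_by_timestep"], z2["latents_by_timestep"]
    if a.keys() != b.keys():
        raise ValueError("latents cover different timesteps")
    out: Dict[int, List[List[float]]] = {}
    for step in a:
        out[step] = [
            [(1.0 - t) * u + t * v for u, v in zip(layer1, layer2)]
            for layer1, layer2 in zip(a[step], b[step])
        ]
    return {"latents_by_timestep": out}


z1 = {"latents_by_timestep": {0: [[0.0, 2.0]]}}
z2 = {"latents_by_timestep": {0: [[4.0, 6.0]]}}
print(lerp_latents(z1, z2, 0.25))  # {'latents_by_timestep': {0: [[1.0, 3.0]]}}
```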

calculate_mslp_and_append(y_arr, datetime_ref, latlons, surface_geopotential, conf, ensemble_size=1)

Calculate Mean Sea Level Pressure (MSLP) and append it to the input tensor.

process_pressure_interp(y_phys, y_pred_phys, batch, latlons, surface_geopotential, conf)

Process MSLP for entire batches of truth and prediction tensors.

forward(*args, **kwargs)

Forward pass through the model.

_forward_with_latent_control(x, forecast_step: int, use_latents: Dict | None = None)

Internal forward pass with optional latent control.

Parameters:
  • x – Input tensor

  • forecast_step – Current forecast step (0-indexed)

  • use_latents – Dict with ‘latents_by_timestep’ containing noise deltas
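A per-step lookup is one plausible reading of how use_latents drives a forecast step. In this hypothetical sketch, noise_for_step replays the stored per-layer values when 'latents_by_timestep' has an entry for forecast_step and otherwise draws fresh Gaussian noise; that fallback, the list-of-floats layout, and the layer_sizes argument are assumptions, not documented behaviour.

```python
import random
from typing import Dict, List, Optional


def noise_for_step(
    forecast_step: int,
    layer_sizes: List[int],
    use_latents: Optional[Dict] = None,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Return one noise vector per layer for this forecast step (hypothetical helper)."""
    if rng is None:
        rng = random.Random()
    if use_latents is not None:
        by_step = use_latents.get("latents_by_timestep", {})
        if forecast_step in by_step:
            # Replay: copy the stored values so callers cannot mutate the saved entry.
            return [list(z) for z in by_step[forecast_step]]
    # Assumed fallback: fresh standard-normal noise for each layer.
    return [[rng.gauss(0.0, 1.0) for _ in range(n)] for n in layer_sizes]


stored = {"latents_by_timestep": {0: [[0.5, -0.5]]}}
print(noise_for_step(0, [2], stored))  # [[0.5, -0.5]]
```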

rollout_forecast(data_loader, state_transformer=None, ensemble_size: int = 1, history_len: int = 1, device: str = 'cuda', metrics_fn=None, capture_latents: bool = False, store_latents_as: str | None = None, use_stored_latents: str | None = None, use_interpolated_latents: Tuple[str, str, float] | None = None, conf: Dict | None = None) → Dict

Unified hurricane forecast rollout with integrated latent vector control.

Parameters:
  • data_loader – DataLoader containing hurricane data batches

  • state_transformer – Transformer for normalization/denormalization

  • ensemble_size – Number of ensemble members

  • history_len – Length of input history

  • device – Device to run inference on

  • metrics_fn – Function to compute metrics per step

  • conf – Configuration dictionary (required for MSLP calculation)

  Latent control parameters:

  • capture_latents – If True, capture latent vectors during this rollout

  • store_latents_as – Name under which to store the captured latents (implies capture_latents=True)

  • use_stored_latents – Name of stored latents to use for exact reproduction

  • use_interpolated_latents – Tuple of (name1, name2, t) for interpolation

Returns:

Dict with keys ‘predictions’, ‘truth’, and ‘metrics’, and, optionally, ‘latent_name’
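The latent-control arguments interact: per the parameter list, store_latents_as implies capture_latents=True. The sketch below resolves the four options into a single plan. LatentPlan and resolve_latent_options are hypothetical names, and treating use_stored_latents and use_interpolated_latents as mutually exclusive is an assumption made for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class LatentPlan:
    capture: bool
    store_as: Optional[str]
    replay: Optional[str]
    interpolate: Optional[Tuple[str, str, float]]


def resolve_latent_options(
    capture_latents: bool = False,
    store_latents_as: Optional[str] = None,
    use_stored_latents: Optional[str] = None,
    use_interpolated_latents: Optional[Tuple[str, str, float]] = None,
) -> LatentPlan:
    # Documented: store_latents_as implies capture_latents=True.
    capture = capture_latents or store_latents_as is not None
    # Assumption: replaying and interpolating cannot both be requested.
    if use_stored_latents is not None and use_interpolated_latents is not None:
        raise ValueError("pass either use_stored_latents or use_interpolated_latents, not both")
    return LatentPlan(capture, store_latents_as, use_stored_latents, use_interpolated_latents)


print(resolve_latent_options(store_latents_as="base").capture)  # True
```

With the documented signature, capturing and then replaying the same noise might look like wrapper.rollout_forecast(loader, conf=conf, store_latents_as="base") followed by wrapper.rollout_forecast(loader, conf=conf, use_stored_latents="base"), where wrapper, loader, and conf are placeholders.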

generate_interpolation_sequence(dataset, name1: str, name2: str, num_steps: int = 5, state_transformer=None, device: str = 'cuda', conf: Dict | None = None, **rollout_kwargs) → List[Dict]

Generate a sequence of forecasts along the interpolation path between two stored latent vectors.

Parameters:
  • dataset – Dataset (deep-copied for each interpolation point)

  • name1 – Name of the starting latent vector

  • name2 – Name of the ending latent vector

  • num_steps – Number of interpolation points (default: 5)

  • state_transformer – Transformer for normalization

  • device – Device to run on

  • conf – Configuration dictionary

  • **rollout_kwargs – Additional arguments for rollout_forecast

Returns:

List of forecast dicts, one per interpolation point
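A natural way to place num_steps points on the path is even spacing of t from 0 to 1 inclusive, with each t passed to rollout_forecast through its documented use_interpolated_latents=(name1, name2, t) argument. Both the spacing and the one-point behaviour below are assumptions; interpolation_points and interpolation_rollout_kwargs are hypothetical helpers.

```python
from typing import Dict, List


def interpolation_points(num_steps: int = 5) -> List[float]:
    """Evenly spaced t values from 0 to 1 inclusive (assumed spacing)."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        return [0.0]  # assumed: a single point sits at the start of the path
    return [i / (num_steps - 1) for i in range(num_steps)]


def interpolation_rollout_kwargs(name1: str, name2: str, num_steps: int = 5) -> List[Dict]:
    """One set of rollout_forecast keyword arguments per interpolation point."""
    return [{"use_interpolated_latents": (name1, name2, t)} for t in interpolation_points(num_steps)]


print(interpolation_points(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```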

scale_latents(name: str, beta: float) → Dict

Scale stored latent vectors by a factor beta: Z_scaled = beta * Z.

Parameters:
  • name – Latent vector identifier

  • beta – Scaling factor

Returns:

Scaled latent dict
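Scaling multiplies every stored noise value by beta, so beta = 1 reproduces Z and beta = 0 zeroes it. This sketch assumes the same hypothetical list-based {"latents_by_timestep": ...} layout used for illustration elsewhere; it returns a new dict rather than mutating the stored entry, which is a design choice here, not confirmed wrapper behaviour.

```python
from typing import Dict


def scale_latents(latents: Dict, beta: float) -> Dict:
    """Z_scaled = beta * Z over every timestep and layer (hypothetical list layout)."""
    by_step = latents["latents_by_timestep"]
    return {
        "latents_by_timestep": {
            step: [[beta * v for v in layer] for layer in layers]
            for step, layers in by_step.items()
        }
    }


z = {"latents_by_timestep": {0: [[1.0, -2.0]]}}
print(scale_latents(z, 0.5))  # {'latents_by_timestep': {0: [[0.5, -1.0]]}}
```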

generate_scaled_ensemble(dataset, base_latent_name: str, beta_values: List[float], state_transformer=None, device: str = 'cuda', conf: Dict | None = None, **rollout_kwargs) → List[Dict]

Generate an ensemble by applying each scaling factor in beta_values to a base latent vector.

Parameters:
  • dataset – Dataset for forecasting

  • base_latent_name – Name of base latent vector to scale

  • beta_values – List of scaling factors to apply

  • state_transformer – Transformer for normalization

  • device – Device to run on

  • conf – Configuration dictionary

  • **rollout_kwargs – Additional arguments for rollout_forecast

Returns:

List of forecast dicts, one per beta value
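The ensemble loop can be sketched as "scale the base latent, run one rollout, tag the result with its beta". Here run_rollout is a hypothetical callable standing in for a rollout_forecast call, and the list-based latent layout and the 'beta' result key are illustrative assumptions.

```python
from typing import Callable, Dict, List


def scaled_ensemble(
    base: Dict, beta_values: List[float], run_rollout: Callable[[Dict], Dict]
) -> List[Dict]:
    """One forecast per beta, each driven by beta * base (hypothetical helper)."""
    results = []
    for beta in beta_values:
        scaled = {
            "latents_by_timestep": {
                step: [[beta * v for v in layer] for layer in layers]
                for step, layers in base["latents_by_timestep"].items()
            }
        }
        out = run_rollout(scaled)
        out["beta"] = beta  # record which member this forecast is
        results.append(out)
    return results


base = {"latents_by_timestep": {0: [[2.0, -1.0]]}}
runs = scaled_ensemble(base, [0.0, 0.5, 1.0], lambda zz: {"first": zz["latents_by_timestep"][0][0][0]})
print([r["first"] for r in runs])  # [0.0, 1.0, 2.0]
```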

scale_latents_multilevel(name: str, beta_per_layer: List[float]) → Dict

Scale stored latent vectors with a different beta for each noise injection layer.

Parameters:
  • name – Latent vector identifier

  • beta_per_layer – List of scaling factors, one per noise injection layer. Must match the number of layers in self._noise_layers.

Returns:

Scaled latent dict with ‘latents_by_timestep’ structure
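Per-layer scaling pairs the i-th beta with the i-th layer's noise at every timestep, which is why the list length must match the number of noise layers. The sketch below enforces that check; the list-based layout and the helper itself are hypothetical illustrations, not the wrapper's code.

```python
from typing import Dict, List


def scale_latents_multilevel(latents: Dict, beta_per_layer: List[float]) -> Dict:
    """Multiply layer i's noise by beta_per_layer[i] at every timestep (hypothetical layout)."""
    out: Dict[int, List[List[float]]] = {}
    for step, layers in latents["latents_by_timestep"].items():
        if len(layers) != len(beta_per_layer):
            raise ValueError(
                f"step {step}: {len(layers)} layers but {len(beta_per_layer)} betas"
            )
        out[step] = [[beta * v for v in layer] for beta, layer in zip(beta_per_layer, layers)]
    return {"latents_by_timestep": out}


z = {"latents_by_timestep": {0: [[1.0, 1.0], [2.0]]}}
print(scale_latents_multilevel(z, [0.5, 3.0]))  # {'latents_by_timestep': {0: [[0.5, 0.5], [6.0]]}}
```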

credit.models.wxformer.sdl_inference_wrapper.run_multiscale_control_experiment(experiments, wrapper, dataset, state_transformer, conf, lat_centers, lon_centers, device='cuda', ensemble_size=1) → Dict[str, Dict]
credit.models.wxformer.sdl_inference_wrapper.parser