credit.models.wxformer.sdl_inference_wrapper
Attributes
| logger | |
| parser | |
Classes
| SDLWrapper | Specialized wrapper for hurricane track stylization with latent vector control. |
Functions
| run_multiscale_control_experiment | |
Module Contents
- credit.models.wxformer.sdl_inference_wrapper.logger
- class credit.models.wxformer.sdl_inference_wrapper.SDLWrapper(pretrained_model: torch.nn.Module, channel_config: Dict = None)
Bases: torch.nn.Module
Specialized wrapper for hurricane track stylization with latent vector control.
Capabilities:
  - Directional bias injection (steering flow modification)
  - Intensity-dependent noise scaling
  - Store and retrieve latent vectors Z for exact forecast reproduction
  - Interpolate between latent vectors for smooth ensemble exploration
- model
- _noise_layers = []
- _original_factors
- _stored_latents: Dict[str, Dict]
- _current_latents: List[torch.Tensor] | None = None
- _current_timestep_map: List[int] | None = None
- _capture_enabled = False
- _collect_noise_layers() → List[torch.nn.Module]
Collect all PixelNoiseInjection layers from the model.
- get_noise_factors() → List[float]
Get the current noise factors from all layers.
- set_noise_factors(factors: float | List[float])
Set noise factors for all layers.
- reset_to_original()
Reset to the original pretrained noise factors.
- set_encoder_noise_factors(factors: float | List[float])
Set encoder noise factors.
- set_decoder_noise_factors(factors: float | List[float])
Set decoder noise factors.
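The `float | List[float]` signatures suggest that a scalar applies to every targeted layer and a list supplies one factor per layer. The following is a minimal sketch of that broadcasting rule under that assumption. `broadcast_factors` is a hypothetical helper, not part of this module, and the real setters may validate differently.

```python
from typing import List, Union


def broadcast_factors(factors: Union[float, List[float]], num_layers: int) -> List[float]:
    """Hypothetical helper: expand a scalar or validate a per-layer list.

    Assumes a scalar applies to every layer and a list must provide
    exactly one factor per noise-injection layer.
    """
    if isinstance(factors, (int, float)):
        # Scalar: the same noise factor for every layer.
        return [float(factors)] * num_layers
    factors = [float(f) for f in factors]
    if len(factors) != num_layers:
        raise ValueError(f"expected {num_layers} factors, got {len(factors)}")
    return factors
```

For example, `broadcast_factors(0.5, 3)` yields `[0.5, 0.5, 0.5]`, and a two-element list passed for three layers raises `ValueError`.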
- set_decoder_modulation(target_channels: List[int] = None, weight: float = 2.0)
Set decoder modulation weights for specific channels.
- set_decoder_style_vector(channel_weights: Dict[int, float])
Modify the model’s style transformation weights.
- set_manual_factors(large_scale: float, medium_scale: float, fine_scale: float)
Manually set decoder noise factors with optional variable and level targeting.
- enable_latent_capture()
Enable capturing of latent vectors during the forward pass.
- disable_latent_capture()
Disable latent vector capture and restore the original forward methods.
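Restoring the "original forward methods" implies that capture works by temporarily wrapping each noise layer's `forward`. The sketch below shows that patch-and-restore pattern on plain Python objects. `NoiseLayer` and `LatentCapture` are hypothetical stand-ins, floats replace tensors, and the wrapper's real bookkeeping may differ.

```python
import random
from typing import Callable, List


class NoiseLayer:
    """Hypothetical stand-in for a noise-injection layer."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def forward(self, x: float) -> float:
        # Add scaled Gaussian noise, as a noise-injection layer might.
        return x + self.factor * random.gauss(0.0, 1.0)


class LatentCapture:
    """Temporarily wrap each layer's forward to record the injected noise."""

    def __init__(self, layers: List[NoiseLayer]) -> None:
        self.layers = layers
        self.captured: List[float] = []

    def enable(self) -> None:
        for layer in self.layers:
            original = layer.forward  # bound method defined on the class

            def patched(x: float, _orig: Callable[[float], float] = original) -> float:
                out = _orig(x)
                # Record the delta added by the layer so it can be replayed.
                self.captured.append(out - x)
                return out

            # An instance attribute shadows the class-level method.
            layer.forward = patched

    def disable(self) -> None:
        for layer in self.layers:
            # Dropping the instance attribute restores the class method.
            vars(layer).pop("forward", None)
```

Patching instances rather than the class keeps the change local to the collected layers, and disabling capture leaves no wrapper behind.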
- store_latents(name: str)
Store the captured latent vectors under the given name.
- get_stored_latents(name: str) → Dict | None
Retrieve stored latent vectors by name.
- list_stored_latents() → List[str]
List all stored latent vector names.
- clear_stored_latents(name: str | None = None)
Clear stored latents (all of them, or only the given name).
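`get_stored_latents` returns `Dict | None`, and `clear_stored_latents` takes an optional name. A minimal sketch of a registry with those semantics follows. `LatentStore` is hypothetical, and copying on store is an assumption about how snapshots are kept independent of later captures.

```python
import copy
from typing import Dict, List, Optional


class LatentStore:
    """Hypothetical name-keyed registry mirroring store/get/list/clear."""

    def __init__(self) -> None:
        self._stored: Dict[str, Dict] = {}

    def store(self, name: str, latents: Dict) -> None:
        # Deep-copy so later captures cannot mutate the stored snapshot.
        self._stored[name] = copy.deepcopy(latents)

    def get(self, name: str) -> Optional[Dict]:
        # Missing names return None rather than raising.
        return self._stored.get(name)

    def names(self) -> List[str]:
        return list(self._stored)

    def clear(self, name: Optional[str] = None) -> None:
        if name is None:
            self._stored.clear()  # clear everything
        else:
            self._stored.pop(name, None)  # clear one name, if present
```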
- reset_latent_capture()
Reset the latent capture state; useful for recovering from a failed or interrupted capture.
- interpolate_latents(name1: str, name2: str, t: float) → Dict
Linear interpolation between two stored latent vectors: Z_t = (1 - t) * Z1 + t * Z2.
- Parameters:
name1 – First latent vector identifier
name2 – Second latent vector identifier
t – Interpolation parameter in [0, 1] (t = 0 gives Z1, t = 1 gives Z2)
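The formula Z_t = (1 - t) * Z1 + t * Z2 is applied element by element. The sketch below applies it leaf by leaf across two latent dicts. It assumes both dicts share the same nested structure, with floats standing in for tensors. The function name and the range check on `t` are illustrative assumptions, not the method's actual code.

```python
from typing import Any


def lerp_latents(z1: Any, z2: Any, t: float) -> Any:
    """Compute Z_t = (1 - t) * Z1 + t * Z2 for every leaf.

    Hypothetical sketch: both latents must share the same nested
    structure of dicts and lists; floats stand in for tensors.
    """
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")  # assumed validation
    if isinstance(z1, dict):
        if z1.keys() != z2.keys():
            raise ValueError("latent structures differ")
        return {k: lerp_latents(z1[k], z2[k], t) for k in z1}
    if isinstance(z1, list):
        if len(z1) != len(z2):
            raise ValueError("latent structures differ")
        return [lerp_latents(a, b, t) for a, b in zip(z1, z2)]
    return (1.0 - t) * z1 + t * z2
```

With tensors as leaves, the final expression is unchanged, since `torch.Tensor` supports the same scalar arithmetic.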
- calculate_mslp_and_append(y_arr, datetime_ref, latlons, surface_geopotential, conf, ensemble_size=1)
Calculate Mean Sea Level Pressure (MSLP) and append it to the input tensor.
- process_pressure_interp(y_phys, y_pred_phys, batch, latlons, surface_geopotential, conf)
Process MSLP for entire batches of truth and prediction tensors.
- forward(*args, **kwargs)
Forward pass through the model.
- _forward_with_latent_control(x, forecast_step: int, use_latents: Dict | None = None)
Internal forward pass with optional latent control.
- Parameters:
x – Input tensor
forecast_step – Current forecast step (0-indexed)
use_latents – Dict with ‘latents_by_timestep’ containing noise deltas
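Exact reproduction relies on replaying stored noise deltas at the matching forecast step instead of sampling new noise. The sketch below shows that lookup. It assumes `latents_by_timestep` maps a 0-indexed step to a list with one delta per noise layer. `noise_for_step` is hypothetical, and the fallback to fresh Gaussian samples is also an assumption.

```python
import random
from typing import Dict, List, Optional


def noise_for_step(
    forecast_step: int,
    num_layers: int,
    use_latents: Optional[Dict] = None,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Pick per-layer noise deltas for one forecast step (hypothetical)."""
    rng = rng or random.Random()
    if use_latents is not None:
        stored = use_latents.get("latents_by_timestep", {}).get(forecast_step)
        if stored is not None:
            if len(stored) != num_layers:
                raise ValueError("stored deltas do not match layer count")
            return list(stored)  # replay stored deltas for exact reproduction
    # No stored deltas for this step: sample fresh noise (assumed fallback).
    return [rng.gauss(0.0, 1.0) for _ in range(num_layers)]
```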
- rollout_forecast(data_loader, state_transformer=None, ensemble_size: int = 1, history_len: int = 1, device: str = 'cuda', metrics_fn=None, capture_latents: bool = False, store_latents_as: str | None = None, use_stored_latents: str | None = None, use_interpolated_latents: Tuple[str, str, float] | None = None, conf: Dict | None = None) → Dict
Unified hurricane forecast rollout with integrated latent vector control.
- Parameters:
data_loader – DataLoader containing hurricane data batches
state_transformer – Transformer for normalization/denormalization
ensemble_size – Number of ensemble members
history_len – Length of input history
device – Device to run inference on
metrics_fn – Function to compute metrics per step
conf – Configuration dictionary (required for MSLP calculation)
Latent control parameters:
capture_latents – If True, capture latent vectors during this rollout
store_latents_as – Name to store captured latents (implies capture_latents=True)
use_stored_latents – Name of stored latents to use for exact reproduction
use_interpolated_latents – Tuple of (name1, name2, t) for interpolation
- Returns:
Dict with ‘predictions’, ‘truth’, ‘metrics’, and optional ‘latent_name’
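The four latent control parameters interact. Per the parameter list, `store_latents_as` implies `capture_latents=True`. The sketch below resolves the arguments into one plan. `LatentPlan` and `resolve_latent_options` are hypothetical. Treating `use_stored_latents` and `use_interpolated_latents` as mutually exclusive is an assumption; the actual precedence is not documented here.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class LatentPlan:
    capture: bool
    store_as: Optional[str]
    source: Optional[Tuple]  # ("stored", name) or ("interp", name1, name2, t)


def resolve_latent_options(
    capture_latents: bool = False,
    store_latents_as: Optional[str] = None,
    use_stored_latents: Optional[str] = None,
    use_interpolated_latents: Optional[Tuple[str, str, float]] = None,
) -> LatentPlan:
    """Hypothetical resolution of rollout_forecast's latent arguments."""
    if use_stored_latents is not None and use_interpolated_latents is not None:
        # Assumption: two replay sources at once are treated as an error.
        raise ValueError(
            "pass at most one of use_stored_latents and use_interpolated_latents"
        )
    source = None
    if use_stored_latents is not None:
        source = ("stored", use_stored_latents)
    elif use_interpolated_latents is not None:
        source = ("interp", *use_interpolated_latents)
    # Documented behaviour: naming a store implies capturing.
    capture = capture_latents or store_latents_as is not None
    return LatentPlan(capture, store_latents_as, source)
```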
- generate_interpolation_sequence(dataset, name1: str, name2: str, num_steps: int = 5, state_transformer=None, device: str = 'cuda', conf: Dict | None = None, **rollout_kwargs) → List[Dict]
Generate sequence of forecasts along interpolation path.
- Parameters:
dataset – Dataset (deep-copied for each interpolation point)
name1 – Start latent vector
name2 – End latent vector
num_steps – Number of interpolation points (default: 5)
state_transformer – Transformer for normalization
device – Device to run on
conf – Configuration dictionary
**rollout_kwargs – Additional arguments for rollout_forecast
- Returns:
List of forecast dicts, one per interpolation point
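Each interpolation point corresponds to one value of t passed to `rollout_forecast` through `use_interpolated_latents=(name1, name2, t)`. The sketch below builds those arguments. It assumes evenly spaced t values that include both endpoints; the spacing `generate_interpolation_sequence` actually uses is not stated here. Both helper names are hypothetical.

```python
from typing import Dict, List


def interpolation_ts(num_steps: int = 5) -> List[float]:
    """Evenly spaced t from 0 (name1) to 1 (name2), endpoints included (assumed)."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        return [0.0]
    return [i / (num_steps - 1) for i in range(num_steps)]


def interpolation_rollout_kwargs(name1: str, name2: str, num_steps: int = 5) -> List[Dict]:
    """One set of rollout_forecast keyword arguments per interpolation point."""
    return [
        {"use_interpolated_latents": (name1, name2, t)}
        for t in interpolation_ts(num_steps)
    ]
```

With the default `num_steps=5`, this yields t values of 0.0, 0.25, 0.5, 0.75 and 1.0, so the first and last forecasts reproduce the two stored latents.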
- scale_latents(name: str, beta: float) → Dict
Scale stored latent vectors by factor beta. Z_scaled = beta * Z
- Parameters:
name – Latent vector identifier
beta – Scaling factor
- Returns:
Scaled latent dict
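Z_scaled = beta * Z multiplies every stored value by the same factor. beta = 1 reproduces the stored latents, and beta = 0 zeroes them. The sketch below applies the scaling across a nested latent dict. `scale_latent_tree` is hypothetical, floats stand in for tensors, and passing non-numeric metadata through unchanged is an assumption.

```python
from typing import Any


def scale_latent_tree(z: Any, beta: float) -> Any:
    """Return beta * Z for every numeric leaf of a nested latent structure."""
    if isinstance(z, dict):
        return {k: scale_latent_tree(v, beta) for k, v in z.items()}
    if isinstance(z, list):
        return [scale_latent_tree(v, beta) for v in z]
    if isinstance(z, (int, float)) and not isinstance(z, bool):
        return beta * z
    # Non-numeric metadata (e.g. names) is left untouched (assumption).
    return z
```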
- generate_scaled_ensemble(dataset, base_latent_name: str, beta_values: List[float], state_transformer=None, device: str = 'cuda', conf: Dict | None = None, **rollout_kwargs) → List[Dict]
Generate ensemble with different scaling factors applied to base latent.
- Parameters:
dataset – Dataset for forecasting
base_latent_name – Name of base latent vector to scale
beta_values – List of scaling factors to apply
state_transformer – Transformer for normalization
device – Device to run on
conf – Configuration dictionary
**rollout_kwargs – Additional arguments for rollout_forecast
- Returns:
List of forecast dicts, one per beta value
- scale_latents_multilevel(name: str, beta_per_layer: List[float]) → Dict
Scale stored latent vectors with different beta for each layer.
- Parameters:
name – Latent vector identifier
beta_per_layer – List of scaling factors, one per noise injection layer. Must match the number of layers in self._noise_layers.
- Returns:
Scaled latent dict with ‘latents_by_timestep’ structure
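Multilevel scaling gives each noise-injection layer its own beta and requires exactly one beta per layer. The sketch below shows this. It assumes `latents_by_timestep` maps each step to a list holding one latent per layer, with lists of floats standing in for tensors. `scale_multilevel` is hypothetical.

```python
from typing import Dict, List


def scale_multilevel(latents: Dict, beta_per_layer: List[float]) -> Dict:
    """Scale each layer's latent by its own beta (hypothetical sketch)."""
    scaled: Dict[int, List[List[float]]] = {}
    for step, layers in latents["latents_by_timestep"].items():
        # Enforce the documented rule: one beta per noise-injection layer.
        if len(layers) != len(beta_per_layer):
            raise ValueError(
                f"step {step}: {len(layers)} layers, {len(beta_per_layer)} betas"
            )
        scaled[step] = [
            [beta * v for v in layer]
            for beta, layer in zip(beta_per_layer, layers)
        ]
    return {"latents_by_timestep": scaled}
```

With every entry of `beta_per_layer` set to the same value, this reduces to the uniform rule Z_scaled = beta * Z.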
- credit.models.wxformer.sdl_inference_wrapper.run_multiscale_control_experiment(experiments, wrapper, dataset, state_transformer, conf, lat_centers, lon_centers, device='cuda', ensemble_size=1) → Dict[str, Dict]
- credit.models.wxformer.sdl_inference_wrapper.parser