credit.ensemble.bred_vector
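Attributes
- logger

Classes
- BredVector

Functions
- generate_bred_vectors(x_batch, model[, num_cycles, ...]) – Generate bred vectors and use them to build perturbed initial conditions for the given batch.
- generate_bred_vectors_cycle(initial_condition, dataset, model[, ...]) – Generate bred vectors from an initial condition and a dataset, and use them to build perturbed initial conditions.
- clone_dataset(dataset) – Clones a PyTorch Dataset by creating a deep copy.
- adjust_start_times(time_ranges[, hours]) – Adjusts the start times by subtracting the given number of hours (24 by default).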
Module Contents
- class credit.ensemble.bred_vector.BredVector(model: Callable[[torch.Tensor], torch.Tensor], noise_amplitude: float = 0.15, num_cycles: int = 5, integration_steps: int = 1, perturbation_method: Callable[[torch.Tensor, collections.OrderedDict[str, numpy.ndarray]], torch.Tensor] | None = None, hemispheric_rescale: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, terrain_file: str = None, perturb_channel_idx: int = None, ensemble_perturb: bool = False, clamp: bool = False, clamp_min: float = None, clamp_max: float = None, input_static_dim: int = 3, varnum_diag: int = 0, post_conf: dict = {})
  - model
  - noise_amplitude = 0.15
  - num_cycles = 5
  - integration_steps = 1
  - perturbation_method = None
  - hemispheric_rescale
  - ensemble_perturb = False
  - perturb_channel_idx = None
  - clamp = False
  - clamp_min = None
  - clamp_max = None
  - input_static_dim = 3
  - varnum_diag = 0
  - post_conf
  - flag_mass_conserve = False
  - flag_water_conserve = False
  - flag_energy_conserve = False
  - use_post_block = False
  - perturb(x_input: torch.Tensor, forecast_step: int = 1) -> torch.Tensor
  - __call__(initial_condition: torch.Tensor, dataset, return_delta_x=False) -> list[torch.Tensor]
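BredVector carries no docstring. As a rough illustration of what its noise_amplitude, num_cycles and integration_steps parameters control, here is a minimal NumPy sketch of a generic breeding cycle, not the CREDIT implementation: a perturbation is added to the initial state, the control and perturbed states are both integrated forward, and their difference is rescaled back to the target amplitude before the next cycle. The function name breed_vector, the RMS rescaling and the simplification of holding the control state fixed between cycles are all assumptions made for the sketch.

```python
import numpy as np


def breed_vector(model, x, noise_amplitude=0.15, num_cycles=5,
                 integration_steps=1, rng=None):
    """Toy breeding cycle (hypothetical helper, not the CREDIT API).

    Assumes an RMS norm for rescaling and keeps the control state x fixed
    between cycles instead of advancing it along an analysis.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Random initial perturbation, rescaled to the target RMS amplitude
    delta = rng.standard_normal(x.shape)
    delta *= noise_amplitude / np.sqrt(np.mean(delta ** 2))
    for _ in range(num_cycles):
        control, perturbed = x, x + delta
        # Integrate both trajectories forward with the same model
        for _ in range(integration_steps):
            control = model(control)
            perturbed = model(perturbed)
        # The grown difference becomes the new perturbation, rescaled again
        delta = perturbed - control
        delta *= noise_amplitude / np.sqrt(np.mean(delta ** 2))
    return delta


# Usage: a linear "model" whose first component grows fastest
A = np.diag([3.0, 1.0, 0.5])
bv = breed_vector(lambda s: A @ s, np.ones(3), rng=np.random.default_rng(0))
print(bv)  # dominated by the first (fastest-growing) component
```

Repeated rescaling keeps the perturbation small while letting the fastest-growing directions dominate, which is the usual motivation for breeding over purely random noise.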
- credit.ensemble.bred_vector.generate_bred_vectors(x_batch, model, num_cycles=5, perturbation_std=0.15, epsilon=1.0, flag_clamp=False, clamp_min=None, clamp_max=None)
Generate bred vectors and use them to build perturbed initial conditions for the given batch.
- Parameters:
x_batch (torch.Tensor) – The input batch.
model (nn.Module) – The model used for predictions.
num_cycles (int, optional) – Number of breeding (perturbation) cycles. Defaults to 5.
perturbation_std (float, optional) – Magnitude (standard deviation) of the initial perturbations. Defaults to 0.15.
epsilon (float, optional) – Scaling factor applied to the bred vectors. Defaults to 1.0.
flag_clamp (bool, optional) – Whether to clamp inputs. Defaults to False.
clamp_min (float, optional) – Minimum clamp value. Required if flag_clamp is True.
clamp_max (float, optional) – Maximum clamp value. Required if flag_clamp is True.
- Returns:
List of initial conditions generated using bred vectors.
- Return type:
list[torch.Tensor]
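A minimal NumPy sketch of how epsilon and the clamp options could be applied when turning bred vectors into ensemble initial conditions. It assumes one member per bred vector, formed as x + epsilon * bv and optionally clipped; the function name initial_conditions_from_bred_vectors is hypothetical and the real function may combine the vectors differently.

```python
import numpy as np


def initial_conditions_from_bred_vectors(x, bred_vectors, epsilon=1.0,
                                         flag_clamp=False, clamp_min=None,
                                         clamp_max=None):
    """Hypothetical helper: one perturbed initial condition per bred vector."""
    # Mirror the documented requirement on clamp_min / clamp_max
    if flag_clamp and (clamp_min is None or clamp_max is None):
        raise ValueError("clamp_min and clamp_max are required when flag_clamp is True")
    members = []
    for bv in bred_vectors:
        # Scale the bred vector by epsilon and add it to the unperturbed state
        member = x + epsilon * bv
        if flag_clamp:
            # Keep every value inside [clamp_min, clamp_max]
            member = np.clip(member, clamp_min, clamp_max)
        members.append(member)
    return members


members = initial_conditions_from_bred_vectors(
    np.zeros(3), [np.array([1.0, -2.0, 0.5])], epsilon=2.0,
    flag_clamp=True, clamp_min=-1.0, clamp_max=1.0)
print(members[0])  # values clipped to [-1, 1]
```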
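- credit.ensemble.bred_vector.generate_bred_vectors_cycle(initial_condition, dataset, model, num_cycles=5, perturbation_std=0.15, epsilon=1.0, flag_clamp=False, clamp_min=None, clamp_max=None, device='cuda', history_len=1, varnum_diag=None, static_dim_size=None, post_conf={})
Generate bred vectors from an initial condition and a dataset, and use them to build perturbed initial conditions.
- Parameters:
initial_condition (torch.Tensor) – The initial condition to perturb.
dataset (torch.utils.data.Dataset) – The dataset used during the breeding cycles.
model (nn.Module) – The model used for predictions.
num_cycles (int, optional) – Number of breeding (perturbation) cycles. Defaults to 5.
perturbation_std (float, optional) – Magnitude (standard deviation) of the initial perturbations. Defaults to 0.15.
epsilon (float, optional) – Scaling factor applied to the bred vectors. Defaults to 1.0.
flag_clamp (bool, optional) – Whether to clamp inputs. Defaults to False.
clamp_min (float, optional) – Minimum clamp value. Required if flag_clamp is True.
clamp_max (float, optional) – Maximum clamp value. Required if flag_clamp is True.
device (str, optional) – Device on which the computation runs. Defaults to 'cuda'.
history_len (int, optional) – Number of input time steps (history length) the model expects. Defaults to 1.
varnum_diag (int, optional) – Number of diagnostic variables in the model output. Defaults to None.
static_dim_size (int, optional) – Number of static channels in the model input. Defaults to None.
post_conf (dict, optional) – Post-processing configuration. Defaults to {}.
- Returns:
List of initial conditions generated using bred vectors.
- Return type:
list[torch.Tensor]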
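generate_bred_vectors_cycle also takes varnum_diag and static_dim_size, which suggests that model outputs are fed back as inputs between breeding cycles. The sketch below shows one way that bookkeeping could look for a single time step (history_len = 1). It is a hypothetical NumPy illustration: the function name next_model_input and the channel layout (channels on axis 1, diagnostics as the trailing output channels, static fields as the trailing input channels) are assumptions, not taken from CREDIT.

```python
import numpy as np


def next_model_input(x_input, y_pred, varnum_diag=0, static_dim_size=0):
    """Hypothetical helper: build the next input from a one-step prediction.

    Assumed layout: x_input is (batch, prognostic + static, ...) and
    y_pred is (batch, prognostic + diagnostic, ...).
    """
    # Drop trailing diagnostic channels; explicit end index avoids the
    # y_pred[:, :-0] pitfall when varnum_diag == 0
    prognostic = y_pred[:, : y_pred.shape[1] - varnum_diag]
    # Static fields are not predicted, so carry them over from the input
    static = x_input[:, x_input.shape[1] - static_dim_size:]
    return np.concatenate([prognostic, static], axis=1)


# Usage: 3 prognostic + 2 static input channels, 3 prognostic + 1 diagnostic output
x_in = np.arange(10.0).reshape(1, 5, 2)
y = np.arange(100.0, 108.0).reshape(1, 4, 2)
print(next_model_input(x_in, y, varnum_diag=1, static_dim_size=2).shape)  # (1, 5, 2)
```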
- credit.ensemble.bred_vector.clone_dataset(dataset)
Clones a PyTorch Dataset by creating a deep copy.
- Parameters:
dataset (torch.utils.data.Dataset) – The original dataset.
- Returns:
A cloned dataset.
- Return type:
torch.utils.data.Dataset
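Since the docstring says clone_dataset creates a deep copy, its behaviour presumably matches copy.deepcopy: nested attributes such as a list of time ranges are duplicated, so editing the clone leaves the original dataset untouched. A minimal sketch, using a hypothetical ToyDataset stand-in (not a CREDIT class) so that it runs without PyTorch:

```python
import copy


class ToyDataset:
    """Hypothetical stand-in for a torch.utils.data.Dataset."""

    def __init__(self, time_ranges):
        self.time_ranges = time_ranges

    def __len__(self):
        return len(self.time_ranges)

    def __getitem__(self, idx):
        return self.time_ranges[idx]


def clone_dataset(dataset):
    # deepcopy also copies nested containers, so edits to the clone
    # (e.g. shifted start times) do not leak back into the original
    return copy.deepcopy(dataset)


original = ToyDataset([["2020-01-02 00:00:00", "2020-01-05 00:00:00"]])
clone = clone_dataset(original)
clone.time_ranges[0][0] = "2020-01-01 00:00:00"
print(original.time_ranges[0][0])  # 2020-01-02 00:00:00
```

A shallow copy (copy.copy) would share the inner lists, so the edit above would also change the original.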
- credit.ensemble.bred_vector.adjust_start_times(time_ranges, hours=24)
Adjusts the start times by subtracting the given number of hours (24 by default).
- Parameters:
time_ranges (list of lists) – Each sublist contains [start_time, end_time] as strings.
hours (int, optional) – Number of hours to subtract from each start time. Defaults to 24.
- Returns:
Adjusted time ranges [[start_time - hours, start_time], …]
- Return type:
list of lists
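A minimal sketch of the documented behaviour of adjust_start_times, assuming the time strings use the "%Y-%m-%d %H:%M:%S" format (the actual format is not stated in this reference). Each range is replaced by the window of length hours that ends at the original start time:

```python
from datetime import datetime, timedelta

TIME_FORMAT = "%Y-%m-%d %H:%M:%S"  # assumed string format


def adjust_start_times(time_ranges, hours=24):
    adjusted = []
    for start_time, _end_time in time_ranges:
        start = datetime.strptime(start_time, TIME_FORMAT)
        new_start = start - timedelta(hours=hours)
        # The original start becomes the end of the new window
        adjusted.append([new_start.strftime(TIME_FORMAT), start_time])
    return adjusted


print(adjust_start_times([["2020-01-02 00:00:00", "2020-01-05 00:00:00"]]))
# [['2020-01-01 00:00:00', '2020-01-02 00:00:00']]
```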
- credit.ensemble.bred_vector.logger