credit.models.checkpoint

Attributes

model

Classes

TorchFSDPCheckpointIO

ModelWrapper

A wrapper class to define the common interface used by FSDP.

TorchFSDPModel

A wrapper class to define the common interface used by FSDP.

OptimizerWrapper

A standard interface for optimizers wrapped by the Booster.

FSDPOptimizerWrapper

A standard interface for optimizers wrapped by the Booster.

Functions

load_state_dict_error_handler(load_msg)

get_file_extension(file_path)

copy_checkpoint(→ None)

Copy every checkpoint after it's saved.

load_model_state(conf, model, device)

save_state_dict(→ None)

Save state dict to checkpoint.

load_state_dict(checkpoint_file_path)

Load state dict from checkpoint.

is_dtensor_checkpoint(→ bool)

Check whether the checkpoint file is a dtensor checkpoint.

is_safetensor_checkpoint(→ bool)

Check whether the checkpoint file is a safetensor checkpoint.

is_safetensors_available(→ bool)

Check whether safetensors is available.

Module Contents

credit.models.checkpoint.load_state_dict_error_handler(load_msg)
credit.models.checkpoint.get_file_extension(file_path)
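get_file_extension has no docstring. The sketch below shows what such a helper commonly does, assuming it returns the last suffix of the path without the leading dot. This is a hypothetical re-implementation; the real function may keep the dot or treat multi-part suffixes differently.

```python
from pathlib import Path


def get_file_extension(file_path: str) -> str:
    # Hypothetical re-implementation: return the final suffix without the dot.
    # The real credit.models.checkpoint version may behave differently.
    return Path(file_path).suffix.lstrip(".")


print(get_file_extension("checkpoint.pt"))      # pt
print(get_file_extension("model.safetensors"))  # safetensors
```

Using pathlib rather than splitting on "." avoids mistaking a dot in a directory name for an extension.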
credit.models.checkpoint.copy_checkpoint(checkpoint_file_path: str, number) → None

Copy every checkpoint after it's saved.

Parameters:
  • checkpoint_file_path (str) – path to the checkpoint file to copy.

  • number – number used to label the copy.
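The docstring does not say where the copy is written. A minimal sketch, assuming the copy goes next to the original with number (for example, an epoch index) inserted before the extension; that naming scheme is hypothetical and not taken from the module.

```python
import shutil
import tempfile
from pathlib import Path


def copy_checkpoint(checkpoint_file_path: str, number) -> None:
    """Copy a just-saved checkpoint to a numbered sibling file.

    Hypothetical naming scheme: "model.pt" with number=3 -> "model_3.pt".
    The real credit.models.checkpoint implementation may differ.
    """
    src = Path(checkpoint_file_path)
    dst = src.with_name(f"{src.stem}_{number}{src.suffix}")
    # copy2 also preserves file metadata such as the modification time.
    shutil.copy2(src, dst)


# Usage: write a dummy checkpoint into a temporary directory, then copy it.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "model.pt"
    ckpt.write_bytes(b"weights")
    copy_checkpoint(str(ckpt), 3)
    print(sorted(p.name for p in Path(tmp).iterdir()))  # ['model.pt', 'model_3.pt']
```

Keeping numbered copies means the next save can overwrite the main checkpoint file without losing earlier states.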

credit.models.checkpoint.load_model_state(conf, model, device)
credit.models.checkpoint.save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) → None

Save state dict to checkpoint.

Parameters:
  • state_dict (dict) – state dict.

  • checkpoint_file_path (str) – path to the checkpoint file.

  • use_safetensors (bool) – whether to use safetensors to save the checkpoint.

credit.models.checkpoint.load_state_dict(checkpoint_file_path: pathlib.Path)

Load state dict from checkpoint.

Parameters:

checkpoint_file_path (Path) – path to the checkpoint file.

Returns:

state dict.

Return type:

dict

credit.models.checkpoint.is_dtensor_checkpoint(checkpoint_file_path: str) → bool

Check whether the checkpoint file is a dtensor checkpoint.

Parameters:

checkpoint_file_path (str) – path to the checkpoint file.

Returns:

whether the checkpoint file is a dtensor checkpoint.

Return type:

bool

credit.models.checkpoint.is_safetensor_checkpoint(checkpoint_file_path: str) → bool

Check whether the checkpoint file is a safetensor checkpoint.

Parameters:

checkpoint_file_path (str) – path to the checkpoint file.

Returns:

whether the checkpoint file is a safetensor checkpoint.

Return type:

bool
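A minimal sketch of such a check, assuming (hypothetically) that it relies on the ".safetensors" file-name suffix alone rather than inspecting the file contents:

```python
def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
    # Hypothetical sketch: decide by file-name suffix only, without opening the file.
    return checkpoint_file_path.endswith(".safetensors")


for name in ["model.safetensors", "model.pt", "model.safetensors.bak"]:
    print(name, is_safetensor_checkpoint(name))
```

A suffix test is cheap but trusts the file name. A stricter check could read the file itself, since the safetensors format begins with an 8-byte little-endian header length followed by a JSON header.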

credit.models.checkpoint.is_safetensors_available() → bool

Check whether safetensors is available.

Returns:

whether safetensors is available.

Return type:

bool
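One common way to write this check, sketched with the standard library's importlib. The module itself may instead attempt `import safetensors` and catch ImportError; this version trades the actual import for a lookup.

```python
import importlib.util


def is_safetensors_available() -> bool:
    # find_spec() looks the package up without importing it and returns None
    # when it is not installed, so the check has no import side effects.
    return importlib.util.find_spec("safetensors") is not None


print("safetensors installed:", is_safetensors_available())
```

Callers can use the result to fall back to another serialization format when the optional dependency is missing.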

class credit.models.checkpoint.TorchFSDPCheckpointIO
load_unsharded_model(model, checkpoint)
load_unsharded_optimizer(optimizer, checkpoint)
save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, rank)

Save the model to a checkpoint, but only on the master process.

save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, rank)

Save the optimizer state to a checkpoint, but only on the master process.
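Both save methods write only on the master process. The gating pattern can be sketched without PyTorch, assuming rank 0 is the master and using pickle as a stand-in serializer; save_on_master is a hypothetical helper, and the real methods serialize with PyTorch or safetensors instead.

```python
import pickle
import tempfile
from pathlib import Path


def save_on_master(state: dict, checkpoint: str, rank: int) -> bool:
    """Write `state` to `checkpoint` only when this process is rank 0.

    Hypothetical helper; returns True if this rank wrote the file.
    """
    if rank != 0:
        # Non-master ranks skip the write so that several processes
        # do not race to write the same file.
        return False
    with open(checkpoint, "wb") as f:
        pickle.dump(state, f)
    return True


# Simulate three ranks calling the same save in turn.
with tempfile.TemporaryDirectory() as tmp:
    path = str(Path(tmp) / "model.pkl")
    writes = [save_on_master({"step": 10}, path, rank) for rank in range(3)]
    print(writes)  # [True, False, False]
```

In a real FSDP job, gathering the full state dict is usually a collective operation that every rank must join; only the file write is restricted to rank 0.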

class credit.models.checkpoint.ModelWrapper(module: torch.nn.Module)

Bases: torch.nn.Module

A wrapper class to define the common interface used by FSDP.

Parameters:

module (nn.Module) – The model to be wrapped.

module
unwrap()

Unwrap the model to return the original model for checkpoint saving/loading.

forward(*args, **kwargs)
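The wrapper interface above (store the model in module, forward calls to it, return it from unwrap()) can be illustrated without PyTorch. In this hypothetical sketch ToyModel stands in for an nn.Module, and the assumption that unwrap() peels through nested wrappers is ours, not confirmed by the docstring.

```python
class ToyModel:
    """Hypothetical stand-in for a plain torch.nn.Module."""

    def __call__(self, x):
        return 2 * x


class ModelWrapperSketch:
    """Hypothetical wrapper mirroring the documented interface: the wrapped
    model lives in `.module`, calls are forwarded to it, and unwrap() returns it."""

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        # Equivalent of forward(): delegate to the wrapped model unchanged.
        return self.module(*args, **kwargs)

    def unwrap(self):
        # Peel off nested wrappers so checkpoint code sees the original model.
        if isinstance(self.module, ModelWrapperSketch):
            return self.module.unwrap()
        return self.module


base = ToyModel()
wrapped = ModelWrapperSketch(ModelWrapperSketch(base))
print(wrapped.unwrap() is base, wrapped(3))  # True 6
```

Unwrapping matters for checkpoints because, in PyTorch, a wrapper that registers the model as a submodule named module prefixes every state-dict key with "module.", so saving the unwrapped model keeps the keys loadable without the wrapper.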
class credit.models.checkpoint.TorchFSDPModel(module, *args, **kwargs)

Bases: ModelWrapper

A wrapper class to define the common interface used by FSDP.

Parameters:

module (nn.Module) – The model to be wrapped.

module
unwrap()

Unwrap the model to return the original model for checkpoint saving/loading.

concat_and_reshape(x1, x2)
reshape_only(x1)

As in “concat_and_reshape”, but for upper-air variables only.

class credit.models.checkpoint.OptimizerWrapper(optim: torch.optim.Optimizer)

A standard interface for optimizers wrapped by the Booster.

Parameters:

optim (Optimizer) – The optimizer to be wrapped.

optim
property parameters
property param_groups
property defaults
add_param_group(*args, **kwargs)
step(*args, **kwargs)

Performs a single optimization step.

zero_grad(*args, **kwargs)

Clears the gradients of all optimized torch.Tensor objects.

backward(loss: torch.Tensor, *args, **kwargs)

Performs a backward pass on the loss.

backward_by_grad(tensor: torch.Tensor, grad: torch.Tensor)
state_dict()

Returns the optimizer state.

load_state_dict(*args, **kwargs)

Loads the optimizer state.
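The methods above forward to the wrapped optimizer. A minimal, hypothetical sketch of that delegation, using a toy SGD optimizer on dicts of floats in place of torch.optim.Optimizer (ToyOptimizer and OptimizerWrapperSketch are illustrative names, not part of the module):

```python
class ToyOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer: plain SGD on dicts of floats."""

    def __init__(self, params, lr=0.1):
        self.defaults = {"lr": lr}
        self.param_groups = [{"params": list(params), "lr": lr}]

    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                p["value"] -= group["lr"] * p["grad"]

    def zero_grad(self):
        for group in self.param_groups:
            for p in group["params"]:
                p["grad"] = 0.0

    def state_dict(self):
        return {"param_groups": [{"lr": g["lr"]} for g in self.param_groups]}


class OptimizerWrapperSketch:
    """Hypothetical wrapper: every documented call is forwarded to `self.optim`."""

    def __init__(self, optim):
        self.optim = optim

    @property
    def param_groups(self):
        return self.optim.param_groups

    @property
    def defaults(self):
        return self.optim.defaults

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        self.optim.zero_grad(*args, **kwargs)

    def state_dict(self):
        return self.optim.state_dict()

    def unwrap(self):
        return self.optim


param = {"value": 1.0, "grad": 0.5}
opt = OptimizerWrapperSketch(ToyOptimizer([param], lr=0.5))
opt.step()       # value becomes 1.0 - 0.5 * 0.5
opt.zero_grad()
print(param)     # {'value': 0.75, 'grad': 0.0}
```

Because param_groups is a property returning the inner optimizer's list, changes such as a learning-rate schedule applied through the wrapper reach the real optimizer.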

clip_grad_by_value(clip_value: float, *args, **kwargs) → None

Clips the gradients of an iterable of parameters at the specified minimum and maximum values.

Parameters:

clip_value (float or int) – maximum allowed value of the gradients. Gradients are clipped in the range [-clip_value, clip_value].

Note

In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the faster implementation. Please refer to the PyTorch documentation for more details.
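The element-wise rule behind value clipping, shown in plain Python on a hypothetical list of per-parameter gradient lists (clip_grad_value_sketch is an illustrative name; the method itself operates on tensors via clip_grad_value_):

```python
def clip_grad_value_sketch(grads, clip_value):
    """Clamp every gradient entry in place to [-clip_value, clip_value].

    Pure-Python illustration of the element-wise rule applied by
    torch.nn.utils.clip_grad_value_; `grads` is a hypothetical list of
    lists of floats (one list per parameter).
    """
    for g in grads:
        for i, v in enumerate(g):
            g[i] = max(-clip_value, min(clip_value, v))


grads = [[-3.0, 0.5], [2.0]]
clip_grad_value_sketch(grads, 1.0)
print(grads)  # [[-1.0, 0.5], [1.0]]
```

Unlike norm clipping, each entry is clamped independently, so the direction of the overall gradient vector can change.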

clip_grad_by_norm(max_norm: float | int, norm_type: float | int = 2.0, error_if_nonfinite: bool = False, *args, **kwargs) → torch.Tensor

Clips the gradient norm of an iterable of parameters.

Parameters:
  • max_norm (float or int) – max norm of the gradients

  • norm_type (float or int) – type of the used p-norm. Can be 'inf' for infinity norm.

  • error_if_nonfinite (bool) – if True, an error is raised if the total norm is non-finite. Default: False

Note

In PyTorch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the faster implementation. Please refer to the PyTorch documentation for more details.
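Norm clipping measures one p-norm over all gradients together and rescales them uniformly when it exceeds max_norm. A plain-Python sketch of that rule on a hypothetical list of per-parameter gradient lists (clip_grad_norm_sketch is an illustrative name), mirroring the 1e-6 stabilizer and the "never scale up" clamp used by torch.nn.utils.clip_grad_norm_:

```python
import math


def clip_grad_norm_sketch(grads, max_norm, norm_type=2.0, error_if_nonfinite=False):
    """Scale gradients in place so their combined p-norm is at most max_norm.

    `grads` is a hypothetical list of lists of floats (one list per parameter).
    Returns the total norm measured before clipping.
    """
    norm_type = float(norm_type)  # accepts 2, 2.0, "inf", or math.inf
    flat = [v for g in grads for v in g]
    if norm_type == math.inf:
        total = max((abs(v) for v in flat), default=0.0)
    else:
        total = sum(abs(v) ** norm_type for v in flat) ** (1.0 / norm_type)
    if error_if_nonfinite and not math.isfinite(total):
        raise RuntimeError(f"The total norm of order {norm_type} is non-finite")
    # Add 1e-6 to avoid division by zero, and never scale gradients up.
    coef = min(max_norm / (total + 1e-6), 1.0)
    for g in grads:
        for i, v in enumerate(g):
            g[i] = v * coef
    return total


grads = [[3.0], [4.0]]
total = clip_grad_norm_sketch(grads, max_norm=1.0)
print(total)                                      # 5.0
print([[round(v, 4) for v in g] for g in grads])  # [[0.6], [0.8]]
```

Scaling every gradient by the same factor preserves the update direction, which is why norm clipping is often preferred over value clipping.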

abstractmethod scale_loss(loss: torch.Tensor)

Scales the loss for mixed precision training.

Note: Only available for optimizers with mixed precision training.

Parameters:

loss (Tensor) – The loss to be scaled.

abstractmethod unscale_grad()

Unscale the gradients for mixed precision training.

Note: Only available for optimizers with mixed precision training.
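scale_loss and unscale_grad are abstract here. The idea behind them can be shown with the standard library alone, using struct's "e" (IEEE half precision) format to emulate float16 and a hypothetical fixed scale of 2**16; real mixed-precision scalers such as torch.cuda.amp.GradScaler adjust the scale dynamically.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


LOSS_SCALE = 2.0 ** 16  # hypothetical fixed scale factor S

true_grad = 1e-8  # below float16's smallest subnormal (about 6e-8)

# Without scaling, storing this gradient in float16 flushes it to zero.
print(to_fp16(true_grad))  # 0.0

# scale_loss: backpropagating S * loss multiplies every gradient by S
# (chain rule), which lifts this one into float16's representable range.
scaled_grad = to_fp16(true_grad * LOSS_SCALE)

# unscale_grad: divide by S in higher precision before the optimizer step.
recovered = scaled_grad / LOSS_SCALE
print(f"recovered gradient: {recovered:.3e}")
```

Unscaling must happen before gradient clipping or the optimizer step; otherwise thresholds such as max_norm would be compared against gradients that are S times too large.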

unwrap()

Unwrap the optimizer for checkpoint saving/loading.

class credit.models.checkpoint.FSDPOptimizerWrapper(optimizer, model)

Bases: OptimizerWrapper

A standard interface for optimizers wrapped by the Booster.

Parameters:
  • optimizer (Optimizer) – The optimizer to be wrapped.

  • model – The model the optimizer is paired with.

model#
unwrap_model() → torch.nn.Module
credit.models.checkpoint.model