Checkpointing and Fine-Tuning a Model with Model Parallelism
The SageMaker model parallelism library provides checkpointing APIs to save the model state and the optimizer state, split according to the model parallelism strategy in use, and to load checkpoints so you can resume training or fine-tune from where you left off. The APIs also support options to save the model and optimizer states partially or fully.
Checkpointing a distributed model
Choose one of the following topics depending on the framework you use (PyTorch or TensorFlow) and the version of the SageMaker model parallelism library.
Checkpointing a distributed PyTorch model (for the SageMaker model parallelism library v1.10.0 and later)
The SageMaker model parallelism library provides checkpoint APIs to save and load full or partial checkpoints of the distributed model state and its optimizer state.
Note
This checkpointing method is recommended if you use PyTorch and the SageMaker model parallelism library v1.10.0 or later.
Partial checkpointing
To save checkpoints of a model trained with model parallelism, use the smdistributed.modelparallel.torch.save_checkpoint API and set the partial checkpointing option (partial=True). This saves each model partition individually. In addition to the model and the optimizer state, you can also save any additional custom data through the user_content argument. The checkpointed model, optimizer, and user content are saved as separate files. The save_checkpoint API call creates checkpoint folders in the following structure.
- path
  - ${tag}_partial (folder for partial checkpoints)
    - model_rankinfo.pt
    - optimizer_rankinfo.pt
    - fp16_states_rankinfo.pt
    - user_content.pt
  - $tag (checkpoint file for full checkpoints)
  - user_content_$tag (user_content file for full checkpoints)
  - newest (a file that indicates the newest checkpoint)
To resume training from partial checkpoints, use the smdistributed.modelparallel.torch.resume_from_checkpoint API with partial=True, and specify the checkpoint directory and the tag used while saving the partial checkpoints. Note that the actual loading of model weights happens after model partitioning, during the first run of the smdistributed.modelparallel.torch.step-decorated training step function.
When saving a partial checkpoint, the library also saves the model partition decision as files with the .pt file extension. When resuming from the partial checkpoint, the library loads the partition decision files along with it. After the partition decision is loaded, you can't change the partition.
The following code snippet shows how to set the checkpoint APIs in a PyTorch training script.
import smdistributed.modelparallel.torch as smp

model = ...
model = smp.DistributedModel(model)
optimizer = ...
optimizer = smp.DistributedOptimizer(optimizer)
user_content = ...  # additional custom data
checkpoint_path = "/opt/ml/checkpoint/model_parallel"

# Save a checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{total_steps}",
    partial=True,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=5
)

# Load a checkpoint.
# This automatically loads the most recently saved checkpoint.
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path,
    partial=True
)
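By default, resume_from_checkpoint loads the newest checkpoint; to resume from a specific one, pass the same tag you used when saving. The following is a minimal sketch of how resuming ties into a training loop; resume_step, loss_fn, and train_loader are placeholders that are not part of the library, and, as noted above, the weights of a partial checkpoint are actually applied during the first run of the smp.step-decorated function, after partitioning.

import smdistributed.modelparallel.torch as smp

# Resume from the checkpoint saved under a specific tag (placeholder value).
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{resume_step}",  # hypothetical step count used at save time
    partial=True
)

@smp.step
def train_step(model, data, target):
    output = model(data)
    loss = loss_fn(output, target)  # placeholder loss function
    model.backward(loss)            # use model.backward inside smp.step
    return loss

# The partial checkpoint's model weights are applied after partitioning,
# during the first call of this smp.step-decorated function.
for data, target in train_loader:   # placeholder data loader
    optimizer.zero_grad()
    loss_mb = train_step(model, data, target)
    loss = loss_mb.reduce_mean()    # average the loss across microbatches
    optimizer.step()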
Full checkpointing
To save the final model artifact for inference purposes, use the smdistributed.modelparallel.torch.save_checkpoint API with partial=False, which combines the model partitions to create a single model artifact. Note that this does not combine the optimizer states.

To initialize training with particular weights, given a full model checkpoint, you can use the smdistributed.modelparallel.torch.resume_from_checkpoint API with partial=False. Note that this does not load optimizer states.
Note
With tensor parallelism, in general, the state_dict must be translated between the original model implementation and the DistributedModel implementation. Optionally, you can provide the state_dict translation function as an argument to the smdistributed.modelparallel.torch.resume_from_checkpoint API. However, for Supported Models Out of the Box, the library takes care of this translation automatically.
The following code shows an example of how to use the checkpoint APIs to save and load a full checkpoint of a PyTorch model trained with model parallelism.
import smdistributed.modelparallel.torch as smp

model = ...
model = smp.DistributedModel(model)
optimizer = ...
optimizer = smp.DistributedOptimizer(optimizer)
user_content = ...  # additional custom data
checkpoint_path = "/opt/ml/checkpoint/model_parallel"

# Save a checkpoint.
smp.save_checkpoint(
    path=checkpoint_path,
    tag=f"total_steps{total_steps}",
    partial=False,
    model=model,
    optimizer=optimizer,
    user_content=user_content,
    num_kept_partial_checkpoints=5
)

# Load a checkpoint.
# This automatically loads the most recently saved checkpoint.
smp_checkpoint = smp.resume_from_checkpoint(
    path=checkpoint_path,
    partial=False
)
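Because partial=False combines the model partitions into a single artifact, that artifact can also be used outside the library, for example to load the weights into the original, non-distributed model definition for inference. The following is a minimal sketch under assumptions not stated on this page: it assumes the full checkpoint is the file named after the tag under checkpoint_path (per the folder structure above), that it can be read with torch.load, and that it holds the gathered model weights directly; MyModel and the tag value are placeholders. Inspect the saved file to confirm its layout before relying on this.

import os
import torch

# Hypothetical tag used when the full checkpoint was saved.
tag = f"total_steps{total_steps}"

# Load the combined model artifact on a single process.
# Assumption: the file holds the gathered model state_dict directly (it may instead
# be nested under a key such as "model_state_dict"; check the file's contents).
state_dict = torch.load(os.path.join(checkpoint_path, tag), map_location="cpu")

# Re-create the original, non-distributed model and apply the weights.
model = MyModel(...)  # placeholder for the original model definition
model.load_state_dict(state_dict)
model.eval()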
Checkpointing a distributed PyTorch model (for the SageMaker model parallelism library between v1.6.0 and v1.9.0)
The SageMaker model parallelism library provides Python functions for saving partial or full checkpoints for training jobs with tensor parallelism. The following procedure shows how to use smp.save() and smp.load() to save and load checkpoints; a consolidated sketch that puts the steps together follows the procedure.
Note
This checkpointing method is recommended if you use PyTorch, tensor parallelism, and the SageMaker model parallelism library between v1.6.0 and v1.9.0.
- Prepare a model object and wrap it with the library's wrapper function smp.DistributedModel().

      model = MyModel(...)
      model = smp.DistributedModel(model)
- Prepare an optimizer for the model. A set of model parameters is an iterable argument required by optimizer functions. To prepare a set of model parameters, you must process model.parameters() to assign unique IDs to individual model parameters.

  If there are parameters with duplicated IDs in the model parameter iterable, loading the checkpointed optimizer state fails. To create an iterable of model parameters with unique IDs for your optimizer, see the following:

      unique_params = []
      unique_params_set = set()
      for p in model.parameters():
          if p not in unique_params_set:
              unique_params.append(p)
              unique_params_set.add(p)
      del unique_params_set

      optimizer = MyOpt(unique_params, ...)
- Wrap the optimizer using the library's wrapper function smp.DistributedOptimizer().

      optimizer = smp.DistributedOptimizer(optimizer)
- Save the model and the optimizer state using smp.save(). Depending on how you want to save checkpoints, choose one of the following two options:

  Option 1: Save a partial model on each mp_rank for a single MP_GROUP.

      model_dict = model.local_state_dict()    # save a partial model
      opt_dict = optimizer.local_state_dict()  # save a partial optimizer state
      # Save the dictionaries at rdp_rank 0 as a checkpoint
      if smp.rdp_rank() == 0:
          smp.save(
              {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
              f"/checkpoint.pt",
              partial=True,
          )

  With tensor parallelism, the library saves checkpointed files named in the following format: checkpoint.pt_{pp_rank}_{tp_rank}.

  Note: With tensor parallelism, make sure you set the if statement as if smp.rdp_rank() == 0 instead of if smp.dp_rank() == 0. When the optimizer state is sharded with tensor parallelism, all reduced-data parallel ranks must save their own partition of the optimizer state. Using a wrong if statement for checkpointing might result in a stalling training job. For more information about using if smp.dp_rank() == 0 without tensor parallelism, see General Instruction for Saving and Loading in the SageMaker Python SDK documentation.

  Option 2: Save the full model.

      if smp.rdp_rank() == 0:
          model_dict = model.state_dict(gather_to_rank0=True)  # save the full model
          if smp.rank() == 0:
              smp.save(
                  {"model_state_dict": model_dict},
                  "/checkpoint.pt",
                  partial=False,
              )

  Note: Consider the following for full checkpointing:

  - If you set gather_to_rank0=True, all ranks other than 0 return empty dictionaries.
  - For full checkpointing, you can only checkpoint the model. Full checkpointing of optimizer states is currently not supported.
  - The full model only needs to be saved at smp.rank() == 0.
- Load the checkpoints using smp.load(). Depending on how you checkpointed in the previous step, choose one of the following two options:

  Option 1: Load the partial checkpoints.

      checkpoint = smp.load("/checkpoint.pt", partial=True)
      model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False)
      optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

  You can set same_partition_load=True in model.load_state_dict() for a faster load, if you know that the partition will not change.

  Option 2: Load the full checkpoints.

      if smp.rdp_rank() == 0:
          checkpoint = smp.load("/checkpoint.pt", partial=False)
          model.load_state_dict(checkpoint["model_state_dict"])

  The if smp.rdp_rank() == 0 condition is not required, but it can help avoid redundant loading among different MP_GROUPs. Full checkpointing of the optimizer state dict is currently not supported with tensor parallelism.
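To tie the steps together, the following is a minimal end-to-end sketch of saving and loading a partial checkpoint with the v1.6.0-v1.9.0 API, composed from the same calls shown in the procedure above; MyModel, MyOpt, and the checkpoint path are placeholders.

import smdistributed.modelparallel.torch as smp

# Step 1: wrap the model.
model = MyModel(...)  # placeholder model
model = smp.DistributedModel(model)

# Step 2: build an iterable of parameters with unique IDs.
unique_params = []
unique_params_set = set()
for p in model.parameters():
    if p not in unique_params_set:
        unique_params.append(p)
        unique_params_set.add(p)
del unique_params_set

# Steps 2-3: create and wrap the optimizer.
optimizer = MyOpt(unique_params, ...)  # placeholder optimizer
optimizer = smp.DistributedOptimizer(optimizer)

# Step 4, option 1: save a partial checkpoint at rdp_rank 0.
if smp.rdp_rank() == 0:
    smp.save(
        {
            "model_state_dict": model.local_state_dict(),
            "optimizer_state_dict": optimizer.local_state_dict(),
        },
        "/checkpoint.pt",
        partial=True,
    )

# Step 5, option 1: load the partial checkpoint.
checkpoint = smp.load("/checkpoint.pt", partial=True)
model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])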
Checkpointing a distributed TensorFlow model
To save a TensorFlow model while training with model parallelism, use the following functions provided by the SageMaker model parallelism library.
Fine-tuning a distributed model
Fine-tuning needs to be configured in your training script. The following code snippet shows the example structure of a training script that uses the Hugging Face Transformers AutoModelForCausalLM class with the smdistributed.modelparallel.torch modules and settings for fine-tuning.
Note
When you fine-tune a distributed transformer (a Transformer model wrapped by smp.DistributedModel()) with delayed parameter initialization (smp.delayed_param_initialization), the pretrained weights are not loaded through AutoModelForCausalLM.from_pretrained. Instead, as shown in the following script, rank 0 saves the pretrained state_dict as a full checkpoint file before model creation, and the weights are loaded after model creation through smp.resume_from_checkpoint() with partial=False.
import argparse
import os

import torch
from transformers import AutoModelForCausalLM

import smdistributed.modelparallel
import smdistributed.modelparallel.torch as smp

def parse_args():
    parser = argparse.ArgumentParser()

    # set an arg group for model
    model_grp = parser.add_argument_group(
        title="model", description="arguments to describe model configuration"
    )

    ...  # set up numerous args to parse from the configuration dictionary to the script for training

    # add arg for activating fine-tuning
    model_grp.add_argument(
        "--fine_tune",
        type=int,
        default=0,
        help="Fine-tune model from checkpoint or pretrained model",
    )

def main():
    """Main function to train GPT."""
    args = parse_args()

    ...  # parse numerous args

    if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0:
        pretrained_model = AutoModelForCausalLM.from_pretrained(
            args.model_name or args.model_dir
        )
        model_state_dict = pretrained_model.state_dict()
        path = os.path.join(args.model_dir, "fullmodel.pt")
        torch.save(model_state_dict, path)

    # create a Transformer model and wrap by smp.model_creation()
    # with options to configure model parallelism parameters offered by SageMaker AI
    with smp.model_creation(
        tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0,
        zero_init=args.use_distributed_transformer == 0,
        dtype=dtype,
        distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1,
        use_alibi=args.alibi > 0,
        attention_in_fp32=args.attention_in_fp32 > 0,
        fp32_residual_addition=args.residual_addition_in_fp32 > 0,
        query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1,
        fused_softmax=args.fused_softmax > 0,
        fused_dropout=args.fused_dropout > 0,
        fused_bias_gelu=args.fused_bias_gelu > 0,
        flash_attention=args.flash_attention > 0,
    ):
        if args.fine_tune > 0 and args.delayed_param == 0:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name or args.model_dir
            )
        else:
            model = AutoModelForCausalLM.from_config(model_config)

    # wrap the model by smp.DistributedModel() to apply SageMaker model parallelism
    model = smp.DistributedModel(
        model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation
    )

    # wrap the optimizer by smp.DistributedOptimizer() to apply SageMaker model parallelism
    optimizer = ...  # define an optimizer
    optimizer = smp.DistributedOptimizer(
        optimizer,
        static_loss_scale=None,
        dynamic_loss_scale=True,
        dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
    )

    # for fine-tuning, use smp.resume_from_checkpoint() to load a pre-trained model
    if args.fine_tune > 0 and args.delayed_param > 0:
        smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)
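How the --fine_tune flag and the other script arguments reach the script depends on how you launch the training job. The following is a hedged sketch, not taken from this page, of launching such a script with the SageMaker Python SDK's PyTorch estimator and the smdistributed model parallelism distribution configuration; the entry point name, IAM role, instance type, framework and Python versions, hyperparameter values, and modelparallel parameters are all placeholder assumptions to replace with values supported in your environment.

from sagemaker.pytorch import PyTorch

# Hypothetical estimator configuration; adjust every value for your account and region.
smp_estimator = PyTorch(
    entry_point="train_gpt.py",  # assumed name of the training script sketched above
    role="<your-iam-role-arn>",
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
    framework_version="1.13.1",  # placeholder; use a version the library supports
    py_version="py39",
    hyperparameters={
        "fine_tune": 1,          # turns on the fine-tuning branch in the script
        "delayed_param": 1,
        "model_name": "gpt2",    # placeholder pretrained model identifier
    },
    distribution={
        "smdistributed": {
            "modelparallel": {
                "enabled": True,
                "parameters": {  # placeholder model parallelism parameters
                    "partitions": 2,
                    "tensor_parallel_degree": 2,
                    "ddp": True,
                },
            }
        },
        "mpi": {"enabled": True, "processes_per_host": 8},
    },
)

smp_estimator.fit("s3://<your-bucket>/<training-data-prefix>")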
For a complete example of training scripts and Jupyter notebooks, see the GPT-2 examples for PyTorch.