

# Core features of the SageMaker model parallelism library v2
<a name="model-parallel-core-features-v2"></a>

The Amazon SageMaker AI model parallelism library v2 (SMP v2) offers distribution strategies and memory-saving techniques, such as sharded data parallelism, tensor parallelism, and checkpointing. These strategies and techniques help distribute large models across multiple devices while optimizing training speed and memory consumption. SMP v2 also provides the Python package `torch.sagemaker`, which helps you adapt your training script with only a few lines of code changes.

This guide follows the basic two-step flow introduced in [Use the SageMaker model parallelism library v2](model-parallel-use-api-v2.md). To dive deep into the core features of SMP v2 and how to use them, see the following topics.

**Note**  
These core features are available in SMP v2.0.0 and later and the SageMaker Python SDK v2.200.0 and later, and work with PyTorch v2.0.1 and later. To check the versions of the packages, see [Supported frameworks and AWS Regions](distributed-model-parallel-support-v2.md).

**Topics**
+ [Hybrid sharded data parallelism](model-parallel-core-features-v2-sharded-data-parallelism.md)
+ [Expert parallelism](model-parallel-core-features-v2-expert-parallelism.md)
+ [Context parallelism](model-parallel-core-features-v2-context-parallelism.md)
+ [Compatibility with the SMDDP library optimized for AWS infrastructure](model-parallel-core-features-v2-smddp-allgather.md)
+ [Mixed precision training](model-parallel-core-features-v2-mixed-precision.md)
+ [Delayed parameter initialization](model-parallel-core-features-v2-delayed-param-init.md)
+ [Activation checkpointing](model-parallel-core-features-v2-pytorch-activation-checkpointing.md)
+ [Activation offloading](model-parallel-core-features-v2-pytorch-activation-offloading.md)
+ [Tensor parallelism](model-parallel-core-features-v2-tensor-parallelism.md)
+ [Fine-tuning](model-parallel-core-features-v2-fine-tuning.md)
+ [FlashAttention](model-parallel-core-features-v2-flashattention.md)
+ [Checkpointing using SMP](model-parallel-core-features-v2-checkpoints.md)

# Hybrid sharded data parallelism
<a name="model-parallel-core-features-v2-sharded-data-parallelism"></a>

*Sharded data parallelism* is a memory-saving distributed training technique that splits the state of a model (model parameters, gradients, and optimizer states) across devices. This helps you fit a larger model or increase the batch size using the freed-up GPU memory. The SMP library offers the capability to run sharded data parallelism with PyTorch Fully Sharded Data Parallel (FSDP). PyTorch FSDP by default shards across the whole set of GPUs being used. In SMP v2, the library offers sharded data parallelism on top of PyTorch FSDP by extending PyTorch hybrid sharding (`HYBRID_SHARD`), which is one of the [sharding strategies provided by PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy): `FULL_SHARD`, `SHARD_GRAD_OP`, `HYBRID_SHARD`, and `_HYBRID_SHARD_ZERO2`. Extending hybrid sharding in this manner helps implement scale-aware sharding as described in the blog [Near-linear scaling of gigantic-model training on AWS](https://www.amazon.science/blog/near-linear-scaling-of-gigantic-model-training-on-aws) for PyTorch FSDP.

The SMP library makes it easy to use `HYBRID_SHARD` and `_HYBRID_SHARD_ZERO2` across any configurable number of GPUs, extending native PyTorch FSDP, which supports hybrid sharding only within a single node (`HYBRID_SHARD`) or across all GPUs (`FULL_SHARD`). Your PyTorch FSDP calls can stay as is; you only need to add the `hybrid_shard_degree` argument to the SMP configuration, as shown in the following code example. You don't need to change the value of the `sharding_strategy` argument in the PyTorch FSDP wrapper around your PyTorch model: you can pass `ShardingStrategy.HYBRID_SHARD` as the value, or the SMP library overrides the strategy in the script and sets it to `ShardingStrategy.HYBRID_SHARD` if you specify a value of 2 or greater for the `hybrid_shard_degree` parameter.
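As a simplified sketch (not SMP's internal implementation), you can think of `hybrid_shard_degree` as splitting the cluster into groups: model state is sharded within each group of `hybrid_shard_degree` GPUs, and the groups replicate one another.

```python
# Simplified sketch (not the SMP implementation): with hybrid sharding,
# model state is sharded within groups of `hybrid_shard_degree` GPUs,
# and those groups are replicas of one another.
def replication_groups(world_size, hybrid_shard_degree):
    # hybrid_shard_degree must evenly divide the total number of GPUs
    assert world_size % hybrid_shard_degree == 0
    return world_size // hybrid_shard_degree

# A 64-GPU cluster with hybrid_shard_degree=16 forms 4 replication groups,
# each sharding the model state 16 ways.
groups = replication_groups(64, 16)
```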

The following code snippets show how to add the SMP initialization module `torch.sagemaker.init()` to your training script and set up the SMP configuration dictionary in JSON format for training job launcher while following the two-step process introduced in [Use the SageMaker model parallelism library v2](model-parallel-use-api-v2.md). You don’t need to make any changes to your PyTorch model or [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp) configuration. For more information about the `hybrid_shard_degree` parameter, see [SMP v2 core feature configuration parameters](distributed-model-parallel-v2-reference.md#distributed-model-parallel-v2-reference-init-config).

**SMP configuration dictionary**

```
{ "hybrid_shard_degree": 16 }
```

**In training script**

```
import torch.sagemaker as tsm
tsm.init()

# Import the PyTorch FSDP wrapper
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Set up a PyTorch model
model = ...

# Wrap the PyTorch model using the PyTorch FSDP module
model = FSDP(
    model,
    ...
)

# Optimizer needs to be created after the FSDP wrapper
optimizer = ...
```

# Expert parallelism
<a name="model-parallel-core-features-v2-expert-parallelism"></a>

A *Mixture of Experts* (MoE) model is a type of transformer model that employs a *sparse* approach, making it lighter to train than traditional dense models. In this MoE neural network architecture, only a subset of the model's components, called *experts*, is used for each input. This approach offers several advantages, including more efficient training and faster inference, even with a larger model size. In other words, with the same compute budget as for training a full dense model, you can fit a larger model or dataset when using MoE.

An MoE model consists of multiple *experts*, each consisting of a neural network, typically a feed-forward network (FFN). A gate network called a *router* determines which tokens are sent to which expert. These experts specialize in processing specific aspects of the input data, enabling the model to train faster and at lower compute cost while achieving the same performance quality as its dense counterpart. To learn more about Mixture of Experts in general, see the blog [Applying Mixture of Experts in LLM Architectures](https://developer.nvidia.com/blog/applying-mixture-of-experts-in-llm-architectures/) on the *NVIDIA Developer website*.
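To make the router's role concrete, the following is a toy sketch of top-1 routing (an illustration only, not the SMP or Megatron implementation), where each token is dispatched to the expert with the highest router score.

```python
# Toy illustration of top-1 MoE routing (not the SMP/Megatron implementation):
# each token goes to the expert with the highest router score.
def route_top1(router_scores):
    # router_scores: one list of per-expert scores for each token
    return [max(range(len(scores)), key=scores.__getitem__)
            for scores in router_scores]

# Two tokens, two experts: token 0 goes to expert 1, token 1 to expert 0.
assignments = route_top1([[0.1, 0.9], [0.7, 0.3]])
```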

*Expert parallelism* is a type of parallelism that handles splitting experts of an MoE model across GPU devices.

SMP v2 integrates with [NVIDIA Megatron](https://github.com/NVIDIA/Megatron-LM) for implementing expert parallelism to support training MoE models, and runs on top of PyTorch FSDP APIs. You keep using your PyTorch FSDP training code as is and activate SMP expert parallelism for training MoE models.

## Hugging Face Transformer models compatible with SMP expert parallelism
<a name="model-parallel-core-features-v2-expert-parallelism-supported-models"></a>

SMP v2 currently offers expert parallelism support for the following Hugging Face transformer models.
+ [Mixtral](https://huggingface.co/docs/transformers/en/model_doc/mixtral)

## Configure expert parallelism
<a name="model-parallel-core-features-v2-expert-parallelism-configure"></a>

For `expert_parallel_degree`, select a value for the degree of expert parallelism. The value must evenly divide the number of GPUs in your cluster. For example, if you use an instance with 8 GPUs, choose 2, 4, or 8 for `expert_parallel_degree`. We recommend that you start with a small number and gradually increase it until the model fits in the GPU memory.
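As an illustration of the divisibility requirement (a hypothetical helper, not part of SMP), the valid degrees for a given cluster are the divisors of its GPU count:

```python
# Hypothetical helper (not part of SMP): list valid expert parallel degrees,
# which must evenly divide the number of GPUs in the cluster.
def valid_degrees(num_gpus):
    return [d for d in range(1, num_gpus + 1) if num_gpus % d == 0]

# An 8-GPU instance allows degrees 1, 2, 4, and 8.
degrees = valid_degrees(8)
```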

The following code snippets show how to add the SMP initialization module `torch.sagemaker.init()` to your training script and set up the SMP configuration dictionary in JSON format for training job launcher while following the two-step process introduced in [Use the SageMaker model parallelism library v2](model-parallel-use-api-v2.md). You don’t need to make any changes to your PyTorch model or [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp) configuration. For more information about the `expert_parallel_degree` parameter, see [SMP v2 core feature configuration parameters](distributed-model-parallel-v2-reference.md#distributed-model-parallel-v2-reference-init-config).

**Note**  
You can use expert parallelism with [Hybrid sharded data parallelism](model-parallel-core-features-v2-sharded-data-parallelism.md). Note that expert parallelism is currently not compatible with tensor parallelism.

**Note**  
This expert parallelism training feature is available in the following combination of libraries of SageMaker and the PyTorch library:  
SMP v2.3.0 and later
The SageMaker Python SDK v2.214.4 and later
PyTorch v2.2.0 and later

### In your training script
<a name="model-parallel-core-features-v2-expert-parallelism-configure-in-script"></a>

As part of [Step 1](model-parallel-use-api-v2.md#model-parallel-adapt-pytorch-script-v2), initialize your script with `torch.sagemaker.init()` to activate SMP v2 and wrap your model with the [`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API, adding the `config` parameter to the API to activate MoE. The following code snippet shows how to activate SMP MoE for the generic model class `AutoModelForCausalLM` pulling an MoE transformer model configuration using the `from_config` method for training from scratch, or the `from_pretrained` method for fine-tuning. To learn more about the SMP `MoEConfig` class, see [`torch.sagemaker.moe.moe_config.MoEConfig`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-moe).

```
# Import the torch.sagemaker.transform API and initialize.
import torch.sagemaker as tsm
tsm.init()

# Import transformers AutoModelForCausalLM class.
from transformers import AutoModelForCausalLM

# Import the SMP-implementation of MoE configuration class.
from torch.sagemaker.moe.moe_config import MoEConfig

# Define a transformer model with an MoE model configuration
model = AutoModelForCausalLM.from_config(MoEModelConfig)

# Wrap it by torch.sagemaker.transform with the SMP MoE configuration.
model = tsm.transform(
    model, 
    config=MoEConfig(
        smp_moe=True,
        random_seed=12345,
        moe_load_balancing="sinkhorn",
        global_token_shuffle=False,
        moe_all_to_all_dispatcher=True,
        moe_aux_loss_coeff=0.001,
        moe_z_loss_coeff=0.001
    )
)
```

### SMP configuration
<a name="model-parallel-core-features-v2-expert-parallelism-configure-in-estimator-config"></a>

As part of [Step 2](model-parallel-use-api-v2.md#model-parallel-launch-a-training-job-v2), add the following parameter to the SMP configuration dictionary for the SageMaker PyTorch estimator.

```
{   
    ..., # other SMP config parameters
    "expert_parallel_degree": 8
}
```

# Context parallelism
<a name="model-parallel-core-features-v2-context-parallelism"></a>

*Context parallelism* is a type of model parallelism that partitions the model activations along the sequence dimension. Unlike other [sequence parallelism](https://arxiv.org/abs/2205.05198) techniques, which partition only the `LayerNorm` and `RMSNorm` activations, context parallelism partitions the network inputs and all intermediate activations along the sequence dimension.

SMP v2 integrates with [Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/index.html) for context parallelism and can be used in conjunction with PyTorch FSDP and SMP [Tensor parallelism](model-parallel-core-features-v2-tensor-parallelism.md). You can enable all three parallelisms simultaneously for model training. Context parallelism is beneficial for training models with large activation sizes and long sequence lengths. It accelerates the computation of attention scores and attention outputs by allowing each device to compute only a part of the scores and outputs along the sequence dimension. While tensor parallelism also accelerates computation through partitioning along the hidden dimension, the advantage of context parallelism is more substantial because computational requirements increase quadratically with sequence length.
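The quadratic argument can be sketched with simple arithmetic (an illustrative cost model, not a profiler measurement): the attention score matrix has `seq_len × seq_len` entries, and context parallelism gives each device a `seq_len / cp_degree` slice of the rows.

```python
# Illustrative cost model (not a measurement): entries of the attention
# score matrix that each device computes under context parallelism.
def scores_per_device(seq_len, cp_degree):
    assert seq_len % cp_degree == 0
    # the full matrix is seq_len x seq_len; each rank computes seq_len/cp rows
    return (seq_len // cp_degree) * seq_len

# Doubling the sequence length quadruples the total attention work,
# while doubling cp_degree halves each device's share.
base = scores_per_device(4096, 2)
longer = scores_per_device(8192, 2)  # 4x base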

## Hugging Face Transformer models compatible with SMP context parallelism
<a name="model-parallel-core-features-v2-context-parallelism-supported-models"></a>

SMP v2 currently offers context parallelism support for the following Hugging Face transformer models.
+ GPT-NeoX
+ Llama 2 and Llama 3
+ [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.3)

## Configure context parallelism
<a name="model-parallel-core-features-v2-context-parallelism-configure"></a>

Set the `context_parallel_degree` parameter to an integer value that evenly divides the number of GPUs in your cluster. For example, if you have an 8-GPU instance, use 2, 4, or 8 for `context_parallel_degree`. We recommend starting with a small `context_parallel_degree` value and gradually increasing it until the model fits in the GPU memory with the required input sequence length.

The following code snippets show how to add the SMP initialization module `torch.sagemaker.init()` to your training script and set up the SMP configuration dictionary in JSON format for training job launcher while following the two-step process introduced in [Use the SageMaker model parallelism library v2](model-parallel-use-api-v2.md). You don’t need to make any changes to your PyTorch model or [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp) configuration. For more information about the `context_parallel_degree` parameter, see [SMP v2 core feature configuration parameters](distributed-model-parallel-v2-reference.md#distributed-model-parallel-v2-reference-init-config).

### In your training script
<a name="model-parallel-core-features-v2-context-parallelism-configure-in-script"></a>

As part of [Step 1](model-parallel-use-api-v2.md#model-parallel-adapt-pytorch-script-v2), initialize your script with `torch.sagemaker.init()` to activate SMP v2 and wrap your model with the [`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API. 

Starting from SMP v2.6.0, you can use the argument `cp_comm_type` to determine which context parallelism implementation to use. The SMP library currently supports two implementations: `p2p` and `all_gather`. The `p2p` implementation uses peer-to-peer send-receive calls for key-value tensor accumulation during the attention computation and runs asynchronously, allowing overlap with compute. The `all_gather` implementation, in contrast, uses the `AllGather` collective operation and runs synchronously.

```
import torch.sagemaker as tsm
tsm.init()

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_config(..)
model = tsm.transform(model, cp_comm_type="p2p")
```

### SMP configuration
<a name="model-parallel-core-features-v2-context-parallelism-configure-in-estimator"></a>

As part of [Step 2](model-parallel-use-api-v2.md#model-parallel-launch-a-training-job-v2), add the following parameter to the SMP configuration dictionary for the SageMaker PyTorch estimator.

```
{   
    ..., # other SMP config parameters
    "context_parallel_degree": 2
}
```

# Compatibility with the SMDDP library optimized for AWS infrastructure
<a name="model-parallel-core-features-v2-smddp-allgather"></a>

You can use the SageMaker model parallelism library v2 (SMP v2) in conjunction with the [SageMaker distributed data parallelism (SMDDP) library](data-parallel.md), which offers the `AllGather` collective communication operation optimized for AWS infrastructure. In distributed training, collective communication operations are designed to synchronize multiple GPU workers and exchange information between them. `AllGather` is one of the core collective communication operations typically used in sharded data parallelism. To learn more about the SMDDP `AllGather` operation, see [SMDDP `AllGather` collective operation](data-parallel-intro.md#data-parallel-allgather). Optimizing such collective communication operations directly contributes to faster end-to-end training without side effects on convergence.

**Note**  
The SMDDP library supports P4 and P4de instances (see also [Supported frameworks, AWS Regions, and instances types](distributed-data-parallel-support.md) by the SMDDP library).

The SMDDP library integrates natively with PyTorch through the [process group](https://pytorch.org/docs/stable/distributed.html) layer. To use the SMDDP library, you only need to add two lines of code to your training script. It supports training frameworks such as the SageMaker model parallelism library, PyTorch FSDP, and DeepSpeed.

To activate SMDDP and use its `AllGather` operation, you need to add two lines of code to your training script as part of [Step 1: Adapt your PyTorch FSDP training script](model-parallel-use-api-v2.md#model-parallel-adapt-pytorch-script-v2). Note that you need to initialize PyTorch Distributed with the SMDDP backend first, and then run the SMP initialization.

```
import torch.distributed as dist

# Initialize with SMDDP
import smdistributed.dataparallel.torch.torch_smddp
dist.init_process_group(backend="smddp") # Replacing "nccl"

# Initialize with SMP
import torch.sagemaker as tsm
tsm.init()
```

[SageMaker Framework Containers](https://github.com/aws/deep-learning-containers/blob/master/available_images.md#sagemaker-framework-containers-sm-support-only) for PyTorch (see also [Supported frameworks and AWS Regions](distributed-model-parallel-support-v2.md) by SMP v2 and [Supported frameworks, AWS Regions, and instances types](distributed-data-parallel-support.md) by the SMDDP library) are pre-packaged with the SMP binary and the SMDDP binary. To learn more about the SMDDP library, see [Run distributed training with the SageMaker AI distributed data parallelism library](data-parallel.md). 

# Mixed precision training
<a name="model-parallel-core-features-v2-mixed-precision"></a>

The SageMaker model parallelism (SMP) library v2 supports mixed precision training out of the box by integrating with open source frameworks such as PyTorch FSDP and Transformer Engine. To learn more, see the following topics.

**Topics**
+ [Mixed precision training with FP8 on P5 instances using Transformer Engine](#model-parallel-core-features-v2-mixed-precision-fp8-training-on-p5)
+ [Mixed precision training with half-precision data types using PyTorch FSDP](#model-parallel-core-features-v2-mixed-precision-half-precision)

## Mixed precision training with FP8 on P5 instances using Transformer Engine
<a name="model-parallel-core-features-v2-mixed-precision-fp8-training-on-p5"></a>

Starting from the SageMaker model parallelism (SMP) library v2.2.0, the SMP library integrates with [Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/index.html) and supports [FP8 mixed precision training](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) out of the box, keeping compatibility with [PyTorch FSDP `MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision). This means that you can use both PyTorch FSDP for mixed precision training and Transformer Engine for FP8 training. For model layers not supported by Transformer Engine's FP8 training feature, those layers fall back to PyTorch FSDP mixed precision.

**Note**  
SMP v2 offers FP8 support for the following Hugging Face Transformer models:  
GPT-NeoX (available in SMP v2.2.0 and later)
Llama 2 (available in SMP v2.2.0 and later)
Mixtral 8x7b and Mixtral 8x22b (available in SMP v2.5.0 and later)

**Note**  
This FP8 training on the P5 feature is available in the following combination of libraries of SageMaker and the PyTorch library:  
The SageMaker Python SDK v2.212.0 and later
PyTorch v2.2.0 and later

*FP8* (8-bit floating point) is a data type that has emerged as another paradigm to accelerate deep learning training of LLMs. With the release of NVIDIA H100 GPUs supporting FP8 data types, you can benefit from the performance improvements of P5 instances equipped with H100 GPUs while accelerating distributed training with FP8 mixed precision training.

The FP8 data type branches into the E4M3 and E5M2 formats. *E4M3* offers better precision but a limited dynamic range and is ideal for the forward pass in model training. *E5M2* has a broader dynamic range but reduced precision and is better suited for the backward pass, where precision is less critical and a wider dynamic range is beneficial. Hence, we recommend that you use the [hybrid FP8 strategy recipe](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-recipe) to leverage these characteristics effectively.
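The range/precision trade-off can be made concrete with the formats' maximum representable normal values, per the FP8 formats paper (arXiv:2209.05433); shown here as an illustration:

```python
# Maximum normal values of the FP8 formats, per the FP8 formats paper
# (arXiv:2209.05433); an illustration of the range/precision trade-off.
E4M3_MAX = 1.75 * 2**8    # = 448; the top mantissa code is reserved for NaN
E5M2_MAX = 1.75 * 2**15   # = 57344; IEEE-like max significand (2 - 2**-2)

# E5M2 covers a range 128x wider than E4M3, at one fewer mantissa bit.
ratio = E5M2_MAX / E4M3_MAX
```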

For half-precision data types (FP16 and BF16), global loss-scaling techniques such as static loss-scaling or dynamic loss-scaling handle convergence issues that arise from information loss due to rounding gradients in half precision. However, the dynamic range of FP8 is even narrower, and global loss-scaling techniques are not sufficient; a finer-grained, per-tensor scaling technique is needed. *Delayed scaling* is a strategy that selects a scaling factor based on the maximum absolute values observed in a number of tensors from previous iterations. There's a trade-off in this strategy: it preserves the full performance benefits of FP8 computation but requires memory for keeping the maximum-value history of tensors. To learn more about the delayed scaling strategy in general, see the paper [FP8 Formats for Deep Learning](https://arxiv.org/pdf/2209.05433.pdf).
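As a rough sketch of the idea (hypothetical code, not Transformer Engine's implementation), delayed scaling derives a tensor's scaling factor from the maximum absolute values (`amax`) recorded over recent iterations:

```python
# Hypothetical sketch of delayed scaling (not Transformer Engine's code):
# pick a scale so the largest recently observed value maps near the FP8 max.
E4M3_MAX = 448.0  # max normal value of the E4M3 format

def delayed_scale(amax_history):
    # amax_history: max absolute values of the tensor from recent iterations
    return E4M3_MAX / max(amax_history)

# With a recent amax of 3.5, values are scaled by 128 before casting to FP8.
scale = delayed_scale([1.5, 2.0, 3.5])
```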

In practice, using FP8 is helpful in all training scenarios on P5 instances. We strongly recommend enabling FP8 whenever possible for enhancing training performance.

SMP v2 supports Transformer Engine out of the box. Therefore, when running FP8 training with SMP v2 on SageMaker AI P5 instances (`ml.p5.48xlarge`), the only thing you need to do is import `torch.sagemaker` in your training script and keep using the native Transformer Engine Python package. To learn more about using Transformer Engine for FP8 training in general, see [Using FP8 with Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) in the *NVIDIA Transformer Engine documentation*. The following code snippet shows how the code lines for importing the SMP library and setting up FP8 in your training script should look.

```
import torch.sagemaker as tsm
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Initialize the SMP torch.sagemaker API.
tsm.init()

# Define a transformer model and wrap it with the torch.sagemaker.transform API.
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_config(ModelConfig)
model = tsm.transform(model)

# Enable E4M3 during forward pass, E5M2 during backward pass.
fp8_format = Format.HYBRID

# Create an FP8 recipe.
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")

# Enable FP8 autocasting.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group):
    out = model(inp)

loss = out.sum()
loss.backward()
```

To find a practical example of FP8 training with SMP v2 on P5 instances, see the example notebook at [Accelerate SageMaker PyTorch FSDP Training of Llama-v2 (or GPT-NeoX) with FP8 on P5 instances](https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel_v2/llama_v2/smp-train-llama-fsdp-tp-fp8.ipynb).

## Mixed precision training with half-precision data types using PyTorch FSDP
<a name="model-parallel-core-features-v2-mixed-precision-half-precision"></a>

SMP v2 supports [PyTorch FSDP `MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision) for training jobs on P4 and P5 instances. PyTorch FSDP provides various configurations for mixed precision for both performance improvement and memory reduction. 

**Note**  
This mixed precision training with the PyTorch FSDP feature is available in the following combination of libraries of SageMaker and the PyTorch library.  
SMP v2.0.0 and later
the SageMaker Python SDK v2.200.0 and later
PyTorch v2.0.1 and later

The standard way to configure a model for mixed precision is to create the model in `float32`, and then allow FSDP to cast the parameters to `float16` or `bfloat16` on the fly by passing a `MixedPrecision` policy, as shown in the following code snippet. For more information about options to change the `dtype` for parameters, reduction, or buffers for mixed precision in PyTorch, see [PyTorch FSDP `MixedPrecision` API](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision) in the *PyTorch documentation*.

```
# Native PyTorch API
from torch.distributed.fsdp import MixedPrecision

dtype = torch.bfloat16
mixed_precision_policy = MixedPrecision(
    param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
)

model = FSDP(
    model,
    ...,
    mixed_precision=mixed_precision_policy
)
```

Note that certain models (such as the Hugging Face Transformers Llama model) expect buffers as `float32`. To use `float32`, replace `torch.bfloat16` with `torch.float32` in the line defining the `dtype` object.
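Alternatively, if you want to keep parameters and gradient reductions in `bfloat16` and cast only the buffers to `float32`, the PyTorch `MixedPrecision` API lets you set each `dtype` separately. The following is a sketch of that finer-grained policy.

```python
# Sketch: keep buffers in float32 (as models such as Llama expect)
# while casting parameters and gradient reductions to bfloat16.
import torch
from torch.distributed.fsdp import MixedPrecision

mixed_precision_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.float32,
)
```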

# Delayed parameter initialization
<a name="model-parallel-core-features-v2-delayed-param-init"></a>

Initializing a large model for training is not always possible with limited GPU memory. To resolve this problem of insufficient GPU memory, you can initialize the model on CPU memory. However, for larger models with more than 20 or 40 billion parameters, even CPU memory might not be enough. For such cases, we recommend that you initialize the model on what PyTorch calls a *meta device*, which allows the creation of tensors without any data attached to them. A tensor on a meta device needs only the shape information, which makes it possible to create a large model with its parameters on meta devices. [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index) provides the context manager `init_empty_weights` to help create such a model on meta devices while initializing the buffers on a regular device.

Before training starts, PyTorch FSDP initializes the model parameters. The delayed parameter initialization feature of SMP v2 delays this creation of model parameters until after PyTorch FSDP performs parameter sharding. PyTorch FSDP accepts a parameter initialization function (`param_init_fn`) when sharding the modules, and it calls `param_init_fn` for each module. The `param_init_fn` API takes a module as an argument and initializes all the parameters in it, not including the parameters of any child module. Note that this behavior *differs* from native PyTorch v2.0.1, which has a bug that causes parameters to be initialized multiple times.

SMP v2 provides the [`torch.sagemaker.delayed_param.DelayedParamIniter`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-delayed-param-init) API for applying delayed parameter initialization.

The following code snippets show how to apply the `torch.sagemaker.delayed_param.DelayedParamIniter` API to your training script.

Assume that you have a PyTorch FSDP training script as follows.

```
# Creation of the model on the meta device
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.pytorch_utils import Conv1D
from accelerate import init_empty_weights

with init_empty_weights():
    model = create_model()

# Define a param init fn. Below is an example for Hugging Face GPT-NeoX.
def init_weights(module):
    d = torch.cuda.current_device()
    # Note that the following doesn't work if you have buffers in the model;
    # buffers need to be reinitialized after this call.
    module.to_empty(device=d, recurse=False)
    if isinstance(module, (nn.Linear, Conv1D)):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

# Changes to the FSDP wrapper.
model = FSDP(
    model,
    ...,
    param_init_fn=init_weights
)

# At this point the model is initialized and sharded for sharded data parallelism.
```

Note that the delayed parameter initialization approach is not model agnostic. You need to write an `init_weights` function, as shown in the preceding example, that matches the initialization in the original model definition and covers all the parameters of the model. To simplify the process of preparing such an `init_weights` function, SMP v2 implements this initialization function for the following models: GPT-2, GPT-J, GPT-NeoX, and Llama from Hugging Face Transformers. The `torch.sagemaker.delayed_param.DelayedParamIniter` API also works with the SMP tensor parallel implementation, the `torch.sagemaker.tensor_parallel.transformer.TransformerLMHead` model, which you can call after the [`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API call.

Using the `torch.sagemaker.delayed_param.DelayedParamIniter` API, you can adapt your PyTorch FSDP script as follows. After creating a model with empty weights, create a `DelayedParamIniter` object for the model, and pass its parameter initialization function to the `param_init_fn` argument of the PyTorch FSDP class.

```
from torch.sagemaker.delayed_param import DelayedParamIniter
from accelerate import init_empty_weights

with init_empty_weights():
    model = create_model()
    
delayed_initer = DelayedParamIniter(model)

with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn()
    )
```

**Notes on tied weights**

When training models with tied weights, you need to take special care to tie the weights after initializing them with delayed parameter initialization. PyTorch FSDP does not have a mechanism to tie the weights after they're initialized using `param_init_fn` as shown above. To address such cases, we added an API that allows a `post_init_hook_fn`, which can be used to tie the weights. You can pass any function that accepts a module as an argument, but `DelayedParamIniter` also has a predefined `post_param_init_fn`, which calls the `tie_weights` method of the module if it exists. Note that it's always safe to pass in `post_param_init_fn`, even if the module has no `tie_weights` method.

```
with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn(),
        post_param_init_fn=delayed_initer.get_post_param_init_fn()
    )
```

# Activation checkpointing
<a name="model-parallel-core-features-v2-pytorch-activation-checkpointing"></a>

*Activation checkpointing* is a technique to reduce memory usage by clearing activations of certain layers and recomputing them during the backward pass. Effectively, this trades extra computation time for reducing memory usage. If a module is checkpointed, at the end of a forward pass, only the initial inputs to the module and final outputs from the module stay in memory. PyTorch releases any intermediate tensors that are part of the computation inside that module during the forward pass. During the backward pass of the checkpointed modules, PyTorch recomputes these tensors. At this point, the layers beyond this checkpointed module have finished their backward pass, so the peak memory usage with checkpointing becomes lower.
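This trade-off can be sketched with a simplified memory model (illustrative arithmetic only, not actual profiler numbers): without checkpointing, every layer's intermediates stay live for the backward pass; with checkpointing, only each layer's boundary tensors stay live, plus one layer's intermediates during recomputation.

```python
# Simplified, illustrative memory model for activation checkpointing
# (not actual profiler numbers). Units are arbitrary (e.g., bytes per layer).
def peak_activation_memory(num_layers, boundary, intermediates, checkpointed):
    if not checkpointed:
        # all layers keep boundary tensors and intermediates until backward
        return num_layers * (boundary + intermediates)
    # only boundary tensors are kept; intermediates of one layer at a time
    # are live while that layer is recomputed during the backward pass
    return num_layers * boundary + intermediates

# 24 layers, 1 unit of boundary and 10 units of intermediates per layer:
without_ckpt = peak_activation_memory(24, 1, 10, checkpointed=False)  # 264
with_ckpt = peak_activation_memory(24, 1, 10, checkpointed=True)      # 34
```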

SMP v2 supports the PyTorch activation checkpointing module, [https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing](https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing). The following are examples of activation checkpointing of the Hugging Face GPT-NeoX model.

**Checkpointing Transformer layers of the Hugging Face GPT-NeoX model**

```
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing
)
    
# check_fn receives a module as the argument,
# and it needs to return whether the module is to be checkpointed
def is_transformer_layer(module):
    return isinstance(module, GPTNeoXLayer)
    
apply_activation_checkpointing(model, check_fn=is_transformer_layer)
```

**Checkpointing every other Transformer layer of the Hugging Face GPT-NeoX model**

```
# check_fn receives a module as the argument,
# and it needs to return whether the module is to be checkpointed.
# Here, the function is defined based on the global variable transformer_layers.
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing
)

transformer_layers = [
    m for m in model.modules() if isinstance(m, GPTNeoXLayer)
]

def is_every_other_transformer_layer(module):
    return transformer_layers.index(module) % 2 == 0
    
apply_activation_checkpointing(model, check_fn=is_every_other_transformer_layer)
```

Alternatively, PyTorch also has the `torch.utils.checkpoint` module for checkpointing, which is used by a subset of Hugging Face Transformers models. This module also works with SMP v2. However, it requires you to have access to the model definition to add the checkpoint wrapper. Therefore, we recommend that you use the `apply_activation_checkpointing` method.
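
For cases where you do have access to the model definition, a minimal sketch of the `torch.utils.checkpoint` approach might look as follows. The block and dimensions here are hypothetical, for illustration only, and are not SMP code.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    # Hypothetical block: the checkpointed segment frees its intermediate
    # activations after the forward pass and recomputes them during backward.
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        # use_reentrant=False selects the non-reentrant checkpointing mode.
        return checkpoint(self.ff, x, use_reentrant=False)

block = CheckpointedBlock(16)
x = torch.randn(4, 16, requires_grad=True)
out = block(x)
out.sum().backward()  # intermediate activations are recomputed here
```

Because the checkpoint wrapper lives inside `forward`, you must be able to edit the model source, which is why `apply_activation_checkpointing` is generally more convenient.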

# Activation offloading
<a name="model-parallel-core-features-v2-pytorch-activation-offloading"></a>

**Important**  
In SMP v2.2.0, the activation offloading functionality of the SMP library doesn't work. Use the native PyTorch activation offloading instead.

Typically, the forward pass computes activations at each layer and keeps them in GPU memory until the backward pass for the corresponding layer finishes. Offloading these tensors to CPU memory after the forward pass and fetching them back to GPU when they are needed can save substantial GPU memory. PyTorch supports offloading activations, but its implementation leaves GPUs idle while activations are fetched back from CPU during the backward pass, which causes a major performance degradation.

SMP v2 improves on this activation offloading. It pre-fetches activations ahead of the time the GPU needs them to start the backward pass on those activations. This pre-fetching lets training proceed efficiently without idle GPUs, so you benefit from lower memory usage without a performance degradation.

You can keep the native PyTorch modules for offloading activations in your training script. The following is an example structure of applying the SMP activation offloading feature in your script. Note that activation offloading is applicable *only* when used together with [Activation checkpointing](model-parallel-core-features-v2-pytorch-activation-checkpointing.md). To learn more about the native PyTorch checkpoint tools for activation offloading, see:
+ [checkpoint\_wrapper.py](https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py#L171) in the *PyTorch GitHub repository*
+ [Activation Checkpointing](https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing) in the PyTorch blog *Scaling Multi-modal Foundation Models in TorchMultimodal with PyTorch Distributed*.

You can apply the SMP activation offloading feature on [PyTorch activation checkpointing](https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing). This is done by adding the `sm_activation_offloading` and `activation_loading_horizon` parameters to the SMP configuration dictionary during [Step 2: Launch a training job](model-parallel-use-api-v2.md#model-parallel-launch-a-training-job-v2). 

The following code snippets show how to add the SMP initialization module `torch.sagemaker.init()` to your training script and set up the SMP configuration dictionary in JSON format for training job launcher while following the two-step process introduced in [Use the SageMaker model parallelism library v2](model-parallel-use-api-v2.md). You don’t need to make any changes to your PyTorch model or [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp) configuration. For more information about the `sm_activation_offloading` and `activation_loading_horizon` parameters, see [SMP v2 core feature configuration parameters](distributed-model-parallel-v2-reference.md#distributed-model-parallel-v2-reference-init-config).

**SMP configuration**

```
{
    "activation_loading_horizon": 2,
    "sm_activation_offloading": true
}
```

**In training script**

**Note**  
When enabling the SMP activation offloading feature, make sure that you also use the PyTorch `offload_wrapper` function and apply it to the root module. The SMP activation offloading feature uses the root module to determine when the forward pass is done, to start pre-fetching.

```
import torch.sagemaker as tsm
tsm.init()

# Native PyTorch module for activation offloading
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing, 
    offload_wrapper,
)

model = FSDP(...)

# Activation offloading requires activation checkpointing.
# checkpoint_transformer_layers_policy is a user-defined check_fn,
# as shown in the activation checkpointing examples.
apply_activation_checkpointing(
    model,
    check_fn=checkpoint_transformer_layers_policy,
)

model = offload_wrapper(model)
```

# Tensor parallelism
<a name="model-parallel-core-features-v2-tensor-parallelism"></a>

*Tensor parallelism* is a type of model parallelism in which specific model weights, gradients, and optimizer states are split across devices. In contrast to pipeline parallelism, which keeps individual weights intact but partitions the *set* of weights, gradients, or optimizer across devices, tensor parallelism shards *individual* weights. This typically involves distributed computation of specific operations, modules, or layers of the model.

Tensor parallelism is required in cases in which a single parameter consumes most of the GPU memory (such as large embedding tables with a large vocabulary size or a large softmax layer with a large number of classes). In this case, treating this large tensor or operation as an atomic unit is inefficient and impedes balance of the memory load.
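
As a toy illustration of why sharding an individual weight preserves the computation (plain Python, not SMP code): splitting a linear layer's weight matrix column-wise across two devices lets each device compute a slice of the output, and concatenating the slices reproduces the unsharded result.

```python
def matmul(x, w):
    # x: list of rows, w: list of rows -> standard matrix multiply
    cols = list(zip(*w))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in x]

x = [[1.0, 2.0], [3.0, 4.0]]                       # activations
w = [[1.0, 0.0, 2.0, 1.0], [0.5, 1.0, 0.0, 2.0]]   # full 2x4 weight

# Column shards, as if placed on two devices
w_rank0 = [row[:2] for row in w]
w_rank1 = [row[2:] for row in w]

full = matmul(x, w)
partials = [matmul(x, w_rank0), matmul(x, w_rank1)]
sharded = [p0 + p1 for p0, p1 in zip(*partials)]   # concatenate along columns
assert sharded == full
```

Each "device" holds only half of the weight columns, yet no accuracy is lost; the communication cost is the concatenation step, which in practice is a collective operation across GPUs.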

SMP v2 integrates with [Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/index.html) for its implementation of tensor parallelism, and runs on top of PyTorch FSDP APIs. You can enable PyTorch FSDP and SMP tensor parallelism simultaneously, and determine the model parallelism configuration that yields the best performance.

In practice, tensor parallelism is especially helpful in the following scenarios.
+ When training with long context lengths, which leads to high activation memory with FSDP alone.
+ When training with very large clusters on which the global batch size exceeds desired limits.

## Hugging Face Transformer models compatible with the SMP tensor parallelism
<a name="model-parallel-core-features-v2-tensor-parallelism-supported-models"></a>

SMP v2 currently offers tensor parallelism support for the following Hugging Face transformer models.
+ GPT-NeoX
+ Llama 2
+ Llama 3
+ [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.3)
+ [Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
+ [Mixtral 8x22B](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1)

For reference configuration for applying tensor parallelism on these models, see [Configuration tips](model-parallel-best-practices-v2.md#model-parallel-best-practices-v2-config-tips).

## Configure tensor parallelism
<a name="model-parallel-core-features-v2-tensor-parallelism-configuration"></a>

For `tensor_parallel_degree`, you select a value for the degree of tensor parallelism. The value must evenly divide the number of GPUs in your cluster. For example, to shard your model while using an instance with 8 GPUs, choose 2, 4, or 8. We recommend that you start with a small number, and gradually increase it until the model fits in the GPU memory.
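
For example, a quick way to enumerate the valid choices is the following hypothetical helper (not an SMP API, shown for illustration only):

```python
# Hypothetical helper, not part of the SMP library: list the
# tensor_parallel_degree values that evenly divide the number of GPUs.
def valid_tp_degrees(num_gpus):
    return [d for d in range(1, num_gpus + 1) if num_gpus % d == 0]

print(valid_tp_degrees(8))  # [1, 2, 4, 8]
```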

The following code snippets show how to add the SMP initialization module `torch.sagemaker.init()` to your training script and set up the SMP configuration dictionary in JSON format for training job launcher while following the two-step process introduced in [Use the SageMaker model parallelism library v2](model-parallel-use-api-v2.md). You don’t need to make any changes to your PyTorch model or [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp) configuration. For more information about the `tensor_parallel_degree` and `random_seed` parameters, see [SMP v2 core feature configuration parameters](distributed-model-parallel-v2-reference.md#distributed-model-parallel-v2-reference-init-config).

**SMP configuration**

```
{
    "tensor_parallel_degree": 8,
    "random_seed": 0 
}
```

**In your training script**

Initialize with `torch.sagemaker.init()` to activate SMP v2 and wrap your model with the [`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API.

```
import torch.sagemaker as tsm
tsm.init()

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_config(...)
model = tsm.transform(model)
```

## Saving and loading Hugging Face Transformer checkpoints
<a name="model-parallel-core-features-v2-tensor-parallelism-checkpoints"></a>

After the SMP library transforms a model, it changes the state dictionary (`state_dict`) of the model. This means that the model becomes incompatible with the original Hugging Face Transformer checkpointing functionalities. To handle this, the SMP library provides APIs to save checkpoints from a transformed model in Hugging Face Transformer representation, and the `torch.sagemaker.transform` API to load a Hugging Face Transformer model checkpoint for fine-tuning.

For more information about saving checkpoints while using the tensor parallelism feature of SMP v2, see [Checkpointing using SMP](model-parallel-core-features-v2-checkpoints.md).

For more information about fine-tuning a model applying the tensor parallelism feature of SMP v2, see [Fine-tuning](model-parallel-core-features-v2-fine-tuning.md).

# Fine-tuning
<a name="model-parallel-core-features-v2-fine-tuning"></a>

Fine-tuning is a process of continuously training pre-trained models to improve performance for specific use cases.

Fine-tuning small models that fit fully on a single GPU, or models small enough that eight copies fit fully in CPU memory, is straightforward and requires no special change to regular FSDP training. For models larger than this, you need to consider using the delayed parameter initialization functionality, which can be tricky.

To address this, the SMP library loads the full model on one of the ranks while the rest of the ranks create models with empty weights on a meta device. Then, PyTorch FSDP initializes the weights on the non-zero ranks through the `param_init_fn` provided by `DelayedParamIniter`, and synchronizes the weights on all ranks to the weights on the 0th rank with `sync_module_states` set to `True`. The following code snippet shows how you should set it up in your training script.

```
import torch.distributed as dist
from transformers import AutoModelForCausalLM, AutoConfig
from accelerate import init_empty_weights
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.delayed_param import DelayedParamIniter

if dist.get_rank() == 0:
    model = AutoModelForCausalLM.from_pretrained(..., low_cpu_mem_usage=True)
else:
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(...))
    delayed_initer = DelayedParamIniter(model)

model = FSDP(
    model,
    ...,
    sync_module_states=True,
    param_init_fn=delayed_initer.get_param_init_fn() if dist.get_rank() > 0 else None
)
```

## Fine-tuning a pre-trained Hugging Face Transformer model with SMP tensor parallelism
<a name="model-parallel-core-features-v2-tensor-parallelism-fine-tuning-hf-transformer-with-tp"></a>

This section discusses loading Transformer models for two use cases: fine-tuning small Transformer models and fine-tuning large Transformer models. For smaller models without delayed parameter initialization, wrap the model with the `torch.sagemaker.transform` API before wrapping it with PyTorch FSDP.

```
import functools
from transformers import AutoModelForCausalLM
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.sagemaker import transform

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", low_cpu_mem_usage=True)

# Transform model while loading state dictionary from rank 0.
tp_model = transform(model, load_state_dict_from_rank0=True)

# Wrap with FSDP.
model = FSDP(
    tp_model, 
    ...
    sync_module_states=True,
)
```

For larger models, the preceding approach can run out of CPU memory. We recommend that you use delayed parameter initialization to avoid such CPU memory issues. In this case, you can apply the `torch.sagemaker.transform` API and the `torch.sagemaker.delayed_param.DelayedParamIniter` API as shown in the following code example.

```
import torch.distributed as dist
from contextlib import nullcontext
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM, AutoConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker import transform
from torch.sagemaker.delayed_param import DelayedParamIniter

# Create one instance of the model without delayed param
# on CPU, on one rank.
if dist.get_rank() == 0:
    model = AutoModelForCausalLM.from_pretrained(..., low_cpu_mem_usage=True)
else:
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(...))

# Transform the model while loading the state dictionary from rank 0.
model = transform(model, load_state_dict_from_rank0=True)

if dist.get_rank() != 0:  # For fine-tuning, delay parameters on non-zero ranks
    delayed_initer = DelayedParamIniter(model)
else:
    delayed_initer = None

with (
    delayed_initer.validate_params_and_buffers_inited() if delayed_initer else nullcontext()
):
    # Wrap the model with FSDP
    model = FSDP(
        model, 
        ..., 
        sync_module_states=True,
        param_init_fn=delayed_initer.get_param_init_fn() if delayed_initer else None
    )
```

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2 supports [FlashAttention](https://github.com/HazyResearch/flash-attention) kernels and makes it easy to apply them to various scenarios for Hugging Face Transformer models. Note that if the FlashAttention package v2.0 or later is installed, SMP uses FlashAttention v2. However, the Triton flash attention kernel defaults to the one in FlashAttention v1.x, so it is supported exclusively with FlashAttention v1. 

The attention module (`nn.Module`) is a low-level API that defines the attention layers of a model. Apply it right after model creation (for example, with the `AutoModelForCausalLM.from_config()` API), and before the model is transformed or wrapped with FSDP.

## Use FlashAttention kernels for self attention
<a name="model-parallel-core-features-v2-flashattention-self"></a>

The following code snippet shows how to use the [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) API provided by SMP v2.

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flash_mod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## Use FlashAttention kernels for grouped-query attention
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

SMP v2 also supports [FlashAttention](https://github.com/HazyResearch/flash-attention) kernels for grouped-query attention (GQA) and makes it easy to apply them to various scenarios for Hugging Face Transformer models. Different from the original attention architecture, GQA equally partitions query heads into groups, and query heads in the same group share the same key and value heads. Therefore, the q and kv heads are passed into the forward call separately. Note that the number of q heads needs to be divisible by the number of kv heads.
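
The grouping rule can be sketched in plain Python (illustration only, not SMP code): with `num_q_heads` divisible by `num_kv_heads`, query head `i` attends with key/value head `i // (num_q_heads // num_kv_heads)`.

```python
# Toy illustration of GQA head grouping, not SMP code.
def kv_head_for_query_head(i, num_q_heads, num_kv_heads):
    assert num_q_heads % num_kv_heads == 0  # divisibility requirement
    group_size = num_q_heads // num_kv_heads
    return i // group_size

# 8 query heads sharing 2 kv heads -> groups of 4
mapping = [kv_head_for_query_head(i, 8, 2) for i in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```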

**Example of using FlashGroupedQueryAttention**

The following code snippet shows how to use the [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API provided by SMP v2.

```
from typing import Optional

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

The SMP library also provides [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn), which uses the [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API under the hood. Hugging Face Transformers has a similar implementation, [LlamaFlashAttention2](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py), available from v4.36.0. The following code snippet shows how to use the SMP v2 `LlamaFlashAttention` API or the Transformers `LlamaFlashAttention2` API to replace the attention layers of an existing Llama model.

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)  # reference to the original attention module
    setattr(layer, attn_name, flash_attn_class(model.config))
```

# Checkpointing using SMP
<a name="model-parallel-core-features-v2-checkpoints"></a>

The SageMaker model parallelism (SMP) library supports PyTorch APIs for checkpoints, and provides APIs that help checkpoint properly while using the SMP library. 

PyTorch FSDP (Fully Sharded Data Parallelism) supports three types of checkpoints: full, sharded, and local, each serving different purposes. Full checkpoints are used when exporting the model after training is completed, as generating a full checkpoint is a computationally expensive process. Sharded checkpoints help save and load the state of a model sharded for each individual rank. With sharded checkpoints, you can resume training with different hardware configurations, such as a different number of GPUs. However, loading sharded checkpoints can be slow due to the communication involved among multiple devices. The SMP library provides local checkpointing functionalities, which allow faster retrieval of the model's state without additional communication overhead. Note that checkpoints created by FSDP require writing to a shared network file system such as Amazon FSx.

## Async local checkpoints
<a name="w2aac25c25c19c19c33b7"></a>

When training machine learning models, subsequent iterations don't need to wait for the checkpoint files to be saved to disk. With the release of SMP v2.5, the library supports saving checkpoint files asynchronously. This means that subsequent training iterations can run simultaneously with the input and output (I/O) operations for creating checkpoints, without being slowed down or held back by those I/O operations. Also, the process of retrieving sharded model and optimizer parameters in PyTorch can be time-consuming due to the additional collective communication required to exchange distributed tensor metadata across ranks. Even when using `StateDictType.LOCAL_STATE_DICT` to save local checkpoints for each rank, PyTorch still invokes hooks that perform collective communication. To mitigate this issue and reduce the time required for checkpoint retrieval, SMP introduces `SMStateDictType.SM_LOCAL_STATE_DICT`, which allows for faster retrieval of model and optimizer checkpoints by bypassing the collective communication overhead. 

**Note**  
Maintaining a consistent FSDP `SHARD_DEGREE` is a requirement for using `SMStateDictType.SM_LOCAL_STATE_DICT`. Ensure that the `SHARD_DEGREE` remains unchanged: while the number of model replicas can vary, the model shard degree must be identical to the previous training setup when resuming from a checkpoint.
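
This consistency requirement can be enforced with a small guard in your own script, for example (hypothetical helper, not an SMP API):

```python
# Hypothetical guard, not part of the SMP library: fail fast if the shard
# degree at resume time differs from the one recorded at save time.
def check_shard_degree(saved_shard_degree, current_shard_degree):
    if saved_shard_degree != current_shard_degree:
        raise ValueError(
            f"SM_LOCAL_STATE_DICT requires an unchanged shard degree: "
            f"checkpoint saved with {saved_shard_degree}, "
            f"current run uses {current_shard_degree}"
        )

check_shard_degree(8, 8)  # passes silently when degrees match
```

Call it with the shard degree recorded at save time before loading an `SM_LOCAL_STATE_DICT` checkpoint.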

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
       "model": model.state_dict(),
       "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint 
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)
```

The following code snippet demonstrates how to load a checkpoint utilizing `SMStateDictType.SM_LOCAL_STATE_DICT`.

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }
 
    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()
 
# 3. Load a checkpoint
load(
    state_dict=state_dict,
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
    storage_reader=storage_reader,
)
```

Storing checkpoints for large language models (LLMs) can be expensive as it often requires creating a large filesystem volume. To reduce costs, you have the option to save checkpoints directly to Amazon S3 without the need for additional filesystem services such as Amazon FSx. You can leverage the previous example with the following code snippet to save checkpoints to S3 by specifying an S3 URL as the destination. 

```
key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id= f"s3://{your_s3_bucket}/{key}"
async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)
```

## Async sharded checkpoints
<a name="w2aac25c25c19c19c33b9"></a>

There may be situations where you need to continue training with different hardware configurations, such as changing the number of GPUs. In these cases, your training processes must load checkpoints while resharding, which means resuming subsequent training with a different `SHARD_DEGREE`. To handle this scenario, you must save your model checkpoints using the sharded state dictionary type, represented by `StateDictType.SHARDED_STATE_DICT`. Saving checkpoints in this format allows you to properly handle the resharding process when continuing training with a modified hardware configuration. The following code snippet illustrates how to use the `tsm` API to asynchronously save sharded checkpoints, enabling a more efficient and streamlined training process.

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether the current rank needs to take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0
process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }
 
# 3. save checkpoints asynchronously using async_save
if action_rank:
    async_save(
        state_dict,
        checkpoint_id=checkpoint_id,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
```

The process of loading sharded checkpoints is similar to the previous section, but it involves using the `torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader` class and passing it to the `load` function. This allows you to load the sharded checkpoint data, following a process analogous to the one described earlier.

```
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
   # 1. Load model and everything else except the optimizer.
   state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
   }
   load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
   )
   model.load_state_dict(state_dict["model"])
 
   # 2. Load optimizer.
   optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )    
   flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer,
         group=model.process_group
   )
   optimizer.load_state_dict(flattened_optimizer_state)
```

## Full model checkpoints
<a name="model-parallel-core-features-v2-checkpoints-full"></a>

At the end of training, you can save a full checkpoint that combines all shards of a model into a single model checkpoint file. The SMP library fully supports the PyTorch full model checkpoints API, so you don't need to make any changes.

Note that if you use the SMP [Tensor parallelism](model-parallel-core-features-v2-tensor-parallelism.md), the SMP library transforms the model. When checkpointing the full model in this case, the SMP library translates the model back to the Hugging Face Transformers checkpoint format by default.

In cases where you train with the SMP tensor parallelism, you can use the `translate_on_save` argument of the PyTorch `FullStateDictConfig` API to switch the SMP auto-translation on or off as needed. For example, if you are focused on training a model, you don't need to add the translation process, which adds overhead. In that case, we recommend that you set `translate_on_save=False`. Also, if you plan to keep using the SMP transformation of the model for further training in the future, you can switch translation off to save the SMP-transformed checkpoint for later reuse. Translating the model back to the Hugging Face Transformers checkpoint format is needed when you wrap up the training of your model and use it for inference.

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT, 
    FullStateDictConfig(
        rank0_only=True, offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs(save_dir, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))
        hf_model_config.save_pretrained(save_dir)
    dist.barrier()
```

Note that the option `FullStateDictConfig(rank0_only=True, offload_to_cpu=True)` is to gather the model on the CPU of the 0th rank device to save memory when training large models.

To load the model back for inference, use the `from_pretrained` API as shown in the following code example. Note that the class `AutoModelForCausalLM` might change to other factory classes in Hugging Face Transformers, such as `AutoModelForSeq2SeqLM`, depending on your model. For more information, see the [Hugging Face Transformers documentation](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/auto#natural-language-processing).

```
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(save_dir)
```