

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# 激活检查点
<a name="model-parallel-core-features-v2-pytorch-activation-checkpointing"></a>

*激活检查点*技术通过清除某些层的激活并在向后传递期间重新计算它们，来减少内存使用量。实际上，这是用额外的计算时间来换取内存使用量的减少。如果对模块进行了检查点检查，则在正向传递结束时，只有该模块的初始输入和该模块的最终输出会保留在内存中。 PyTorch 在向前传递期间，释放作为该模块内部计算一部分的任何中间张量。在检查点模块的向后传递过程中， PyTorch 重新计算这些张量。此时，有检查点的模块之外的层已经完成其向后传递，因此检查点操作的峰值内存使用量会变得更低。

SMP v2 支持 PyTorch 激活检查点模块。[https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing](https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/#activation-checkpointing)以下是 Hugging Face GPT-NeoX 模型激活检查点的示例。

**Hugging Face GPT-NeoX 模型的检查点转换器层**

```
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing
)
    
# check_fn receives a module as the arg, 
# and it needs to return whether the module is to be checkpointed
def is_transformer_layer(module):
    from transformers.models.gpt_neox import GPTNeoXLayer
    return isinstance(submodule, GPTNeoXLayer)
    
apply_activation_checkpointing(model, check_fn=is_transformer_layer)
```

**Hugging Face GPT-NeoX 模型的每个其他转换器层都要进行检查点检查**

```
# check_fn receives a module as arg, 
# and it needs to return whether the module is to be checkpointed
# here we define that function based on global variable (transformer_layers)
from transformers.models.gpt_neox import GPTNeoXLayer
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing
)

transformer_layers = [
    m for m model.modules() if isinstance(m, GPTNeoXLayer)
]

def is_odd_transformer_layer(module):
    return transformer_layers.index(module) % 2 == 0
    
apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)
```

或者， PyTorch 还有检查点`torch.utils.checkpoint`模块，Hugging Face Transformers 模型的子集使用该模块。此模块也适用于 SMP v2。但是，这需要您有权访问模型定义，才能添加检查点封装器。因此，我们建议您使用 `apply_activation_checkpointing` 方法。