啟用檢查點
啟用檢查點是減少記憶體使用量的技術,方法是清除某些圖層的啟用,並在向後傳遞期間重新加以運算。實際上,這是以額外運算時間換取減少記憶體使用量。如果對模組進行了檢查點作業,則在向前傳遞結束時,只有模組的初始輸入和模組的最終輸出會保留在記憶體中。PyTorch 在向前傳遞期間,會釋放屬於該模組內部運算一部分的任何中級張量。在檢查點模組的向後傳遞期間,PyTorch 會重新運算這些張量。此時,超出此檢查點模組的圖層已完成其向後傳遞,因此可降低運用檢查點的最高記憶體使用量。
SMP v2 支援 PyTorch 啟用檢查點模組 apply_activation_checkpointing
Hugging Face GPT-NeoX 模型的檢查點轉換器層
from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) # check_fn receives a module as the arg, # and it needs to return whether the module is to be checkpointed def is_transformer_layer(module): from transformers.models.gpt_neox import GPTNeoXLayer return isinstance(submodule, GPTNeoXLayer) apply_activation_checkpointing(model, check_fn=is_transformer_layer)
對 Hugging Face GPT-NeoX 模型的其他每個轉換器層進行檢查點
# check_fn receives a module as arg, # and it needs to return whether the module is to be checkpointed # here we define that function based on global variable (transformer_layers) from transformers.models.gpt_neox import GPTNeoXLayer from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing ) transformer_layers = [ m for m model.modules() if isinstance(m, GPTNeoXLayer) ] def is_odd_transformer_layer(module): return transformer_layers.index(module) % 2 == 0 apply_activation_checkpointing(model, check_fn=is_odd_transformer_layer)
或者,PyTorch 也有用於檢查點的 torch.utils.checkpoint 模組,由 Hugging Face Transformer 模型的子集使用。此模組也適用於 SMP v2。不過,這需要您存取模型定義以新增檢查點包裝函式。因此,建議您使用 apply_activation_checkpointing 方法。