混合精確度訓練 - Amazon SageMaker AI

混合精確度訓練

SageMaker 模型平行化 (SMP) 程式庫 v2 透過與 PyTorch FSDP 和轉換器引擎等開放原始碼架構整合,支援立即可用的混合精確度訓練。如需詳細資訊,請參閱下列主題。

使用轉換器引擎在 P5 執行個體上使用 FP8 進行混合精確度訓練

從 SageMaker 模型平行化 (SMP) 程式庫 v2.2.0 開始,SMP 程式庫與轉換器引擎整合,並支援開箱即用的 FP8 混合精確度訓練,保持與 PyTorch FSDP MixedPrecision的相容性。這表示您可以將 PyTorch FSDP 用於混合精確度訓練,並將轉換器引擎用於 FP8 訓練。對於轉換器引擎 FP8 訓練功能不支援的模型層,這些層會回復到 PyTorch FSDP 混合精確度。

注意

SMP v2 為下列 Hugging Face Transformer 模型提供 FP8 支援:

  • GPT-NeoX (適用於 SMP v2.2.0 和更新版本)

  • Llama 2 (可在 SMP 2.2.0 版和更新版本中使用)

  • Mixtral 8x7b 和 Mixtral 8x22b (適用於 SMP v2.5.0 和更新版本)

注意

P5 功能的此 FP8 訓練可在 SageMaker 程式庫和 PyTorch 程式庫的下列組合中使用:

  • SageMaker Python SDK 2.212.0 版及更新版本

  • PyTorch v2.2.0 及較新版本

FP8 (8 位元浮點精確度) 是一種資料類型,已出現為另一個範例,以加速 LLM 模型的深度學習訓練。隨著支援 FP8 資料類型的 NVIDIA H100 GPU 發行,您可以從配備 H100 GPU 之 P5 執行個體的效能改進中受益,同時透過 FP8 混合精確度訓練加速分散式訓練。

FP8 資料類型會進一步細分為 E4M3 和 E5M2 格式。E4M3 提供更高的精確度、有限的動態範圍,並且非常適合向前傳遞模型訓練。E5M2 具有更廣泛的動態範圍,但降低了精確度,更適合向後傳遞,其中精確度較不重要,而更寬的動態範圍會受益。因此,我們建議您使用混合式 FP8 策略配方來有效利用這些特性。

對於半精度資料類型 (FP16 和 BF16),全域損失擴展技術,例如靜態損失擴展或動態損失擴展會處理由於半精度中四捨五入梯度而導致的資訊損失引起的收斂問題。不過,FP8 的動態範圍甚至更窄,而且全域損失擴展技術還不夠。此時,我們需要更精細的每個張量擴展技術。延遲擴展是一種策略,可根據在先前迭代運算的張量中觀察到的最大絕對值來選取擴展係數。此策略存在權衡;它使用 FP8 運算的完整效能優勢,但需要記憶體來保留張量的最大值歷史記錄。若要進一步了解一般延遲擴展策略,請參閱適用於深度學習的紙質 FP8 格式

實際上,使用 FP8 有助於 P5 執行個體上的所有訓練案例。我們強烈建議盡可能啟用 FP8,以提高訓練效能。

SMP v2 支援立即可用的轉換器引擎。因此,在 SageMaker AI (ml.p5.48xlarge) 的 P5 執行個體上使用 SMP v2 執行 FP8 訓練時,您只需在訓練指令碼中匯入 torch.sagemaker 並繼續使用原生轉換器引擎 Python 套件。若要進一步了解如何使用轉換器引擎進行 FP8 訓練,請參閱《NVIDIA 轉換器引擎》文件中的使用 FP8 搭配轉換器引擎。下列程式碼片段顯示匯入 SMP 程式庫和在訓練指令碼中設定 FP8 的程式碼行看起來如何。

import torch.sagemaker as tsm import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling, Format # Initialize the SMP torch.sagemaker API. tsm.init() # Define a transformer model and wrap it with the torch.sagemaker.transform API. from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_config(ModelConfig) model = tsm.transform(model) # Enable E4M3 during forward pass, E5M2 during backward pass. fp8_format = Format.HYBRID # Create an FP8 recipe. fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") # Enable FP8 autocasting. with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=tsm.state.world_process_group): out = model(inp) loss = out.sum() loss.backward()

若要尋找在 P5 執行個體上使用 SMP v2 進行 FP8 訓練的實際範例,請參閱在 P5 執行個體上使用 FP8 加速 Llama-v2 (或 GPT-NeoX) 的 SageMaker PyTorch FSDP 訓練的範例筆記本。

使用 PyTorch FSDP 搭配半精確度資料類型的混合精確度訓練

SMP v2 支援在 P4 和 P5 執行個體上訓練任務的 PyTorch FSDP MixedPrecision。PyTorch FSDP 為混合精確度提供各種組態,以改善效能和減少記憶體。

注意

以下 SageMaker 程式庫和 PyTorch 程式庫組合提供此 PyTorch FSDP 功能的混合精確度訓練。

  • SMP 2.0.0 版及更新版本

  • SageMaker Python SDK 2.200.0 版及更新版本

  • PyTorch v2.0.1 及較新版本

為混合精確度設定模型的標準方法是在 float32 中建立模型,然後允許 FSDP 透過傳遞 MixedPrecision 政策將參數轉換為 float16bfloat16,如下列程式碼片段所示。如需在 PyTorch 中變更 dtype 參數、降低或緩衝區混合精確度選項的詳細資訊,請參閱 PyTorch 文件中的 PyTorch FSDP MixedPrecision API

# Native PyTorch API from torch.distributed.fsdp import MixedPrecision dtype = torch.bfloat16 mixed_precision_policy = MixedPrecision( param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype ) model = FSDP( model, ..., mixed_precision=mixed_precision_policy )

請注意,某些模型 (例如 Hugging Face Transformer Llama 模型) 預期緩衝區為 float32。若要使用 float32,請在定義 dtype 物件的行中將 torch.bfloat16 取代為 torch.float32