延遲參數初始化 - Amazon SageMaker AI

本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。

延遲參數初始化

有限 GPU 記憶體不一定能初始化大型模型以進行訓練。若要解決 GPU 記憶體不足的問題,您可以在 CPU 記憶體上初始化模型。不過,對於參數超過 200 或 400 億個的大型模型,即使 CPU 記憶體也可能不夠。在這種情況下,我們建議您在 PyTorch 呼叫中繼裝置時初始化模型,以允許建立張量,而不會連接任何資料。中繼裝置上的張量只需要形狀資訊,這可在中繼裝置上建立具有其參數的大型模型。Hugging Face Accelerate 提供內容管理員 init_empty_weights,可協助在中繼裝置上建立此類模型,同時在一般裝置上初始化緩衝區。在訓練開始之前,PyTorch FSDP 會初始化模型參數。SMP v2 的延遲參數初始化功能會延遲在 PyTorch FSDP 執行參數碎片之後建立模型參數。PyTorch FSDP 會在分割模組時接受參數初始化函數 (param_init_fn),並呼叫每個模組的 param_init_fnparam_init_fn API 採用模組做為引數,並初始化其中的所有參數,不包括任何子模組的參數。請注意,此行為與原生 PyTorch 2.0.1 版「不同」,其具有導致參數初始化多次的錯誤。

SMP v2 提供用於套用延遲參數初始化的 torch.sagemaker.delayed_param.DelayedParamIniter API。

下列程式碼片段示範如何將 torch.sagemaker.delayed_param.DelayedParamIniter API 套用至訓練指令碼。

假設您有 PyTorch FSDP 訓練指令碼,如下所示。

# Creation of model on meta device from accelerate import init_empty_weights with init_empty_weights(): model = create_model() # Define a param init fn, below is an example for Hugging Face GPTNeoX. def init_weights(module): d = torch.cuda.current_device() # Note that below doesn't work if you have buffers in the model # buffers will need to reinitialized after this call module.to_empty(device=d, recurse=False) if isinstance(module, (nn.Linear, Conv1D)): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.bias: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=args.initializer_range) if module.padding_idx: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) # Changes to FSDP wrapper. model = FSDP( model, ..., param_init_fn=init_weights ) # At this point model is initialized and sharded for sharded data parallelism.

請注意,延遲參數初始化方法與模型無關。若要解決此問題,您需要撰寫 init_weights 函數,如上述範例所示,以符合原始模型定義中的初始化,且應涵蓋模型的所有參數。為了簡化準備此類 init_weights 函數的程序,SMP v2 會針對下列模型實作此初始化函數:GPT-2、GPT-J、GPT-NeoX 和來自 Hugging Face Transformer 的 Llama。torch.sagemaker.delayed_param.DelayedParamIniter API 也適用於可在 torch.sagemaker.transform API 呼叫後呼叫的 SMP 張量平行實作torch.sagemaker.tensor_parallel.transformer.TransformerLMHead模型。

您可以使用 torch.sagemaker.delayed_param.DelayedParamIniter API 調整 PyTorch FSDP 指令碼,如下所示。建立具有空權重的模型之後,請將 torch.sagemaker.delayed_param.DelayedParamIniter API 註冊到模型,並定義其物件。將物件傳遞至 PyTorch FSDP 類別的 param_init_fn

from torch.sagemaker.delayed_param import DelayedParamIniter from accelerate import init_empty_weights with init_empty_weights(): model = create_model() delayed_initer = DelayedParamIniter(model) with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn() )

綁定權重的注意事項

使用綁定權重訓練模型時,我們需要特別注意在初始化權重和延遲參數初始化之後綁定權重。PyTorch FSDP 沒有機制可在使用 初始化權重後繫結權重param_init_fn,如上所述。為了解決這類情況,我們新增了 API 以允許 post_init_hook_fn,可用於繫結權重。您可以在其中傳遞接受模組做為引數的任何函數,但我們也在 DelayedParamIniter 中定義了預先定義的 post_param_init_fn,在其中呼叫模組 tie_weights 方法 (若存在)。請注意,即使沒有模組的 tie_weights 方法,一律傳入 post_param_init_fn 也是安全的。

with delayed_initer.validate_params_and_buffers_inited(): model = FSDP( model, ..., param_init_fn=delayed_initer.get_param_init_fn(), post_param_init_fn=delayed_initer.get_post_param_init_fn() )