

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

# パラメータの遅延初期化
<a name="model-parallel-core-features-v2-delayed-param-init"></a>

トレーニング用の大規模なモデルの初期化は、制限された GPU メモリでは常に可能とは限りません。この GPU メモリ不足という問題を解決するために、CPU メモリでモデルを初期化できます。ただし、パラメータ数が 200 億または 400 億を超えるような大規模なモデルでは、CPU メモリでさえ十分ではないことがあります。このような場合は、PyTorch が呼ぶところの*メタデバイス*でモデルを初期化することをお勧めします。これにより、データが関連付いていないテンソルを作成できます。メタデバイス上のテンソルは形状情報のみを必要とするため、パラメータがメタデバイス上にある大規模なモデルを作成できます。[Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index) が提供している `init_empty_weights` コンテキストマネージャーを使用すれば、このようなモデルをメタデバイスで作成し、バッファは通常のデバイス上で初期化できます。トレーニングが始まる前に、PyTorch FSDP がモデルパラメータを初期化します。SMP v2 のパラメータ遅延初期化の機能では、このモデルパラメータの作成を、PyTorch FSDP がパラメータシャーディングを実行した後まで遅らせます。PyTorch FSDP は、モジュールをシャーディングするときにパラメータ初期化関数 (`param_init_fn`) を受け入れ、各モジュールに対して `param_init_fn` を呼び出します。`param_init_fn` API はモジュールを引数として受け取り、そのモジュール内のすべてのパラメータを、子モジュールのパラメータを除いて初期化します。この動作は、ネイティブ PyTorch v2.0.1 とは*異なります*。PyTorch v2.0.1 では、バグのせいでパラメータが複数回初期化されてしまいます。

SMP v2 は、パラメータの遅延初期化を適用するための [`torch.sagemaker.delayed_param.DelayedParamIniter`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-delayed-param-init) API を提供します。

以下のコードスニペットは、`torch.sagemaker.delayed_param.DelayedParamIniter` API をトレーニングスクリプトに適用する方法を示しています。

次のような PyTorch FSDP トレーニングスクリプトがあるとします。

```
# Creation of model on meta device
from accelerate import init_empty_weights
with init_empty_weights():
    model = create_model()

# Define a param init fn, below is an example for Hugging Face GPTNeoX.
def init_weights(module):
    d = torch.cuda.current_device()
    # Note that below doesn't work if you have buffers in the model
    # buffers will need to reinitialized after this call
    module.to_empty(device=d, recurse=False)
    if isinstance(module, (nn.Linear, Conv1D)):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.bias:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.padding_idx:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

# Changes to FSDP wrapper.
model = FSDP(
    model,
    ...,
    param_init_fn=init_weights
)

# At this point model is initialized and sharded for sharded data parallelism.
```

パラメータの遅延初期化アプローチはモデル非依存ではありません。この問題を解決するには、前の例に示したように、元のモデル定義の初期化関数と一致するように `init_weights` 関数を記述し、モデルのすべてのパラメータを網羅する必要があります。このような `init_weights` 関数を簡単に準備できるように、SMP v2 では、Hugging Face Transformers の GPT-2、GPT-J、GPT-NeoX、Llama の各モデルに対してこの初期化関数を実装しています。`torch.sagemaker.delayed_param.DelayedParamIniter` API は、SMP テンソル並列処理の実装である `torch.sagemaker.tensor_parallel.transformer.TransformerLMHead` モデルでも動作し、[`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API コールの後に呼び出すことができます。

`torch.sagemaker.delayed_param.DelayedParamIniter` API を使用して、PyTorch FSDP スクリプトを次のように適応させることができます。空の重みを持つモデルを作成してから、`torch.sagemaker.delayed_param.DelayedParamIniter` API をそのモデルに登録し、そのオブジェクトを定義します。このオブジェクトを PyTorch FSDP クラスの `param_init_fn` に渡します。

```
from torch.sagemaker.delayed_param import DelayedParamIniter
from accelerate import init_empty_weights

with init_empty_weights():
    model = create_model()
    
delayed_initer = DelayedParamIniter(model)

with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn()
    )
```

**重み共有に関する注意事項**

重みを共有する (tied weight) モデルをトレーニングする場合は、特別な注意が必要です。パラメータの遅延初期化で重みを初期化した後で、重みを共有する必要があります。PyTorch FSDP には、上記のように、`param_init_fn` を使用して重みを初期化した後で、それらを共有するメカニズムがありません。このようなケースに対処するために、`post_init_hook_fn` を許可する API を追加しました。この API を使用して重みを共有できます。モジュールを引数として受け取る任意の関数を渡すことができますが、`DelayedParamIniter` で事前に定義された `post_param_init_fn` もあり、これは、モジュールに `tie_weights` メソッドがある場合はそのメソッドを呼び出します。モジュールに `tie_weights` メソッドがない場合でも、常に `post_param_init_fn` を渡すと安全です。

```
with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn(),
        post_param_init_fn=delayed_initer.get_post_param_init_fn()
    )
```