

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

# 지연된 파라미터 초기화
<a name="model-parallel-core-features-v2-delayed-param-init"></a>

제한된 GPU 메모리를 사용하면 훈련을 위한 대형 모델을 초기화할 수 없습니다. GPU 메모리 부족의 이 문제를 해결하려면 CPU 메모리에서 모델을 초기화할 수 있습니다. 그러나 파라미터가 200억 또는 400억 이상인 대형 모델의 경우 CPU 메모리만으로는 충분하지 않을 수 있습니다. 이러한 경우 PyTorch가 *메타 디바이스*를 호출하는 에 대해 모델을 초기화하는 것이 좋습니다. 그러면 연결된 데이터 없이 텐서를 생성할 수 있습니다. 메타 디바이스의 텐서는 모양 정보만 필요하며, 이를 통해 메타 디바이스에서 파라미터가 있는 대형 모델을 생성할 수 있습니다. [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index)는 일반 디바이스에서 버퍼를 초기화하는 동안 메타 디바이스에서 이러한 모델을 생성하는 데 도움이 되는 컨텍스트 관리자 `init_empty_weights`를 제공합니다. 훈련을 시작하기 전에 PyTorch FSDP는 모델 파라미터를 초기화합니다. SMP v2의 이 지연된 파라미터 초기화 기능은 PyTorch FSDP가 파라미터 샤딩을 수행한 후 모델 파라미터 생성이 지연됩니다. PyTorch FSDP는 모듈을 샤딩할 때 파라미터 초기화 함수(`param_init_fn`)를 수락하고 각 모듈에 대해 `param_init_fn`을 호출합니다. `param_init_fn` API는 모듈을 인수로 받아들이고 하위 모듈의 파라미터를 제외한 모든 파라미터를 초기화합니다. 이 동작은 파라미터가 여러 번 초기화되는 버그가 있는 기본 PyTorch v2.0.1과 *다릅니다*.

SMP v2는 지연된 파라미터 초기화를 적용하는 [`torch.sagemaker.delayed_param.DelayedParamIniter`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-delayed-param-init) API를 제공합니다.

다음 코드 조각은 `torch.sagemaker.delayed_param.DelayedParamIniter` API를 훈련 스크립트에 적용하는 방법을 보여줍니다.

다음과 같이 PyTorch FSDP 훈련 스크립트가 있다고 가정합니다.

```
# Creation of model on meta device
from accelerate import init_empty_weights
with init_empty_weights():
    model = create_model()

# Define a param init fn, below is an example for Hugging Face GPTNeoX.
def init_weights(module):
    d = torch.cuda.current_device()
    # Note that below doesn't work if you have buffers in the model
    # buffers will need to reinitialized after this call
    module.to_empty(device=d, recurse=False)
    if isinstance(module, (nn.Linear, Conv1D)):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.bias:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.padding_idx:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

# Changes to FSDP wrapper.
model = FSDP(
    model,
    ...,
    param_init_fn=init_weights
)

# At this point model is initialized and sharded for sharded data parallelism.
```

지연된 파라미터 초기화 접근 방식은 모델에 구애받지 않습니다. 이 문제를 해결하려면 이전 예제와 같이 `init_weights` 함수를 작성하여 원래 모델 정의의 초기화와 일치해야 하며, 모델의 모든 파라미터를 포함해야 합니다. 이러한 `init_weights` 함수를 준비하는 이 프로세스를 간소화하기 위해 SMP v2는 Hugging Face 트랜스포머의 GPT-2, GPT-J, GPT-NeoX 및 Llama 모델에 대해 이 초기화 함수를 구현합니다. `torch.sagemaker.delayed_param.DelayedParamIniter` API는 또한 [`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API 호출 후 호출할 수 있는 SMP 텐서 병렬 구현 `torch.sagemaker.tensor_parallel.transformer.TransformerLMHead` 모델과 함께 작동합니다.

`torch.sagemaker.delayed_param.DelayedParamIniter` API를 사용하여 다음과 같이 PyTorch FSDP 스크립트를 조정할 수 있습니다. 가중치가 비어 있는 모델을 생성한 후 `torch.sagemaker.delayed_param.DelayedParamIniter` API를 모델에 등록하고 해당 모델의 객체를 정의합니다. 객체를 PyTorch FSDP 클래스의 `param_init_fn`로 전달합니다.

```
from torch.sagemaker.delayed_param import DelayedParamIniter
from accelerate import init_empty_weights

with init_empty_weights():
    model = create_model()
    
delayed_initer = DelayedParamIniter(model)

with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn()
    )
```

**연결된 가중치에 대한 참고 사항**

가중치가 묶인 모델을 훈련할 때는 파라미터 초기화가 지연된 가중치를 초기화한 후 가중치를 묶을 수 있도록 각별히 주의해야 합니다. PyTorch FSDP에는 위와 같이 `param_init_fn`를 사용하여 가중치를 초기화한 후 가중치를 연결하는 메커니즘이 없습니다. 이러한 경우를 해결하기 위해 가중치를 연결하는 데 사용할 수 있는 `post_init_hook_fn`를 허용하도록 API를 추가했습니다. 모듈을 인수로 수락하는 함수를 거기서 전달할 수 있지만 모듈이 있는 경우 모듈의 `tie_weights` 메서드를 호출하는 `DelayedParamIniter`에 미리 정의된 `post_param_init_fn`도 있습니다. 모듈에 대한 `tie_weights` 메서드가 없더라도 항상 `post_param_init_fn`에서 패스하는 것이 안전합니다.

```
with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn(),
        post_param_init_fn=delayed_initer.get_post_param_init_fn()
    )
```