

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

# SMP를 사용한 체크포인트 지정
<a name="model-parallel-core-features-v2-checkpoints"></a>

SageMaker 모델 병렬 처리(SMP) 라이브러리는 체크포인트에 대한 PyTorch API를 지원하며 SMP 라이브러리를 사용하는 동안 체크포인트를 올바르게 수행하는 데 도움이 되는 APIs를 제공합니다.

PyTorch FSDP(Fully Sharded Data Parallelism)는 각각 서로 다른 목적을 제공하는 전체, 샤딩 및 로컬의 세 가지 유형의 체크포인트를 지원합니다. 전체 체크포인트 생성은 계산 비용이 많이 드는 프로세스이므로 전체 체크포인트는 훈련이 완료된 후 모델을 내보낼 때 사용됩니다. 샤딩된 체크포인트는 각 개별 순위에 대해 샤딩된 모델의 상태를 저장하고 로드하는 데 도움이 됩니다. 샤딩된 체크포인트를 사용하면 GPU 개수 등 다양한 하드웨어 구성으로 훈련을 다시 시작할 수 있습니다. 그러나 여러 디바이스 간의 통신으로 인해 샤딩된 체크포인트 로드가 느려질 수 있습니다. SMP 라이브러리는 추가 통신 오버헤드 없이 모델 상태를 더 빠르게 검색할 수 있는 로컬 체크포인트 기능을 제공합니다. FSDP에서 생성한 체크포인트는 Amazon FSx 와 같은 공유 네트워크 파일 시스템에 작성해야 합니다.

## 로컬 체크포인트 비동기화
<a name="w2aac25c25c19c19c33b7"></a>

기계 학습 모델을 훈련할 때 체크포인트 파일이 디스크에 저장될 때까지 기다리지 않아도 됩니다. SMP v2.5 릴리스와 함께 라이브러리는 체크포인트 파일 비동기 저장을 지원합니다. 즉, 후속 훈련 반복은 I/O 작업으로 인해 속도가 느려지거나 지연되지 않고 체크포인트를 생성하기 위한 입력 및 출력(I/O) 작업과 동시에 실행될 수 있습니다. 또한 PyTorch에서 샤딩된 모델 및 옵티마이저 파라미터 검색 프로세스는 순위 간에 분산된 텐서 메타데이터를 교환하는 데 필요한 추가 집합 통신으로 인해 시간이 많이 걸릴 수 있습니다. `StateDictType.LOCAL_STATE_DICT`를 사용하여 각 순위에 대한 로컬 체크포인트를 저장할 때도 PyTorch는 집합 통신을 수행하는 후크를 여전히 호출합니다. 이 문제를 완화하고 체크포인트 검색에 필요한 시간을 줄이기 위해 SMP는 집합 통신 오버헤드를 우회하여 모델 및 옵티마이저 체크포인트를 더 빠르게 검색할 수 있도록 `SMStateDictType.SM_LOCAL_STATE_DICT`를 도입했습니다.

**참고**  
FSDP `SHARD_DEGREE`의 일관성을 유지하는 것은 `SMStateDictType.SM_LOCAL_STATE_DICT`를 활용하기 위한 요구 사항입니다. `SHARD_DEGREE`가 변경되지 않은 상태로 유지되는지 확인합니다. 모델 복제 수는 다를 수 있지만 체크포인트에서 다시 시작할 때 모델 샤드 정도는 이전 훈련 설정과 동일해야 합니다.

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
       "model": model.state_dict(),
       "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint 
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)
```

다음 코드 조각은 `SMStateDictType.SM_LOCAL_STATE_DICT`를 사용하여 체크포인트를 로드하는 방법을 보여줍니다.

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }
 
    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()
 
# 3. Load a checkpoint
load(
    state_dict=state_dict,
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
    storage_reader=storage_reader,
)
```

대규모 언어 모델(LLM)에 대한 체크포인트를 저장하려면 대용량 파일 시스템 볼륨을 생성해야 하는 경우가 많기 때문에 비용이 많이 들 수 있습니다. 비용을 절감하기 위해 Amazon FSx와 같은 추가 파일 시스템 서비스가 필요 없이 Amazon S3에 체크포인트를 직접 저장할 수 있습니다. 다음 코드 조각을 사용하여 S3 URL을 대상으로 지정하여 이전 예제를 활용하여 S3에 체크포인트를 저장할 수 있습니다.

```
key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id= f"s3://{your_s3_bucket}/{key}"
async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)
```

## 비동기 샤딩된 체크포인트
<a name="w2aac25c25c19c19c33b9"></a>

GPU 수를 변경하는 등 다른 하드웨어 구성으로 훈련을 계속해야 하는 상황이 있을 수 있습니다. 이러한 경우 훈련 프로세스는 리샤딩 중에 체크포인트를 로드해야 합니다. 즉, 다른 수의 `SHARD_DEGREE`로 후속 훈련을 재개해야 합니다. 다른 수의 `SHARD_DEGREE`로 훈련을 재개해야 하는 시나리오를 해결하려면 `StateDictType.SHARDED_STATE_DICT`로 표시되는 샤딩된 상태 사전 유형을 사용하여 모델 체크포인트를 저장해야 합니다. 이 형식으로 체크포인트를 저장하면 수정된 하드웨어 구성으로 훈련을 계속할 때 재분배 프로세스를 올바르게 처리할 수 있습니다. 제공된 코드 조각은 `tsm` API를 사용하여 샤딩된 체크포인트를 비동기적으로 저장하여 보다 효율적이고 간소화된 훈련 프로세스를 가능하게 하는 방법을 보여줍니다.

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether curreto take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0
process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }
 
# 3. save checkpoints asynchronously using async_save
if action_rank:
    async_save(
        state_dict,
        checkpoint_id=checkpoint_id,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
```

공유 체크포인트를 로드하는 프로세스는 이전 섹션과 유사하지만 `torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader` 및 해당 `load` 메서드를 사용해야 합니다. 이 클래스의 `load` 메서드를 사용하면 앞서 설명한 것과 유사한 프로세스에 따라 공유 체크포인트 데이터를 로드할 수 있습니다.

```
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)
 
 load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
   # 1. Load model and everything else except the optimizer.
   state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
   }
   load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
   )
   model.load_state_dict(state_dict["model"])
 
   # 2. Load optimizer.
   optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )    
   flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer,
         group=model.process_group
   )
   optimizer.load_state_dict(flattened_optimizer_state)
```

## 전체 모델 체크포인트
<a name="model-parallel-core-features-v2-checkpoints-full"></a>

훈련이 끝나면 모델의 모든 샤드를 단일 모델 체크포인트 파일로 결합하는 전체 체크포인트를 저장할 수 있습니다. SMP 라이브러리는 PyTorch 전체 모델 체크포인트 API를 완전히 지원하므로 변경할 필요가 없습니다.

SMP [텐서 병렬화](model-parallel-core-features-v2-tensor-parallelism.md)을 사용하면 SMP 라이브러리가 모델을 변환합니다. 이 경우 전체 모델을 체크포인트할 때 SMP 라이브러리는 모델을 Hugging Face 트랜스포머 체크포인트 형식으로 다시 변환합니다.

SMP 텐서 병렬 처리로 훈련하고 SMP 번역 프로세스를 끄는 경우 PyTorch `FullStateDictConfig` API의 `translate_on_save` 인수를 사용하여 필요에 따라 SMP 자동 번역을 켜거나 끌 수 있습니다. 예를 들어 모델 훈련에 집중하는 경우 오버헤드를 추가하는 번역 프로세스를 추가할 필요가 없습니다. `translate_on_save=False`를 설정하는 것이 좋습니다. 또한 향후 추가 훈련을 위해 모델의 SMP 번역을 계속 사용할 계획이라면 끄면 나중에 사용할 수 있도록 모델의 SMP 번역을 저장할 수 있습니다. 모델 훈련을 마무리하고 추론에 사용할 때는 모델을 Hugging Face 트랜스포머 모델 체크포인트 형식으로 다시 변환해야 합니다.

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT, 
    FullStateDictConfig(
        rank0_only=True, offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs(save_dir, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))
        hf_model_config.save_pretrained(save_dir)
    dist.barrier()
```

`FullStateDictConfig(rank0_only=True, offload_to_cpu=True)` 옵션은 0순위 디바이스의 CPU에서 모델을 수집하여 대형 모델을 훈련할 때 메모리를 저장하는 것입니다.

추론을 위해 모델을 다시 로드하려면 다음 코드 예제와 같이 로드합니다. 모델에 따라 `AutoModelForCausalLM` 클래스가 `AutoModelForSeq2SeqLM`등 Hugging Face 트랜스포머의 다른 팩터 빌더 클래스로 변경될 수 있습니다. 자세한 내용은 [Hugging Face 트랜스포머 설명서](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/auto#natural-language-processing)를 참조하세요.

```
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(save_dir)
```