

本文為英文版的機器翻譯版本，如內容有任何歧義或不一致之處，概以英文版為準。

# 使用 SMP 進行檢查點
<a name="model-parallel-core-features-v2-checkpoints"></a>

SageMaker 模型平行化 (SMP) 程式庫支援適用於檢查點的 PyTorch APIs，並提供可在使用 SMP 程式庫時協助正確檢查點的 APIs。

PyTorch FSDP (完整碎片資料平行化) 支援三種類型的檢查點：完整、碎片和本機，每個檢查點都有不同的用途。訓練完成後匯出模型時會使用完整檢查點，因為產生完整檢查點是運算上昂貴的程序。碎片檢查點有助於儲存和載入針對每個個別排名碎片的模型狀態。使用碎片檢查點，您可以使用不同的硬體組態繼續訓練，例如不同的 GPU 數量。不過，由於多個裝置之間涉及通訊，載入碎片檢查點可能會很慢。SMP 程式庫提供本機檢查點功能，可更快速擷取模型的狀態，無需額外的通訊開銷。請注意，FSDP 建立的檢查點需要寫入共用網路檔案系統，例如 Amazon FSx。

## 非同步本機檢查點
<a name="w2aac25c25c19c19c33b7"></a>

訓練機器學習模型時，不需要後續迭代運算來等待檢查點檔案儲存至磁碟。隨著 SMP v2.5 的發行，程式庫支援以非同步方式儲存檢查點檔案。這表示後續的訓練迭代運算可以與用於建立檢查點的輸入和輸出 (I/O) 操作同時執行，而不會被這些 I/O 操作減慢或保留。此外，在 PyTorch 中擷取碎片模型和最佳化工具參數的程序可能很耗時，因為在排名之間交換分散式張量中繼資料所需的額外集體通訊。即使使用 `StateDictType.LOCAL_STATE_DICT` 為每個排名儲存本機檢查點，PyTorch 仍會調用執行集體通訊的勾點。為了減輕此問題並縮短檢查點擷取所需的時間，SMP 引進 `SMStateDictType.SM_LOCAL_STATE_DICT`，透過略過集體通訊開銷，可更快速擷取模型和最佳化工具檢查點。

**注意**  
在 FSDP `SHARD_DEGREE` 中保持一致性是使用 `SMStateDictType.SM_LOCAL_STATE_DICT` 的必要條件。確保 `SHARD_DEGREE` 保持不變。雖然模型複寫的數量可能不同，但從檢查點繼續時，模型碎片程度必須與先前的訓練設定相同。

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
       "model": model.state_dict(),
       "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint 
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)
```

下列程式碼片段示範如何使用 `SMStateDictType.SM_LOCAL_STATE_DICT` 載入檢查點。

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }
 
    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()
 
# 3. Load a checkpoint
load(
    state_dict=state_dict,
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
    storage_reader=storage_reader,
)
```

儲存大型語言模型 (LLMs) 的檢查點可能很昂貴，因為它通常需要建立大型檔案系統磁碟區。若要降低成本，您可以選擇將檢查點直接儲存到 Amazon S3，而不需要其他檔案系統服務，例如 Amazon FSx。您可以使用下列程式碼片段利用先前的範例，透過指定 S3 URL 做為目的地，將檢查點儲存至 S3。

```
key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id= f"s3://{your_s3_bucket}/{key}"
async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)
```

## 非同步碎片檢查點
<a name="w2aac25c25c19c19c33b9"></a>

在某些情況下，您可能需要繼續使用不同的硬體組態進行訓練，例如變更 GPU 數量。在這些情況下，您的訓練程序必須在重新分片時載入檢查點，這表示使用不同數量的 `SHARD_DEGREE` 繼續後續訓練。為了解決您需要使用不同數量的 `SHARD_DEGREE` 繼續訓練的情況，您必須使用碎片狀態字典類型儲存模型檢查點，該類型由 `StateDictType.SHARDED_STATE_DICT` 表示。使用此格式儲存檢查點可讓您在繼續使用已修改硬體組態的訓練時，正確處理重新分片程序。提供的程式碼片段說明如何使用 `tsm` API 以非同步方式儲存碎片化檢查點，讓訓練程序更有效率且更簡化。

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether curreto take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0
process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }
 
# 3. save checkpoints asynchronously using async_save
if action_rank:
    async_save(
        state_dict,
        checkpoint_id=checkpoint_id,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
```

載入共用檢查點的程序類似於上一節，但涉及使用 `torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader` 及其 `load` 方法。此類別的 `load` 方法可讓您載入共用檢查點資料，遵循類似先前所述程序。

```
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)
 
 load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
   # 1. Load model and everything else except the optimizer.
   state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
   }
   load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
   )
   model.load_state_dict(state_dict["model"])
 
   # 2. Load optimizer.
   optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )    
   flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer,
         group=model.process_group
   )
   optimizer.load_state_dict(flattened_optimizer_state)
```

## 完整模型檢查點
<a name="model-parallel-core-features-v2-checkpoints-full"></a>

在訓練結束時，您可以儲存完整的檢查點，將模型的所有碎片合併為單一模型檢查點檔案。SMP 程式庫完全支援 PyTorch 完整模型檢查點 API，因此您不需要進行任何變更。

請注意，如果您使用 SMP [張量平行化](model-parallel-core-features-v2-tensor-parallelism.md)，SMP 程式庫會轉換模型。在這種情況下檢查點完整模型時，SMP 程式庫預設會將模型轉譯回 Hugging Face Transformer 檢查點格式。

如果您使用 SMP 張量平行化訓練並關閉 SMP 轉譯程序，您可以使用 PyTorch `FullStateDictConfig` API 的 `translate_on_save` 引數，視需要開啟或關閉 SMP 自動轉譯。例如，如果您專注於訓練模型，則不需要新增會增加額外負荷的翻譯程序。在這種情況下，我們建議您設定 `translate_on_save=False`。此外，如果您計劃繼續使用模型的 SMP 轉譯來進行未來的訓練，您可以將其關閉以儲存模型的 SMP 轉譯以供日後使用。當您完成模型的訓練並使用模型進行推論時，需要將模型轉換回 Hugging Face Transformer 模型檢查點格式。

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT, 
    FullStateDictConfig(
        rank0_only=True, offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs(save_dir, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))
        hf_model_config.save_pretrained(save_dir)
    dist.barrier()
```

請注意，選項 `FullStateDictConfig(rank0_only=True, offload_to_cpu=True)` 是在排名 0 裝置的 CPU 上收集模型，以便在訓練大型模型時節省記憶體。

若要重新載入模型以進行推論，請執行此操作，如下列程式碼範例所示。請注意，類別 `AutoModelForCausalLM` 可能會變更為 Hugging Face Transformer 中的其他因素建置器類別，例如 `AutoModelForSeq2SeqLM`，視您的模型而定。如需詳細資訊，請參閱 [Hugging Face Transformer 文件](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/auto#natural-language-processing)。

```
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(save_dir)
```