

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

# SMP を使用したチェックポイント
<a name="model-parallel-core-features-v2-checkpoints"></a>

SageMaker モデル並列処理 (SMP) ライブラリは、チェックポイント用の PyTorch API に対応し、SMP ライブラリの使用中にチェックポイントを適切に行うのに役立つ API を提供しています。

PyTorch FSDP (Fully Sharded Data Parallelism) は、フル (full)、シャード (sharded)、ローカル (local) の 3 種類のチェックポイントをサポートしていますが、それぞれ目的が異なります。フルチェックポイントの生成は計算コストがかかるプロセスであるため、フルチェックポイントは、トレーニングの完了後にモデルをエクスポートする際に使用されます。シャードチェックポイントは、個々のランクごとに分割されたモデルの状態を保存し、ロードするのに役立ちます。シャードチェックポイントを使用すると、異なるハードウェア構成 (GPU の数が違うなど) でトレーニングを再開できます。ただし、シャードチェックポイントのロードは、複数のデバイス間の通信が必要になるため、遅くなる可能性があります。SMP ライブラリにはローカルチェックポイント機能があり、追加の通信オーバーヘッドを伴わずにモデルの状態をすばやく取得できます。FSDP によって作成されたチェックポイントは、Amazon FSx などの共有ネットワークファイルシステムへの書き込みが必要です。

## 非同期ローカルチェックポイント
<a name="w2aac25c25c19c19c33b7"></a>

機械学習モデルをトレーニングする際、チェックポイントファイルがディスクに保存されるのを後続のイテレーションが待つ必要はありません。SMP v2.5 のリリースで、チェックポイントファイルの非同期保存に対応しました。つまり、後続のトレーニングイテレーションは、チェックポイント作成時の入出力 (I/O) 演算と同時に実行でき、それらの I/O 演算によって低速化したり、中断されたりすることはありません。また、PyTorch でシャーディングされたモデルとオプティマイザのパラメータを取得するプロセスは、分散テンソルのメタデータをランク間で交換するために追加の集合通信が必要になるため、時間がかかる場合があります。`StateDictType.LOCAL_STATE_DICT` を使用して各ランクのローカルチェックポイントを保存した場合でも、PyTorch は依然として集合通信を実行するフックを呼び出します。この問題を軽減し、チェックポイントの取得にかかる時間を短縮するために、SMP は `SMStateDictType.SM_LOCAL_STATE_DICT` を導入しました。これにより、集合通信のオーバーヘッドを回避することで、モデルおよびオプティマイザのチェックポイントをより迅速に取得できます。

**注記**  
`SMStateDictType.SM_LOCAL_STATE_DICT` を使用するには、FSDP `SHARD_DEGREE` の一貫性を維持する必要があります。`SHARD_DEGREE` が変更されないように徹底してください。モデルレプリケーションの数は変更されても大丈夫ですが、チェックポイントから再開する際は、モデルのシャード度合いが以前のトレーニング設定と同一である必要があります。

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
       "model": model.state_dict(),
       "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint 
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)
```

次のコードスニペットは、`SMStateDictType.SM_LOCAL_STATE_DICT` を使用してチェックポイントをロードする方法を示しています。

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }
 
    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()
 
# 3. Load a checkpoint
load(
    state_dict=state_dict,
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
    storage_reader=storage_reader,
)
```

大規模言語モデル (LLM) のチェックポイントの保存には、多くの場合、大容量のファイルシステムの作成が必要になるため、コストがかかる可能性があります。コストを削減するために、Amazon FSx などの追加のファイルシステムサービスを使用する必要なく、チェックポイントを Amazon S3 に直接保存することができます。前述の例を基にして次のコードスニペットを使用し、S3 の URL を保存先として指定することで、チェックポイントを S3 に保存できます。

```
key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id= f"{{s3://{your_s3_bucket}/{key}}}"
async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)
```

## 非同期シャードチェックポイント
<a name="w2aac25c25c19c19c33b9"></a>

GPU の数を変えるなど、異なるハードウェア構成でトレーニングを続行しなければならない場合があります。このような場合、トレーニングプロセスではリシャーディング中にチェックポイントをロードする必要があります。つまり、後続のトレーニングを異なる数の `SHARD_DEGREE` で再開することになります。異なる数の `SHARD_DEGREE` でトレーニングを再開する必要がある状況に対応するには、シャーディングされた状態のディクショナリタイプ (`StateDictType.SHARDED_STATE_DICT`) でモデルのチェックポイントを保存する必要があります。この形式でチェックポイントを保存すると、変更後のハードウェア構成でトレーニングを続行するときに、リシャーディングプロセスを適切に処理できます。提示されているコードスニペットは、`tsm` API を使用してシャードチェックポイントを非同期的に保存し、より効率的で合理的なトレーニングプロセスを実現する方法を示しています。

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether curreto take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0
process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }
 
# 3. save checkpoints asynchronously using async_save
if action_rank:
    async_save(
        state_dict,
        checkpoint_id=checkpoint_id,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
```

共有チェックポイントをロードするプロセスは前のセクションと似ていますが、ここでは `torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader` とその `load` メソッドを使用しています。このクラスの `load` メソッドを使用すると、前述のプロセスに類似したプロセスに従って、共有チェックポイントデータをロードできます。

```
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)
 
 load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
   # 1. Load model and everything else except the optimizer.
   state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
   }
   load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
   )
   model.load_state_dict(state_dict["model"])
 
   # 2. Load optimizer.
   optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )    
   flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer,
         group=model.process_group
   )
   optimizer.load_state_dict(flattened_optimizer_state)
```

## フルモデルチェックポイント
<a name="model-parallel-core-features-v2-checkpoints-full"></a>

トレーニングの最後には、モデルのすべてのシャードを 1 つのモデルチェックポイントファイルにまとめたフルチェックポイントを保存できます。SMP ライブラリは PyTorch のフルモデルチェックポイント API を完全にサポートしているため、変更を加える必要はありません。

SMP [テンソル並列性](model-parallel-core-features-v2-tensor-parallelism.md) を使用する場合、SMP ライブラリはモデルを変換します。この場合、完全なモデルをチェックポイントする際に、モデルがデフォルトで元の Hugging Face Transformers チェックポイント形式に変換されます。

SMP のテンソル並列処理でトレーニングを行い、SMP の変換プロセスをオフにする場合は、PyTorch の `FullStateDictConfig` API の `translate_on_save` 引数を使用して、SMP の自動変換のオン/オフを適宜切り替えることができます。例えば、モデルのトレーニングに集中する場合は、追加のオーバーヘッドが生じる変換プロセスを追加する必要はありません。この場合は、`translate_on_save=False` を設定することをお勧めします。また、今後のトレーニングでモデルの SMP 変換を使用する予定がある場合は、オフに切り替えて、モデルの SMP 変換を後で使用するために保存できます。モデルのトレーニングを終了し、その後推論に使用する場合は、モデルを Hugging Face Transformers モデルチェックポイント形式に戻す必要があります。

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT, 
    FullStateDictConfig(
        rank0_only=True, offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs({{save_dir}}, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join({{save_dir}}, "pytorch_model.bin"))
        hf_model_config.save_pretrained({{save_dir}})
    dist.barrier()
```

`FullStateDictConfig(rank0_only=True, offload_to_cpu=True)` オプションは、大型モデルのトレーニング時にメモリを節約するために、0 番ランクのデバイスの CPU にモデルを集約します。

推論のためにモデルを再ロードするには、次のコード例のように処理します。`AutoModelForSeq2SeqLM` クラスは、モデルによっては、Hugging Face Transformer の他のファクタービルダークラス (`AutoModelForCausalLM` など) に変更される場合があります。詳細については、[Hugging Face Transformers のドキュメント](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/auto#natural-language-processing)を参照してください。

```
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained({{save_dir}})
```