设置托管层检查点 - Amazon SageMaker AI

设置托管层检查点

此部分包含 Amazon SageMaker HyperPod 的托管层检查点的设置过程。您将了解如何在集群上启用该功能并在训练代码中实现检查点。

先决条件

在设置托管层检查点之前,请确保您具有:

  • 一个具备可用于检查点分配的充足 CPU 内存的 Amazon EKS HyperPod 集群

  • PyTorch 训练工作负载和 DCP 作业(两者都受支持)

  • 用于管理集群的相应 IAM 权限,包括:

    • Amazon CloudWatch 和 Amazon S3 写入权限,用于训练容器组(pod)以读取/写入检查点和推送指标

    • 这些权限可通过 EKS OIDC 设置进行配置

步骤 1:为集群启用托管层检查点

重要

您必须选择使用托管层检查点。

在创建或更新集群时,通过 HyperPod API 启用托管层检查点。当您指定 TieredStorageConfig 参数时,该服务会自动安装内存管理系统。对于新集群,请创建:

aws sagemaker update-cluster \ --cluster-name my-training-cluster \ --tiered-storage-config { "Mode": "Enable" "InstanceMemoryAllocationPercentage": percentage }

InstanceMemoryAllocationPercentage 参数指定要为检查点分配的集群内存的 percentage(int)。范围是 20-100。

步骤 2:在训练映像中安装 Python 库

通过将 Amazon SageMaker 检查点库添加到 Dockerfile 中,将其安装到训练映像中:

# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing

步骤 3:创建检查点配置

创建一个 CheckpointConfig 对象来指定检查点行为。这包括:

  • 检查点位置

  • 检查点的频率

  • 命名空间的名称

以下示例显示了一个检查点配置:

from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import SageMakerCheckpointConfig from amzn_sagemaker_checkpointing.checkpointing.filesystem import SageMakerTieredStorageWriter, SageMakerTieredStorageReader checkpoint_config = sm_ckpt.CheckpointConfig( world_size = 100, in_memory_namespace: my-ml-workload, # Logical grouping for checkpoints s3_base_path: "s3://bucket-name/checkpointing-path-prefix/", s3_every_n_steps: 100, # Every 100 steps, save to S3 )

步骤 4:定义 SageMaker 文件系统写入器

定义检查点文件系统写入器。您可以选择在初始化期间指定步骤编号。

基础写入器(在 save 调用中指定的 step):

smWriter = sagemaker_checkpointing.SageMakerTieredStorageWriter(checkpoint_config)

带 step 参数的写入器(在初始化时指定的 step):

smWriter = sagemaker_checkpointing.SageMakerTieredStorageWriter( checkpoint_config, step=step_number )
注意

当您在写入器初始化期间指定 step 参数时,save 调用中的 checkpoint_id 参数将变为可选参数。step 参数的优先级高于检查点目录格式。

步骤 5:将检查点保存在训练循环中

在训练循环中,将 PyTorch DCP 与 FileSystemWriter 结合使用来保存检查点。

将 PyTorch DCP 与 FileSystemWriter 结合使用

调用 dist_cp.save() 方法,并使用 FileSystemWriter 作为输入:

选项 1:将 checkpoint_id 与 step 格式结合使用(当未在写入器中指定 step 时)

# Construct checkpoint directory with step number checkpoint_dir = f"step_number" dist_cp.save_state_dict( state_dict=state_dict, # state_dict is a dictionary containing model parameters, optimizer state, etc. checkpoint_id=checkpoint_dir, # Should contain step number storage_writer=smWriter )

选项 2:将写入器与 step 参数结合使用(checkpoint_id 变为可选参数)

dist_cp.save_state_dict( state_dict=state_dict, storage_writer=smWriter # Step already specified in writer initialization )
注意

checkpoint_id 值(或 checkpoint_dir 字符串)的格式必须为 step_number。例如 step_5。在写入器初始化期间使用 step 参数时,checkpoint_id 会变为可选参数。

步骤 6:加载检查点以进行恢复

当您需要加载检查点时,可将 PyTorch DCP 与 FileSystemReader 结合使用。

将 PyTorch DCP 与 FileSystemReader 结合使用

调用 DCP 加载方法,并使用 FileSystemReader 作为输入:

# Define FileSystemReader smReader = sagemaker_checkpointing.SageMakerTieredStorageReader( config=checkpoint_config ) # Load checkpoint dist_cp.load_state_dict( state_dict=state_dict, checkpoint_id=checkpoint_dir, storage_reader=smReader )

监控和验证

您可以通过指标和日志监控和验证托管层检查点操作。

自定义日志记录(可选)

您可以通过将自定义记录器传递给库来将检查点日志与其他日志集成。例如,您可以将自定义记录器添加到训练代码,这样库中的所有日志也会被收集到训练记录器中。

增强型服务日志记录(可选)

要增强调试和服务可见性,可以将检查点日志路径 /var/log/sagemaker_checkpointing 从容器组(pod)中挂载到主机上的路径 /var/logs/sagemaker_checkpointing。这可确保仅单独收集库特定的日志,并为服务团队提供更高的调试和支持可见性。