Set up managed tiered checkpointing - Amazon SageMaker AI

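Set up managed tiered checkpointing

This section covers the setup process for Amazon SageMaker HyperPod managed tiered checkpointing. You will learn how to enable the feature on your cluster and how to implement checkpointing in your training code.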

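Prerequisites

Before you set up managed tiered checkpointing, make sure that you have:

  • An Amazon EKS HyperPod cluster with enough CPU memory available for checkpoint allocation

  • PyTorch training workloads and DCP jobs (both are supported)

  • Appropriate IAM permissions for cluster management, including:

    • Amazon CloudWatch and Amazon S3 write permissions for the training pod, used to read and write checkpoints and to push metrics

    • These permissions can be configured through the EKS OIDC setup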
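The permissions listed above could be expressed as an IAM policy document along the following lines. This is a minimal sketch, not the service's required policy: the bucket name, the prefix, and the specific action list (`s3:GetObject`, `s3:PutObject`, `s3:ListBucket`, `cloudwatch:PutMetricData`) are assumptions chosen for illustration, and `build_checkpoint_policy` is a hypothetical helper.

```python
import json


def build_checkpoint_policy(bucket: str, prefix: str) -> dict:
    """Build an illustrative IAM policy for the training pod role.

    The action list is an assumption, not an official minimum set.
    """
    prefix = prefix.strip("/")
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                # Read and write checkpoint objects under the prefix
                "Sid": "CheckpointObjects",
                "Effect": "Allow",
                "Action": ["s3:GetObject", "s3:PutObject"],
                "Resource": f"arn:aws:s3:::{bucket}/{prefix}/*",
            },
            {
                # List the bucket so checkpoints can be discovered
                "Sid": "CheckpointListing",
                "Effect": "Allow",
                "Action": ["s3:ListBucket"],
                "Resource": f"arn:aws:s3:::{bucket}",
            },
            {
                # Push metrics to CloudWatch
                "Sid": "PushMetrics",
                "Effect": "Allow",
                "Action": ["cloudwatch:PutMetricData"],
                "Resource": "*",
            },
        ],
    }


policy = build_checkpoint_policy("my-bucket", "checkpoints")
print(json.dumps(policy, indent=2))
```

The resulting JSON would be attached to the IAM role that the training pod assumes through the EKS OIDC provider. Tighten the resources and actions to match what your workload actually needs.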

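Step 1: Enable managed tiered checkpointing for your cluster

Important

You must opt in to use managed tiered checkpointing.

Enable managed tiered checkpointing through the HyperPod APIs when you create or update your cluster. When you specify the TieredStorageConfig parameter, the service automatically installs the memory management system.

For a new cluster, you can use the create-cluster AWS CLI command.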

aws sagemaker create-cluster \
    --cluster-name cluster-name \
    --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \
    --instance-groups '{
        "InstanceGroupName": "instance-group-name",
        "InstanceType": "instance-type",
        "InstanceCount": instance-count,
        "LifeCycleConfig": {
            "SourceS3Uri": "s3-path-to-lifecycle-scripts",
            "OnCreate": "lifecycle-script-name"
        },
        "ExecutionRole": "instance-group-iam-role",
        "ThreadsPerCore": threads-per-core,
        "InstanceStorageConfigs": [
            { "EbsVolumeConfig": {"VolumeSizeInGB": volume-size} }
        ]
    }' \
    --vpc-config '{
        "SecurityGroupIds": ["security-group-ids"],
        "Subnets": ["subnets"]
    }' \
    --tiered-storage-config '{
        "Mode": "Enable"
    }'

The InstanceMemoryAllocationPercentage parameter specifies the percentage (int) of cluster memory to allocate for checkpointing. The range is 20-100.
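To make the range constraint concrete, the following sketch builds the JSON value for `--tiered-storage-config` and rejects percentages outside 20-100 before any API call is made. It rests on two assumptions that this page does not confirm: that `InstanceMemoryAllocationPercentage` sits next to `Mode` in the same object, and that both bounds are inclusive. `build_tiered_storage_config` is a hypothetical helper, not part of any AWS SDK.

```python
import json
from typing import Optional


def build_tiered_storage_config(memory_percentage: Optional[int] = None) -> dict:
    """Return a TieredStorageConfig-style dict with tiered storage enabled.

    Assumes InstanceMemoryAllocationPercentage is a sibling of Mode and
    that the documented 20-100 range is inclusive.
    """
    config = {"Mode": "Enable"}
    if memory_percentage is not None:
        # bool is a subclass of int, so reject it explicitly
        if isinstance(memory_percentage, bool) or not isinstance(memory_percentage, int):
            raise TypeError("InstanceMemoryAllocationPercentage must be an int")
        if not 20 <= memory_percentage <= 100:
            raise ValueError(
                "InstanceMemoryAllocationPercentage must be between 20 and 100"
            )
        config["InstanceMemoryAllocationPercentage"] = memory_percentage
    return config


# Serialize for use as the value of --tiered-storage-config
print(json.dumps(build_tiered_storage_config(40)))
```

Validating locally keeps a typo such as `5` or `"40"` from reaching the cluster API.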

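Step 2: Install the Python library in your training image

Add the Amazon SageMaker checkpointing library and its dependencies to your Dockerfile to install them in your training image: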

# Add this line to your training image Dockerfile
RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3
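After building the image, a quick check inside the container can confirm that the packages were installed. The sketch below uses only the standard library; `missing_distributions` is a hypothetical helper, and the package list simply mirrors the `pip install` line above.

```python
from importlib import metadata

# Distribution names taken from the pip install line above
REQUIRED = [
    "amzn-sagemaker-checkpointing",
    "s3torchconnector",
    "tenacity",
    "torch",
    "boto3",
]


def missing_distributions(names):
    """Return the names that are not installed in this environment."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


missing = missing_distributions(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All checkpointing dependencies are installed.")
```

Running this as a container start-up step surfaces a broken image before a long training job begins.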

Step 3: Save checkpoints in your training loop

In your training loop, you can save checkpoints asynchronously with PyTorch DCP. The following example shows how.

import os
import time

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import async_save, load
from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import (
    SageMakerTieredStorageWriter,
    SageMakerTieredStorageReader,
)
# Import path assumed here; confirm it against your installed library version
from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import (
    SageMakerCheckpointConfig,
)

# Initialize distributed training
dist.init_process_group(backend="nccl")

# Configure checkpointing
checkpoint_config = SageMakerCheckpointConfig(
    # Unique ID for your training job
    # Allowed characters in ID include: alphanumeric, hyphens, and underscores
    namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'),
    # Number of distributed processes/available GPUs
    world_size=dist.get_world_size(),
    # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks
    # Required for SageMakerTieredStorageWriter when save_to_s3 is True
    s3_tier_base_path="s3://my-bucket/checkpoints"
)

# Your model and optimizer
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters())

# Training loop
future = None
in_memory_ckpt_freq = 10
s3_ckpt_freq = 50
epoch = 0  # Placeholder: update this from your own epoch loop

for training_step in range(1000):
    # ... training code ...

    # Save checkpoint
    if (training_step % in_memory_ckpt_freq == 0 or
            training_step % s3_ckpt_freq == 0):
        # Create state dictionary
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": training_step,
            "epoch": epoch
        }

        # Create storage writer for current step
        checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0
        storage_writer = SageMakerTieredStorageWriter(
            checkpoint_config=checkpoint_config,
            step=training_step
        )

        # Wait for the previous checkpoint to complete
        if future is not None:
            exc = future.exception()
            if exc:
                print(f"Failure in saving previous checkpoint: {str(exc)}")
                # Handle failures as required
            else:
                result = future.result()
                # Process results from save, if required

        # Async save checkpoint using PyTorch DCP
        future = async_save(state_dict=state_dict, storage_writer=storage_writer)

        # Continue training while checkpoint saves in background
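The save cadence in the loop above can be isolated into a small pure function, which makes it easy to see which steps stay in memory and which also go to Amazon S3. This is a sketch of the same modulo logic; `checkpoint_action` is a hypothetical name and not part of the checkpointing library.

```python
def checkpoint_action(step: int, in_memory_freq: int, s3_freq: int):
    """Return (should_save, save_to_s3) for a training step.

    Mirrors the condition in the training loop: save when the step is a
    multiple of either frequency, and write to S3 only on S3 multiples.
    """
    save_to_s3 = step % s3_freq == 0
    should_save = save_to_s3 or step % in_memory_freq == 0
    return should_save, save_to_s3


# Print the decision for a few representative steps
for step in (0, 7, 10, 50, 60):
    print(step, checkpoint_action(step, in_memory_freq=10, s3_freq=50))
```

Step 0 satisfies both conditions, so the first iteration of the loop also writes to S3. Start the step range at 1 if that is not wanted.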

Step 4: Load checkpoints for recovery

The following example shows how to load a checkpoint.

# Create state dictionary template
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 0,
    "epoch": 0
}

# Load latest checkpoint
storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config)
load(state_dict, storage_reader=storage_reader)

# Load specific checkpoint step
storage_reader = SageMakerTieredStorageReader(
    checkpoint_config=checkpoint_config,
    step=500  # Or don't pass step if you have to load the latest available step.
)
try:
    load(state_dict, storage_reader=storage_reader)
except BaseException as e:
    # Caught broadly because PyTorch DCP's CheckpointException derives from BaseException
    print(f"Checkpoint load failed: {str(e)}")
    # Add additional exception handling
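One way to fill in the "additional exception handling" is to try a preferred step first and fall back to the latest available checkpoint. The sketch below shows that control flow with an injected loader so that it runs without the library; `load_first_available` and `fake_load` are hypothetical, and the fallback order is a design assumption rather than library behavior.

```python
def load_first_available(load_step, candidate_steps, errors=(Exception,)):
    """Try each candidate step in order and return the first that loads.

    `None` stands for "latest available step". Raises RuntimeError,
    chained to the last failure, if every candidate fails.
    """
    last_error = None
    for step in candidate_steps:
        try:
            load_step(step)
        except errors as exc:
            print(f"Load failed for step={step!r}: {exc}")
            last_error = exc
        else:
            return step
    raise RuntimeError("No checkpoint could be loaded") from last_error


def fake_load(step):
    """Stand-in loader: pretend that step 500 does not exist."""
    if step == 500:
        raise FileNotFoundError("step 500 not found")


loaded = load_first_available(fake_load, [500, None])
print("Loaded step:", "latest" if loaded is None else loaded)
```

In real training code, `load_step` would build a `SageMakerTieredStorageReader` for the given step and call `load`, as in the example above. Pass `errors=(BaseException,)` if you also want to catch the exceptions raised by PyTorch DCP.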

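Validate your managed tiered checkpointing operations

You can use logs to validate managed tiered checkpointing operations.

Custom logging (optional)

You can integrate checkpointing logs with your other logs by passing a custom logger to the library. For example, you can add a custom logger to your training code so that all logs from the library are also collected in the training logger.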
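A custom logger of the kind described above can be built with the standard `logging` module. The sketch below only constructs the logger. How it is handed to the checkpointing library (for example, through a `logger` argument on the checkpoint configuration) is an assumption that this section does not confirm, so check the library documentation before relying on it. `build_training_logger` is a hypothetical helper.

```python
import io
import logging


def build_training_logger(name: str = "training", stream=None) -> logging.Logger:
    """Create a logger with a single stream handler and a fixed format."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # stream=None makes StreamHandler write to sys.stderr
        handler = logging.StreamHandler(stream)
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
        )
        logger.addHandler(handler)
    # Avoid duplicate lines through the root logger
    logger.propagate = False
    return logger


# Demo: capture output in memory instead of stderr
buffer = io.StringIO()
logger = build_training_logger("my_training_job", stream=buffer)
logger.info("checkpoint saved at step %d", 10)
print(buffer.getvalue(), end="")
```

Using one logger for both training and checkpointing messages keeps them in a single, consistently formatted stream.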

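Enhanced service logging (optional)

For enhanced debugging and service visibility, you can mount the checkpointing log path /var/log/sagemaker_checkpointing from within the pod to the path /var/logs/sagemaker_checkpointing on the host. This ensures that library-specific logs are collected separately, which gives the service team better visibility for debugging and support.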
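The mount described above corresponds to a Kubernetes `hostPath` volume. The sketch below adds such a volume and its mount to a pod spec represented as a Python dict. The volume name `checkpoint-logs`, the `DirectoryOrCreate` type, and the helper `add_checkpoint_log_mount` are assumptions for illustration; only the two paths come from the text above.

```python
import copy
import json

POD_LOG_PATH = "/var/log/sagemaker_checkpointing"    # path inside the pod
HOST_LOG_PATH = "/var/logs/sagemaker_checkpointing"  # path on the host
VOLUME_NAME = "checkpoint-logs"                      # hypothetical volume name


def add_checkpoint_log_mount(pod_spec: dict) -> dict:
    """Return a copy of pod_spec with the log volume mounted in every container."""
    spec = copy.deepcopy(pod_spec)
    spec.setdefault("volumes", []).append(
        {
            "name": VOLUME_NAME,
            "hostPath": {"path": HOST_LOG_PATH, "type": "DirectoryOrCreate"},
        }
    )
    for container in spec.get("containers", []):
        container.setdefault("volumeMounts", []).append(
            {"name": VOLUME_NAME, "mountPath": POD_LOG_PATH}
        )
    return spec


pod_spec = {"containers": [{"name": "trainer", "image": "my-training-image"}]}
patched = add_checkpoint_log_mount(pod_spec)
print(json.dumps(patched, indent=2))
```

The printed structure maps directly onto the `volumes` and `volumeMounts` fields of a pod manifest, so the same values can be copied into YAML.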