マネージド階層型チェックポイントの設定 - Amazon SageMaker AI

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

マネージド階層型チェックポイントの設定

このセクションでは、Amazon SageMaker HyperPod のマネージド階層型チェックポイントの設定プロセスについて説明します。クラスターで機能を有効にし、トレーニングコードにチェックポイントを実装する方法を説明します。

前提条件

マネージド階層型チェックポイントを設定する前に、以下を確認してください。

  • チェックポイントの割り当てに使用できる十分な CPU メモリを備えた Amazon EKS HyperPod クラスター

  • PyTorch トレーニングワークロードと DCP ジョブ (どちらもサポート対象)

  • 以下を含めた、クラスター管理のための適切な IAM アクセス許可:

    • トレーニングポッドがチェックポイントの読み取り/書き込みとメトリクスのプッシュを行うための、Amazon CloudWatch および Amazon S3 への書き込みアクセス許可

    • これらのアクセス許可は、「EKS OIDC の設定」を介して設定できます。

ステップ 1: クラスターのマネージド階層型チェックポイントを有効にする

重要

マネージド階層型チェックポイントを使用するには、オプトインする必要があります。

クラスターを作成または更新するときに、HyperPod APIs を介してマネージド階層型チェックポイントを有効にします。TieredStorageConfig パラメータを指定すると、サービスはメモリ管理システムを自動的にインストールします。

新しいクラスターでは、 を使用できますcreate-cluster AWS CLI。

aws sagemaker create-cluster \ --cluster-name cluster-name \ --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \ --instance-groups '{ "InstanceGroupName": "instance-group-name", "InstanceType": "instance-type", "InstanceCount": instance-count, "LifeCycleConfig": { "SourceS3Uri": "s3-path-to-lifecycle-scripts", "OnCreate": "lifecycle-script-name" }, "ExecutionRole": "instance-group-iam-role", "ThreadsPerCore": threads-per-core, "InstanceStorageConfigs": [ { "EbsVolumeConfig": {"VolumeSizeInGB": volume-size} } ] }' \ --vpc-config '{ "SecurityGroupIds": ["security-group-ids"], "Subnets": ["subnets"] }' \ --tiered-storage-config '{ "Mode": "Enable" }'

InstanceMemoryAllocationPercentage パラメータは、チェックポイントに割り当てるクラスターメモリの percentage (int) を指定します。範囲は 20~100 です。

ステップ 2: トレーニングイメージに Python ライブラリをインストールする

Dockerfile に追加して、Amazon SageMaker チェックポイントライブラリとその依存関係をトレーニングイメージにインストールします。

# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3 s3torchconnector

ステップ 3: トレーニングループにチェックポイントを保存する

トレーニングループでは、PyTorch DCP を使用してチェックポイントを非同期的に保存できます。これを行う方法の例を次に示します。

import torch import torch.distributed as dist from torch.distributed.checkpoint import async_save, load from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import ( SageMakerTieredStorageWriter, SageMakerTieredStorageReader ) # Initialize distributed training dist.init_process_group(backend="nccl") # Configure checkpointing checkpoint_config = SageMakerCheckpointConfig( # Unique ID for your training job # Allowed characters in ID include: alphanumeric, hyphens, and underscores namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'), # Number of distributed processes/available GPUs world_size=dist.get_world_size(), # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks # Required for SageMakerTieredStorageWriter when save_to_s3 is True s3_tier_base_path="s3://my-bucket/checkpoints" ) # Your model and optimizer model = MyModel() optimizer = torch.optim.AdamW(model.parameters()) # Training loop future = None in_memory_ckpt_freq = 10 s3_ckpt_freq = 50 for training_step in range(1000): # ... training code ... # Save checkpoint if (training_step % in_memory_ckpt_freq == 0 or training_step % s3_ckpt_freq == 0): # Create state dictionary state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": training_step, "epoch": epoch } # Create storage writer for current step checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0 storage_writer = SageMakerTieredStorageWriter( checkpoint_config=checkpoint_config, step=training_step ) # wait for previous checkpoint to get completed if future is not None: exc = future.exception() if exc: print(f"Failure in saving previous checkpoint:{str(exc)}") # Handle failures as required else: result = future.result() # Process results from save, if required # Async save checkpoint using PyTorch DCP future = async_save(state_dict=state_dict, storage_writer=storage_writer) # Continue training while checkpoint saves in background

ステップ 4: 復旧用のチェックポイントをロードする

チェックポイントをロードする例を次に示します。

# Create state dictionary template state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0, "epoch": 0 } # Load latest checkpoint storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config) load(state_dict, storage_reader=storage_reader) # Load specific checkpoint step storage_reader = SageMakerTieredStorageReader( checkpoint_config=checkpoint_config, step=500 # Or don't pass step if you have to load the latest available step. ) try: load(state_dict, storage_reader=storage_reader) except BaseException as e: print(f"Checkpoint load failed: {str(e)}") # Add additional exception handling

マネージド階層型チェックポイントオペレーションを検証する

ログを使用して、マネージド階層型チェックポイントオペレーションを検証できます。

カスタムログ記録 (オプション)

カスタムロガーをライブラリに渡すことで、チェックポイントログを他のログと統合できます。例えば、カスタムロガーをトレーニングコードに追加して、ライブラリのすべてのログがトレーニングロガーにも収集されるようにできます。

拡張サービスログ記録 (オプション)

デバッグとサービスの可視性を強化するために、ポッド内のチェックポイントログパス /var/log/sagemaker_checkpointing をホスト上のパス /var/logs/sagemaker_checkpointing にマウントできます。これにより、ライブラリ固有のログのみが個別に収集されることになり、サービスチームはデバッグとサポートの可視性を向上させることができます。