관리형 계층형 체크포인트 설정 - Amazon SageMaker AI

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

관리형 계층형 체크포인트 설정

이 섹션에는 Amazon SageMaker HyperPod에 대한 관리형 계층형 체크포인트 설정 프로세스가 포함되어 있습니다. 클러스터에서 기능을 활성화하고 훈련 코드에서 체크포인트 지정을 구현하는 방법을 알아봅니다.

사전 조건

관리형 계층형 체크포인트를 설정하기 전에 다음을 갖추어야 합니다.

  • 체크포인트 할당에 사용할 수 있는 CPU 메모리가 충분한 Amazon EKS HyperPod 클러스터가 있어야 합니다.

  • PyTorch 훈련 워크로드 및 DCP 작업(둘 다 지원됨)이 있어야 합니다.

  • 다음을 포함한 클러스터 관리를 위한 적절한 IAM 권한이 있어야 합니다.

    • 훈련 포드가 체크포인트를 읽고 쓰고 지표를 푸시할 수 있는 Amazon CloudWatch 및 Amazon S3 쓰기 권한이 있어야 합니다.

    • 이러한 권한은 EKS OIDC 설정을 통해 구성할 수 있습니다.

1단계: 클러스터에 대한 관리형 계층형 체크포인트 활성화

중요

관리형 계층형 체크포인트를 사용하려면 옵트인해야 합니다.

클러스터를 생성하거나 업데이트할 때 HyperPod APIs 통해 관리형 계층형 체크포인트를 활성화합니다. TieredStorageConfig 파라미터를 지정하면 서비스가 메모리 관리 시스템을 자동으로 설치합니다.

새 클러스터의 경우를 사용할 수 있습니다create-cluster AWS CLI.

aws sagemaker create-cluster \ --cluster-name cluster-name \ --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \ --instance-groups '{ "InstanceGroupName": "instance-group-name", "InstanceType": "instance-type", "InstanceCount": instance-count, "LifeCycleConfig": { "SourceS3Uri": "s3-path-to-lifecycle-scripts", "OnCreate": "lifecycle-script-name" }, "ExecutionRole": "instance-group-iam-role", "ThreadsPerCore": threads-per-core, "InstanceStorageConfigs": [ { "EbsVolumeConfig": {"VolumeSizeInGB": volume-size} } ] }' \ --vpc-config '{ "SecurityGroupIds": ["security-group-ids"], "Subnets": ["subnets"] }' \ --tiered-storage-config '{ "Mode": "Enable" }'

InstanceMemoryAllocationPercentage 파라미터는 체크포인트 지정에 할당할 클러스터 메모리의 percentage(int)를 지정합니다. 범위는 20~100입니다.

2단계: 훈련 이미지에 Python 라이브러리 설치

Amazon SageMaker 체크포인트 라이브러리와 해당 종속성을 Dockerfile에 추가하여 훈련 이미지에 설치합니다.

# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3 s3torchconnector

3단계: 훈련 루프에 체크포인트 저장

훈련 루프에서 PyTorch DCP를 사용하여 체크포인트를 비동기적으로 저장할 수 있습니다. 다음은 이를 수행하는 방법에 대한 예제입니다.

import torch import torch.distributed as dist from torch.distributed.checkpoint import async_save, load from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import ( SageMakerTieredStorageWriter, SageMakerTieredStorageReader ) # Initialize distributed training dist.init_process_group(backend="nccl") # Configure checkpointing checkpoint_config = SageMakerCheckpointConfig( # Unique ID for your training job # Allowed characters in ID include: alphanumeric, hyphens, and underscores namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'), # Number of distributed processes/available GPUs world_size=dist.get_world_size(), # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks # Required for SageMakerTieredStorageWriter when save_to_s3 is True s3_tier_base_path="s3://my-bucket/checkpoints" ) # Your model and optimizer model = MyModel() optimizer = torch.optim.AdamW(model.parameters()) # Training loop future = None in_memory_ckpt_freq = 10 s3_ckpt_freq = 50 for training_step in range(1000): # ... training code ... # Save checkpoint if (training_step % in_memory_ckpt_freq == 0 or training_step % s3_ckpt_freq == 0): # Create state dictionary state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": training_step, "epoch": epoch } # Create storage writer for current step checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0 storage_writer = SageMakerTieredStorageWriter( checkpoint_config=checkpoint_config, step=training_step ) # wait for previous checkpoint to get completed if future is not None: exc = future.exception() if exc: print(f"Failure in saving previous checkpoint:{str(exc)}") # Handle failures as required else: result = future.result() # Process results from save, if required # Async save checkpoint using PyTorch DCP future = async_save(state_dict=state_dict, storage_writer=storage_writer) # Continue training while checkpoint saves in background

4단계: 복구를 위한 로드 체크포인트

다음은 체크포인트를 로드하는 예제입니다.

# Create state dictionary template state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0, "epoch": 0 } # Load latest checkpoint storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config) load(state_dict, storage_reader=storage_reader) # Load specific checkpoint step storage_reader = SageMakerTieredStorageReader( checkpoint_config=checkpoint_config, step=500 # Or don't pass step if you have to load the latest available step. ) try: load(state_dict, storage_reader=storage_reader) except BaseException as e: print(f"Checkpoint load failed: {str(e)}") # Add additional exception handling

관리형 계층형 체크포인트 작업 검증

로그를 사용하여 관리형 계층형 체크포인트 작업을 검증할 수 있습니다.

사용자 지정 로깅(선택 사항)

사용자 지정 로거를 라이브러리에 전달하여 체크포인트 지정 로그를 다른 로그와 통합할 수 있습니다. 예를 들어 라이브러리의 모든 로그가 훈련 로거에도 수집되도록 훈련 코드에 사용자 지정 로거를 추가할 수 있습니다.

향상된 서비스 로깅(선택 사항)

디버깅 및 서비스 가시성을 높이기 위해 포드 내에서 호스트의 경로 /var/logs/sagemaker_checkpointing으로 체크포인트 로그 경로 /var/log/sagemaker_checkpointing을 탑재할 수 있습니다. 이렇게 하면 라이브러리에 한정된 로그만 별도로 수집됩니다. 이를 통해 서비스 팀은 디버깅 및 지원에 대한 향상된 가시성을 확보할 수 있습니다.