Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.
Siapkan pos pemeriksaan berjenjang terkelola
Bagian ini berisi proses penyiapan untuk pos pemeriksaan berjenjang terkelola untuk Amazon. SageMaker HyperPod Anda akan mempelajari cara mengaktifkan kemampuan pada cluster Anda dan menerapkan checkpointing dalam kode pelatihan Anda.
Topik
Prasyarat
Sebelum menyiapkan pos pemeriksaan berjenjang terkelola, pastikan Anda memiliki:
-
HyperPod Cluster Amazon EKS dengan memori CPU yang cukup tersedia untuk alokasi pos pemeriksaan
-
PyTorch beban kerja pelatihan dan pekerjaan DCP (keduanya didukung)
-
Izin IAM yang sesuai untuk manajemen klaster, termasuk:
-
Amazon CloudWatch dan Amazon S3 menulis izin untuk pod pelatihan untuk membaca/menulis pos pemeriksaan dan mendorong metrik
-
Izin ini dapat dikonfigurasi melalui pengaturan EKS OIDC
-
Langkah 1: Aktifkan pos pemeriksaan berjenjang terkelola untuk klaster Anda
penting
Anda harus memilih untuk menggunakan pos pemeriksaan berjenjang terkelola.
Aktifkan pos pemeriksaan berjenjang terkelola melalui HyperPod APIs saat membuat atau memperbarui klaster Anda. Layanan secara otomatis menginstal sistem manajemen memori ketika Anda menentukan TieredStorageConfig parameter.
Untuk cluster baru, Anda dapat menggunakan create-cluster AWS CLI.
aws sagemaker create-cluster \ --cluster-namecluster-name\ --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \ --instance-groups '{ "InstanceGroupName": "instance-group-name", "InstanceType": "instance-type", "InstanceCount":instance-count, "LifeCycleConfig": { "SourceS3Uri": "s3-path-to-lifecycle-scripts", "OnCreate": "lifecycle-script-name" }, "ExecutionRole": "instance-group-iam-role", "ThreadsPerCore":threads-per-core, "InstanceStorageConfigs": [ { "EbsVolumeConfig": {"VolumeSizeInGB":volume-size} } ] }' \ --vpc-config '{ "SecurityGroupIds": ["security-group-ids"], "Subnets": ["subnets"] }' \ --tiered-storage-config '{ "Mode": "Enable" }'
InstanceMemoryAllocationPercentageParameter menentukan (int) memori cluster untuk mengalokasikan untuk checkpointing. Kisarannya 20-100.percentage
Langkah 2: Instal pustaka Python di gambar pelatihan Anda
Instal library SageMaker checkpointing Amazon
# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3 s3torchconnector
Langkah 3: Simpan pos pemeriksaan di loop pelatihan Anda
Dalam loop pelatihan Anda, Anda dapat menyimpan pos pemeriksaan secara asinkron menggunakan DCP. PyTorch Berikut ini adalah contoh bagaimana melakukannya.
import torch import torch.distributed as dist from torch.distributed.checkpoint import async_save, load from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import ( SageMakerTieredStorageWriter, SageMakerTieredStorageReader ) # Initialize distributed training dist.init_process_group(backend="nccl") # Configure checkpointing checkpoint_config = SageMakerCheckpointConfig( # Unique ID for your training job # Allowed characters in ID include: alphanumeric, hyphens, and underscores namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'), # Number of distributed processes/available GPUs world_size=dist.get_world_size(), # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks # Required for SageMakerTieredStorageWriter when save_to_s3 is True s3_tier_base_path="s3://my-bucket/checkpoints" ) # Your model and optimizer model = MyModel() optimizer = torch.optim.AdamW(model.parameters()) # Training loop future = None in_memory_ckpt_freq = 10 s3_ckpt_freq = 50 for training_step in range(1000): # ... training code ... # Save checkpoint if (training_step % in_memory_ckpt_freq == 0 or training_step % s3_ckpt_freq == 0): # Create state dictionary state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": training_step, "epoch": epoch } # Create storage writer for current step checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0 storage_writer = SageMakerTieredStorageWriter( checkpoint_config=checkpoint_config, step=training_step ) # wait for previous checkpoint to get completed if future is not None: exc = future.exception() if exc: print(f"Failure in saving previous checkpoint:{str(exc)}") # Handle failures as required else: result = future.result() # Process results from save, if required # Async save checkpoint using PyTorch DCP future = async_save(state_dict=state_dict, storage_writer=storage_writer) # Continue training while checkpoint saves in background
Langkah 4: Muat pos pemeriksaan untuk pemulihan
Berikut ini adalah contoh tentang memuat pos pemeriksaan.
# Create state dictionary template state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 0, "epoch": 0 } # Load latest checkpoint storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config) load(state_dict, storage_reader=storage_reader) # Load specific checkpoint step storage_reader = SageMakerTieredStorageReader( checkpoint_config=checkpoint_config, step=500 # Or don't pass step if you have to load the latest available step. ) try: load(state_dict, storage_reader=storage_reader) except BaseException as e: print(f"Checkpoint load failed: {str(e)}") # Add additional exception handling
Validasi operasi pos pemeriksaan berjenjang terkelola
Anda dapat memvalidasi operasi pos pemeriksaan berjenjang terkelola dengan log.
Logging kustom (opsional)
Anda dapat mengintegrasikan log checkpointing dengan log lain dengan meneruskan logger kustom ke perpustakaan. Misalnya, Anda dapat menambahkan logger kustom ke kode pelatihan Anda sehingga semua log dari perpustakaan juga dikumpulkan dalam logger pelatihan.
Pencatatan layanan yang disempurnakan (opsional)
Untuk meningkatkan debugging dan visibilitas layanan, Anda dapat memasang jalur log checkpointing /var/log/sagemaker_checkpointing dari dalam pod Anda ke jalur di host Anda. /var/logs/sagemaker_checkpointing Ini memastikan bahwa hanya log khusus perpustakaan yang dikumpulkan secara terpisah. Ini memberi tim layanan visibilitas yang ditingkatkan untuk debugging dan dukungan.