Configurar pontos de verificação de nível gerenciado
Esta seção contém o processo de configuração de pontos de verificação de nível gerenciado para o Amazon SageMaker HyperPod. Você aprenderá a habilitar o recurso no cluster e implementar pontos de verificação no código de treinamento.
Tópicos
Pré-requisitos
Você deve atender aos seguintes pré-requisitos antes de configurar o ponto de verificação de nível gerenciado:
-
Um cluster do HyperPod orquestrado pelo Amazon EKS com memória de CPU suficiente disponível para alocação de pontos de verificação.
-
Workloads de treinamento do PyTorch e tarefas de DCP (ambas são permitidas).
-
Permissões do IAM apropriadas para gerenciamento de clusters, como:
-
Permissões de gravação do Amazon CloudWatch e do Amazon S3 para o pod de treinamento ler/gravar pontos de verificação e enviar métricas.
-
Essas permissões podem ser definidas por meio da configuração do OIDC do EKS.
-
Etapa 1: habilitar o ponto de verificação de nível gerenciado para o cluster
Importante
Você deve optar por usar o ponto de verificação de nível gerenciado.
Habilite o ponto de verificação de nível gerenciado por meio da API do HyperPod ao criar ou atualizar o cluster. O serviço instala automaticamente o sistema de gerenciamento de memória quando você especifica o parâmetro TieredStorageConfig. Para a criação de clusters:
aws sagemaker update-cluster \ --cluster-name my-training-cluster \ --tiered-storage-config { "Mode": "Enable" "InstanceMemoryAllocationPercentage":percentage}
O parâmetro InstanceMemoryAllocationPercentage especifica a (int) da memória do cluster a ser alocada ao ponto de verificação. O intervalo é 20-100.percentage
Etapa 2: instalar a biblioteca do Python na imagem de treinamento
Instale a biblioteca de pontos de verificação do Amazon SageMaker
# Add this line to your training image Dockerfile RUN pip install amzn-sagemaker-checkpointing
Etapa 3: criar uma configuração de ponto de verificação
Crie um objeto CheckpointConfig para especificar o comportamento do ponto de verificação. Isso inclui:
-
Locais dos pontos de verificação
-
Frequência dos pontos de verificação
-
Nome dos namespaces
O seguinte exemplo é uma amostra de configuração:
from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import SageMakerCheckpointConfig from amzn_sagemaker_checkpointing.checkpointing.filesystem import SageMakerTieredStorageWriter, SageMakerTieredStorageReader checkpoint_config = sm_ckpt.CheckpointConfig( world_size = 100, in_memory_namespace:my-ml-workload, # Logical grouping for checkpoints s3_base_path: "s3://bucket-name/checkpointing-path-prefix/", s3_every_n_steps: 100, # Every 100 steps, save to S3 )
Etapa 4: definir um gravador do sistema de arquivos do SageMaker
Defina o gravador do sistema de arquivos do ponto de verificação. Você tem a opção de especificar o número da etapa durante a inicialização.
Gravador básico (etapa especificada em salvar chamada):
smWriter = sagemaker_checkpointing.SageMakerTieredStorageWriter(checkpoint_config)
Gravador com o parâmetro da etapa (a etapa especificada na inicialização):
smWriter = sagemaker_checkpointing.SageMakerTieredStorageWriter( checkpoint_config, step=step_number )
nota
Quando você especifica o parâmetro step durante a inicialização do gravador, o parâmetro checkpoint_id na chamada de salvamento se torna opcional. O parâmetro da etapa tem precedência sobre o formato do diretório do ponto de verificação.
Etapa 5: salvar pontos de verificação no ciclo de treinamento
No ciclo de treinamento, salve os pontos de verificação usando o DCP do PyTorch com FileSystemWriter.
Usar o DCP do PyTorch com FileSystemWriter
Chame o método dist_cp.save() com FileSystemWriter como entrada:
Opção 1: usar checkpoint_id com formato de etapa (quando a etapa não for especificada no gravador)
# Construct checkpoint directory with step number checkpoint_dir = f"step_number" dist_cp.save_state_dict( state_dict=state_dict, # state_dict is a dictionary containing model parameters, optimizer state, etc. checkpoint_id=checkpoint_dir, # Should contain step number storage_writer=smWriter )
Opção 2: usar o gravador com o parâmetro da etapa (o checkpoint_id se torna opcional)
dist_cp.save_state_dict( state_dict=state_dict, storage_writer=smWriter # Step already specified in writer initialization )
nota
O valor checkpoint_id (ou a string checkpoint_dir) deve ter o formato step_. Por exemplo, numberstep_5. Ao usar o parâmetro da etapa na inicialização do gravador, o checkpoint_id se torna opcional.
Etapa 6: carregar pontos de verificação para recuperação
Quando precisar carregar um ponto de verificação, use o DCP do PyTorch com FileSystemReader.
Usar o DCP do PyTorch com FileSystemReader
Chame o método para carregar o DCP com FileSystemReader como entrada:
# Define FileSystemReader smReader = sagemaker_checkpointing.SageMakerTieredStorageReader( config=checkpoint_config ) # Load checkpoint dist_cp.load_state_dict( state_dict=state_dict, checkpoint_id=checkpoint_dir, storage_reader=smReader )
Monitoramento e validação
É possível monitorar e validar operações de ponto de verificação de nível gerenciado por meio de métricas e logs.
Registro em log personalizado (opcional)
Você pode integrar logs de ponto de verificação a outros logs transferindo um agente de log personalizado para a biblioteca. Por exemplo, é possível adicionar um agente de log personalizado ao código de treinamento para que todos os logs da biblioteca também sejam coletados no agente de log de treinamento.
Registro em log aprimorado do serviço (opcional)
Para melhorar a depuração e a visibilidade do serviço, é possível montar o caminho de log do ponto de verificação /var/log/sagemaker_checkpointing dentro do pod até um caminho /var/logs/sagemaker_checkpointing no host. Isso garante que somente os logs específicos da biblioteca sejam coletados separadamente. Isso oferece visibilidade aprimorada à equipe do serviço para depuração e suporte.