Checkpointing con SMP - Amazon SageMaker AI

Le traduzioni sono generate tramite traduzione automatica. In caso di conflitto tra il contenuto di una traduzione e la versione originale in Inglese, quest'ultima prevarrà.

Checkpointing con SMP

La libreria SageMaker Model Parallelism (SMP) supporta i checkpoint e fornisce APIs questo supporto PyTorch APIs per il corretto funzionamento del checkpoint durante l'utilizzo della libreria SMP.

PyTorch FSDP (Fully Sharded Data Parallelism) supporta tre tipi di checkpoint: completi, frammentati e locali, ognuno dei quali serve a scopi diversi. I checkpoint completi vengono utilizzati quando si esporta il modello dopo il completamento dell’addestramento, poiché la generazione di un checkpoint completo è un processo impegnativo dal punto di vista del calcolo. I checkpoint suddivisi aiutano a salvare e caricare lo stato di un modello con sharding per ogni singola classificazione. Con i checkpoint condivisi, puoi riprendere l'allenamento con diverse configurazioni hardware, ad esempio un numero diverso di. GPUs Tuttavia, il caricamento di checkpoint sottoposti a sharding può essere lento, poiché la comunicazione coinvolge più dispositivi. La libreria SMP fornisce funzionalità di checkpoint locali, che permettono un recupero più rapido dello stato del modello senza l’ulteriore sovraccarico delle comunicazioni. Tieni presente che i checkpoint creati da FSDP richiedono la scrittura su un file system di rete condiviso come Amazon. FSx

Checkpoint locali asincroni

Durante l’addestramento dei modelli di machine learning, non è necessario attendere che i file di checkpoint vengano salvati su disco nelle iterazioni successive. Con il rilascio di SMP v2.5, la libreria supporta il salvataggio asincrono dei file di checkpoint. Ciò significa che la successiva iterazione dell'addestramento può essere eseguita contemporaneamente alle operazioni di input e output (). I/O) operations for creating checkpoints, without being slowed down or held back by those I/O Inoltre, il processo di recupero dei parametri del modello frammentato e dell'ottimizzatore PyTorch può richiedere molto tempo a causa della comunicazione collettiva aggiuntiva necessaria per lo scambio di metadati tensoriali distribuiti tra i ranghi. Anche quando viene utilizzato StateDictType.LOCAL_STATE_DICT per salvare i checkpoint locali per ogni rango, richiama comunque gli hook che eseguono comunicazioni collettive. PyTorch Per attenuare questo problema e ridurre il tempo necessario per il recupero dei checkpoint, SMP introduce SMStateDictType.SM_LOCAL_STATE_DICT, che consente un recupero più rapido dei checkpoint dell’ottimizzatore e del modello evitando il sovraccarico causato dalla comunicazione collettiva.

Nota

Mantenere la coerenza nel SHARD_DEGREE FSDP è un requisito per l’utilizzo di SMStateDictType.SM_LOCAL_STATE_DICT. Assicurati che SHARD_DEGREE rimanga invariato. Sebbene il numero di repliche del modello possa variare, quando si riprende da un checkpoint il grado di sharding del modello deve essere identico alla configurazione di addestramento precedente.

import os import torch.distributed as dist import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, ) global_rank = dist.get_rank() save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Wait for the previous checkpointing done maybe_finalize_async_calls( blocking=True, process_group=current_replication_group ) # 3. Get model local checkpoint with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": optimizer.state_dict(), # Potentially add more customized state dicts. } # 4. Save a local checkpoint async_save( state_dict, checkpoint_id=os.path.join(save_dir, sub_dir), process_group=current_replication_group, coordinator_rank=coordinator_rank, )

Il frammento di codice riportato di seguito mostra come caricare un checkpoint utilizzando SMStateDictType.SM_LOCAL_STATE_DICT.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.distributed.checkpoint.state_dict_utils import ( sm_state_dict_type, SMStateDictType, init_optim_state ) from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}" global_rank = dist.get_rank() checkpoint_id = os.path.join(load_dir, sub_dir) storage_reader = DistributedFileSystemReader(checkpoint_id) # 1. Get replication ranks and group current_replication_group = None current_replication_ranks = None for replication_ranks in state.ranker.get_rep_groups(): rep_group = dist.new_group(replication_ranks) if global_rank in replication_ranks: current_replication_group = rep_group current_replication_ranks = replication_ranks coordinator_rank = min(current_replication_ranks) # 2. Create local state_dict with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT): state_dict = { "model": model.state_dict(), # Potentially add more customized state dicts. } # Init optimizer state_dict states by setting zero grads and step. init_optim_state(optimizer, skip_empty_param=True) state_dict["optimizer"] = optimizer.state_dict() # 3. Load a checkpoint load( state_dict=state_dict, process_group=current_replication_group, coordinator_rank=coordinator_rank, storage_reader=storage_reader, )

La memorizzazione di checkpoint per modelli linguistici di grandi dimensioni (LLMs) può essere costosa in quanto spesso richiede la creazione di un grande volume di file system. Per ridurre i costi, hai la possibilità di salvare i checkpoint direttamente su Amazon S3 senza la necessità di servizi di file system aggiuntivi come Amazon. FSx Per salvare i checkpoint in S3 specificando un URL S3 come destinazione puoi applicare l’esempio precedente al frammento di codice riportato di seguito.

key = os.path.join(checkpoint_dir, sub_dir) checkpoint_id= f"s3://{your_s3_bucket}/{key}" async_save(state_dict, checkpoint_id=checkpoint_id, **kw) load(state_dict, checkpoint_id=checkpoint_id, **kw)

Checkpoint sottoposti a sharding eseguiti in modo asincrono

Potrebbero verificarsi situazioni in cui è necessario continuare la formazione con diverse configurazioni hardware, ad esempio modificando il numero di. GPUs In questi casi, i job di addestramento devono caricare i checkpoint durante il resharding e riprendere l’addestramento successivo con un numero diverso di SHARD_DEGREE. Per risolvere lo scenario in cui è necessario riprendere l’addestramento con un numero diverso di SHARD_DEGREE, è necessario salvare i checkpoint del modello utilizzando il tipo di dizionario di stato sottoposto a sharding, rappresentato da StateDictType.SHARDED_STATE_DICT. Il salvataggio dei checkpoint in questo formato consente di gestire correttamente il processo di resharding quando si continua l’addestramento con una configurazione hardware modificata. Il frammento di codice fornito illustra come utilizzare l’API tsm per salvare in modo asincrono i checkpoint sottoposti a sharding, rendendo il processo di addestramento più semplice ed efficiente.

import os import torch.sagemaker as tsm from torch.sagemaker import state from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.state_dict_saver import ( async_save, maybe_finalize_async_calls, ) save_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(save_dir, sub_dir) # To determine whether curreto take part in checkpointing. global_rank = dist.get_rank() action_rank = state.ranker.get_rep_rank(global_rank) == 0 process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) # 1. wait for the previous checkpointing done maybe_finalize_async_calls(blocking=True, process_group=process_group) # 2. retrieve model & optimizer sharded state_dict with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): state_dict = { "model": model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer), # Potentially add more customized state dicts. } # 3. save checkpoints asynchronously using async_save if action_rank: async_save( state_dict, checkpoint_id=checkpoint_id, process_group=process_group, coordinator_rank=coordinator_rank, )

Il processo di caricamento dei checkpoint sottoposti a sharding è simile a quello della sezione precedente, ma prevede l’utilizzo di torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader e del relativo metodo load. Il metodo load di questa classe consente di caricare i dati dei checkpoint sottoposti a sharding, seguendo un processo analogo a quello descritto in precedenza.

import os from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import StateDictType from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict from torch.sagemaker.distributed.checkpoint.state_dict_loader import load from torch.sagemaker.utils.process_group_utils import get_global_ranks from torch.sagemaker.distributed.checkpoint.filesystem import ( DistributedFileSystemReader, ) load_dir = "/opt/ml/checkpoints" sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}" checkpoint_id = os.path.join(load_dir, sub_dir) reader = DistributedFileSystemReader(checkpoint_id) process_group = model.process_group coordinator_rank = min(get_global_ranks(process_group)) with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): # 1. Load model and everything else except the optimizer. state_dict = { "model": model.state_dict() # Potentially more customized state dicts. } load( state_dict, storage_reader=reader, process_group=process_group, coordinator_rank=coordinator_rank, ) model.load_state_dict(state_dict["model"]) # 2. Load optimizer. optim_state = load_sharded_optimizer_state_dict( model_state_dict=state_dict["model"], optimizer_key="optimizer", storage_reader=reader, process_group=process_group, ) flattened_optimizer_state = FSDP.optim_state_dict_to_load( optim_state["optimizer"], model, optimizer, group=model.process_group ) optimizer.load_state_dict(flattened_optimizer_state)

Checkpoint del modello completi

Al termine dell’addestramento, è possibile salvare un checkpoint completo che combini tutti gli shard di un modello in un unico file di checkpoint. La libreria SMP supporta completamente l'API dei checkpoint del modello PyTorch completo, quindi non è necessario apportare alcuna modifica.

Tieni presente che se utilizzi Parallelizzazione tensoriale SMP, la libreria SMP trasforma il modello. In questo caso, quando si esegue il checkpoint del modello completo, la libreria SMP traduce il modello nel formato di checkpoint Hugging Face Transformers per impostazione predefinita.

Nei casi in cui ci si allena con il parallelismo del tensore SMP e si disattiva il processo di traduzione SMP, è possibile utilizzare l'translate_on_saveargomento dell' PyTorch FullStateDictConfigAPI per attivare o disattivare la traduzione automatica SMP in base alle esigenze. Ad esempio, se ti stai concentrando sull’addestramento di un modello, non è necessario aggiungere il processo di traduzione che comporta sovraccarichi aggiuntivi. In questo caso, è consigliabile scegliere l’impostazione translate_on_save=False. Inoltre, se prevedi di continuare a utilizzare la traduzione SMP del modello per ulteriori cicli di addestramento futuri, puoi disattivarla e salvare la traduzione SMP del modello per un uso successivo. La traduzione del modello nel formato di checkpoint del modello di Hugging Face Transformers è necessaria quando si conclude l’addestramento del modello e lo si utilizza per l’inferenza.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullStateDictConfig import torch.sagemaker as tsm # Save checkpoints. with FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=True, offload_to_cpu=True, # Default value is to translate back to Hugging Face Transformers format, # when saving full checkpoints for models trained with SMP tensor parallelism. # translate_on_save=True ), ): state_dict = model.state_dict() if dist.get_rank() == 0: logger.info("Processed state dict to save. Starting write to disk now.") os.makedirs(save_dir, exist_ok=True) # This name is needed for HF from_pretrained API to work. torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin")) hf_model_config.save_pretrained(save_dir) dist.barrier()

Tieni presente che l’opzione FullStateDictConfig(rank0_only=True, offload_to_cpu=True) consiste nel riunire il modello sulla CPU del dispositivo con classificazione 0 per risparmiare memoria durante l’addestramento di modelli di grandi dimensioni.

Per caricare il modello per l’inferenza, puoi eseguire questa operazione come illustrato nell’esempio di codice riportato di seguito. Tieni presente che la classe AutoModelForCausalLM potrebbe cambiare passando ad altre classi del generatore di fattori in Hugging Face Transformers, ad esempioAutoModelForSeq2SeqLM, a seconda del modello. Per ulteriori informazioni, consulta la documentazione di Hugging Face Transformers.

from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained(save_dir)