

As traduções são geradas por tradução automática. Em caso de conflito entre o conteúdo da tradução e da versão original em inglês, a versão em inglês prevalecerá.

# Ponto de verificação com uso do SMP
<a name="model-parallel-core-features-v2-checkpoints"></a>

A biblioteca de paralelismo de SageMaker modelos (SMP) oferece suporte a pontos de PyTorch APIs verificação e fornece APIs esses pontos de verificação de forma adequada ao usar a biblioteca SMP. 

PyTorch O FSDP (paralelismo de dados totalmente fragmentado) suporta três tipos de pontos de verificação: completos, fragmentados e locais, cada um com propósitos diferentes. Os pontos de verificação completos são usados ao exportar o modelo após a conclusão do treinamento, pois gerar um ponto de verificação completo é um processo computacionalmente caro. Os pontos de verificação fragmentados ajudam a salvar e a carregar o estado de um modelo fragmentado para cada classificação individual. Com pontos de verificação fragmentados, você pode retomar o treinamento com diferentes configurações de hardware, como um número diferente de. GPUs Entretanto, o carregamento de pontos de verificação fragmentados pode ser lento devido à comunicação envolvida entre vários dispositivos. A biblioteca de SMP fornece funcionalidades de ponto de verificação local, que permitem uma recuperação mais rápida do estado do modelo sem sobrecarga adicional de comunicação. Observe que os pontos de verificação criados pelo FSDP exigem gravação em um sistema de arquivos de rede compartilhado, como o Amazon. FSx

## Pontos de verificação locais assíncronos
<a name="w2aac25c25c19c19c33b7"></a>

Ao treinar modelos de machine learning, não é necessário que as iterações posteriores aguardem até que os arquivos do ponto de verificação sejam salvos no disco. Com o lançamento do SMP v2.5, a biblioteca aceita o salvamento de arquivos de ponto de verificação de forma assíncrona. Isso significa que a iteração de treinamento subsequente pode ser executada simultaneamente com as I/O) operations for creating checkpoints, without being slowed down or held back by those I/O operações de entrada e saída (). Além disso, o processo de recuperação dos parâmetros fragmentados do modelo e do otimizador PyTorch pode ser demorado devido à comunicação coletiva adicional necessária para trocar metadados de tensores distribuídos entre as classificações. Mesmo quando usado `StateDictType.LOCAL_STATE_DICT` para salvar pontos de verificação locais para cada classificação, PyTorch ainda invoca ganchos que realizam comunicação coletiva. Para reduzir esse problema e diminuir o tempo necessário para a recuperação do ponto de verificação, o SMP apresenta o `SMStateDictType.SM_LOCAL_STATE_DICT`, que permite uma recuperação mais rápida dos pontos de verificação do modelo e do otimizador, ignorando a sobrecarga de comunicação coletiva. 

**nota**  
Manter a consistência no FSDP `SHARD_DEGREE` é um requisito para utilizar o `SMStateDictType.SM_LOCAL_STATE_DICT`. Certifique-se de que o `SHARD_DEGREE` permaneça inalterado. Embora o número de replicações do modelo possa variar, o grau de fragmentação do modelo precisa ser idêntico à configuração de treinamento anterior ao retornar de um ponto de verificação.

```
import os
import torch.distributed as dist
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
)

global_rank = dist.get_rank()
save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Wait for the previous checkpointing done
maybe_finalize_async_calls(
    blocking=True, process_group=current_replication_group
)

# 3. Get model local checkpoint
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
       "model": model.state_dict(),
       "optimizer": optimizer.state_dict(),
        # Potentially add more customized state dicts.
    }

# 4. Save a local checkpoint 
async_save(
    state_dict,
    checkpoint_id=os.path.join(save_dir, sub_dir),
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
)
```

O trecho de código a seguir demonstra como carregar um ponto de verificação com uso de `SMStateDictType.SM_LOCAL_STATE_DICT`.

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.distributed.checkpoint.state_dict_utils import (
    sm_state_dict_type,
    SMStateDictType,
    init_optim_state
)
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)

load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}_fsdp{model.rank}"
global_rank = dist.get_rank()
checkpoint_id = os.path.join(load_dir, sub_dir)
storage_reader = DistributedFileSystemReader(checkpoint_id)

# 1. Get replication ranks and group
current_replication_group = None
current_replication_ranks = None
for replication_ranks in state.ranker.get_rep_groups():
    rep_group = dist.new_group(replication_ranks)
    if global_rank in replication_ranks:
        current_replication_group = rep_group
        current_replication_ranks = replication_ranks

coordinator_rank = min(current_replication_ranks)

# 2. Create local state_dict
with sm_state_dict_type(model, SMStateDictType.SM_LOCAL_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # Potentially add more customized state dicts.
    }
 
    # Init optimizer state_dict states by setting zero grads and step.
    init_optim_state(optimizer, skip_empty_param=True)
    state_dict["optimizer"] = optimizer.state_dict()
 
# 3. Load a checkpoint
load(
    state_dict=state_dict,
    process_group=current_replication_group,
    coordinator_rank=coordinator_rank,
    storage_reader=storage_reader,
)
```

Armazenar pontos de verificação para modelos de linguagem grandes (LLMs) pode ser caro, pois geralmente requer a criação de um grande volume de sistema de arquivos. Para reduzir custos, você tem a opção de salvar pontos de verificação diretamente no Amazon S3 sem a necessidade de serviços adicionais de sistema de arquivos, como o Amazon. FSx Você pode aproveitar o exemplo anterior com o trecho de código a seguir para salvar os pontos de verificação no S3 ao especificar uma URL do S3 como destino. 

```
key = os.path.join(checkpoint_dir, sub_dir)
checkpoint_id= f"{{s3://{your_s3_bucket}/{key}}}"
async_save(state_dict, checkpoint_id=checkpoint_id, **kw)
load(state_dict, checkpoint_id=checkpoint_id, **kw)
```

## Pontos de verificação fragmentados assíncronos
<a name="w2aac25c25c19c19c33b9"></a>

Pode haver situações em que você precise continuar treinando com diferentes configurações de hardware, como alterar o número de GPUs. Nesses casos, seus processos de treinamento devem carregar os pontos de verificação durante a fragmentação, o que significa retomar o treinamento posterior com um número diferente de `SHARD_DEGREE`. Para chegar ao cenário em que você precisa retomar o treinamento com um número diferente de `SHARD_DEGREE`, você deve salvar os pontos de verificação do modelo usando o tipo de dicionário de estado fragmentado, representado por `StateDictType.SHARDED_STATE_DICT`. Salvar pontos de verificação nesse formato permite que você gerencie adequadamente o processo de refragmentação ao continuar o treinamento com uma configuração de hardware modificada. O trecho de código fornecido ilustra como usar a API `tsm` para salvar pontos de verificação fragmentados de forma assíncrona, o que permite um processo de treinamento mais eficiente e simplificado.

```
import os
import torch.sagemaker as tsm
from torch.sagemaker import state
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.state_dict_saver import (
    async_save,
    maybe_finalize_async_calls,
)

save_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(save_dir, sub_dir)

# To determine whether curreto take part in checkpointing.
global_rank = dist.get_rank()
action_rank = state.ranker.get_rep_rank(global_rank) == 0
process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

# 1. wait for the previous checkpointing done
maybe_finalize_async_calls(blocking=True, process_group=process_group)

# 2. retrieve model & optimizer sharded state_dict
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optimizer": FSDP.optim_state_dict(model, optimizer),
        # Potentially add more customized state dicts.
    }
 
# 3. save checkpoints asynchronously using async_save
if action_rank:
    async_save(
        state_dict,
        checkpoint_id=checkpoint_id,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
    )
```

O processo de carregamento de pontos de verificação compartilhados é semelhante ao da seção anterior, mas inclui o uso do `torch.sagemaker.distributed.checkpoint.filesystem.DistributedFileSystemReader` e seu método `load`. O método `load` dessa função permite carregar os dados compartilhados do ponto de verificação, seguindo um processo semelhante ao descrito anteriormente.

```
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.sagemaker.distributed.checkpoint.state_dict_loader import load
from torch.sagemaker.utils.process_group_utils import get_global_ranks
from torch.sagemaker.distributed.checkpoint.filesystem import (
    DistributedFileSystemReader,
)
 
 load_dir = "/opt/ml/checkpoints"
sub_dir = f"tp{state.tp_rank}_ep{state.ep_rank}"
checkpoint_id = os.path.join(load_dir, sub_dir)
reader = DistributedFileSystemReader(checkpoint_id)

process_group = model.process_group
coordinator_rank = min(get_global_ranks(process_group))

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
   # 1. Load model and everything else except the optimizer.
   state_dict = {
        "model": model.state_dict()
        # Potentially more customized state dicts.
   }
   load(
        state_dict,
        storage_reader=reader,
        process_group=process_group,
        coordinator_rank=coordinator_rank,
   )
   model.load_state_dict(state_dict["model"])
 
   # 2. Load optimizer.
   optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optimizer",
        storage_reader=reader,
        process_group=process_group,
    )    
   flattened_optimizer_state = FSDP.optim_state_dict_to_load(
        optim_state["optimizer"], model, optimizer,
         group=model.process_group
   )
   optimizer.load_state_dict(flattened_optimizer_state)
```

## Pontos de verificação de modelos completos
<a name="model-parallel-core-features-v2-checkpoints-full"></a>

Ao final do treinamento, você pode salvar um ponto de verificação completo que une todos os fragmentos de um modelo em um único arquivo de ponto de verificação do modelo. A biblioteca SMP é totalmente compatível com a API PyTorch completa de pontos de verificação do modelo, portanto, você não precisa fazer nenhuma alteração.

Observe que, se usar o SMP [Paralelismo de tensores](model-parallel-core-features-v2-tensor-parallelism.md), a biblioteca de SMP transforma o modelo. Ao verificar o modelo completo nesse caso, a biblioteca de SMP converte o modelo de volta para o formato de ponto de verificação do Hugging Face Transformers por padrão.

Nos casos em que você treina com o paralelismo do tensor SMP e desativa o processo de tradução do SMP, você pode usar o `translate_on_save` argumento da PyTorch `FullStateDictConfig` API para ativar ou desativar a tradução automática do SMP conforme necessário. Por exemplo, se você se concentrar em treinar um modelo, não precisa adicionar o processo de conversão, que aumenta a sobrecarga. Nesse caos, recomendamos que você defina como `translate_on_save=False`. Além disso, se planeja continuar usando a conversão do SMP do modelo para treinamento adicional no futuro, você pode desativá-la para salvar a conversão do SMP do modelo para uso posterior. É necessário converter o modelo de volta para o formato de ponto de verificação do modelo Hugging Face Transformers quando você encerra o treinamento do seu modelo e o usa para inferência.

```
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig
import torch.sagemaker as tsm

# Save checkpoints.
with FSDP.state_dict_type(
    model, 
    StateDictType.FULL_STATE_DICT, 
    FullStateDictConfig(
        rank0_only=True, offload_to_cpu=True,
        # Default value is to translate back to Hugging Face Transformers format,
        # when saving full checkpoints for models trained with SMP tensor parallelism.
        # translate_on_save=True
    ),
):
    state_dict = model.state_dict()
    if dist.get_rank() == 0:
        logger.info("Processed state dict to save. Starting write to disk now.")
        os.makedirs({{save_dir}}, exist_ok=True)
        # This name is needed for HF from_pretrained API to work.
        torch.save(state_dict, os.path.join({{save_dir}}, "pytorch_model.bin"))
        hf_model_config.save_pretrained({{save_dir}})
    dist.barrier()
```

Observe que a opção `FullStateDictConfig(rank0_only=True, offload_to_cpu=True)` é reunir o modelo na CPU do dispositivo de 0ª classificação para economizar memória ao treinar grandes modelos.

Para carregar o modelo de volta para inferência, faça isso conforme apresentado no exemplo de código a seguir. Observe que a função `AutoModelForCausalLM` pode mudar para outras funções de criação de fatores no modelo Hugging Face Transformers, como `AutoModelForSeq2SeqLM`, dependendo do seu modelo. Para obter mais informações, consulte a [documentação do Hugging Face Transformers](https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/auto#natural-language-processing).

```
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained({{save_dir}})
```