

Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.

# Inicialización diferida de parámetros
<a name="model-parallel-core-features-v2-delayed-param-init"></a>

La inicialización de un modelo grande para entrenamiento no siempre es posible con memoria limitada de la GPU. Para resolver este problema de memoria insuficiente de la GPU, puede inicializar el modelo en la memoria de la CPU. Sin embargo, para modelos más grandes con más de 20 000 o 40 000 millones de parámetros, incluso la memoria de la CPU podría resultar insuficiente. En tal caso, le recomendamos que inicialice el modelo en lo que se PyTorch denomina un *metadispositivo*, que permite la creación de tensores sin necesidad de adjuntar ningún dato a ellos. Un tensor en un metadispositivo solo necesita información de la forma, lo que permite crear un modelo grande con sus parámetros en metadispositivos. [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index) proporciona el administrador de contexto `init_empty_weights` para ayudar a crear dicho modelo en metadispositivos y, al mismo tiempo, inicializar los búferes en un dispositivo normal. Antes de que comience el entrenamiento, el PyTorch FSDP inicializa los parámetros del modelo. Esta función de inicialización retardada de los parámetros del SMP v2 retrasa la creación de los parámetros del modelo para que se produzca después PyTorch de que el FSDP realice la fragmentación de los parámetros. PyTorch El FSDP acepta una función de inicialización de parámetros (`param_init_fn`) al fragmentar los módulos y llama a cada módulo. `param_init_fn` La `param_init_fn` API toma un módulo como argumento e inicializa todos sus parámetros, sin incluir los parámetros de módulos secundarios. Tenga en cuenta que este comportamiento *difiere* del de la versión PyTorch 2.0.1 nativa, que tiene un error que provoca que los parámetros se inicialicen varias veces.

SMP v2 proporciona la [`torch.sagemaker.delayed_param.DelayedParamIniter`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-delayed-param-init) API para aplicar la inicialización diferida de parámetros.

Los siguientes fragmentos de código muestran cómo aplicar la `torch.sagemaker.delayed_param.DelayedParamIniter` API al script de entrenamiento.

Suponga que tiene el siguiente guion de formación sobre el PyTorch FSDP.

```
# Creation of model on meta device
from accelerate import init_empty_weights
with init_empty_weights():
    model = create_model()

# Define a param init fn, below is an example for Hugging Face GPTNeoX.
def init_weights(module):
    d = torch.cuda.current_device()
    # Note that below doesn't work if you have buffers in the model
    # buffers will need to reinitialized after this call
    module.to_empty(device=d, recurse=False)
    if isinstance(module, (nn.Linear, Conv1D)):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.bias:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=args.initializer_range)
        if module.padding_idx:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

# Changes to FSDP wrapper.
model = FSDP(
    model,
    ...,
    param_init_fn=init_weights
)

# At this point model is initialized and sharded for sharded data parallelism.
```

Tenga en cuenta que el enfoque de inicialización diferido de parámetros no es independiente del modelo. Para resolver este problema, debe escribir una función `init_weights`, como se muestra en el ejemplo anterior, que coincida con la inicialización de la definición del modelo original y que abarque todos los parámetros del modelo. Para simplificar este proceso de preparación de dicha función `init_weights`, SMP v2 implementa esta función de inicialización para los siguientes modelos: GPT-2, GPT-J, GPT-NeoX y Llama de Hugging Face Transformers. La `torch.sagemaker.delayed_param.DelayedParamIniter` API también funciona con la implementación de paralelismo de tensores de SMP, modelo de `torch.sagemaker.tensor_parallel.transformer.TransformerLMHead`, al que puede llamar después de la llamada a la [`torch.sagemaker.transform`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-transform) API.

Con la `torch.sagemaker.delayed_param.DelayedParamIniter` API, puede adaptar su script del PyTorch FSDP de la siguiente manera. Después de crear un modelo con ponderaciones vacías, registre la `torch.sagemaker.delayed_param.DelayedParamIniter` API en el modelo y defina un objeto de este. Pase el objeto a la clase `param_init_fn` PyTorch FSDP.

```
from torch.sagemaker.delayed_param import DelayedParamIniter
from accelerate import init_empty_weights

with init_empty_weights():
    model = create_model()
    
delayed_initer = DelayedParamIniter(model)

with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn()
    )
```

**Notas sobre ponderaciones vinculadas**

Al entrenar modelos con pesas atadas, debemos tener especial cuidado de atar las pesas después de inicializar las pesas con una inicialización retardada de los parámetros. PyTorch El FSDP no tiene un mecanismo para vincular los pesos después de inicializarlos utilizando `param_init_fn` lo anterior. Para abordar estos casos, hemos agregado una API para permitir un `post_init_hook_fn`, que se puede usar para vincular las ponderaciones. Puede pasar cualquier función que acepte el módulo como argumento, pero también tenemos un `post_param_init_fn` predefinido que se define en `DelayedParamIniter` que llama al método `tie_weights` del módulo si existe. Tenga en cuenta que es seguro pasar siempre en `post_param_init_fn`, incluso aunque no haya ningún método `tie_weights` para el módulo.

```
with delayed_initer.validate_params_and_buffers_inited():
    model = FSDP(
        model,
        ...,
        param_init_fn=delayed_initer.get_param_init_fn(),
        post_param_init_fn=delayed_initer.get_post_param_init_fn()
    )
```