

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2 prend en charge [FlashAttention](https://github.com/HazyResearch/flash-attention)les noyaux et permet de les appliquer facilement à différents scénarios pour les modèles Hugging Face Transformer. Notez que si vous utilisez le FlashAttention package v2.0 ou une version ultérieure, SMP utilise la version FlashAttention v2 ; toutefois, le Triton Flash Attention utilise par défaut le noyau Flash Attention dans la FlashAttention version v1.x, ce qui le rend exclusivement pris en charge dans la version v1. FlashAttention 

Le module (`nn.Module`) est une API de bas niveau qui définit les couches d’attention d’un modèle. Il doit être appliqué juste après la création du modèle, à partir de l’API `AutoModelForCausalLM.from_config()`, par exemple, et avant que le modèle ne soit transformé ou encapsulé avec FSDP.

## Utilisez des FlashAttention noyaux pour vous concentrer
<a name="model-parallel-core-features-v2-flashattention-self"></a>

L’extrait de code suivant illustre comment utiliser l’API [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) fournie par SMP v2.

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## Utiliser des FlashAttention noyaux pour attirer l'attention sur les requêtes groupées
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

SMP v2 prend également en charge les [FlashAttention](https://github.com/HazyResearch/flash-attention)noyaux pour l'attention par requêtes groupées (GQA) et permet de les appliquer facilement à différents scénarios pour les modèles Hugging Face Transformer. Contrairement à l’architecture d’attention originale, GQA partitionne de façon égale les têtes de requête en groupes, et les têtes de requête d’un même groupe partagent les mêmes têtes de clé et de valeur. Par conséquent, les têtes q et kv sont transmises séparément à l’appel de transmission vers l’avant. Remarque : le nombre de têtes q doit être divisible par le nombre de têtes kv.

**Exemple d'utilisation FlashGroupedQueryAttention**

L’extrait de code suivant illustre comment utiliser l’API [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) fournie par SMP v2.

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

La bibliothèque SMP fournit également [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn), qui utilise l’API [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) à bas niveau. Les transformeurs Hugging Face bénéficient d’une implémentation similaire appelée [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) depuis la version v4.36.0. L’extrait de code suivant montre comment utiliser l’API `LlamaFlashAttention` SMP v2 ou l’API `LlamaFlashAttention2` Transformers pour remplacer les couches d’attention d’un modèle Llama existant.

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```