

Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2 unterstützt [FlashAttention](https://github.com/HazyResearch/flash-attention)Kernel und macht es einfach, sie auf verschiedene Szenarien für Hugging Face Transformer-Modelle anzuwenden. Beachten Sie, dass SMP FlashAttention v2 verwendet, wenn Sie FlashAttention Paket v2.0 oder höher verwenden. Triton Flash Attention verwendet jedoch standardmäßig den Flash Attention-Kernel in FlashAttention v1.x, sodass er ausschließlich in Version 1 unterstützt wird. FlashAttention 

Das Modul (`nn.Module`) ist eine Low-Level-API, die die Aufmerksamkeitsebenen eines Modells definiert. Es sollte direkt nach der Modellerstellung angewendet werden, beispielsweise über die `AutoModelForCausalLM.from_config()`-API, bevor das Modell transformiert oder mit FSDP umschlossen wird.

## Verwenden Sie Kernel zur Selbstwahrnehmung FlashAttention
<a name="model-parallel-core-features-v2-flashattention-self"></a>

Der folgende Codeausschnitt veranschaulicht, wie die von SMP v2 bereitgestellte [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention)-API verwendet wird.

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## Verwenden Sie FlashAttention Kernel für die Aufmerksamkeit bei gruppierten Abfragen
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

SMP v2 unterstützt auch [FlashAttention](https://github.com/HazyResearch/flash-attention)Kernel für Grouped-Query Attention (GQA) und macht es einfach, sie auf verschiedene Szenarien für Hugging Face Transformer-Modelle anzuwenden. Im Unterschied zur ursprünglichen Aufmerksamkeitsarchitektur unterteilt GQA Abfrageköpfe gleichmäßig in Gruppen und die Abfrageköpfe in derselben Gruppe verwenden dieselben Schlüssel- und Wertköpfe. Daher werden q- und kv-Köpfe getrennt an den Vorwärtsaufruf übergeben. Hinweis: Die Anzahl der q-Köpfe muss durch die Anzahl der kv-Köpfe teilbar sein.

**Beispiel für die Verwendung FlashGroupedQueryAttention**

Der folgende Codeausschnitt veranschaulicht, wie die von SMP v2 bereitgestellte [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn)-API verwendet wird.

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

Die SMP-Bibliothek bietet auch die Funktion [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn), die die [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn)-API auf niedriger Ebene verwendet. Hugging Face Transformers hat eine ähnliche Implementierung, die ab Version 4.36.0 [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) genannt wird. Der folgende Codeausschnitt zeigt, wie die APIs SMP `LlamaFlashAttention` v2 oder Transformers `LlamaFlashAttention2` verwendet werden, um die Aufmerksamkeitsebenen eines vorhandenen Llama-Modells zu ersetzen.

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```