

本文属于机器翻译版本。若本译文内容与英语原文存在差异，则一律以英文原文为准。

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2 支持[FlashAttention](https://github.com/HazyResearch/flash-attention)内核，可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。请注意，如果您使用 v2.0 或更高版本的 FlashAttention 软件包，SMP 使用 FlashAttention v2；但是，在 v FlashAttention 1.x 中，Triton 闪光注意力默认为闪光注意内核，因此在 v1 中仅支持该内核。 FlashAttention 

模块 (`nn.Module`) 是一种低级 API，用于定义模型的注意层。它应在模型创建后立即应用，例如从 `AutoModelForCausalLM.from_config()` API，并在使用 FSDP 对模型进行转换或封装之前应用。

## 使用 FlashAttention 内核来集中注意力
<a name="model-parallel-core-features-v2-flashattention-self"></a>

下面的代码片段显示了如何使用 SMP v2 提供的 [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) API。

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## 使用 FlashAttention 内核进行分组查询注意
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

SMP v2 还支持用于分组查询注意力 (GQA) 的[FlashAttention](https://github.com/HazyResearch/flash-attention)内核，并且可以轻松地将其应用于 Hugging Face Transformer 模型的各种场景。与最初的注意架构不同，GQA 将查询磁头平均分为若干组，同一组中的查询磁头共享相同的键和值磁头。因此，q 和 kv 磁头被分别传入前向调用。注意：q 磁头的数量需要可以被 kv 磁头的数量整除。

**使用示例 FlashGroupedQueryAttention**

下面的代码片段显示了如何使用 SMP v2 提供的 [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API。

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

SMP 库还提供 [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn)，它在低级别使用 [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API。Hugging Face 转换器在 4.36.0 版中也有一个名为 [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) 的类似实现。下面的代码片段显示了如何使用 SMP v2 `LlamaFlashAttention` API 或转换器 `LlamaFlashAttention2` API 替换现有 Llama 模型的注意层。

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```