

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2 は [FlashAttention](https://github.com/HazyResearch/flash-attention) カーネルをサポートしており、Hugging Face Transformer モデルのさまざまなシナリオに簡単に適用できます。FlashAttention パッケージ v2.0 以降を使用する場合、SMP は FlashAttention v2 を使用します。しかし、Triton のフラッシュアテンションはデフォルトで FlashAttention v1.x のフラッシュアテンションカーネルを使用するので、FlashAttention v1 でのみサポートされます。

モジュール (`nn.Module`) は、モデルのアテンション層を定義する低レベル API です。モデルの作成直後に、例えば `AutoModelForCausalLM.from_config()` API から、モデルが変換されるか、FSDP でラップされる前に適用する必要があります。

## セルフアテンションに FlashAttention カーネルを使用する
<a name="model-parallel-core-features-v2-flashattention-self"></a>

次のコードスニペットは、SMP v2 が提供する [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) API の使用方法を示しています。

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## グループ化クエリアテンションに FlashAttention カーネルを使用する
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

SMP v2 は、グループ化クエリアテンション (GQA) でも [FlashAttention](https://github.com/HazyResearch/flash-attention) カーネルをサポートしており、Hugging Face Transformer モデルのさまざまなシナリオに簡単に適用できます。元のアテンションアーキテクチャとは異なり、GQA はクエリヘッドを複数のグループに均等に分割し、同じグループ内のクエリヘッドは同じキーと値のヘッドを共有します。したがって、q ヘッドと kv ヘッドは前方呼び出しで個別に渡されます。注: q ヘッドの数は、kv ヘッドの数で割り切れる必要があります。

**FlashGroupedQueryAttention の使用例**

次のコードスニペットは、SMP v2 が提供する [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API の使用方法を示しています。

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

SMP ライブラリは [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn) も提供していますが、これは低レベルで [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API を使用します。Hugging Face Transformers には、v4.36.0 以降で [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) という同様の実装があります。次のコードスニペットは、SMP v2 の `LlamaFlashAttention` API または Transformers の `LlamaFlashAttention2` API を使用して、既存の Llama モデルのアテンション層を置き換える方法を示しています。

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```