

Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2 mendukung [FlashAttention](https://github.com/HazyResearch/flash-attention)kernel dan membuatnya mudah untuk menerapkannya ke berbagai skenario untuk model Hugging Face Transformer. Perhatikan bahwa jika Anda menggunakan FlashAttention paket v2.0 atau yang lebih baru, SMP menggunakan FlashAttention v2; Namun, perhatian flash Triton default ke kernel perhatian flash di FlashAttention v1.x, membuatnya didukung secara eksklusif di v1. FlashAttention 

Module (`nn.Module`) adalah API tingkat rendah yang mendefinisikan lapisan perhatian model. Ini harus diterapkan tepat setelah pembuatan model, dari `AutoModelForCausalLM.from_config()` API misalnya, dan sebelum model diubah atau dibungkus dengan FSDP.

## Gunakan FlashAttention kernel untuk perhatian diri
<a name="model-parallel-core-features-v2-flashattention-self"></a>

Cuplikan kode berikut menunjukkan cara menggunakan [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) API yang disediakan oleh SMP v2.

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## Gunakan FlashAttention kernel untuk perhatian kueri yang dikumpulkan
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

SMP v2 juga mendukung [FlashAttention](https://github.com/HazyResearch/flash-attention)kernel untuk grouped-query attention (GQA) dan membuatnya mudah untuk menerapkannya ke berbagai skenario untuk model Hugging Face Transformer. Berbeda dari arsitektur perhatian asli, GQA sama-sama mempartisi kepala kueri ke dalam grup, dan kepala kueri dalam grup yang sama berbagi kunci dan kepala nilai yang sama. Oleh karena itu, kepala q dan kv diteruskan ke panggilan maju secara terpisah. Catatan: Jumlah kepala q harus habis dibagi dengan jumlah kepala kv.

**Contoh penggunaan FlashGroupedQueryAttention**

Cuplikan kode berikut menunjukkan cara menggunakan [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API yang disediakan oleh SMP v2.

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

Pustaka SMP juga menyediakan[`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn), yang menggunakan [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API pada tingkat rendah. Hugging Face Transformers memiliki [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)implementasi serupa yang disebut dari v4.36.0. Cuplikan kode berikut menunjukkan cara menggunakan API SMP v2 atau Transformers `LlamaFlashAttention` `LlamaFlashAttention2` API untuk mengganti lapisan perhatian model Llama yang ada.

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```