

기계 번역으로 제공되는 번역입니다. 제공된 번역과 원본 영어의 내용이 상충하는 경우에는 영어 버전이 우선합니다.

# FlashAttention
<a name="model-parallel-core-features-v2-flashattention"></a>

SMP v2는 [FlashAttention](https://github.com/HazyResearch/flash-attention) 커널을 지원하므로 Hugging Face 트랜스포머 모델의 다양한 시나리오에 쉽게 적용할 수 있습니다. FlashAttention 패키지 v2.0 이상을 사용하는 경우 SMP는 FlashAttention v2를 사용하지만 Triton 플래시 주의는 FlashAttention v1.x의 플래시 주의 커널로 기본 설정되므로 FlashAttention v1에서만 지원됩니다.

모듈(`nn.Module`)은 모델의 주의 계층을 정의하는 하위 수준 API입니다. 예를 들어 `AutoModelForCausalLM.from_config()` API에서 모델 생성 직후, 모델을 변환하거나 FSDP로 래핑하기 전에 적용해야 합니다.

## 자기 관심을 위해 FlashAttention 커널 사용
<a name="model-parallel-core-features-v2-flashattention-self"></a>

다음 코드 조각은 SMP v2에서 제공하는 [`torch.sagemaker.nn.attn.FlashSelfAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashselfattention) API를 사용하는 방법을 보여줍니다.

```
def new_attn(self, q, k, v, attention_mask=None, head_mask=None):
    return (
        self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"),
        None,
    )

for layer in model.gpt_neox.layers:
    layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention()
    layer.attention._attn = functools.partial(new_attn, layer.attention)
```

## 그룹화된 쿼리 주의에 FlashAttention 커널 사용
<a name="model-parallel-core-features-v2-flashattention-grouped-query"></a>

또한 SMP v2는 그룹화된 쿼리 주의(GQA)를 위한 [FlashAttention](https://github.com/HazyResearch/flash-attention) 커널을 지원하므로 Hugging Face 트랜스포머 모델의 다양한 시나리오에 쉽게 적용할 수 있습니다. GQA는 원래 주의 아키텍처와 다르게 쿼리 헤드를 그룹으로 균등하게 분할하고 동일한 그룹의 쿼리 헤드는 동일한 키와 값 헤드를 공유합니다. 따라서 q 및 kv 헤드는 전달 호출로 별도로 전달됩니다. 참고: q 헤드 수는 kv 헤드 수로 나눌 수 있어야 합니다.

**FlashGroupedQueryAttention 사용 예제**

다음 코드 조각은 SMP v2에서 제공하는 [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API를 사용하는 방법을 보여줍니다.

```
from transformers.models.llama.modeling_llama import LlamaAttention
from torch.sagemaker.nn.attn import FlashGroupedQueryAttention

class LlamaFlashAttention(LlamaAttention):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        self.flash_attn = FlashGroupedQueryAttention(
            attention_dropout_prob=0.0,
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        ...
    ):
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        ...
        kv = (key_states, value_states)
        attn_output = self.flash_attn(
            query_states,
            kv,
            attn_mask=attention_mask,
            causal=True,
            layout="b h s d",
        )
        ...
        attn_output = self.o_proj(attn_output)
        ...
        return attn_output
```

SMP 라이브러리는 [`torch.sagemaker.nn.attn.FlashGroupedQueryAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-flashGroupedQueryAttn) API를 낮은 수준에서 사용하는 [`torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention`](distributed-model-parallel-v2-reference.md#model-parallel-v2-torch-sagemaker-reference-llamaFlashAttn)도 제공합니다. Hugging Face 트랜스포머는 v4.36.0에서 [https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)라는 유사한 구현을 수행합니다. 다음 코드 조각은 SMP v2 `LlamaFlashAttention` API 또는 트랜스포머 `LlamaFlashAttention2` API를 사용하여 기존 Llama 모델의 주의 계층을 대체하는 방법을 보여줍니다.

```
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2

attn_name = "self_attn"
for layer in model.model.layers:
    prev_layer = getattr(layer, attn_name)
    setattr(layer, attn_name, flash_attn_class(model.config))
```