FlashAttention
SMP v2 prend en charge les noyaux FlashAttention
Le module (nn.Module) est une API de bas niveau qui définit les couches d’attention d’un modèle. Il doit être appliqué juste après la création du modèle, à partir de l’API AutoModelForCausalLM.from_config(), par exemple, et avant que le modèle ne soit transformé ou encapsulé avec FSDP.
Utilisation des noyaux FlashAttention pour l’auto-attention
L’extrait de code suivant illustre comment utiliser l’API torch.sagemaker.nn.attn.FlashSelfAttention fournie par SMP v2.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
Utilisation de noyaux FlashAttention pour l’attention par requêtes groupées
SMP v2 prend également en charge les noyaux FlashAttention
Exemple d’utilisation de FlashGroupedQueryAttention
L’extrait de code suivant illustre comment utiliser l’API torch.sagemaker.nn.attn.FlashGroupedQueryAttention fournie par SMP v2.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
La bibliothèque SMP fournit également torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, qui utilise l’API torch.sagemaker.nn.attn.FlashGroupedQueryAttention à bas niveau. Les transformeurs Hugging Face bénéficient d’une implémentation similaire appelée LlamaFlashAttention2LlamaFlashAttention SMP v2 ou l’API LlamaFlashAttention2 Transformers pour remplacer les couches d’attention d’un modèle Llama existant.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))