FlashAttention
SMP v2 es compatible con los kernels FlashAttention
El módulo (nn.Module) es una API de bajo nivel que define las capas de atención de un modelo. Debe aplicarse inmediatamente después de la creación del modelo, desde la API AutoModelForCausalLM.from_config(), por ejemplo, y antes de transformar o encapsular el modelo con FSDP.
Uso de kernels de FlashAttention para atención automática
En el siguiente fragmento de código se muestra cómo usar la API torch.sagemaker.nn.attn.FlashSelfAttention que proporciona SMP v2.
def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)
Uso de kernels de FlashAttention para atención de consultas agrupadas
SMP v2 también es compatible con los kernels FlashAttention
Ejemplo de uso de FlashGroupedQueryAttention
En el siguiente fragmento de código se muestra cómo usar la API torch.sagemaker.nn.attn.FlashGroupedQueryAttention que proporciona SMP v2.
from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output
La biblioteca de SMP también proporciona torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, que utiliza la API torch.sagemaker.nn.attn.FlashGroupedQueryAttention en un nivel bajo. Hugging Face Transformers tiene una implementación similar llamada LlamaFlashAttention2LlamaFlashAttention de SMP v2 o la API LlamaFlashAttention2 de transformadores para reemplazar las capas de atención de un modelo de Llama existente.
from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))