FlashAttention - Amazon SageMaker AI

FlashAttention

SMP v2 es compatible con los kernels FlashAttention y facilita su aplicación a distintos escenarios para los modelos de Hugging Face Transformer. Tenga en cuenta que si utiliza el paquete FlashAttention v2.0 o posterior, SMP utilizará FlashAttention v2; sin embargo, Triton Flash Attention utiliza de forma predeterminada el kernel de atención flash de FlashAttention v1.x, por lo que es compatible exclusivamente con FlashAttention v1.

El módulo (nn.Module) es una API de bajo nivel que define las capas de atención de un modelo. Debe aplicarse inmediatamente después de la creación del modelo, desde la API AutoModelForCausalLM.from_config(), por ejemplo, y antes de transformar o encapsular el modelo con FSDP.

Uso de kernels de FlashAttention para atención automática

En el siguiente fragmento de código se muestra cómo usar la API torch.sagemaker.nn.attn.FlashSelfAttention que proporciona SMP v2.

def new_attn(self, q, k, v, attention_mask=None, head_mask=None): return ( self.flashmod((q, k, v), causal=True, cast_dtype=torch.bfloat16, layout="b h s d"), None, ) for layer in model.gpt_neox.layers: layer.attention.flash_mod = torch.sagemaker.nn.attn.FlashSelfAttention() layer.attention._attn = functools.partial(new_attn, layer.attention)

Uso de kernels de FlashAttention para atención de consultas agrupadas

SMP v2 también es compatible con los kernels FlashAttention para atención de consultas agrupadas (GQA) y facilita su aplicación a distintos escenarios para los modelos de Hugging Face Transformer. A diferencia de la arquitectura de atención original, GQA divide los encabezados de consulta en grupos iguales, y los encabezados de consulta del mismo grupo comparten los mismos encabezados clave y de valor. Por tanto, los encabezados q y kv se pasan a la llamada hacia delante por separado. Nota: el número de encabezados q debe ser divisible por el número de encabezados kv.

Ejemplo de uso de FlashGroupedQueryAttention

En el siguiente fragmento de código se muestra cómo usar la API torch.sagemaker.nn.attn.FlashGroupedQueryAttention que proporciona SMP v2.

from transformers.models.llama.modeling_llama import LlamaAttention from torch.sagemaker.nn.attn import FlashGroupedQueryAttention class LlamaFlashAttention(LlamaAttention): def __init__(self, config: LlamaConfig): super().__init__(config) self.flash_attn = FlashGroupedQueryAttention( attention_dropout_prob=0.0, ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, ... ): query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) ... kv = (key_states, value_states) attn_output = self.flash_attn( query_states, kv, attn_mask=attention_mask, causal=True, layout="b h s d", ) ... attn_output = self.o_proj(attn_output) ... return attn_output

La biblioteca de SMP también proporciona torch.sagemaker.nn.huggingface.llama_flashattn.LlamaFlashAttention, que utiliza la API torch.sagemaker.nn.attn.FlashGroupedQueryAttention en un nivel bajo. Hugging Face Transformers tiene una implementación similar llamada LlamaFlashAttention2 desde la v4.36.0. El siguiente fragmento de código muestra cómo usar la API LlamaFlashAttention de SMP v2 o la API LlamaFlashAttention2 de transformadores para reemplazar las capas de atención de un modelo de Llama existente.

from torch.sagemaker.nn.huggingface.llama_flashattn import LlamaFlashAttention from transformers.models.llama.modeling_llama import LlamaFlashAttention2 flash_attn_class = LlamaFlashAttention # or flash_attn_class = LlamaFlashAttention2 attn_name = "self_attn" for layer in model.model.layers: prev_layer = getattr(layer, attn_name) setattr(layer, attn_name, flash_attn_class(model.config))