本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
支持 FlashAttention
Suppor FlashAttention t for 是该库的一项功能,仅适用于分布式变压器模型,分布式变压器模型是为模型并行训练而封装的 Trans smp.DistributedModel()
该FlashAttentionattention_head_size为 8 的倍数且小于 128 的值时才支持模型。因此,在训练分布式变压器并确保其 FlashAttention 正常工作时,应调整参数以使注意力头大小符合要求。有关更多信息,另请参阅FlashAttention GitHub存储库中的安装和功能
例如,假设您使用 hidden_width=864 和 num_heads=48 配置转换器模型。的头部大小计算公式 FlashAttention 为attention_head_size = hidden_width / num_heads = 864 / 48 = 18。要启用 FlashAttention,您需要将num_heads参数调整为 54attention_head_size = hidden_width / num_heads = 864
/ 54 = 16,即 8 的倍数。