Support for FlashAttention
Support for FlashAttention is a feature of the library that applies only to the distributed transformer model, which is a Transformer model wrapped by smp.DistributedModel(). FlashAttention supports models only when attention_head_size is set to a value that's a multiple of 8 and less than 128. Therefore, when you train a distributed transformer and want to make sure that FlashAttention works properly, you should adjust parameters so that the attention head size complies with these requirements. For more information, see also Installation and features in the FlashAttention repository.
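As a quick sanity check, you can compute the attention head size from your model configuration and verify it against these constraints before training. The following is a minimal sketch in plain Python; the helper name check_flash_attention_head_size is illustrative and not part of the library API.

def check_flash_attention_head_size(hidden_width: int, num_heads: int) -> int:
    """Compute the attention head size and verify the FlashAttention constraints.

    FlashAttention requires the head size to be a multiple of 8 and less than 128.
    This helper is illustrative only and not part of smdistributed.modelparallel.
    """
    if hidden_width % num_heads != 0:
        raise ValueError(
            f"hidden_width={hidden_width} is not divisible by num_heads={num_heads}"
        )
    head_size = hidden_width // num_heads
    if head_size % 8 != 0 or head_size >= 128:
        raise ValueError(
            f"attention_head_size={head_size} must be a multiple of 8 and less than 128"
        )
    return head_size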
For example, assume that you configure a Transformer model with hidden_width=864 and num_heads=48. The attention head size is calculated as attention_head_size = hidden_width / num_heads = 864 / 48 = 18. To enable FlashAttention, you need to adjust the num_heads parameter to 54, so that attention_head_size = hidden_width / num_heads = 864 / 54 = 16, which is a multiple of 8.
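Continuing with the hypothetical check_flash_attention_head_size helper sketched above, the numbers from this example work out as follows: the original configuration fails the check, and the adjusted one passes.

# hidden_width=864, num_heads=48 gives a head size of 18, which is not a multiple of 8.
try:
    check_flash_attention_head_size(hidden_width=864, num_heads=48)
except ValueError as err:
    print(err)  # attention_head_size=18 must be a multiple of 8 and less than 128

# Adjusting num_heads to 54 gives a head size of 16, which satisfies both constraints.
print(check_flash_attention_head_size(hidden_width=864, num_heads=54))  # prints 16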