diff --git a/_posts/2024-08-07-flexattention.md b/_posts/2024-08-07-flexattention.md
index 4c34879d33b6..acfc1fc40f01 100644
--- a/_posts/2024-08-07-flexattention.md
+++ b/_posts/2024-08-07-flexattention.md
@@ -1,7 +1,7 @@
 ---
 layout: blog_detail
 title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
-author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
+author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
 ---
 
 ![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
 alibi_bias = generate_alibi_bias() # [num_heads]
 
 def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
     return score + bias
 ```
 
@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 # If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks
 
 def sliding_window(b, h, q_idx, kv_idx)
     return q_idx - kv_idx <= SLIDING_WINDOW
 
-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
 ```
 
 We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
 
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
 - The Jax team's work on SplashAttention
 - Philippe Tillet and Keren Zhou for helping us with Triton
 - Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
\ No newline at end of file
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
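
The two functional changes in this diff flip the ALiBi distance term to `kv_idx - q_idx` (so, with positive per-head slopes, earlier keys receive a more negative bias) and combine the causal and sliding-window predicates with `and_masks` (intersection) rather than `or_masks` (union). As a sanity check, here is a minimal sketch of wiring the corrected functions into `flex_attention`. It assumes PyTorch ≥ 2.5 with a CUDA device; the tensor shapes, window size, and ALiBi slope generation are illustrative and not taken from the post.

```python
# Minimal sketch (not from the post): exercise the corrected mask/score functions.
# Assumes PyTorch >= 2.5 (torch.nn.attention.flex_attention) and a CUDA device;
# shapes, SLIDING_WINDOW, and the ALiBi slope generation are illustrative.
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention

B, H, S, D = 2, 8, 1024, 64
SLIDING_WINDOW = 256

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# One positive slope per head (ALiBi-style 2^(-8i/H) slopes, illustrative).
alibi_bias = torch.exp2(-8.0 * torch.arange(1, H + 1, device="cuda") / H)

def alibi(score, b, h, q_idx, kv_idx):
    # With positive slopes, keys further in the past (kv_idx < q_idx)
    # get a more negative bias, matching the sign used after this diff.
    return score + alibi_bias[h] * (kv_idx - q_idx)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# and_masks keeps a position only if it is causal AND inside the window;
# or_masks would instead admit every causal position regardless of the window.
sliding_window_causal = and_masks(causal_mask, sliding_window)

block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, score_mod=alibi, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```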