Update 2024-08-07-flexattention.md (#1707)
* Update 2024-08-07-flexattention.md

* Update 2024-08-07-flexattention.md

---------

Co-authored-by: Chris Abraham <cjyabraham@gmail.com>
Chillee and cjyabraham authored Feb 6, 2025
1 parent 25636f5 commit 67fc54c
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions _posts/2024-08-07-flexattention.md
@@ -1,7 +1,7 @@
---
layout: blog_detail
title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
---

![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias
```
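
For reference, here is a minimal sketch of how the corrected `alibi` score_mod could be exercised end to end with `flex_attention`. The `generate_alibi_bias` below is a hypothetical stand-in using the standard ALiBi slopes (it is not the definition from the post), and the shapes, head count, and device are illustrative:

```
import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical stand-in: standard ALiBi slopes, one positive value per head.
def generate_alibi_bias(num_heads=8):
    return torch.tensor(
        [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)],
        device="cuda",
    )

alibi_bias = generate_alibi_bias()  # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    # kv_idx - q_idx is <= 0 for causal attention, so keys further in the past
    # receive a larger penalty; that sign is what this commit fixes.
    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias

# flex_attention = torch.compile(flex_attention)  # recommended for performance
q, k, v = (
    torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
out = flex_attention(q, k, v, score_mod=alibi)
```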

@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
    return causal_mask & window_mask

# If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
```
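
As a usage sketch of the corrected mask combination: the snippet below assumes `causal_mask` matches the mask_mod defined earlier in the post, imports the helpers from `torch.nn.attention.flex_attention` (where they live in recent PyTorch releases), and uses illustrative sequence lengths:

```
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

SLIDING_WINDOW = 1024

# Assumed to match the causal mask_mod defined earlier in the post.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# and_masks keeps only positions allowed by *both* mask_mods, which is why
# this commit swaps or_masks for and_masks.
sliding_window_causal = and_masks(causal_mask, sliding_window)

block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=4096, KV_LEN=4096
)
q, k, v = (
    torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
out = flex_attention(q, k, v, block_mask=block_mask)
```

`or_masks` would admit any position allowed by either mask_mod; since every future position trivially satisfies the window condition, the union masks out nothing at all, whereas `and_masks` gives the intended sliding window causal pattern.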

We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
- The Jax team's work on SplashAttention
- Philippe Tillet and Keren Zhou for helping us with Triton
- Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
