You are required to implement a PyTorch module named `OnlineSlidingWindowAttn` in `src/modeling/attention.py`.
- Building upon the `OfflineSlidingWindowAttn` module described in task1, we continue to implement the `OnlineSlidingWindowAttn` module, which is the online version of the former: it applies attention only to one block of $Q_{bq_i},K_{bkv_j},V_{bkv_j}$ in `AttnQKVLayout.BSHD` layout and `AttnQKVPackFormat.Q_K_V` packing format, and aggregates the local output $O_{bq_i}^{(bkv_j)}$ of this block into the global output $O$, with the help of a `log-sum-exp`-style softmax calibration coefficient $lse$.
- To be more specific, although both the computation cost and the memory footprint of the attention operation generally follow quadratic complexity, we can reduce the memory complexity to almost linear by transforming the offline softmax into the online softmax (See the Online Softmax Paper in References). The basic idea is to split the `sq`-dim of $Q$ and the `skv`-dim of $K,V$ equally into blocks of size `bq` and `bkv` respectively, and each time apply attention only to a single block $Q_{bq_i},K_{bkv_j},V_{bkv_j}$, where the block indices satisfy $bq_i \in [0, \frac{sq}{bq})$ and $bkv_j \in [0, \frac{skv}{bkv})$, as in the brief slicing sketch below.
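For instance, with `AttnQKVLayout.BSHD` tensors, block $(bq_i, bkv_j)$ is just a contiguous slice along the sequence dim; a tiny sketch (the shapes here are made up purely for illustration):

```python
import torch

b, sq, skv, hq, hd = 1, 16, 16, 2, 8
bq, bkv = 4, 4                      # block sizes along the sq- / skv-dims
Q = torch.randn(b, sq, hq, hd)      # AttnQKVLayout.BSHD: [batch, seq, num_heads, head_dim]
K = torch.randn(b, skv, hq, hd)
V = torch.randn(b, skv, hq, hd)

bq_i, bkv_j = 1, 2                  # any valid block indices
q_blk = Q[:, bq_i * bq:(bq_i + 1) * bq]       # [b, bq,  hq, hd]
k_blk = K[:, bkv_j * bkv:(bkv_j + 1) * bkv]   # [b, bkv, hq, hd]
v_blk = V[:, bkv_j * bkv:(bkv_j + 1) * bkv]   # [b, bkv, hq, hd]
```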
- The local attention output of this block is denoted as $O_{bq_i}^{(bkv_j)}$, with the shape `[b, bq, hq, hd]`. Given the global output buffer $O$ with the shape `[b, sq, hq, hd]`, how can we aggregate $O_{bq_i}^{(bkv_j)}$ into $O$ accurately, since the local and global softmax weights are not normalized by the same factors?
- As shown by the stable softmax factorization below: if we split a row vector $X \in \mathbb{R}^{n}$ into two parts $X_1 \in \mathbb{R}^{n_1}$ and $X_2 \in \mathbb{R}^{n_2}$, where $n_1 + n_2 = n$, then the key to restoring the softmax of the whole $X$ from the local softmaxes of $X_1$ and $X_2$ is to re-calculate the new normalization factor $l$ and the new maximum value $m$.
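For concreteness, a sketch of that factorization in its standard form, writing $m_k = \max(X_k)$ and $l_k = \sum_i e^{(X_k)_i - m_k}$ for $k \in \{1, 2\}$ (this notation is ours, chosen to match the $l$ and $m$ above):

$$
m = \max(m_1, m_2), \qquad l = e^{m_1 - m}\, l_1 + e^{m_2 - m}\, l_2,
$$

$$
\mathrm{softmax}(X) = \Big[\, \frac{e^{m_1 - m}\, l_1}{l}\,\mathrm{softmax}(X_1),\;\; \frac{e^{m_2 - m}\, l_2}{l}\,\mathrm{softmax}(X_2) \,\Big].
$$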
- To simplify the above calibration of the softmax, we can also utilize the `log-sum-exp` operator $\text{lse}$ (See the PyTorch LSE Functional in References), following flash-attention's strategy (See the Flash Attention 2 Paper in References for more details), to rewrite the stable softmax operation as:
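A sketch of this rewrite, with $m = \max(X)$ and $x_i$ denoting the $i$-th entry of $X$:

$$
\mathrm{softmax}(X)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}
= \exp\Big(x_i - m - \log \sum_j e^{x_j - m}\Big)
= \exp\big(x_i - m - \mathrm{lse}(X - m)\big)
= \exp\big(x_i - \mathrm{lse}(X)\big),
$$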
- where the last step uses a property of $\text{lse}$: $\text{lse}(X) = \max{(X)} + \text{lse}(X - \max{(X)})$ (See the LSE Wiki in References). So the stable softmax factorization can also be re-formulated with the $\text{lse}$ operation as:
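A sketch of the resulting $\text{lse}$-based form, writing $L_1 = \text{lse}(X_1)$ and $L_2 = \text{lse}(X_2)$ as shorthand:

$$
\mathrm{lse}(X) = \log\big(e^{L_1} + e^{L_2}\big)
= \max(L_1, L_2) + \log\big(1 + e^{-\lvert L_1 - L_2 \rvert}\big)
= \max(L_1, L_2) + \mathrm{log1p}\big(e^{-\lvert L_1 - L_2 \rvert}\big)
= \max(L_1, L_2) + \mathrm{softplus}\big(-\lvert L_1 - L_2 \rvert\big),
$$

$$
\mathrm{softmax}(X) = \Big[\, e^{L_1 - \mathrm{lse}(X)}\,\mathrm{softmax}(X_1),\;\; e^{L_2 - \mathrm{lse}(X)}\,\mathrm{softmax}(X_2) \,\Big],
$$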
- where the last three steps are designed to address the $\exp$ explosion problem: the maximum value is extracted as an additive term so that the exponent never becomes a large positive number, and the $\text{log1p}$ or $\text{softplus}$ operation is used for numerical stability (See the PyTorch Log1p / Softplus Functional in References). Therefore, for each online attention step, we just need to apply attention to the local block to get $O_{bq_i}^{(bkv_j)}$ along with the local statistics $lse^{(bkv_j)}_{bq_i}$, and then update the global statistics $lse$ to calibrate the global output $O$ for the rows in the range $[bq_i\cdot bq, (bq_i + 1)\cdot bq)$, following the equations shown above; a PyTorch sketch of one such update step is given below.
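The following is a minimal PyTorch sketch of one such online update step. It assumes the `AttnQKVLayout.BSHD` block shapes above, a boolean mask tile `attn_mask_blk` (True means "attend") already broadcastable to the `[b, hq, bq, bkv]` score tile, and a `global_lse` laid out as `[b, hq, sq]` in `float32`; the helper name, the mask convention, and the `lse` layout are illustrative assumptions, not the required interface:

```python
import torch
import torch.nn.functional as F


def online_attn_step(q_blk, k_blk, v_blk, global_o, global_lse,
                     block_idx_q, block_size_q, softmax_scale, attn_mask_blk):
    """One online-attention step (illustrative sketch, not the required API).

    q_blk: [b, bq, hq, hd]; k_blk, v_blk: [b, bkv, hq, hd] (kv heads already repeated)
    global_o: [b, sq, hq, hd] (updated in-place); global_lse: [b, hq, sq], float32, init -inf
    """
    # local attention scores for this (bq_i, bkv_j) tile: [b, hq, bq, bkv]
    s = torch.einsum("bqhd,bkhd->bhqk", q_blk.float(), k_blk.float()) * softmax_scale
    s = s.masked_fill(~attn_mask_blk, float("-inf"))

    lse_local = torch.logsumexp(s, dim=-1)                        # [b, hq, bq]
    p = torch.exp(s - lse_local.unsqueeze(-1))                    # local softmax weights
    o_local = torch.einsum("bhqk,bkhd->bqhd", p, v_blk.float())   # [b, bq, hq, hd]

    rows = slice(block_idx_q * block_size_q, (block_idx_q + 1) * block_size_q)
    lse_old = global_lse[:, :, rows]                              # [b, hq, bq]

    # lse_new = log(exp(lse_old) + exp(lse_local)), computed via max + softplus
    lse_max = torch.maximum(lse_old, lse_local)
    lse_new = lse_max + F.softplus(-(lse_old - lse_local).abs())
    lse_new = torch.where(torch.isneginf(lse_max), lse_max, lse_new)  # untouched rows stay -inf

    # calibrate what has been accumulated so far, then add this block's contribution
    alpha_old = torch.exp(lse_old - lse_new).transpose(1, 2).unsqueeze(-1)    # [b, bq, hq, 1]
    alpha_new = torch.exp(lse_local - lse_new).transpose(1, 2).unsqueeze(-1)
    global_o[:, rows] = (alpha_old * global_o[:, rows].float()
                         + alpha_new * o_local).to(global_o.dtype)
    global_lse[:, :, rows] = lse_new
    # (rows whose lse is still -inf, i.e. fully masked so far, need extra care; omitted here)
```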
- To make full use of the `OfflineSlidingWindowAttn` module implemented in task1, the `OnlineSlidingWindowAttn` module simply inherits from the `OfflineSlidingWindowAttn` module, with the input arguments differing in several ways as follows:
    - To simplify the diversity of inputs, the `OnlineSlidingWindowAttn` module only accepts a block of $Q_{bq_i},K_{bkv_j},V_{bkv_j}$ in `AttnQKVLayout.BSHD` layout and `AttnQKVPackFormat.Q_K_V` packing format, thus no arguments are required for the QKV packing format and layout.
    - Since `softmax clipping` and `softmax dropout` should only be applied to the global softmax weights $A$, we disable these two stabilization strategies in the `OnlineSlidingWindowAttn` module.
    - To better prepare for the online attention forward pass during initialization, we provide the `block_size` and `seqlen` for $Q$ and $K,V$ respectively in the argument list of the `__init__` method. Therefore, you can pre-calculate things such as the full attention mask in the `__init__` method.
    - Since the layout is fixed to `AttnQKVLayout.BSHD`, we need neither `cu_seqlens_q` nor `cu_seqlens_kv` in the argument list of the forward method anymore.
    - The `q,k,v` arguments of the forward method are only a single block of $Q_{bq_i},K_{bkv_j},V_{bkv_j}$, where $bq_i$ and $bkv_j$ are given as the arguments `block_idx_q` and `block_idx_kv` respectively.
    - The global output $O$ and the global statistics $lse$ (each entry of which is either already partially updated or set to its initial value, `0` for `O` and `-∞` for `lse`) are given as the arguments `global_o` and `global_lse` respectively, and you should update them in-place, thus no return value is needed for the forward method.
In summary, you should implement this `OnlineSlidingWindowAttn` module, which takes a block of $Q_{bq_i},K_{bkv_j},V_{bkv_j}$ in `AttnQKVLayout.BSHD` layout and `AttnQKVPackFormat.Q_K_V` packing format, applies the offline sliding window attention operation on this block given the block indices, gets the local output and local $lse$ statistics, and uses them to update `global_o` and `global_lse` in-place. A rough interface sketch follows below.
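A rough skeleton of what such an interface might look like; the constructor parameter names (e.g. `seqlen_q`, `block_size_q`) and how the remaining arguments are forwarded to the base class are assumptions for illustration, since the exact signatures are fixed by the task1 `OfflineSlidingWindowAttn` module and the assignment spec:

```python
import torch

from src.modeling.attention import OfflineSlidingWindowAttn  # task1 module (no import needed
                                                              # if both classes share one file)


class OnlineSlidingWindowAttn(OfflineSlidingWindowAttn):
    """Online sliding-window attention on one (bq_i, bkv_j) block (interface sketch only)."""

    def __init__(
        self,
        seqlen_q: int,        # assumed name: full sq before blocking
        seqlen_kv: int,       # assumed name: full skv before blocking
        block_size_q: int,    # assumed name: bq
        block_size_kv: int,   # assumed name: bkv
        *args, **kwargs,      # remaining OfflineSlidingWindowAttn arguments
    ) -> None:
        super().__init__(*args, **kwargs)
        self.block_size_q, self.block_size_kv = block_size_q, block_size_kv
        # e.g. pre-build the full (padded) sliding-window attention mask here,
        # so each forward call only slices out the (block_idx_q, block_idx_kv) tile

    def forward(
        self,
        q: torch.Tensor,           # [b, bq,  hq,  hd] block of Q
        k: torch.Tensor,           # [b, bkv, hkv, hd] block of K
        v: torch.Tensor,           # [b, bkv, hkv, hd] block of V
        global_o: torch.Tensor,    # [b, sq, hq, hd], updated in-place
        global_lse: torch.Tensor,  # float32 global lse statistics, updated in-place
        block_idx_q: int,
        block_idx_kv: int,
    ) -> None:
        # 1) slice the pre-built mask for this tile, 2) run the local block attention,
        # 3) calibrate global_o / global_lse with the lse update shown earlier
        raise NotImplementedError  # left to the task
```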
- First of all, we inherit the same notice mentioned in task1.
- The `dtype` and `device` of `q,k,v,global_o` are ensured to be the same, while we keep the `dtype` of `global_lse` as `torch.float32` to maintain high precision and reduce the accumulation error.
- When the `seqlen` cannot be fully divided by the `block_size`, the last incomplete block will be padded at the end of the sequence dim to match the corresponding `block_size`, with the padding entries filled with zeros (see the padding sketch after these notices).
- The `block_idx_q` and `block_idx_kv` are ensured to be within their corresponding valid ranges.
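A tiny sketch of that zero-padding rule for a `[b, s, h, d]` block (the helper name `pad_block` is ours, not part of the required interface):

```python
import torch
import torch.nn.functional as F


def pad_block(x: torch.Tensor, block_size: int) -> torch.Tensor:
    """Right-pad an incomplete [b, s, h, d] block with zeros along the sequence dim."""
    pad_len = block_size - x.size(1)
    # F.pad pads from the last dim backwards: (d_lo, d_hi, h_lo, h_hi, s_lo, s_hi)
    return F.pad(x, (0, 0, 0, 0, 0, pad_len)) if pad_len > 0 else x
```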
- Note that any online attention step in the forward pass of the `OnlineSlidingWindowAttn` module should be regarded as an inner iterative step of the corresponding offline attention, i.e. if we traverse every $bq_i \in [0, \frac{sq}{bq})$ and $bkv_j \in [0, \frac{skv}{bkv})$ with this online attention module, the final updated output $O$ should be the same as that of the corresponding offline attention module, ignoring the accumulation error (see the traversal sketch below).
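A sketch of that traversal as a sanity check, reusing the assumed constructor names from the interface sketch above; the shapes, the `global_lse` layout, and the omitted offline-module arguments are all illustrative assumptions:

```python
import torch

from src.modeling.attention import OfflineSlidingWindowAttn, OnlineSlidingWindowAttn

b, sq, skv, hq, hd = 1, 16, 16, 2, 8
bq, bkv = 4, 4
q = torch.randn(b, sq, hq, hd)   # AttnQKVLayout.BSHD
k = torch.randn(b, skv, hq, hd)
v = torch.randn(b, skv, hq, hd)

global_o = torch.zeros(b, sq, hq, hd)
global_lse = torch.full((b, hq, sq), float("-inf"), dtype=torch.float32)  # assumed layout

online_attn = OnlineSlidingWindowAttn(  # constructor names assumed, as in the sketch above
    seqlen_q=sq, seqlen_kv=skv, block_size_q=bq, block_size_kv=bkv,
)
for i in range(sq // bq):           # bq_i
    for j in range(skv // bkv):     # bkv_j
        online_attn(
            q[:, i * bq:(i + 1) * bq],
            k[:, j * bkv:(j + 1) * bkv],
            v[:, j * bkv:(j + 1) * bkv],
            global_o, global_lse,
            block_idx_q=i, block_idx_kv=j,
        )

# offline_attn = OfflineSlidingWindowAttn(...)  # task1 arguments omitted here
# torch.testing.assert_close(global_o, offline_attn(q, k, v))
```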
Hints: Here are some references that may be helpful for your task, or may deepen / broaden your knowledge of attention layers, particularly in transformers:
!! Remember: it is a fundamental and essential capability to search, read, think and learn from papers, source code, and official documentation for your answer; try NOT to rely too much on biased and superficial blogs, e.g. CSDN !!
- Online Softmax Paper
- LSE Wiki
- PyTorch LSE Functional
- PyTorch Log1p Functional
- PyTorch Softplus Functional
- Nvidia Methods of Improving LLM Training Stability
- Llama Attention Layer
- Google MHA Paper
- Google MQA Paper
- Google GQA Paper
- PyTorch Repeat Interleave Functional
- Transformer Paper
- Flash Attention 2 Paper
- Flash Attention Interface
- PyTorch SDPA Functional
- PyTorch FlexAttention Functional