@@ -87,40 +87,55 @@ KERNEL(sdpa_opt)(
87
87
const uint start_partition_idx = partition_idx * SEQ_LEN_PARTITION_SIZE ;
88
88
89
89
{ // start Gemm1: accumulate query * key products for key rows of this partition and store the results in qk_vals_local
90
- // const uint query_offset = INPUT0_GET_INDEX(batch_idx, head_num_idx, seq_idx, sgid * SUBGROUP_SIZE);
90
+ #define QUERY_BLOCK_SIZE 8 // NOTE(review): hard-coded; the loops below index query_vals up to HEAD_SIZE / SUBGROUP_SIZE - 1, so this presumably has to equal HEAD_SIZE / SUBGROUP_SIZE - confirm
91
+ #define QUERY_BLOCK_READ_NEW (ptr , offset ) BLOCK_READN(INPUT0_TYPE, QUERY_BLOCK_SIZE, ptr, offset) // NOTE(review): as shown, the space before the parameter list would make this an object-like macro; presumably an artifact of this diff rendering - confirm against the real file
92
+ #define QUERY_BLOCK_NEW MAKE_VECTOR_TYPE(INPUT0_TYPE, QUERY_BLOCK_SIZE)
93
+
94
+ const uint query_offset = INPUT0_GET_INDEX (batch_idx , head_num_idx , seq_idx , 0 );
95
+ QUERY_BLOCK_NEW query_vals = QUERY_BLOCK_READ_NEW (query_input , query_offset ); // read once, ahead of the seq_len loop, and reused for every key row
91
96
// query_vals_local[head_size_idx] = QUERY_BLOCK_READ(query_input, query_offset);
92
97
93
98
// barrier(CLK_LOCAL_MEM_FENCE);
94
99
95
100
/* Calculate Gemm1 */
96
- for (uint seq_len = lid ; seq_len < partition_seq_len ; seq_len += wi_num_per_partition ) {
97
- uint query_offset = INPUT0_GET_INDEX (batch_idx , head_num_idx , seq_idx , 0 );
101
+ for (uint seq_len = sgid ; seq_len < partition_seq_len ; seq_len += (HEAD_SIZE / SUBGROUP_SIZE )) { // starts at sgid and strides by HEAD_SIZE / SUBGROUP_SIZE; presumably that is the number of subgroups cooperating on the partition - confirm
98
102
uint key_offset = INPUT1_GET_INDEX (batch_idx , head_num_idx , start_partition_idx + seq_len , 0 );
99
103
100
104
INPUT0_TYPE acc = INPUT0_VAL_ZERO ;
101
- unroll_for (uint h = 0 ; h < HEAD_SIZE ; h += SUBGROUP_SIZE ) {
102
- INPUT0_TYPE query_val = QUERY_BLOCK_READ (query_input , query_offset );
103
- KEY_VEC_TYPE key_vec = AS_VALUE_VEC (VLOAD (0 , key_input + key_offset ));
104
105
105
- unroll_for (uint i = 0 ; i < SUBGROUP_SIZE ; i ++ ) {
106
- acc = mad (sub_group_broadcast (query_val , i ), key_vec [i ], acc );
106
+ #define MULS_NUM 2 // number of key values read per block read and multiplied per inner iteration
107
+ #define KEY_BLOCK_READ_NEW (ptr , offset ) BLOCK_READN(INPUT1_TYPE, MULS_NUM, ptr, offset)
108
+ #define KEY_BLOCK_NEW MAKE_VECTOR_TYPE(INPUT1_TYPE, MULS_NUM)
109
+
110
+ unroll_for (uint h = 0 ; h < HEAD_SIZE / SUBGROUP_SIZE / MULS_NUM ; h ++ ) {
111
+ KEY_BLOCK_NEW key_vec = KEY_BLOCK_READ_NEW (key_input , key_offset );
112
+
113
+ unroll_for (uint i = 0 ; i < MULS_NUM ; i ++ ) {
114
+ #if MULS_NUM == 1 // presumably a 1-wide KEY_BLOCK_NEW is a scalar that cannot be subscripted - confirm
115
+ acc = mad (query_vals [h * MULS_NUM + i ], key_vec , acc );
116
+ #else
117
+ acc = mad (query_vals [h * MULS_NUM + i ], key_vec [i ], acc );
118
+ #endif
107
119
}
108
120
109
- query_offset += SUBGROUP_SIZE ;
110
- key_offset += SUBGROUP_SIZE ;
121
+ key_offset += SUBGROUP_SIZE * MULS_NUM ; // step to the next key block; presumably the span covered by one KEY_BLOCK_READ_NEW - confirm
111
122
}
112
123
113
- // Apply scale
114
- acc *= scale_val ;
124
+ acc = sub_group_reduce_add (acc ); // sum the per-work-item partial accumulators across the subgroup
125
+
126
+ if (sglid == 0 ) { // only the work-item with sglid == 0 applies scale and mask, updates qk_max and stores the value
127
+ // Apply scale
128
+ acc *= scale_val ;
115
129
116
- // Apply attention mask
117
- uint attn_mask_offset = INPUT3_GET_INDEX_SAFE (batch_idx , head_num_idx , seq_idx , start_partition_idx + seq_len );
118
- acc += attn_mask [attn_mask_offset ];
130
+ // Apply attention mask
131
+ uint attn_mask_offset = INPUT3_GET_INDEX_SAFE (batch_idx , head_num_idx , seq_idx , start_partition_idx + seq_len );
132
+ acc += attn_mask [attn_mask_offset ];
119
133
120
- // Update qk_max value
121
- qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC (qk_max , TO_SOFTMAX_ACCUMULATOR_TYPE (acc ));
134
+ // Update qk_max value
135
+ qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC (qk_max , TO_SOFTMAX_ACCUMULATOR_TYPE (acc ));
122
136
123
- qk_vals_local [seq_len ] = acc ;
137
+ qk_vals_local [seq_len ] = acc ;
138
+ }
124
139
}
125
140
} // finish Gemm1
126
141
0 commit comments