Skip to content

Commit 2ad9125

Browse files
committed
Update Gemm1 calculation
1 parent 9bbf8f0 commit 2ad9125

File tree

2 files changed

+52
-18
lines changed

2 files changed

+52
-18
lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl

+33-18
Original file line numberDiff line numberDiff line change
@@ -87,40 +87,55 @@ KERNEL(sdpa_opt)(
8787
const uint start_partition_idx = partition_idx * SEQ_LEN_PARTITION_SIZE;
8888

8989
{ // start Gemm1
90-
// const uint query_offset = INPUT0_GET_INDEX(batch_idx, head_num_idx, seq_idx, sgid * SUBGROUP_SIZE);
90+
#define QUERY_BLOCK_SIZE 8
91+
#define QUERY_BLOCK_READ_NEW(ptr, offset) BLOCK_READN(INPUT0_TYPE, QUERY_BLOCK_SIZE, ptr, offset)
92+
#define QUERY_BLOCK_NEW MAKE_VECTOR_TYPE(INPUT0_TYPE, QUERY_BLOCK_SIZE)
93+
94+
const uint query_offset = INPUT0_GET_INDEX(batch_idx, head_num_idx, seq_idx, 0);
95+
QUERY_BLOCK_NEW query_vals = QUERY_BLOCK_READ_NEW(query_input, query_offset);
9196
// query_vals_local[head_size_idx] = QUERY_BLOCK_READ(query_input, query_offset);
9297

9398
// barrier(CLK_LOCAL_MEM_FENCE);
9499

95100
/* Calculate Gemm1 */
96-
for (uint seq_len = lid; seq_len < partition_seq_len; seq_len += wi_num_per_partition) {
97-
uint query_offset = INPUT0_GET_INDEX(batch_idx, head_num_idx, seq_idx, 0);
101+
for (uint seq_len = sgid; seq_len < partition_seq_len; seq_len += (HEAD_SIZE / SUBGROUP_SIZE)) {
98102
uint key_offset = INPUT1_GET_INDEX(batch_idx, head_num_idx, start_partition_idx + seq_len, 0);
99103

100104
INPUT0_TYPE acc = INPUT0_VAL_ZERO;
101-
unroll_for (uint h = 0; h < HEAD_SIZE; h += SUBGROUP_SIZE) {
102-
INPUT0_TYPE query_val = QUERY_BLOCK_READ(query_input, query_offset);
103-
KEY_VEC_TYPE key_vec = AS_VALUE_VEC(VLOAD(0, key_input + key_offset));
104105

105-
unroll_for (uint i = 0; i < SUBGROUP_SIZE; i++) {
106-
acc = mad(sub_group_broadcast(query_val, i), key_vec[i], acc);
106+
#define MULS_NUM 2
107+
#define KEY_BLOCK_READ_NEW(ptr, offset) BLOCK_READN(INPUT1_TYPE, MULS_NUM, ptr, offset)
108+
#define KEY_BLOCK_NEW MAKE_VECTOR_TYPE(INPUT1_TYPE, MULS_NUM)
109+
110+
unroll_for (uint h = 0; h < HEAD_SIZE / SUBGROUP_SIZE / MULS_NUM; h++) {
111+
KEY_BLOCK_NEW key_vec = KEY_BLOCK_READ_NEW(key_input, key_offset);
112+
113+
unroll_for (uint i = 0; i < MULS_NUM; i++) {
114+
#if MULS_NUM == 1
115+
acc = mad(query_vals[h * MULS_NUM + i], key_vec, acc);
116+
#else
117+
acc = mad(query_vals[h * MULS_NUM + i], key_vec[i], acc);
118+
#endif
107119
}
108120

109-
query_offset += SUBGROUP_SIZE;
110-
key_offset += SUBGROUP_SIZE;
121+
key_offset += SUBGROUP_SIZE * MULS_NUM;
111122
}
112123

113-
// Apply scale
114-
acc *= scale_val;
124+
acc = sub_group_reduce_add(acc);
125+
126+
if (sglid == 0) {
127+
// Apply scale
128+
acc *= scale_val;
115129

116-
// Apply attention mask
117-
uint attn_mask_offset = INPUT3_GET_INDEX_SAFE(batch_idx, head_num_idx, seq_idx, start_partition_idx + seq_len);
118-
acc += attn_mask[attn_mask_offset];
130+
// Apply attention mask
131+
uint attn_mask_offset = INPUT3_GET_INDEX_SAFE(batch_idx, head_num_idx, seq_idx, start_partition_idx + seq_len);
132+
acc += attn_mask[attn_mask_offset];
119133

120-
// Update qk_max value
121-
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(acc));
134+
// Update qk_max value
135+
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(acc));
122136

123-
qk_vals_local[seq_len] = acc;
137+
qk_vals_local[seq_len] = acc;
138+
}
124139
}
125140
} // finish Gemm1
126141

src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,32 @@ bool SDPAKernelOpt::Validate(const Params& p) const {
4848
return true;
4949
}
5050

51+
/// Parses `str` into a value of type T using formatted stream extraction
/// (`operator>>` on a `std::istringstream`).
///
/// @tparam T  Default-constructible type that supports `operator>>`.
/// @param str Text to parse.
/// @return The extracted value. If nothing is extracted (for example `str` is
///         empty or contains only whitespace), a value-initialized T is
///         returned.
template <typename T>
T convert_to(const std::string &str) {
    std::istringstream ss(str);
    // Value-initialize: when the stream is already exhausted, operator>> leaves
    // its target untouched, and a default-initialized scalar would be returned
    // with an indeterminate value.
    T res{};
    ss >> res;
    return res;
}
58+
59+
template <>
60+
std::string convert_to(const std::string &str) {
61+
return str;
62+
}
63+
5164
JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t kernel_idx) const {
5265
auto jit = MakeBaseParamsJitConstants(params);
5366

5467
// const auto softmax_acc_dt = Datatype::F32;
5568
const auto softmax_acc_dt = params.inputs[0].GetDType();
5669
jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
5770

71+
// if (const auto env_var = std::getenv("MULS_NUM")) {
72+
// auto muls_num = convert_to<size_t>(env_var);
73+
// std::cout << "Force MULS_NUM to " << muls_num << "\n";
74+
// jit.AddConstant(MakeJitConstant("MULS_NUM", muls_num));
75+
// }
76+
5877
const auto& config = params.conf;
5978
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
6079
jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size));

0 commit comments

Comments
 (0)