Skip to content

Commit 8a5f0ea

Browse files
committed
[GPU] FP32 acc for 2nd+ token PagedAttention
1 parent cfbc998 commit 8a5f0ea

File tree

4 files changed: +47 −9 lines changed

src/plugins/intel_gpu/src/graph/impls/ocl/crop.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,19 @@ struct crop_impl : typed_primitive_impl_ocl<crop> {
5555
}
5656

5757
update_shapes(*_kernel_data.params, impl_param);
58-
auto runtime_offset = convert_data_tensor(impl_param.get_input_layout(), impl_param.input_offsets[0]).GetFirstElementOffset();
58+
59+
// Reset input_layout padding as the offset configured by crop should affect only "data"
60+
// area and shouldn't depend on input_layout paddings.
61+
// For example, for an input shape like: [1, 32, 128 (pad_before=512, pad_after=0), 8]
62+
// with crop_axis=2 and split_lengths = {64, 64},
63+
// runtime_offset should be set in terms of [1, 32, 128, 8] shape, as the kernel reads data
64+
// using "input[GET_INDEX(INPUT, order) + runtime_offset]", where GET_INDEX already reflects input
65+
// data paddings.
66+
// So crop.out0's runtime_offset=0 and crop.out1's runtime_offset=512.
67+
auto input_layout = impl_param.get_input_layout();
68+
input_layout.data_padding = padding();
69+
70+
auto runtime_offset = convert_data_tensor(input_layout, impl_param.input_offsets[0]).GetFirstElementOffset();
5971
kernel_selector::ScalarDescriptor s;
6072
s.t = kernel_selector::ScalarDescriptor::Types::UINT32;
6173
s.v.u32 = static_cast<uint32_t>(runtime_offset);

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl

+7-7
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ KERNEL(pa_sdpa_opt)(
107107
#endif
108108

109109
// SLM for intermediate QK results
110-
__local OUTPUT_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
110+
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_vals[SEQ_LEN_PARTITION_SIZE];
111111

112112
// SLM buffers for SoftMax calculation and qk_max/qk_sums results aggregation across all WGs
113113
__local SOFTMAX_ACCUMULATOR_TYPE slm_qk_max_vals[SUBGROUPS_PER_WG];
@@ -166,7 +166,7 @@ KERNEL(pa_sdpa_opt)(
166166
#endif
167167
const uint block_offset = block_indices[start_block_idx + block_num * SUBGROUPS_PER_WG] * HEAD_SIZE * KV_HEADS_NUM * SUBGROUP_SIZE + head_idx * HEAD_SIZE * SUBGROUP_SIZE;
168168

169-
INPUT0_TYPE qk_acc = INPUT0_VAL_ZERO;
169+
SOFTMAX_ACCUMULATOR_TYPE qk_acc = SOFTMAX_ACCUMULATOR_VAL_ZERO;
170170

171171
#define KEY_VEC_SIZE SUBGROUP_SIZE
172172
unroll_for (uint qk_idx = 0; qk_idx < HEAD_SIZE / KEY_VEC_SIZE; qk_idx++) {
@@ -181,9 +181,9 @@ KERNEL(pa_sdpa_opt)(
181181

182182
unroll_for (uint i = 0; i < KEY_VEC_SIZE; i++) {
183183
#if STORE_QUERY_TO_SLM
184-
qk_acc = mad(sub_group_broadcast(q_val, i), k_vals[i], qk_acc);
184+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val, i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
185185
#else
186-
qk_acc = mad(sub_group_broadcast(q_val[qk_idx], i), k_vals[i], qk_acc);
186+
qk_acc = mad(TO_SOFTMAX_ACCUMULATOR_TYPE(sub_group_broadcast(q_val[qk_idx], i)), TO_SOFTMAX_ACCUMULATOR_TYPE(k_vals[i]), qk_acc);
187187
#endif
188188
}
189189
}
@@ -196,7 +196,7 @@ KERNEL(pa_sdpa_opt)(
196196
#endif
197197

198198
if (token_idx >= seq_len)
199-
qk_acc = INPUT0_VAL_MIN;
199+
qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;
200200

201201
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));
202202

@@ -235,7 +235,7 @@ KERNEL(pa_sdpa_opt)(
235235
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
236236
#endif
237237
SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
238-
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
238+
slm_qk_vals[local_data_idx] = qk_new;
239239

240240
exp_sum += qk_new;
241241
}
@@ -266,7 +266,7 @@ KERNEL(pa_sdpa_opt)(
266266
if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
267267
#endif
268268
SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
269-
slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
269+
slm_qk_vals[local_data_idx] = qk_new;
270270
}
271271
}
272272

src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,20 @@ std::vector<cldnn::event::ptr> SyncInferRequest::prepare_input(const std::string
799799
auto& engine = m_graph->get_engine();
800800
auto& stream = network->get_stream();
801801

802+
if (internal_name == "parameter:input_ids") {
803+
auto data = user_tensor->data<int64_t>();
804+
805+
auto print_arr = [&](int64_t* vec, size_t max_len, std::string name) {
806+
std::stringstream ss;
807+
for (size_t i = 0; i < max_len; i++) {
808+
ss << vec[i] << ", ";
809+
}
810+
std::cout << "Array " << name << " (len=" << max_len << ") content: " << ss.str() << "\n";
811+
};
812+
813+
print_arr(data, user_tensor->get_size(), "input_ids");
814+
}
815+
802816
auto need_lockable_mem = network->does_node_need_lockable_output(internal_name);
803817

804818
OPENVINO_ASSERT(pshape.compatible(ov::PartialShape(user_tensor->get_shape())) || is_batched_input(port),

src/plugins/intel_gpu/src/runtime/execution_config.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,25 @@ class PerformanceModeValidator : public BaseValidator {
3434
};
3535

3636
void ExecutionConfig::set_default() {
37+
auto default_inference_precision_hint = ov::element::f16;
38+
int USE_FP32 = 0;
39+
if (const auto env_var = std::getenv("USE_FP32")) {
40+
std::istringstream ss(env_var);
41+
ss >> USE_FP32;
42+
}
43+
44+
if (USE_FP32) {
45+
default_inference_precision_hint = ov::element::f32;
46+
std::cout << "inference_precision forced to f32\n";
47+
}
48+
3749
register_property<PropertyVisibility::PUBLIC>(
3850
std::make_tuple(ov::device::id, "0"),
3951
std::make_tuple(ov::enable_profiling, false),
4052
std::make_tuple(ov::cache_dir, ""),
4153
std::make_tuple(ov::num_streams, 1),
4254
std::make_tuple(ov::compilation_num_threads, std::max(1, static_cast<int>(std::thread::hardware_concurrency()))),
43-
std::make_tuple(ov::hint::inference_precision, ov::element::f16, InferencePrecisionValidator()),
55+
std::make_tuple(ov::hint::inference_precision, default_inference_precision_hint, InferencePrecisionValidator()),
4456
std::make_tuple(ov::hint::model_priority, ov::hint::Priority::MEDIUM),
4557
std::make_tuple(ov::hint::performance_mode, ov::hint::PerformanceMode::LATENCY, PerformanceModeValidator()),
4658
std::make_tuple(ov::hint::execution_mode, ov::hint::ExecutionMode::PERFORMANCE),

0 commit comments

Comments (0)