Skip to content

Commit 045204b

Browse files
committed
[GPU] Enable KV-cache compression for systolic platforms by default for PagedAttention-based models
1 parent b88dcc6 commit 045204b

File tree

1 file changed

+23
-9
lines changed

1 file changed

+23
-9
lines changed

src/plugins/intel_gpu/src/runtime/execution_config.cpp

+23-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "openvino/core/model.hpp"
99
#include "openvino/op/loop.hpp"
1010
#include "openvino/op/lstm_sequence.hpp"
11+
#include "openvino/op/paged_attention.hpp"
1112
#include "openvino/op/search_sorted.hpp"
1213
#include "openvino/op/stft.hpp"
1314
#include "ov_ops/dynamic_quantize.hpp"
@@ -135,6 +136,7 @@ void ExecutionConfig::apply_model_specific_options(const IRemoteContext* context
135136

136137
const auto& ops = model.get_ops();
137138

139+
auto is_paged_attention_model = false;
138140
std::function<void(std::shared_ptr<Node>)> process_op = [&, this](std::shared_ptr<Node> op) {
139141
if (requires_new_shape_infer(op)) {
140142
m_allow_new_shape_infer = true;
@@ -158,12 +160,28 @@ void ExecutionConfig::apply_model_specific_options(const IRemoteContext* context
158160
}
159161
}
160162
}
163+
164+
if (ov::is_type<ov::op::PagedAttentionExtension>(op)) {
165+
is_paged_attention_model = true;
166+
}
161167
};
162168

163169
for (const auto& op : ops) {
164170
process_op(op);
165171
}
166172

173+
const auto& info = dynamic_cast<const RemoteContextImpl*>(context)->get_engine().get_device_info();
174+
if (!is_set_by_user(ov::hint::kv_cache_precision) || get_kv_cache_precision() == ov::element::dynamic) {
175+
if (is_paged_attention_model || !info.supports_immad) {
176+
// Enable KV-cache compression by default for:
177+
// 1) Non-systolic platforms in case of SDPA-based models
178+
// 2) For any platforms in case of PagedAttention-based model
179+
m_kv_cache_precision = ov::element::i8;
180+
} else {
181+
m_kv_cache_precision = get_inference_precision();
182+
}
183+
}
184+
167185
m_optimize_data = true;
168186
}
169187

@@ -185,15 +203,6 @@ void ExecutionConfig::finalize_impl(const IRemoteContext* context) {
185203
m_queue_type = QueueTypes::in_order;
186204
}
187205

188-
if (!is_set_by_user(ov::hint::kv_cache_precision) || get_kv_cache_precision() == ov::element::dynamic) {
189-
if (info.supports_immad) { // MFDNN-11755
190-
m_kv_cache_precision = get_inference_precision();
191-
} else {
192-
// Enable KV-cache compression by default for non-systolic platforms only
193-
m_kv_cache_precision = ov::element::i8;
194-
}
195-
}
196-
197206
// Enable dynamic quantization by default for non-systolic platforms
198207
if (!is_set_by_user(ov::hint::dynamic_quantization_group_size) && get_dynamic_quantization_group_size() == 0 && !info.supports_immad) {
199208
m_dynamic_quantization_group_size = 32;
@@ -203,6 +212,11 @@ void ExecutionConfig::finalize_impl(const IRemoteContext* context) {
203212
m_optimize_data = true;
204213
}
205214

215+
// Replace UINT8 KV-cache compression data type with INT8, as plugin is supposed to work with INT8 internally
216+
if (get_kv_cache_precision() == ov::element::u8) {
217+
m_kv_cache_precision = ov::element::i8;
218+
}
219+
206220
#ifdef ENABLE_DEBUG_CAPS
207221
// For now we apply env/config only for build with debug caps, but it can be updated in the future to allow
208222
// reading release options for any build type

0 commit comments

Comments
 (0)