Skip to content

Commit 616ab12

Browse files
authored
Add KV-cache conversion to fp16 for StaticLLMPipeline (openvinotoolkit#898)
1 parent f053e5e commit 616ab12

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

src/cpp/src/llm_pipeline_static.cpp

+24-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "openvino/pass/stateful_to_stateless.hpp"
99
#include "openvino/runtime/core.hpp"
1010
#include "openvino/opsets/opset13.hpp"
11+
#include "openvino/core/preprocess/pre_post_process.hpp"
1112

1213
#include <jinja2cpp/user_callable.h>
1314

@@ -16,6 +17,24 @@
1617

1718
namespace {
1819

20+
// Re-types the model's KV-cache tensors to fp16 via PrePostProcessor:
// every input whose name contains "past_key" and every output whose name
// contains "present" gets its tensor element type set to f16, then the
// rebuilt model is returned. The input `model` itself is not mutated;
// callers should use the returned model.
std::shared_ptr<ov::Model> cvt_kvcache_to_fp16(const std::shared_ptr<ov::Model>& model) {
    ov::preprocess::PrePostProcessor ppp(model);

    // Iterate by const reference — `auto` by value would copy each
    // ov::Output handle; also resolve the tensor name only once.
    for (const auto& tensor : model->inputs()) {
        const auto& name = tensor.get_any_name();
        if (name.find("past_key") != std::string::npos) {
            ppp.input(name).tensor().set_element_type(ov::element::Type_t::f16);
        }
    }

    for (const auto& tensor : model->outputs()) {
        const auto& name = tensor.get_any_name();
        if (name.find("present") != std::string::npos) {
            ppp.output(name).tensor().set_element_type(ov::element::Type_t::f16);
        }
    }

    return ppp.build();
}
37+
1938
void align_u4_zp_constants(const std::shared_ptr<ov::Model>& model) {
2039
for (auto op : model->get_ops()) {
2140
if (ov::op::util::is_constant(op)) {
@@ -280,17 +299,19 @@ void StaticLLMPipeline::setupAndCompileModels(
280299
align_u4_zp_constants(m_kvcache_model);
281300
// (4) Replace KV-tensors for the entire cache to tensors only for new token
282301
m_kvcache_model = redirect_new_kv_to_output(m_kvcache_model);
283-
// (5) Clone the model - this will be prefill
302+
// (5) Convert kvcache tensors to fp16 precision
303+
m_kvcache_model = cvt_kvcache_to_fp16(m_kvcache_model);
304+
// (6) Clone the model - this will be prefill
284305
m_prefill_model = m_kvcache_model->clone();
285306
m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill");
286-
// (6) Reshape both models to static shape
307+
// (7) Reshape both models to static shape
287308
const auto kMaxPromptLen = pop_or_default(pipeline_config, "MAX_PROMPT_LEN", 1024u);
288309
const auto kMinResponseLen = pop_or_default(pipeline_config, "MIN_RESPONSE_LEN", 150u);
289310
KVAxesPosition axes = get_kv_axes(get_model_type_from_json(path / "config.json"));
290311
m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, axes.seq_len };
291312
reshape_to_static(m_prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size, axes);
292313
reshape_to_static(m_kvcache_model, 1u, m_kvcache_desc.total_size, axes);
293-
// (7) Compile both model
314+
// (8) Compile both models
294315
auto prefill_config = pop_or_default(pipeline_config, "PREFILL_CONFIG", get_default_prefill_config());
295316
auto generate_config = pop_or_default(pipeline_config, "GENERATE_CONFIG", get_default_generate_config());
296317
merge_config_with(prefill_config, pipeline_config);

0 commit comments

Comments
 (0)