|
8 | 8 | #include "openvino/pass/stateful_to_stateless.hpp"
|
9 | 9 | #include "openvino/runtime/core.hpp"
|
10 | 10 | #include "openvino/opsets/opset13.hpp"
|
| 11 | +#include "openvino/core/preprocess/pre_post_process.hpp" |
11 | 12 |
|
12 | 13 | #include <jinja2cpp/user_callable.h>
|
13 | 14 |
|
|
16 | 17 |
|
17 | 18 | namespace {
|
18 | 19 |
|
| 20 | +std::shared_ptr<ov::Model> cvt_kvcache_to_fp16(const std::shared_ptr<ov::Model>& model) { |
| 21 | + ov::preprocess::PrePostProcessor ppp(model); |
| 22 | + |
| 23 | + for (auto tensor : model->inputs()) { |
| 24 | + if (tensor.get_any_name().find("past_key") != std::string::npos) { |
| 25 | + ppp.input(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16); |
| 26 | + } |
| 27 | + } |
| 28 | + |
| 29 | + for (auto tensor : model->outputs()) { |
| 30 | + if (tensor.get_any_name().find("present") != std::string::npos) { |
| 31 | + ppp.output(tensor.get_any_name()).tensor().set_element_type(ov::element::Type_t::f16); |
| 32 | + } |
| 33 | + } |
| 34 | + |
| 35 | + return ppp.build(); |
| 36 | +} |
| 37 | + |
19 | 38 | void align_u4_zp_constants(const std::shared_ptr<ov::Model>& model) {
|
20 | 39 | for (auto op : model->get_ops()) {
|
21 | 40 | if (ov::op::util::is_constant(op)) {
|
@@ -280,17 +299,19 @@ void StaticLLMPipeline::setupAndCompileModels(
|
280 | 299 | align_u4_zp_constants(m_kvcache_model);
|
281 | 300 | // (4) Replace KV-tensors for the entire cache to tensors only for new token
|
282 | 301 | m_kvcache_model = redirect_new_kv_to_output(m_kvcache_model);
|
283 |
| - // (5) Clone the model - this will be prefill |
| 302 | + // (5) Convert kvcache tensors to fp16 precision |
| 303 | + m_kvcache_model = cvt_kvcache_to_fp16(m_kvcache_model); |
| 304 | + // (6) Clone the model - this will be prefill |
284 | 305 | m_prefill_model = m_kvcache_model->clone();
|
285 | 306 | m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill");
|
286 |
| - // (6) Reshape both models to static shape |
| 307 | + // (7) Reshape both models to static shape |
287 | 308 | const auto kMaxPromptLen = pop_or_default(pipeline_config, "MAX_PROMPT_LEN", 1024u);
|
288 | 309 | const auto kMinResponseLen = pop_or_default(pipeline_config, "MIN_RESPONSE_LEN", 150u);
|
289 | 310 | KVAxesPosition axes = get_kv_axes(get_model_type_from_json(path / "config.json"));
|
290 | 311 | m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, axes.seq_len };
|
291 | 312 | reshape_to_static(m_prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size, axes);
|
292 | 313 | reshape_to_static(m_kvcache_model, 1u, m_kvcache_desc.total_size, axes);
|
293 |
| - // (7) Compile both model |
| 314 | + // (8) Compile both models |
294 | 315 | auto prefill_config = pop_or_default(pipeline_config, "PREFILL_CONFIG", get_default_prefill_config());
|
295 | 316 | auto generate_config = pop_or_default(pipeline_config, "GENERATE_CONFIG", get_default_generate_config());
|
296 | 317 | merge_config_with(prefill_config, pipeline_config);
|
|
0 commit comments