
Commit 2237b5e
committed: update according to comments
1 parent d8733cb · commit 2237b5e

13 files changed: +94 -68 lines

src/cpp/src/llm_pipeline_stateful.cpp
+2 -3

@@ -275,11 +275,10 @@ EncodedResults StatefulLLMPipeline::generate(
         "but you have '" + std::to_string(num_inputs) + "' inputs");

     if (is_chat_conversation) {
-        if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
+        if (m_use_full_chat_history)
             reset_kv_state();
         else
-            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state.num_tokens_to_trim,
-                                            m_kv_cache_state.seq_length_axis, m_adapter_controller);
+            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state, m_adapter_controller);
     }

     size_t kv_cache_len = 0;

src/cpp/src/llm_pipeline_stateful.hpp
+1 -1

@@ -27,7 +27,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
     bool m_is_npu = false;
     // include reflection of tokens contained in the kv cache and amount of tokens, which are needed to trim from kv cache on the next step of chat
-    KVCacheState m_kv_cache_state;
+    utils::KVCacheState m_kv_cache_state;

    void reset_kv_state();
public:

src/cpp/src/lm_encoding.cpp
+3 -3

@@ -82,7 +82,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
     Sampler& sampler,
     std::vector<SequenceGroup::Ptr> sequence_groups,
     std::optional<ov::Tensor> position_ids,
-    KVCacheState& kv_cache_state,
+    utils::KVCacheState& kv_cache_state,
     std::optional<EmbeddingsModel> m_embedding,
     std::optional<int64_t> rope_delta,
     const size_t max_kv_cache_size

@@ -298,7 +298,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
 }


-TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
     TokenizedInputs encoded_input;
     size_t kv_cache_len = kv_cache_state.get_state().size();
     if (kv_cache_len == 0) {

@@ -325,7 +325,7 @@ TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCach
 }


-void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
     // KV cache in model already contains prompts and answers from previous iterations.
     // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
     // token_ids = {<bos token>, ...<valuable tokens>}. So if tokenizer applies only to the new prompt,

src/cpp/src/lm_encoding.hpp
+3 -24

@@ -8,37 +8,16 @@
 namespace ov {
 namespace genai {

-class KVCacheState {
-    std::vector<int64_t> state;
-public:
-    size_t num_tokens_to_trim = 0;
-    size_t seq_length_axis = 2;
-
-    std::vector<int64_t>& get_state() {
-        return state;
-    }
-
-    void add_inputs(const ov::Tensor& inputs_ids) {
-        std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
-    }
-
-    void reset_state() {
-        num_tokens_to_trim = 0;
-        state.clear();
-    }
-};
-
-
 ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
     const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
-    std::optional<ov::Tensor> position_ids, KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
+    std::optional<ov::Tensor> position_ids, utils::KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
     std::optional<int64_t> rope_delta = std::nullopt, const size_t max_kv_cache_size = std::numeric_limits<size_t>::max());


-void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);


-TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);

 }
 }

src/cpp/src/utils.cpp
+18 -4

@@ -380,13 +380,27 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {
     return kv_pos;
 }

-void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller) {
+void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller) {
+    if (kv_cache_state.get_state().empty()) {
+        if (adapter_controller) {
+            for(auto& state: request.query_state()) {
+                if(!adapter_controller->has_state_name(state.get_name())) {
+                    state.reset();
+                }
+            }
+        } else {
+            request.reset_state();
+        }
+
+        return;
+    }
+
     // nothing to trim in this case
-    if (remove_from_end == 0)
+    if (kv_cache_state.num_tokens_to_trim == 0)
         return;

     auto states = request.query_state();
-
+
     OPENVINO_ASSERT(states.size() > 0, "Request contains no states.");

     for (auto& state : states) {

@@ -396,7 +410,7 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
         ov::Tensor old_tensor = state.get_state();
         // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
         auto shape = old_tensor.get_shape();
-        shape[seq_length_axis] -= remove_from_end;
+        shape[kv_cache_state.seq_length_axis] -= kv_cache_state.num_tokens_to_trim;

         ov::Coordinate new_shape_begin{0, 0, 0, 0};
         ov::Coordinate new_shape_end{shape};

src/cpp/src/utils.hpp
+21 -1

@@ -104,7 +104,27 @@ struct KVAxesPosition {

 KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model);

-void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);
+class KVCacheState {
+    std::vector<int64_t> state;
+public:
+    size_t num_tokens_to_trim = 0;
+    size_t seq_length_axis = 2;
+
+    std::vector<int64_t>& get_state() {
+        return state;
+    }
+
+    void add_inputs(const ov::Tensor& inputs_ids) {
+        std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
+    }
+
+    void reset_state() {
+        num_tokens_to_trim = 0;
+        state.clear();
+    }
+};
+
+void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller);

 ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front);
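
Side note (not part of the commit): a minimal caller-side sketch of how the relocated utils::KVCacheState and the reworked trim_kv_cache() fit together, mirroring the call sites this commit updates in StatefulLLMPipeline::generate and VLMPipelineImpl::generate. The helper function, its parameters, and the include are illustrative assumptions, not code from the repository.

    // Hedged sketch only; assumes the internal headers of this repository.
    #include <optional>
    #include "utils.hpp"  // ov::genai::utils::KVCacheState, trim_kv_cache

    // Illustrative helper: prepare the KV cache before the next chat turn.
    void prepare_kv_cache_for_next_turn(ov::InferRequest& request,
                                        ov::genai::utils::KVCacheState& kv_cache_state,
                                        std::optional<ov::genai::AdapterController> adapter_controller,
                                        bool use_full_chat_history) {
        if (use_full_chat_history) {
            // The whole chat history is re-encoded each turn, so the cache is rebuilt from scratch.
            request.reset_state();
        } else {
            // After this commit trim_kv_cache() owns the reset-vs-trim decision:
            // - empty tracked state    -> it resets the request state itself,
            //   skipping LoRA adapter states when an AdapterController is given;
            // - num_tokens_to_trim > 0 -> it shrinks seq_length_axis of every KV tensor.
            ov::genai::utils::trim_kv_cache(request, kv_cache_state, adapter_controller);
        }
    }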

src/cpp/src/visual_language/inputs_embedder.cpp
+12 -10

@@ -29,7 +29,6 @@ std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::IInputsEmbedder::g

 void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_message) {
     m_is_chat_conversation = true;
-    m_kv_history_trim_manager.reset();
     if (!m_kv_cache_state.get_state().empty()) {
         m_history.clear();
         m_kv_cache_state.reset_state();

@@ -40,17 +39,20 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
         m_history = {{{"role", "system"}, {"content", system_message}}};
 }

-void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    // Tail of chat template is missing in KV cache.
-    // Find the tail to concatenate it with the next input prompt.
-    m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-    m_kv_history_trim_manager.reset();
+void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    m_kv_cache_state.num_tokens_to_trim = 0;
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+        // If chat generation process was cancelled by user, let's rollback to previous state of history
+        m_history.pop_back();
+    } else {
+        // Tail of chat template is missing in KV cache.
+        // Find the tail to concatenate it with the next input prompt.
+        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    }
 }

 void InputsEmbedder::IInputsEmbedder::finish_chat() {
     m_is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
-
     m_history.clear();
     m_kv_cache_state.reset_state();
 }

@@ -123,7 +125,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::apply_chat_template_tokenize(const s
 ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new_chat_tokens) {
     ov::Tensor encoded_inputs;
     if (m_is_chat_conversation) {
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
         encoded_inputs = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state).input_ids;
     } else {
         encoded_inputs = new_chat_tokens;

@@ -225,7 +227,7 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
     return m_impl->get_embedding_model();
 }

-KVCacheState& InputsEmbedder::get_kv_cache_state() {
+ov::genai::utils::KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }

src/cpp/src/visual_language/inputs_embedder.hpp
+4 -12

@@ -45,7 +45,7 @@ class InputsEmbedder {
     Tokenizer get_tokenizer() const;

     // get reflection of tokens contained in the kv cache
-    KVCacheState& get_kv_cache_state();
+    utils::KVCacheState& get_kv_cache_state();

     // starts chat and adds optional system_message to chat history
     void start_chat(const std::string& system_message);

@@ -77,16 +77,12 @@ class InputsEmbedder {
     bool m_is_chat_conversation = false;
     // Chat history
     ChatHistory m_history;
-    // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache
-    // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
-    // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
-    ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
     // True if chat template should be applied for non-chat scenario
     bool m_apply_chat_template = true;
     // Finish reason of last generation for chat scenario
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // reflection of tokens contained in the kv cache
-    KVCacheState m_kv_cache_state;
+    utils::KVCacheState m_kv_cache_state;
 public:
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;

@@ -100,21 +96,17 @@ class InputsEmbedder {
         return m_tokenizer;
     }

-    KVCacheState& get_kv_cache_state() {
+    utils::KVCacheState& get_kv_cache_state() {
         return m_kv_cache_state;
     }

-    size_t get_num_tokens_to_remove_from_hist() const {
-        return m_kv_history_trim_manager.num_tokens_to_trim;
-    }
-
     void set_apply_chat_template_status(bool apply_chat_template) {
         m_apply_chat_template = apply_chat_template;
    }

     virtual void start_chat(const std::string& system_message);

-    void update_chat_history(const std::string& decoded_results);
+    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);

     virtual void finish_chat();
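
For illustration only, this is the override pattern the now-virtual update_chat_history enables for model-specific embedders; the MiniCPM and Phi-3-vision overrides below are the real instances from this commit. The class and members in this sketch are hypothetical stand-ins, and the constructor plus the pure-virtual get_inputs_embeds are omitted.

    // Hedged sketch of a hypothetical model-specific embedder; only the
    // cancel-aware bookkeeping is shown, everything else is omitted.
    class InputsEmbedderSomeVLM : public InputsEmbedder::IInputsEmbedder {
        size_t m_images_seen = 0;       // hypothetical per-chat counter
        size_t m_prev_images_seen = 0;  // snapshot taken after the last committed turn

    public:
        void update_chat_history(const std::string& decoded_results,
                                 const ov::genai::GenerationStatus generation_finish_status) override {
            // Base implementation clears num_tokens_to_trim and either appends the
            // assistant answer to m_history or, on CANCEL, drops the cancelled entry.
            IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
            if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
                m_images_seen = m_prev_images_seen;   // cancelled: roll bookkeeping back
            else
                m_prev_images_seen = m_images_seen;   // committed: advance the snapshot
        }
    };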

src/cpp/src/visual_language/minicpm/classes.cpp
+8 -0

@@ -667,6 +667,14 @@ ov::Tensor InputsEmbedderMiniCPM::get_inputs_embeds(const std::string& prompt, c
     return inputs_embeds;
 }

+void InputsEmbedderMiniCPM::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
+        m_image_id = m_prev_image_id;
+    else
+        m_prev_image_id = m_image_id;
+}
+
 void InputsEmbedderMiniCPM::start_chat(const std::string& system_message) {
     IInputsEmbedder::start_chat(system_message);
     m_image_id = 0;

src/cpp/src/visual_language/minicpm/classes.hpp
+3 -0

@@ -30,6 +30,7 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
     ov::Tensor m_pos_embed_cache;
     // Used to insert <image_id>i</image_id> per image (not a slice).
     size_t m_image_id = 0;
+    size_t m_prev_image_id = 0;

 public:
     InputsEmbedderMiniCPM(

@@ -48,6 +49,8 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {

     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override;

+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) override;
+
     void start_chat(const std::string& system_message) override;

     void finish_chat() override;

src/cpp/src/visual_language/phi3_vision/classes.cpp
+10 -0

@@ -471,6 +471,8 @@ ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, cons
             length,
             merged.data<int64_t>() + offset
         );
+        if (tokens_per_images.empty())
+            continue;
         offset += length;
         if (offset < merged_length) {
             std::fill_n(

@@ -602,6 +604,14 @@ ov::Tensor InputsEmbedderPhi3V::get_inputs_embeds(const std::string& prompt, con
     return inputs_embeds;
 }

+void InputsEmbedderPhi3V::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
+        m_tokens_per_images = m_prev_tokens_per_images;
+    else
+        m_prev_tokens_per_images = m_tokens_per_images;
+}
+
 void InputsEmbedderPhi3V::start_chat(const std::string& system_message) {
     IInputsEmbedder::start_chat(system_message);
     m_tokens_per_images.clear();

src/cpp/src/visual_language/phi3_vision/classes.hpp
+3 -0

@@ -30,6 +30,8 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {

     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override;

+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) override;
+
     void start_chat(const std::string& system_message) override;

     void finish_chat() override;

@@ -38,6 +40,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
     ov::InferRequest m_hd_feature_transformer;
     ov::InferRequest m_vision_projection;
     std::vector<size_t> m_tokens_per_images;
+    std::vector<size_t> m_prev_tokens_per_images;
 };

 } // namespace ov::genai

src/cpp/src/visual_language/pipeline.cpp
+6 -10

@@ -66,8 +66,8 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         utils::print_compiled_model_properties(compiled_language_model, "VLM language model");
         auto language_model = compiled_language_model.get_runtime_model();

-        KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
-        kv_cache_state.seq_length_axis = ov::genai::utils::get_kv_axes_pos(language_model).seq_len;
+        utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+        kv_cache_state.seq_length_axis = utils::get_kv_axes_pos(language_model).seq_len;

         m_language = compiled_language_model.create_infer_request();

@@ -140,25 +140,21 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{

         m_inputs_embedder->set_apply_chat_template_status(generation_config.apply_chat_template);

+        utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+        if (m_is_chat_conversation)
+            utils::trim_kv_cache(m_language, kv_cache_state, std::nullopt);
+
         auto start_get_inputs_embeds = std::chrono::steady_clock::now();
         ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics);
         auto end_get_inputs_embeds = std::chrono::steady_clock::now();

-        KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
-        if (m_is_chat_conversation)
-            if (kv_cache_state.get_state().empty())
-                m_language.reset_state();
-            else
-                ov::genai::utils::trim_kv_cache(m_language, kv_cache_state.num_tokens_to_trim, kv_cache_state.seq_length_axis, std::nullopt);
-
         std::vector<SequenceGroup::Ptr> requests;
         size_t request_id = 0;
         size_t block_size = 1; // not used

         size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - kv_cache_state.num_tokens_to_trim;
         size_t inputs_embeds_size = inputs_embeds.get_shape().at(1);

-
         std::vector<int64_t> tokenized_history = kv_cache_state.get_state();
         ov::Tensor prompt_ids(ov::element::i64, { history_size + inputs_embeds_size });
         OPENVINO_ASSERT(prompt_ids.get_size() >= tokenized_history.size(), "Prompt ids size is less than tokenized history size");
