
Commit f9adf6b

update according to m_kv_cache_state

1 parent c95ae4f · commit f9adf6b

7 files changed: +24 -101 lines

src/cpp/src/llm_pipeline_stateful.cpp

+7 -8

@@ -45,7 +45,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
     m_use_full_chat_history = true;

     if (!m_use_full_chat_history)
-        m_kv_history_trim_manager.kv_cache_seq_length_axis = ov::genai::utils::get_kv_axes_pos(model).seq_len;
+        m_kv_cache_state.seq_length_axis = ov::genai::utils::get_kv_axes_pos(model).seq_len;

     auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters);
     if (m_generation_config.adapters) {

@@ -119,7 +119,7 @@ DecodedResults StatefulLLMPipeline::generate(
         if (m_use_full_chat_history) {
             encoded_input = new_chat_tokens;
         } else {
-            ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens.input_ids, m_kv_cache_state);
+            ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);
             encoded_input = get_chat_encoded_input(new_chat_tokens.input_ids, m_kv_cache_state);
         }
         // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied

@@ -208,7 +208,7 @@ EncodedResults StatefulLLMPipeline::generate(
     // Tail of previous output in chat mode is missing in KV cache.
     if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
         ov::Tensor new_chat_tokens = ov::Tensor{ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()};
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);

         auto encoded_input = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state);
         input_ids = encoded_input.input_ids;

@@ -245,8 +245,8 @@ EncodedResults StatefulLLMPipeline::generate(
         if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
             reset_kv_state();
         else
-            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_trim_manager.num_tokens_to_trim,
-                                            m_kv_history_trim_manager.kv_cache_seq_length_axis, m_adapter_controller);
+            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state.num_tokens_to_trim,
+                                            m_kv_cache_state.seq_length_axis, m_adapter_controller);
     }

     size_t kv_cache_len = 0;

@@ -319,7 +319,7 @@ EncodedResults StatefulLLMPipeline::generate(
     m_chat_generation_finish_status = finish_info.streaming_finish_status;

     if (is_chat_conversation) {
-        m_kv_history_trim_manager.reset();
+        m_kv_cache_state.num_tokens_to_trim = 0;

         if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
             if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {

@@ -328,7 +328,7 @@ EncodedResults StatefulLLMPipeline::generate(
                 std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
             }
         } else if (config.is_beam_search()) {
-            m_kv_history_trim_manager.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
+            m_kv_cache_state.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
         }
     } else {
         m_kv_cache_state.reset_state();

@@ -369,7 +369,6 @@ void StatefulLLMPipeline::reset_kv_state() {

 void StatefulLLMPipeline::finish_chat() {
     is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
     m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
     if (!m_kv_cache_state.get_state().empty() || have_state) {
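Note: the recurring pattern in these hunks is to reset the whole model state when nothing in the KV cache is reusable (or when the full chat history is replayed every turn), and otherwise to trim only the stale tail recorded in m_kv_cache_state.num_tokens_to_trim. Below is a minimal standalone sketch of that decision using simplified, hypothetical types rather than the pipeline's own classes.

// Sketch only: hypothetical simplified types, not ov::genai classes.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct KVCacheStateSketch {
    std::vector<int64_t> tokens;     // tokens currently reflected in the KV cache
    size_t num_tokens_to_trim = 0;   // stale tail to drop before the next step
};

void prepare_kv_cache(const KVCacheStateSketch& state, bool use_full_chat_history) {
    if (state.tokens.empty() || use_full_chat_history)
        std::cout << "reset_kv_state()\n";                                   // start from scratch
    else
        std::cout << "trim_kv_cache(" << state.num_tokens_to_trim << ")\n";  // drop only the stale tail
}

int main() {
    KVCacheStateSketch state;
    prepare_kv_cache(state, /*use_full_chat_history=*/false);   // empty cache -> reset_kv_state()

    state.tokens = {1, 2, 3, 4};
    state.num_tokens_to_trim = 2;                                // e.g. recorded after beam search
    prepare_kv_cache(state, /*use_full_chat_history=*/false);   // -> trim_kv_cache(2)
}

In the pipeline itself, num_tokens_to_trim is cleared only after generation (the @@ -319 hunk), not at trim time.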

src/cpp/src/llm_pipeline_stateful.hpp

+1 -5

@@ -18,15 +18,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ChatHistory m_history;
     std::vector<int64_t> m_tokenized_chat_history;
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
-    // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
-    // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
-    // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
-    ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
     // Finish reason of last generation for chat scenario
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // if True, full history will be used as prompt on each chat generation
     bool m_use_full_chat_history = false;
-    // reflection of tokens contained in the kv cache
+    // reflection of the tokens contained in the kv cache, plus the number of tokens to trim from the kv cache on the next chat step
     KVCacheState m_kv_cache_state;

     void reset_kv_state();

src/cpp/src/lm_encoding.cpp

+2 -2

@@ -320,7 +320,7 @@ TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCach
 }


-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
     // KV cache in model already contains prompts and answers from previous iterations.
     // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
     // token_ids = {<bos token>, ...<valuable tokens>}. So if tokenizer applies only to the new prompt,

@@ -338,7 +338,7 @@ void align_kv_cache_and_history(ov::genai::KVCacheTrimManage
     size_t first_diverse_tokens_idx = ov::genai::utils::get_first_history_difference(new_chat_tokens, state);
     // in the case of beam_search the longest answer is in the kv cache, but the best one is needed
     // so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated sequence
-    kv_history_manager.num_tokens_to_trim = kv_history_manager.num_tokens_to_trim > 0 ? kv_history_manager.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
+    kv_cache_state.num_tokens_to_trim = kv_cache_state.num_tokens_to_trim > 0 ? kv_cache_state.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
     state.resize(first_diverse_tokens_idx);
 }
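Note: the comments in this hunk explain why the cached token history can diverge from the re-tokenized chat, and why beam search may pre-set num_tokens_to_trim. The standalone sketch below mirrors that alignment logic on plain vectors with hypothetical names; the real function operates on ov::Tensor and the library's KVCacheState.

// Standalone illustration (not the library code) of the alignment logic: find where the
// cached history and the freshly tokenized chat diverge, record how much KV cache tail
// to trim, then shrink the cached history to the common prefix.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct KVCacheStateSketch {
    std::vector<int64_t> tokens;     // reflection of tokens already in the KV cache
    size_t num_tokens_to_trim = 0;   // tail to cut from the KV cache before the next step
};

void align(KVCacheStateSketch& state, const std::vector<int64_t>& new_chat_tokens) {
    // First position where the cached history and the re-tokenized history differ.
    size_t first_diverse = 0;
    while (first_diverse < state.tokens.size() &&
           first_diverse < new_chat_tokens.size() &&
           state.tokens[first_diverse] == new_chat_tokens[first_diverse])
        ++first_diverse;

    // Beam search may have pre-set num_tokens_to_trim (the longest beam sits in the cache
    // while only the best answer is kept in history); keep that value if it is non-zero.
    if (state.num_tokens_to_trim == 0)
        state.num_tokens_to_trim = state.tokens.size() - first_diverse;
    state.tokens.resize(first_diverse);
}

int main() {
    KVCacheStateSketch state;
    state.tokens = {1, 10, 11, 12, 99, 98};                        // prompt + previous answer
    std::vector<int64_t> new_chat = {1, 10, 11, 12, 99, 97, 50};   // re-tokenized history diverges

    align(state, new_chat);
    std::cout << "trim " << state.num_tokens_to_trim << " token(s)\n";  // prints: trim 1 token(s)
}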

src/cpp/src/lm_encoding.hpp

+5 -12

@@ -11,6 +11,9 @@ namespace genai {
 class KVCacheState {
     std::vector<int64_t> state;
 public:
+    size_t num_tokens_to_trim = 0;
+    size_t seq_length_axis = 2;
+
     std::vector<int64_t>& get_state() {
         return state;
     }

@@ -20,18 +23,8 @@ class KVCacheState {
     }

     void reset_state() {
-        return state.clear();
-    }
-};
-
-
-struct KVCacheTrimManager
-{
-    size_t num_tokens_to_trim = 0;
-    size_t kv_cache_seq_length_axis = 2;
-
-    void reset() {
         num_tokens_to_trim = 0;
+        state.clear();
     }
 };


@@ -41,7 +34,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest&
     std::optional<ov::Tensor> position_ids, KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding, std::optional<int64_t> rope_delta = std::nullopt);


-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);


 TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
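Note: piecing these hunks together, KVCacheState now owns the trim bookkeeping that previously lived in KVCacheTrimManager. Reconstructed from the diff above (members not touched by this commit are elided), the class reads approximately as follows; this is not a verbatim copy of the header.

// Approximate shape of KVCacheState after this commit, reconstructed from the hunks above.
#include <cstddef>
#include <cstdint>
#include <vector>

namespace ov {
namespace genai {

class KVCacheState {
    std::vector<int64_t> state;      // reflection of the tokens contained in the KV cache
public:
    size_t num_tokens_to_trim = 0;   // tokens to cut from the KV cache on the next chat step
    size_t seq_length_axis = 2;      // KV tensor axis that holds the sequence length

    std::vector<int64_t>& get_state() {
        return state;
    }

    // ... members unchanged by this commit are omitted ...

    void reset_state() {
        num_tokens_to_trim = 0;
        state.clear();
    }
};

}  // namespace genai
}  // namespace ov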

src/cpp/src/visual_language/inputs_embedder.cpp

+8 -52

@@ -44,12 +44,6 @@ class InputsEmbedder::IInputsEmbedder {
 public:
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;

-    ov::Tensor get_input_embeddings(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) {
-        ov::Tensor inputs_embeds = get_inputs_embeds(prompt, images, metrics);
-        m_inputs_embeds_size = inputs_embeds.get_shape().at(1);
-        return inputs_embeds;
-    }
-
     virtual std::pair<ov::Tensor, std::optional<int64_t>> get_position_ids(const size_t inputs_embeds_size, const size_t history_size) {
         ov::Tensor position_ids = ov::Tensor{ov::element::i64, { 1, inputs_embeds_size }};
         std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), history_size);

@@ -72,32 +66,6 @@ class InputsEmbedder::IInputsEmbedder {
         m_stop_token_ids = stop_token_ids;
     }

-<<<<<<< HEAD
-=======
-    virtual void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len) {
-        if (is_beam_search) {
-            m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
-            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
-        } else {
-            m_kv_history_manager.reset();
-        }
-
-        m_last_disappeared_token = generation_finish_info.probably_disappeared_token;
-
-        if (m_is_chat_conversation) {
-            if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL) {
-                // let's remove last answer and prompt
-                m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_inputs_embeds_size + last_answer_len;
-                m_tokenized_history = m_prev_tokenized_history;
-                m_kv_history_manager.reset_kv_cache = m_tokenized_history.empty();
-            } else {
-                auto encoded_result = generation_finish_info.results.tokens[0];
-                std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
-            }
-        }
-    }
-
->>>>>>> 19b756cd (update comments)
     void set_apply_chat_template_status(bool apply_chat_template) {
         m_apply_chat_template = apply_chat_template;
     }

@@ -114,19 +82,15 @@ class InputsEmbedder::IInputsEmbedder {
         m_history = {{{"role", "system"}, {"content", system_message}}};
     }

-    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        m_kv_cache_state.num_tokens_to_trim = 0;
         if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
             // If chat generation process was cancelled by user, let's rollback to previous state of history
             m_history.pop_back();
-            if (!m_history.empty()) {
-                constexpr bool add_generation_prompt = true;
-                m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
-            }
         } else {
             // Tail of chat template is missing in KV cache.
             // Find the tail to concatenate it with the next input prompt.
             m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-            m_kv_cache_state.num_tokens_to_trim = 0;
         }
     }


@@ -410,9 +374,9 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
         return inputs_embeds;
     }

-    virtual void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len) {
-        IInputsEmbedder::update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len);
-        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL) {
+    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
+        if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
             m_image_id = m_prev_image_id;
         }
     }

@@ -1562,9 +1526,9 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         m_tokens_per_images.clear();
     }

-    virtual void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len) {
-        IInputsEmbedder::update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len);
-        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL)
+    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
+        if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
             m_tokens_per_images = m_prev_tokens_per_images;
     }
 };

@@ -2014,10 +1978,6 @@ ov::Tensor InputsEmbedder::get_inputs_embeds(const std::string& prompt, const st
     return m_impl->get_inputs_embeds(prompt, images, metrics);
 }

-ov::Tensor InputsEmbedder::get_input_embeddings(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) {
-    return m_impl->get_input_embeddings(prompt, images, metrics);
-}
-
 std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::get_position_ids(const size_t inputs_embeds_size, const size_t history_size) {
     return m_impl->get_position_ids(inputs_embeds_size, history_size);
 }

@@ -2034,10 +1994,6 @@ KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }

-bool InputsEmbedder::should_reset_kv_cache() const {
-    return m_impl->should_reset_kv_cache();
-}
-
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
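Note: update_chat_history() becomes virtual here so that model-specific embedders (MiniCPM, Phi3V) can roll back their own extra state on cancellation while still delegating the shared history and trim bookkeeping to the base class. Below is a self-contained sketch of that override pattern with simplified, hypothetical names; it is not the ov::genai code.

// Sketch only: hypothetical simplified classes illustrating the virtual-override pattern.
#include <cstddef>
#include <iostream>
#include <string>

enum class GenerationStatus { RUNNING, FINISHED, CANCEL };

struct BaseEmbedder {
    size_t num_tokens_to_trim = 0;
    virtual ~BaseEmbedder() = default;

    virtual void update_chat_history(const std::string& answer, GenerationStatus status) {
        num_tokens_to_trim = 0;                  // the pending trim request was consumed this step
        if (status == GenerationStatus::CANCEL)
            std::cout << "rollback history\n";   // drop the cancelled prompt from the chat history
        else
            std::cout << "append answer: " << answer << "\n";
    }
};

struct MiniCpmLikeEmbedder : BaseEmbedder {
    int image_id = 3;
    int prev_image_id = 2;

    void update_chat_history(const std::string& answer, GenerationStatus status) override {
        BaseEmbedder::update_chat_history(answer, status);  // shared bookkeeping first
        if (status == GenerationStatus::CANCEL)
            image_id = prev_image_id;            // then undo the model-specific counter
    }
};

int main() {
    MiniCpmLikeEmbedder embedder;
    embedder.update_chat_history("", GenerationStatus::CANCEL);
    std::cout << "image_id restored to " << embedder.image_id << "\n";  // prints 2
}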

src/cpp/src/visual_language/inputs_embedder.hpp

-6

@@ -36,9 +36,6 @@ class InputsEmbedder {
     // compute input embedding for prompt and multiple images
     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics);

-    // computes input embedding for prompt and multiple images and saves input_embeddings size
-    ov::Tensor get_input_embeddings(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics);
-
     // compute position ids for language model input
     std::pair<ov::Tensor, std::optional<int64_t>> get_position_ids(const size_t inputs_embeds_size, const size_t history_size);


@@ -53,9 +50,6 @@ class InputsEmbedder {
     // get reflection of tokens contained in the kv cache
     KVCacheState& get_kv_cache_state();

-    // returns true, if we need to remove full kv cache, in that case it's needed to reset it instead of manually updating
-    bool should_reset_kv_cache() const;
-
     // starts chat and adds optional system_message to chat history
     void start_chat(const std::string& system_message);


src/cpp/src/visual_language/pipeline.cpp

+1 -16

@@ -45,8 +45,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
     std::shared_ptr<InputsEmbedder> m_inputs_embedder;
     // Axis num in kv cache from m_language model, which contains information about history len
     size_t m_kv_cache_seq_length_axis = 2;
-    // Load pipeline time
-    float m_load_time_ms = 0;
     // Component for applying sampling to lm outputs
     Sampler m_sampler;
 public:

@@ -163,23 +161,15 @@ class ov::genai::VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
     m_inputs_embedder->set_apply_chat_template_status(generation_config.apply_chat_template);

     auto start_get_inputs_embeds = std::chrono::steady_clock::now();
-    ov::Tensor inputs_embeds = m_inputs_embedder->get_input_embeddings(prompt, rgbs, perf_metrics);
+    ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics);
     auto end_get_inputs_embeds = std::chrono::steady_clock::now();

-<<<<<<< HEAD
     KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
     if (m_is_chat_conversation)
         if (kv_cache_state.get_state().empty())
             m_language.reset_state();
         else
             ov::genai::utils::trim_kv_cache(m_language, kv_cache_state.num_tokens_to_trim, kv_cache_state.seq_length_axis, std::nullopt);
-=======
-    auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist();
-    if (m_inputs_embedder->should_reset_kv_cache())
-        m_language.reset_state();
-    else
-        ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);
->>>>>>> 19b756cd (update comments)

     std::vector<SequenceGroup::Ptr> requests;
     size_t request_id = 0;

@@ -228,11 +218,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
     }
     auto decode_end_time = std::chrono::steady_clock::now();

-<<<<<<< HEAD
-=======
-    m_inputs_embedder->update_tokenized_history(finish_info, generation_config.is_beam_search(), m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));
-
->>>>>>> 19b756cd (update comments)
     std::string decoded_results = decoded.texts.at(0);
     if (m_is_chat_conversation)
         m_inputs_embedder->update_chat_history(decoded_results, finish_info.streaming_finish_status);
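Note: both pipelines now pass kv_cache_state.num_tokens_to_trim and kv_cache_state.seq_length_axis to ov::genai::utils::trim_kv_cache(). Conceptually, trimming drops the last N positions along the sequence-length axis of every KV tensor. The standalone sketch below shows that idea on a flat buffer with an assumed [batch, heads, seq_len, head_dim] layout (axis 2 = seq_len); the real implementation works on the model's ov::Tensor states.

// Conceptual sketch of trimming a KV cache along the sequence-length axis (axis 2 of an
// assumed [batch, heads, seq_len, head_dim] layout). Not the library's trim_kv_cache().
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> trim_seq_axis(const std::vector<float>& kv,
                                 size_t batch, size_t heads, size_t seq_len, size_t head_dim,
                                 size_t num_tokens_to_trim) {
    const size_t kept = seq_len - num_tokens_to_trim;
    std::vector<float> out;
    out.reserve(batch * heads * kept * head_dim);
    for (size_t b = 0; b < batch; ++b)
        for (size_t h = 0; h < heads; ++h)
            for (size_t s = 0; s < kept; ++s)          // keep only the first `kept` positions
                for (size_t d = 0; d < head_dim; ++d)
                    out.push_back(kv[((b * heads + h) * seq_len + s) * head_dim + d]);
    return out;
}

int main() {
    const size_t batch = 1, heads = 2, seq_len = 5, head_dim = 4;
    std::vector<float> kv(batch * heads * seq_len * head_dim, 1.0f);

    auto trimmed = trim_seq_axis(kv, batch, heads, seq_len, head_dim, /*num_tokens_to_trim=*/2);
    std::cout << trimmed.size() << " values remain\n";   // 1 * 2 * 3 * 4 = 24
}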
