Commit d8733cb (1 parent: 892a4f7)

update according to m_kv_cache_state

7 files changed: +16, -44 lines

src/cpp/src/llm_pipeline_stateful.cpp (+7, -8)

@@ -60,7 +60,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
     }
 
     if (!m_use_full_chat_history)
-        m_kv_history_trim_manager.kv_cache_seq_length_axis = kv_pos.seq_len;
+        m_kv_cache_state.seq_length_axis = kv_pos.seq_len;
 
     auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters);
     if (m_generation_config.adapters) {
@@ -143,7 +143,7 @@ DecodedResults StatefulLLMPipeline::generate(
         if (m_use_full_chat_history) {
             encoded_input = new_chat_tokens;
         } else {
-            ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens.input_ids, m_kv_cache_state);
+            ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);
             encoded_input = get_chat_encoded_input(new_chat_tokens.input_ids, m_kv_cache_state);
         }
         // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
@@ -232,7 +232,7 @@ EncodedResults StatefulLLMPipeline::generate(
     // Tail of previous output in chat mode is missing in KV cache.
     if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
         ov::Tensor new_chat_tokens = ov::Tensor{ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()};
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
 
         auto encoded_input = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state);
         input_ids = encoded_input.input_ids;
@@ -278,8 +278,8 @@ EncodedResults StatefulLLMPipeline::generate(
         if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
             reset_kv_state();
         else
-            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_trim_manager.num_tokens_to_trim,
-                                            m_kv_history_trim_manager.kv_cache_seq_length_axis, m_adapter_controller);
+            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state.num_tokens_to_trim,
+                                            m_kv_cache_state.seq_length_axis, m_adapter_controller);
     }
 
     size_t kv_cache_len = 0;
@@ -352,7 +352,7 @@ EncodedResults StatefulLLMPipeline::generate(
     m_chat_generation_finish_status = finish_info.streaming_finish_status;
 
     if (is_chat_conversation) {
-        m_kv_history_trim_manager.reset();
+        m_kv_cache_state.num_tokens_to_trim = 0;
 
         if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
             if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
@@ -361,7 +361,7 @@ EncodedResults StatefulLLMPipeline::generate(
                 std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
             }
         } else if (config.is_beam_search()) {
-            m_kv_history_trim_manager.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
+            m_kv_cache_state.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
         }
     } else {
         m_kv_cache_state.reset_state();
@@ -402,7 +402,6 @@ void StatefulLLMPipeline::reset_kv_state() {
 
 void StatefulLLMPipeline::finish_chat() {
     is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
     m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
     if (!m_kv_cache_state.get_state().empty() || have_state) {
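
Taken together, the hunks above reduce to a single pattern: every place that previously consulted m_kv_history_trim_manager now reads or writes the same two fields on m_kv_cache_state. A minimal sketch of the resulting flow inside generate(), assembled only from lines visible in this diff (not a verbatim excerpt; surrounding code elided):

// align chat history with the KV cache; this may set num_tokens_to_trim
ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);

// either drop the whole cache or trim just the diverged tail
if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
    reset_kv_state();
else
    ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state.num_tokens_to_trim,
                                    m_kv_cache_state.seq_length_axis, m_adapter_controller);

// after generation, clearing the pending trim replaces the old m_kv_history_trim_manager.reset()
m_kv_cache_state.num_tokens_to_trim = 0;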

src/cpp/src/llm_pipeline_stateful.hpp (+1, -5)

@@ -20,17 +20,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ChatHistory m_history;
     std::vector<int64_t> m_tokenized_chat_history;
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
-    // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
-    // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
-    // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
-    ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
     // Finish reason of last generation for chat scenario
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // if True, full history will be used as prompt on each chat generation
     bool m_use_full_chat_history = false;
     size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
     bool m_is_npu = false;
-    // reflection of tokens contained in the kv cache
+    // include reflection of tokens contained in the kv cache and amount of tokens, which are needed to trim from kv cache on the next step of chat
     KVCacheState m_kv_cache_state;
 
     void reset_kv_state();

src/cpp/src/lm_encoding.cpp (+2, -2)

@@ -325,7 +325,7 @@ TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
 }
 
 
-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
     // KV cache in model already contains prompts and answers from previous iterations.
     // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
     // token_ids = {<bos token>, ...<valuable tokens>}. So if tokenizer applies only to the new prompt,
@@ -343,7 +343,7 @@ void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
     size_t first_diverse_tokens_idx = ov::genai::utils::get_first_history_difference(new_chat_tokens, state);
     // in the case of beam_search the longest answer is in the kv cache, but the best one is needed
     // so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated serquence
-    kv_history_manager.num_tokens_to_trim = kv_history_manager.num_tokens_to_trim > 0 ? kv_history_manager.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
+    kv_cache_state.num_tokens_to_trim = kv_cache_state.num_tokens_to_trim > 0 ? kv_cache_state.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
     state.resize(first_diverse_tokens_idx);
 }
 

src/cpp/src/lm_encoding.hpp (+5, -12)

@@ -11,6 +11,9 @@ namespace genai {
 class KVCacheState {
     std::vector<int64_t> state;
 public:
+    size_t num_tokens_to_trim = 0;
+    size_t seq_length_axis = 2;
+
     std::vector<int64_t>& get_state() {
         return state;
     }
@@ -20,18 +23,8 @@ class KVCacheState {
     }
 
     void reset_state() {
-        return state.clear();
-    }
-};
-
-
-struct KVCacheTrimManager
-{
-    size_t num_tokens_to_trim = 0;
-    size_t kv_cache_seq_length_axis = 2;
-
-    void reset() {
         num_tokens_to_trim = 0;
+        state.clear();
     }
 };
 
@@ -42,7 +35,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest&
                    std::optional<int64_t> rope_delta = std::nullopt, const size_t max_kv_cache_size = std::numeric_limits<size_t>::max());
 
 
-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
 
 
 TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
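
For reference, KVCacheState as it reads after this commit, assembled from the two hunks above; a sketch only, with unchanged members between get_state() and reset_state() elided:

class KVCacheState {
    std::vector<int64_t> state;
public:
    // bookkeeping that used to live on the removed KVCacheTrimManager
    size_t num_tokens_to_trim = 0;  // tokens to drop from the KV cache on the next chat step
    size_t seq_length_axis = 2;     // axis of the KV tensors that holds the sequence length

    std::vector<int64_t>& get_state() {
        return state;
    }

    // ... unchanged members elided ...

    void reset_state() {
        num_tokens_to_trim = 0;
        state.clear();
    }
};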

src/cpp/src/visual_language/inputs_embedder.cpp (-8)

@@ -217,10 +217,6 @@ ov::Tensor InputsEmbedder::get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) {
     return m_impl->get_inputs_embeds(prompt, images, metrics);
 }
 
-ov::Tensor InputsEmbedder::get_input_embeddings(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) {
-    return m_impl->get_input_embeddings(prompt, images, metrics);
-}
-
 std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::get_position_ids(const size_t inputs_embeds_size, const size_t history_size) {
     return m_impl->get_position_ids(inputs_embeds_size, history_size);
 }
@@ -233,10 +229,6 @@ KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }
 
-bool InputsEmbedder::should_reset_kv_cache() const {
-    return m_impl->should_reset_kv_cache();
-}
-
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }

src/cpp/src/visual_language/inputs_embedder.hpp (-6)

@@ -35,9 +35,6 @@ class InputsEmbedder {
     // compute input embedding for prompt and multiple images
     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics);
 
-    // computes input embedding for prompt and multiple images and saves input_embeddings size
-    ov::Tensor get_input_embeddings(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics);
-
     // compute position ids for language model input
     std::pair<ov::Tensor, std::optional<int64_t>> get_position_ids(const size_t inputs_embeds_size, const size_t history_size);
 
@@ -50,9 +47,6 @@ class InputsEmbedder {
     // get reflection of tokens contained in the kv cache
     KVCacheState& get_kv_cache_state();
 
-    // returns true, if we need to remove full kv cache, in that case it's needed to reset it instead of manually updating
-    bool should_reset_kv_cache() const;
-
     // starts chat and adds optional system_message to chat history
     void start_chat(const std::string& system_message);
 
src/cpp/src/visual_language/pipeline.cpp (+1, -3)

@@ -42,8 +42,6 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
     std::shared_ptr<InputsEmbedder> m_inputs_embedder;
     // Axis num in kv cache from m_language model, which contains information about history len
     size_t m_kv_cache_seq_length_axis = 2;
-    // Load pipeline time
-    float m_load_time_ms = 0;
     // Component for applying sampling to lm outputs
     Sampler m_sampler;
 public:
@@ -143,7 +141,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         m_inputs_embedder->set_apply_chat_template_status(generation_config.apply_chat_template);
 
         auto start_get_inputs_embeds = std::chrono::steady_clock::now();
-        ov::Tensor inputs_embeds = m_inputs_embedder->get_input_embeddings(prompt, rgbs, perf_metrics);
+        ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics);
         auto end_get_inputs_embeds = std::chrono::steady_clock::now();
 
         KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
