
Commit 621ad30

rebase + update
1 parent b7cbdbd commit 621ad30

9 files changed: +27 -21 lines changed

src/cpp/src/llm_pipeline_stateful.cpp (-1)

@@ -212,7 +212,6 @@ EncodedResults StatefulLLMPipeline::generate(
         reset_kv_state();
         m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
         m_kv_cache_state.reset_state();
-        m_kv_history_trim_manager.reset();
     }

     auto start_time = std::chrono::steady_clock::now();

src/cpp/src/visual_language/inputs_embedder.cpp (+7 -5)

@@ -39,16 +39,17 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
     m_history = {{{"role", "system"}, {"content", system_message}}};
 }

-void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount) {
+void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
     m_kv_cache_state.num_tokens_to_trim = 0;
     if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
         // If chat generation process was cancelled by user, let's rollback to previous state of history
         m_history.pop_back();

         std::vector<int64_t>& state = m_kv_cache_state.get_state();
-        state.resize(state.size() - processed_tokens_amount);
+
+        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        state.resize(m_prev_hist_length);
         m_kv_cache_state.reset_mem_state = state.empty();
-        m_kv_cache_state.num_tokens_to_trim = processed_tokens_amount;
     } else {
         // Tail of chat template is missing in KV cache.
         // Find the tail to concatenate it with the next input prompt.

@@ -142,6 +143,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new
 ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
     const auto new_chat_tokens = apply_chat_template_tokenize(prompt, metrics);
     auto new_input_ids = update_history(new_chat_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_input_ids);

     return new_input_ids;

@@ -244,8 +246,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }

-void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount) {
-    return m_impl->update_chat_history(decoded_results, generation_finish_status, processed_tokens_amount);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }

 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
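Taken together, these inputs_embedder.cpp hunks change the rollback strategy: instead of receiving processed_tokens_amount from the pipeline, the embedder snapshots the KV-cache length into m_prev_hist_length before each turn's inputs are added, and a cancelled generation trims back to that snapshot. A minimal sketch of the bookkeeping, using simplified stand-ins for utils::KVCacheState and the embedder (illustrative types, not the real OpenVINO GenAI classes):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    enum class GenerationStatus { RUNNING, FINISHED, CANCEL };

    struct KVCacheState {              // simplified stand-in for utils::KVCacheState
        std::vector<int64_t> tokens;   // tokens currently reflected in the KV cache
        size_t num_tokens_to_trim = 0; // entries the next inference step must drop
    };

    class EmbedderSketch {
        KVCacheState m_kv_cache_state;
        size_t m_prev_hist_length = 0; // cache length at the beginning of the turn

    public:
        // Mirrors get_encoded_input_ids(): snapshot first, then append the new prompt.
        void add_turn_inputs(const std::vector<int64_t>& new_input_ids) {
            m_prev_hist_length = m_kv_cache_state.tokens.size();
            m_kv_cache_state.tokens.insert(m_kv_cache_state.tokens.end(),
                                           new_input_ids.begin(), new_input_ids.end());
        }

        // Mirrors update_chat_history(): on CANCEL, everything appended since the
        // snapshot (new prompt + generated tokens) is scheduled for trimming.
        void update_chat_history(GenerationStatus status) {
            m_kv_cache_state.num_tokens_to_trim = 0;
            if (status == GenerationStatus::CANCEL) {
                auto& state = m_kv_cache_state.tokens;
                m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
                state.resize(m_prev_hist_length);
            }
        }
    };

The snapshot makes the trim amount self-contained in the embedder, which is why the pipeline.cpp hunk further down can stop computing it from the attention-mask shape.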

src/cpp/src/visual_language/inputs_embedder.hpp (+4 -2)

@@ -51,7 +51,7 @@ class InputsEmbedder {
     void start_chat(const std::string& system_message);

     // adds currently generated text to chat history
-    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount);
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);

     // set the apply_chat_template flag, which determines whether chat template should be applied for non-chat scenarios
     void set_apply_chat_template_status(bool apply_chat_template);

@@ -83,6 +83,8 @@ class InputsEmbedder {
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // reflection of tokens contained in the kv cache
     utils::KVCacheState m_kv_cache_state;
+    // length of attention_mask/kv cache at the beginning of generation()
+    size_t m_prev_hist_length = 0;
 public:
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;

@@ -106,7 +108,7 @@ class InputsEmbedder {
     virtual void start_chat(const std::string& system_message);

-    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount);
+    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);

     virtual void finish_chat();

src/cpp/src/visual_language/minicpm/classes.cpp (+2 -2)

@@ -667,8 +667,8 @@ ov::Tensor InputsEmbedderMiniCPM::get_inputs_embeds(const std::string& prompt, c
     return inputs_embeds;
 }

-void InputsEmbedderMiniCPM::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount) {
-    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status, processed_tokens_amount);
+void InputsEmbedderMiniCPM::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
     if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
         m_image_id = m_prev_image_id;
     else

src/cpp/src/visual_language/minicpm/classes.hpp (+1 -1)

@@ -49,7 +49,7 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {

     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override;

-    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount) override;
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) override;

     void start_chat(const std::string& system_message) override;

src/cpp/src/visual_language/phi3_vision/classes.cpp (+3 -2)

@@ -550,6 +550,7 @@ ov::Tensor InputsEmbedderPhi3V::get_inputs_embeds(const std::string& prompt, con
     }
     ov::Tensor new_merged_tokens = insert_image_placeholders(new_chat_tokens, m_tokens_per_images);
     ov::Tensor new_tokens = update_history(new_merged_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_tokens);

     std::vector<ov::Tensor> tokens = drop_image_placeholders(new_tokens);

@@ -604,8 +605,8 @@ ov::Tensor InputsEmbedderPhi3V::get_inputs_embeds(const std::string& prompt, con
     return inputs_embeds;
 }

-void InputsEmbedderPhi3V::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount) {
-    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status, processed_tokens_amount);
+void InputsEmbedderPhi3V::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
     if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
         m_tokens_per_images = m_prev_tokens_per_images;
     else

src/cpp/src/visual_language/phi3_vision/classes.hpp (+1 -1)

@@ -30,7 +30,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {

     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override;

-    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status, size_t processed_tokens_amount) override;
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) override;

     void start_chat(const std::string& system_message) override;

src/cpp/src/visual_language/pipeline.cpp (+4 -4)

@@ -90,9 +90,6 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         ov::genai::utils::print_compiled_model_properties(compiled_language_model, "VLM language model");

         m_language = compiled_language_model.create_infer_request();
-
-        utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
-        kv_cache_state.seq_length_axis = kv_pos.seq_len;
         m_language.get_tensor("attention_mask").set_shape({1, 0});

         auto embedder_properties = device_propertes.empty()

@@ -102,6 +99,9 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         m_tokenizer = m_inputs_embedder->get_tokenizer();
         m_embedding = m_inputs_embedder->get_embedding_model();

+        utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+        kv_cache_state.seq_length_axis = kv_pos.seq_len;
+
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());

@@ -236,7 +236,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{

         std::string decoded_results = decoded.texts.at(0);
         if (m_is_chat_conversation)
-            m_inputs_embedder->update_chat_history(decoded_results, finish_info.streaming_finish_status, m_language.get_tensor("attention_mask").get_shape().at(1) - history_size);
+            m_inputs_embedder->update_chat_history(decoded_results, finish_info.streaming_finish_status);
         else
             kv_cache_state.reset_state();
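For context on the pipeline.cpp hunks above: the CANCEL status that drives the rollback originates from the user's streamer. Below is a hedged sketch of a caller cancelling mid-generation; the StreamingStatus-returning streamer, the property helpers, and the model path are assumptions based on the public OpenVINO GenAI API, not something this diff shows:

    #include <iostream>
    #include <string>
    #include "openvino/genai/visual_language/pipeline.hpp"

    int main() {
        // Hypothetical model directory holding any exported VLM (e.g. MiniCPM, Phi-3-vision).
        ov::genai::VLMPipeline pipe("./vlm_model_dir", "CPU");

        size_t tokens_seen = 0;
        // Returning StreamingStatus::CANCEL ends generation with
        // GenerationStatus::CANCEL, which triggers the KV-cache rollback above.
        auto streamer = [&](std::string subword) -> ov::genai::StreamingStatus {
            std::cout << subword << std::flush;
            return ++tokens_seen < 32 ? ov::genai::StreamingStatus::RUNNING
                                      : ov::genai::StreamingStatus::CANCEL;
        };

        pipe.start_chat();
        // If this turn is cancelled, history and KV cache roll back, so the
        // next generate() behaves as if the cancelled turn never happened.
        pipe.generate("Describe the picture.", ov::genai::streamer(streamer));
        pipe.generate("What colour is the sky?", ov::genai::streamer(streamer));
        pipe.finish_chat();
        return 0;
    }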

tests/python_tests/test_vlm_pipeline.py (+5 -3)

@@ -224,9 +224,11 @@ def test_vlm_with_scheduler_vs_default(config, cache):
 @pytest.mark.nightly
 @pytest.mark.parametrize("model_id", model_ids)
 @pytest.mark.parametrize("system_message", ["", "You are a helpful assistant."])
-@pytest.mark.parametrize("iteration_images", [[image_links_for_testing[0], image_links_for_testing[0]], [image_links_for_testing[0], image_links_for_testing[2], image_links_for_testing[0]],
-                                              [image_links_for_testing[1], image_links_for_testing[1]], [image_links_for_testing[1], image_links_for_testing[1], image_links_for_testing[1]],
-                                              [image_links_for_testing[2], image_links_for_testing[1]], [image_links_for_testing[2], image_links_for_testing[0], image_links_for_testing[1]]])
+@pytest.mark.parametrize("iteration_images", [[image_links_for_testing[0], image_links_for_testing[0]],  # generation with text input only
+                                              [image_links_for_testing[0], image_links_for_testing[2], image_links_for_testing[0]],  # combination of generations with text input and image input, empty string first
+                                              [image_links_for_testing[2], image_links_for_testing[1]],  # text + image input
+                                              [image_links_for_testing[2], image_links_for_testing[0], image_links_for_testing[1]]]  # combination of generations with text input and image input, image input first
+                                              )
 def test_vlm_pipeline_chat(model_id, system_message, iteration_images, cache):
     def streamer(word: str) -> bool:
         nonlocal result_from_streamer
