Commit faf4eb4

Implement CANCEL for streaming with VLM Pipeline
1 parent b4632ab commit faf4eb4

File tree (4 files changed: +152 −27)

src/cpp/src/visual_language/inputs_embedder.cpp
src/cpp/src/visual_language/inputs_embedder.hpp
src/cpp/src/visual_language/pipeline.cpp
tests/python_tests/test_vlm_pipeline.py
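
From the caller's point of view, cancellation is driven by the streamer: once the callback returns CANCEL, generation stops and the pipeline rolls the chat back to the state before the cancelled turn, so the conversation can continue as if that turn never happened. Below is a minimal C++ sketch of that flow (model path and prompts are placeholders, images are omitted for brevity, and it assumes the C++ streamer callback may return ov::genai::StreamingStatus, mirroring the Python StreamingStatus used in the tests below):

#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/visual_language/pipeline.hpp"

int main() {
    ov::genai::VLMPipeline pipe("./vlm_model_dir", "CPU");  // placeholder model directory

    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 30;

    size_t streamed_chunks = 0;
    auto streamer = [&](std::string subword) {
        std::cout << subword << std::flush;
        // Cancel after a few streamed chunks; the pipeline drops this turn.
        return ++streamed_chunks == 3 ? ov::genai::StreamingStatus::CANCEL
                                      : ov::genai::StreamingStatus::RUNNING;
    };

    pipe.start_chat();
    // This turn is cancelled mid-stream: its prompt and partial answer are
    // removed from the KV cache and from the chat history.
    pipe.generate("Why is the Sun yellow?", ov::genai::generation_config(config), ov::genai::streamer(streamer));
    // The next turn continues from the state before the cancelled one.
    pipe.generate("1+1=", ov::genai::generation_config(config));
    pipe.finish_chat();
    return 0;
}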

src/cpp/src/visual_language/inputs_embedder.cpp

+63 −18
@@ -35,8 +35,10 @@ class InputsEmbedder::IInputsEmbedder {
     ChatHistory m_history;
     // Templated chat history
     std::string m_templated_chat_history;
-    // Tokenized chat history
+    // Tokenized history
     std::vector<int64_t> m_tokenized_history;
+    // Tokenized chat history on previous step
+    std::vector<int64_t> m_prev_tokenized_history;
     // Tail of previous output for LM in chat mode is missing in KV cache.
     std::optional<int64_t> m_last_disappeared_token = std::nullopt;
     // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache
@@ -72,21 +74,33 @@ class InputsEmbedder::IInputsEmbedder {
         return m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
     }
 
+    bool should_reset_kv_cache() const {
+        return m_kv_history_manager.reset_kv_cache;
+    }
+
     void set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
         m_stop_token_ids = stop_token_ids;
     }
 
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
+    virtual void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
         if (is_beam_search) {
             m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
             m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
         } else {
             m_kv_history_manager.reset();
         }
 
-        m_last_disappeared_token = last_disappeared_token;
-
-        std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        m_last_disappeared_token = generation_finish_info.probably_disappeared_token;
+
+        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // let's remove last answer and prompt
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = inputs_embeds_size + last_answer_len;
+            m_tokenized_history = std::move(m_prev_tokenized_history);
+            m_kv_history_manager.reset_kv_cache = m_tokenized_history.empty();
+        } else {
+            auto encoded_result = generation_finish_info.results.tokens[0];
+            std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        }
     }
 
     void set_apply_chat_template_status(bool apply_chat_template) {
@@ -100,6 +114,7 @@ class InputsEmbedder::IInputsEmbedder {
             m_history.clear();
             m_templated_chat_history.clear();
             m_tokenized_history.clear();
+            m_prev_tokenized_history.clear();
         }
         if (system_message.empty()) {
             return;
@@ -109,11 +124,16 @@ class InputsEmbedder::IInputsEmbedder {
         m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
     }
 
-    void update_chat_history(const std::string& decoded_results) {
-        // Tail of chat template is missing in KV cache.
-        // Find the tail to concatenate it with the next input prompt.
-        m_templated_chat_history.append(decoded_results);
-        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // If chat generation process was cancelled by user, let's rollback to previous state of history
+            m_history.pop_back();
+        } else {
+            // Tail of chat template is missing in KV cache.
+            // Find the tail to concatenate it with the next input prompt.
+            m_templated_chat_history.append(decoded_results);
+            m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+        }
     }
 
     virtual void finish_chat() {
@@ -123,6 +143,7 @@ class InputsEmbedder::IInputsEmbedder {
         m_history.clear();
         m_templated_chat_history.clear();
         m_tokenized_history.clear();
+        m_prev_tokenized_history.clear();
     }
 
 protected:
@@ -213,21 +234,29 @@ class InputsEmbedder::IInputsEmbedder {
             trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, m_stop_token_ids);
         }
 
+        m_prev_tokenized_history.clear();
         if (m_tokenized_history.empty()) {
            encoded_input_ids = new_chat_tokens;
-
        } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) {
            // does_history_cache_need_to_update will be true here if beam search is activated
            // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly
            // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager
            if (m_kv_history_manager.does_history_cache_need_to_update()) {
                trusted_history_length = m_kv_history_manager.trusted_history_length;
            } else {
-                m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
+                auto num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
                // last generated token is present in tokenized_history, but not included to attention mask, let's keep it in history
-                m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= 1;
+                if (num_tokens_to_remove_from_kv_cache > 0)
+                    num_tokens_to_remove_from_kv_cache -= 1;
+
+                // if streaming was used and cancelled on prev step, m_kv_history_manager.num_tokens_to_remove_from_kv_cache could be already set
+                // and it would be bigger as it includes answer + prompt
+                m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_kv_history_manager.num_tokens_to_remove_from_kv_cache > num_tokens_to_remove_from_kv_cache ?
+                    m_kv_history_manager.num_tokens_to_remove_from_kv_cache : num_tokens_to_remove_from_kv_cache;
            }
 
+            std::copy_n(m_tokenized_history.data(), trusted_history_length, std::back_inserter(m_prev_tokenized_history));
+
            ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(),
                                               {1, new_chat_tokens.get_shape().at(1) - trusted_history_length},
                                               new_chat_tokens.data<int64_t>() + trusted_history_length);
@@ -239,8 +268,12 @@ class InputsEmbedder::IInputsEmbedder {
                {new_chat_tokens}, {prev_chat_tokens}
            ).input_ids;
 
-            if (m_last_disappeared_token.has_value())
+            if (m_last_disappeared_token.has_value()) {
                encoded_input_ids = ov::genai::utils::push_front_inputs(encoded_input_ids, *m_last_disappeared_token);
+                std::copy_n(prev_chat_tokens.data<int64_t>(), prev_chat_tokens.get_size() - 1, std::back_inserter(m_prev_tokenized_history));
+            } else {
+                std::copy_n(prev_chat_tokens.data<int64_t>(), prev_chat_tokens.get_size(), std::back_inserter(m_prev_tokenized_history));
+            }
        }
        m_tokenized_history.clear();
        std::copy_n(new_chat_tokens.data<int64_t>(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history));
@@ -1436,6 +1469,8 @@ ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, cons
            length,
            merged.data<int64_t>() + offset
        );
+        if (tokens_per_images.empty())
+            continue;
        offset += length;
        if (offset < merged_length) {
            std::fill_n(
@@ -1576,6 +1611,12 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
        IInputsEmbedder::finish_chat();
        m_tokens_per_images.clear();
    }
+
+    virtual void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t full_len) {
+        IInputsEmbedder::update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len, full_len);
+        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL)
+            m_tokens_per_images.clear();
+    }
 };
 
 class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
@@ -2040,14 +2081,18 @@ std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
     return m_impl->get_tokenized_history();
 }
 
-void InputsEmbedder::update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
-    return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len);
+void InputsEmbedder::update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
+    return m_impl->update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len, inputs_embeds_size);
 }
 
 size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
     return m_impl->get_num_tokens_to_remove_from_hist();
 }
 
+bool InputsEmbedder::should_reset_kv_cache() const {
+    return m_impl->should_reset_kv_cache();
+}
+
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -2056,8 +2101,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
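
The heart of the change is the CANCEL branch of update_tokenized_history: instead of appending the freshly generated tokens, it schedules the whole turn (prompt embeddings plus the partial answer) for removal from the KV cache, restores the tokenized-history snapshot taken before the turn (m_prev_tokenized_history), and requests a full KV-cache reset when that snapshot is empty, e.g. when the very first turn of a chat is cancelled. A simplified, standalone sketch of that bookkeeping, using illustrative types rather than the library's own:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative stand-ins for the pipeline's internal state (not the real types).
struct KvHistoryState {
    size_t num_tokens_to_remove_from_kv_cache = 0;
    bool reset_kv_cache = false;
};

struct HistoryState {
    std::vector<int64_t> tokenized_history;       // tokens currently reflected in the KV cache
    std::vector<int64_t> prev_tokenized_history;  // snapshot taken before the current turn
    KvHistoryState kv;
};

// Mirrors the CANCEL vs. normal-finish handling shown in the diff above.
void on_generation_finished(HistoryState& state,
                            bool cancelled,
                            const std::vector<int64_t>& generated_tokens,
                            size_t prompt_len,          // inputs_embeds_size
                            size_t partial_answer_len)  // last_answer_len
{
    if (cancelled) {
        // Drop both the prompt and the partially generated answer from the KV cache
        // and roll the tokenized history back to the pre-turn snapshot.
        state.kv.num_tokens_to_remove_from_kv_cache = prompt_len + partial_answer_len;
        state.tokenized_history = std::move(state.prev_tokenized_history);
        // Nothing left to keep: cheaper to reset the whole KV cache than to trim it.
        state.kv.reset_kv_cache = state.tokenized_history.empty();
    } else {
        // Normal finish: the new answer becomes part of the trusted history.
        state.tokenized_history.insert(state.tokenized_history.end(),
                                       generated_tokens.begin(), generated_tokens.end());
    }
}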

src/cpp/src/visual_language/inputs_embedder.hpp

+6 −2
@@ -7,6 +7,7 @@
 #include <vector>
 #include <filesystem>
 
+#include "utils.hpp"
 #include "openvino/genai/tokenizer.hpp"
 #include "openvino/genai/visual_language/pipeline.hpp"
 #include "openvino/runtime/tensor.hpp"
@@ -49,16 +50,19 @@ class InputsEmbedder {
     std::vector<int64_t> get_tokenized_history() const;
 
     // add new results to tokenized history
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len);
+    void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size);
 
     // returns amount of elements, which need to remove from the end of the KV cache
     size_t get_num_tokens_to_remove_from_hist() const;
 
+    // returns true, if we need to remove full kv cache, in that case it's needed to reset it instead of manually updating
+    bool should_reset_kv_cache() const;
+
     // starts chat and adds optional system_message to chat history
     void start_chat(const std::string& system_message);
 
     // adds currently generated text to chat history
-    void update_chat_history(const std::string& decoded_results);
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);
 
     // set the apply_chat_template flag, which determines whether chat template should be applied for non-chat scenarios
     void set_apply_chat_template_status(bool apply_chat_template);

src/cpp/src/visual_language/pipeline.cpp

+12 −6
@@ -169,7 +169,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (generation_config.eos_token_id == -1)
             generation_config.set_eos_token_id(m_generation_config.eos_token_id);
         generation_config.validate();
-
+
+        // keep it in case of generation will be canceled
+        auto prev_tokenized_history = m_inputs_embedder->get_tokenized_history();
+
         m_inputs_embedder->set_stop_token_ids(generation_config.stop_token_ids);
 
         m_inputs_embedder->set_apply_chat_template_status(generation_config.apply_chat_template);
@@ -179,7 +182,12 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         auto end_get_inputs_embeds = std::chrono::steady_clock::now();
 
         auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist();
-        ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);
+        if (m_inputs_embedder->should_reset_kv_cache())
+            m_language.reset_state();
+        else
+            ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);
+
+        size_t attention_mask_size = m_language.get_tensor("attention_mask").get_shape().at(1);
 
         std::vector<SequenceGroup::Ptr> requests;
         size_t request_id = 0;
@@ -218,7 +226,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
                                                               position_ids, m_embedding, rope_delta);
         ov::genai::EncodedResults& encoded_result = finish_info.results;
 
-
         auto decode_start_time = std::chrono::steady_clock::now();
         VLMDecodedResults decoded;
         for (size_t idx = 0; idx < encoded_result.tokens.size(); ++idx) {
@@ -227,12 +234,11 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         }
         auto decode_end_time = std::chrono::steady_clock::now();
 
-        m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], finish_info.probably_disappeared_token, generation_config.is_beam_search(),
-                                                    m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));
+        m_inputs_embedder->update_tokenized_history(finish_info, generation_config.is_beam_search(), m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size), inputs_embeds_size);
 
         std::string decoded_results = decoded.texts.at(0);
         if (m_is_chat_conversation)
-            m_inputs_embedder->update_chat_history(decoded_results);
+            m_inputs_embedder->update_chat_history(decoded_results, finish_info.streaming_finish_status);
 
         auto generate_end_time = std::chrono::steady_clock::now();
         decoded.perf_metrics = encoded_result.perf_metrics;
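
The last_answer_len argument computed above is derived from the attention mask: whatever the KV cache holds beyond the pre-existing history and the current prompt embeddings must be the tokens generated in this turn. A small numeric sketch of that arithmetic and of the resulting rollback size on CANCEL (all sizes are illustrative):

#include <cassert>
#include <cstddef>

int main() {
    // Illustrative sizes for one chat turn.
    size_t history_size = 120;       // KV-cache length before this turn
    size_t inputs_embeds_size = 30;  // embedded prompt (text + image placeholders)
    size_t attention_mask_len = 165; // KV-cache length after streaming stopped

    // Tokens produced in this turn before the streamer returned CANCEL.
    size_t last_answer_len = attention_mask_len - (history_size + inputs_embeds_size);
    assert(last_answer_len == 15);

    // On CANCEL the embedder schedules the whole turn for removal.
    size_t to_remove_from_kv_cache = inputs_embeds_size + last_answer_len;
    assert(to_remove_from_kv_cache == 45);

    // 165 - 45 == 120: the KV cache is back to its pre-turn length.
    assert(attention_mask_len - to_remove_from_kv_cache == history_size);
    return 0;
}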

tests/python_tests/test_vlm_pipeline.py

+71 −1
@@ -6,7 +6,7 @@
 import pytest
 import transformers
 from optimum.intel.openvino import OVModelForVisualCausalLM
-from openvino_genai import VLMPipeline, GenerationConfig
+from openvino_genai import VLMPipeline, GenerationConfig, StreamingStatus
 
 from utils.generation_config import get_beam_search, get_multinomial_all_parameters
 from utils.constants import get_default_llm_properties
@@ -184,3 +184,73 @@ def test_perf_metrics(cache):
     mean_dur, std_dur = perf_metrics.get_prepare_embeddings_duration()
     assert np.allclose(mean_dur, np.mean(raw_dur))
     assert np.allclose(std_dur, np.std(raw_dur))
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_id", model_ids)
+def test_vlm_pipeline_chat_streamer_cancel_second_generate(model_id, cache):
+    callback_questions = [
+        '1+1=',
+        'Why is the Sun yellow?',
+        'What is the previous answer?'
+    ]
+
+    current_iter = 0
+    num_iters = 3
+    def streamer(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return StreamingStatus.CANCEL if current_iter == num_iters else StreamingStatus.RUNNING
+
+
+    models_path = get_ov_model(model_id, cache)
+    ov_pipe = VLMPipeline(models_path, "CPU")
+    generation_config = ov_pipe.get_generation_config()
+    generation_config.max_new_tokens = 30
+    generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())
+
+    images = []
+    for link in image_links_for_testing[1]:
+        images.append(get_image_by_link(link))
+
+    ov_pipe.start_chat()
+    ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config)
+
+    generation_config.ignore_eos = True
+    ov_pipe.generate(callback_questions[1], generation_config=generation_config, streamer=streamer)
+    ov_pipe.generate(callback_questions[2], generation_config=generation_config)
+    ov_pipe.finish_chat()
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_id", model_ids)
+def test_vlm_pipeline_chat_streamer_cancel_first_generate(model_id, cache):
+    callback_questions = [
+        'Why is the Sun yellow?',
+        '1+1=',
+    ]
+
+    current_iter = 0
+    num_iters = 3
+    def streamer(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return StreamingStatus.CANCEL if current_iter == num_iters else StreamingStatus.RUNNING
+
+    models_path = get_ov_model(model_id, cache)
+    ov_pipe = VLMPipeline(models_path, "CPU")
+    generation_config = ov_pipe.get_generation_config()
+    generation_config.max_new_tokens = 30
+    generation_config.ignore_eos = True
+    generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())
+
+    images = []
+    for link in image_links_for_testing[1]:
+        images.append(get_image_by_link(link))
+
+    ov_pipe.start_chat()
+    ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config, streamer=streamer)
+    ov_pipe.generate(callback_questions[1], generation_config=generation_config)
+    ov_pipe.finish_chat()
