
Commit 69d3622

Implement CANCEL for streaming with VLM Pipeline
1 parent 56fe5bf commit 69d3622

File tree: 4 files changed, +102 -25 lines

    src/cpp/src/visual_language/inputs_embedder.cpp  +48 -16
    src/cpp/src/visual_language/inputs_embedder.hpp   +6 -2
    src/cpp/src/visual_language/pipeline.cpp         +12 -6
    tests/python_tests/test_vlm_pipeline.py          +36 -1

src/cpp/src/visual_language/inputs_embedder.cpp (+48 -16)
@@ -35,8 +35,10 @@ class InputsEmbedder::IInputsEmbedder {
     ChatHistory m_history;
     // Templated chat history
     std::string m_templated_chat_history;
-    // Tokenized chat history
+    // Tokenized history
     std::vector<int64_t> m_tokenized_history;
+    // Tokenized chat history on previous step
+    std::vector<int64_t> m_prev_tokenized_history;
     // Tail of previous output for LM in chat mode is missing in KV cache.
     std::optional<int64_t> m_last_disappeared_token = std::nullopt;
     // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache
@@ -72,21 +74,32 @@ class InputsEmbedder::IInputsEmbedder {
         return m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
     }

+    bool should_reset_kv_cache() const {
+        return m_kv_history_manager.reset_kv_cache;
+    }
+
     void set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
         m_stop_token_ids = stop_token_ids;
     }

-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
+    void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
         if (is_beam_search) {
             m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
             m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
         } else {
             m_kv_history_manager.reset();
         }

-        m_last_disappeared_token = last_disappeared_token;
-
-        std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        m_last_disappeared_token = generation_finish_info.probably_disappeared_token;
+
+        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // let's remove last answer and prompt
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = inputs_embeds_size + last_answer_len;
+            m_tokenized_history = std::move(m_prev_tokenized_history);
+        } else {
+            auto encoded_result = generation_finish_info.results.tokens[0];
+            std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        }
     }

     void set_apply_chat_template_status(bool apply_chat_template) {
@@ -100,6 +113,7 @@ class InputsEmbedder::IInputsEmbedder {
             m_history.clear();
             m_templated_chat_history.clear();
             m_tokenized_history.clear();
+            m_prev_tokenized_history.clear();
         }
         if (system_message.empty()) {
             return;
@@ -109,11 +123,16 @@ class InputsEmbedder::IInputsEmbedder {
         m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
     }

-    void update_chat_history(const std::string& decoded_results) {
-        // Tail of chat template is missing in KV cache.
-        // Find the tail to concatenate it with the next input prompt.
-        m_templated_chat_history.append(decoded_results);
-        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // If chat generation process was cancelled by user, let's rollback to previous state of history
+            m_history.pop_back();
+        } else {
+            // Tail of chat template is missing in KV cache.
+            // Find the tail to concatenate it with the next input prompt.
+            m_templated_chat_history.append(decoded_results);
+            m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+        }
     }

     virtual void finish_chat() {
@@ -123,6 +142,7 @@ class InputsEmbedder::IInputsEmbedder {
         m_history.clear();
         m_templated_chat_history.clear();
         m_tokenized_history.clear();
+        m_prev_tokenized_history.clear();
     }

 protected:
@@ -213,6 +233,9 @@ class InputsEmbedder::IInputsEmbedder {
             trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, m_stop_token_ids);
         }

+        m_prev_tokenized_history.clear();
+        std::copy_n(prev_chat_tokens.data<int64_t>(), prev_chat_tokens.get_size(), std::back_inserter(m_prev_tokenized_history));
+
         if (m_tokenized_history.empty()) {
             encoded_input_ids = new_chat_tokens;

@@ -223,9 +246,14 @@ class InputsEmbedder::IInputsEmbedder {
             if (m_kv_history_manager.does_history_cache_need_to_update()) {
                 trusted_history_length = m_kv_history_manager.trusted_history_length;
             } else {
-                m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
+                auto num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
                 // last generated token is present in tokenized_history, but not included to attention mask, let's keep it in history
-                m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= 1;
+                num_tokens_to_remove_from_kv_cache -= 1;
+
+                // if streaming was used and cancelled on prev step, m_kv_history_manager.num_tokens_to_remove_from_kv_cache could be already set
+                // and it would be bigger as it includes answer + prompt
+                m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_kv_history_manager.num_tokens_to_remove_from_kv_cache > num_tokens_to_remove_from_kv_cache ?
+                    m_kv_history_manager.num_tokens_to_remove_from_kv_cache : num_tokens_to_remove_from_kv_cache;
             }

             ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(),
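The reconciliation at the end of this hunk is the subtle part: if streaming was cancelled on the previous step, the KV history manager may already hold a pending removal that covers the whole cancelled prompt plus its partial answer, which can exceed what the plain history comparison yields, so the larger value wins. A minimal, self-contained illustration of that rule (all names and numbers below are made up for the example, not taken from the GenAI sources):

#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
    // Removal already scheduled because the previous step was cancelled:
    // it covers the cancelled prompt plus its partially generated answer.
    size_t scheduled_after_cancel = 40;

    // Removal derived from the first difference between the stored tokenized
    // history and the re-tokenized chat, minus 1 because the last generated
    // token stays in the history (it is not yet in the attention mask).
    size_t tokenized_history_size = 112, trusted_history_length = 100;
    size_t from_history_diff = tokenized_history_size - trusted_history_length - 1;  // 11

    // The KV cache must shrink by the larger of the two pending removals.
    size_t num_tokens_to_remove = std::max(scheduled_after_cancel, from_history_diff);
    std::cout << "tokens to remove from KV cache: " << num_tokens_to_remove << "\n";  // 40
    return 0;
}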
@@ -2040,14 +2068,18 @@ std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
     return m_impl->get_tokenized_history();
 }

-void InputsEmbedder::update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
-    return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len);
+void InputsEmbedder::update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
+    return m_impl->update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len, inputs_embeds_size);
 }

 size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
     return m_impl->get_num_tokens_to_remove_from_hist();
 }

+bool InputsEmbedder::should_reset_kv_cache() const {
+    return m_impl->should_reset_kv_cache();
+}
+
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -2056,8 +2088,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }

-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }

 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
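Taken together, the changes in this file make `update_tokenized_history` consume the whole `GenerationFinishInfo`: on CANCEL it schedules removal of the just-added prompt and partial answer from the KV cache and restores the token history snapshot captured before the step; otherwise it appends the generated tokens as before. A condensed, self-contained sketch of that bookkeeping (the types and function below are illustrative stand-ins, not the actual GenAI classes):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

enum class FinishStatus { FINISHED, CANCELLED };

struct HistoryState {
    std::vector<int64_t> tokenized_history;       // tokens assumed to be reflected in the KV cache
    std::vector<int64_t> prev_tokenized_history;  // snapshot taken before the current request
    size_t num_tokens_to_remove_from_kv_cache = 0;
};

// Mirrors the CANCEL branch of update_tokenized_history(): drop prompt + partial
// answer from the KV cache and roll the token history back; otherwise append the
// newly generated tokens as before.
void finish_step(HistoryState& state,
                 FinishStatus status,
                 const std::vector<int64_t>& answer_tokens,
                 size_t prompt_len /* inputs_embeds_size */,
                 size_t answer_len /* last_answer_len */) {
    if (status == FinishStatus::CANCELLED) {
        state.num_tokens_to_remove_from_kv_cache = prompt_len + answer_len;
        state.tokenized_history = std::move(state.prev_tokenized_history);
    } else {
        state.tokenized_history.insert(state.tokenized_history.end(),
                                       answer_tokens.begin(), answer_tokens.end());
    }
}

int main() {
    HistoryState state;
    state.tokenized_history = {1, 2, 3};
    state.prev_tokenized_history = state.tokenized_history;  // snapshot before the step

    finish_step(state, FinishStatus::CANCELLED, /*answer_tokens=*/{7, 8}, /*prompt_len=*/5, /*answer_len=*/2);
    std::cout << "history size after cancel: " << state.tokenized_history.size()                 // 3
              << ", KV tokens to remove: " << state.num_tokens_to_remove_from_kv_cache << "\n";  // 7
    return 0;
}

The snapshot itself is refreshed from prev_chat_tokens on every request (the hunk at -213,6 above), which is what makes this rollback possible.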

src/cpp/src/visual_language/inputs_embedder.hpp (+6 -2)
@@ -7,6 +7,7 @@
 #include <vector>
 #include <filesystem>

+#include "utils.hpp"
 #include "openvino/genai/tokenizer.hpp"
 #include "openvino/genai/visual_language/pipeline.hpp"
 #include "openvino/runtime/tensor.hpp"
@@ -49,16 +50,19 @@ class InputsEmbedder {
     std::vector<int64_t> get_tokenized_history() const;

     // add new results to tokenized history
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len);
+    void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size);

     // returns amount of elements, which need to remove from the end of the KV cache
     size_t get_num_tokens_to_remove_from_hist() const;

+    // returns true, if we need to remove full kv cache, in that case it's needed to reset it instead of manually updating
+    bool should_reset_kv_cache() const;
+
     // starts chat and adds optional system_message to chat history
     void start_chat(const std::string& system_message);

     // adds currently generated text to chat history
-    void update_chat_history(const std::string& decoded_results);
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);

     // set the apply_chat_template flag, which determines whether chat template should be applied for non-chat scenarios
     void set_apply_chat_template_status(bool apply_chat_template);

src/cpp/src/visual_language/pipeline.cpp (+12 -6)
@@ -169,7 +169,10 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         if (generation_config.eos_token_id == -1)
             generation_config.set_eos_token_id(m_generation_config.eos_token_id);
         generation_config.validate();
-
+
+        // keep it in case of generation will be canceled
+        auto prev_tokenized_history = m_inputs_embedder->get_tokenized_history();
+
         m_inputs_embedder->set_stop_token_ids(generation_config.stop_token_ids);

         m_inputs_embedder->set_apply_chat_template_status(generation_config.apply_chat_template);
@@ -179,7 +182,12 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         auto end_get_inputs_embeds = std::chrono::steady_clock::now();

         auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist();
-        ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);
+        if (m_inputs_embedder->should_reset_kv_cache())
+            m_language.reset_state();
+        else
+            ov::genai::utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);
+
+        size_t attention_mask_size = m_language.get_tensor("attention_mask").get_shape().at(1);

         std::vector<SequenceGroup::Ptr> requests;
         size_t request_id = 0;
@@ -218,7 +226,6 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
                                                              position_ids, m_embedding, rope_delta);
         ov::genai::EncodedResults& encoded_result = finish_info.results;

-
         auto decode_start_time = std::chrono::steady_clock::now();
         VLMDecodedResults decoded;
         for (size_t idx = 0; idx < encoded_result.tokens.size(); ++idx) {
@@ -227,12 +234,11 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
         }
         auto decode_end_time = std::chrono::steady_clock::now();

-        m_inputs_embedder->update_tokenized_history(encoded_result.tokens[0], finish_info.probably_disappeared_token, generation_config.is_beam_search(),
-                                                    m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size));
+        m_inputs_embedder->update_tokenized_history(finish_info, generation_config.is_beam_search(), m_language.get_tensor("attention_mask").get_shape()[1] - (history_size + inputs_embeds_size), inputs_embeds_size);

         std::string decoded_results = decoded.texts.at(0);
         if (m_is_chat_conversation)
-            m_inputs_embedder->update_chat_history(decoded_results);
+            m_inputs_embedder->update_chat_history(decoded_results, finish_info.streaming_finish_status);

         auto generate_end_time = std::chrono::steady_clock::now();
         decoded.perf_metrics = encoded_result.perf_metrics;
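On the pipeline side the only new decision is whether the language model's KV cache can be trimmed in place or must be dropped entirely; `should_reset_kv_cache()` answers that, and the finish status is simply forwarded to the inputs embedder. A small sketch of that branch, again with illustrative stand-ins (`FakeKvCache` is not the real `ov::InferRequest` state API):

#include <cstddef>
#include <cstdint>
#include <deque>
#include <iostream>

// Stand-in for the language model's KV cache; the real code calls
// m_language.reset_state() or ov::genai::utils::trim_kv_cache() instead.
struct FakeKvCache {
    std::deque<int64_t> cached;                 // one entry per cached token position
    void reset_state() { cached.clear(); }      // full reset
    void trim(size_t n) {                       // drop n trailing positions
        while (n-- && !cached.empty())
            cached.pop_back();
    }
};

void prepare_for_next_prompt(FakeKvCache& cache, bool should_reset, size_t to_remove_from_hist) {
    if (should_reset)
        cache.reset_state();
    else
        cache.trim(to_remove_from_hist);
}

int main() {
    FakeKvCache cache{{10, 11, 12, 13, 14}};
    prepare_for_next_prompt(cache, /*should_reset=*/false, /*to_remove_from_hist=*/2);
    std::cout << "cached positions left: " << cache.cached.size() << "\n";  // 3
    return 0;
}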

tests/python_tests/test_vlm_pipeline.py (+36 -1)
@@ -6,7 +6,7 @@
 import pytest
 import transformers
 from optimum.intel.openvino import OVModelForVisualCausalLM
-from openvino_genai import VLMPipeline, GenerationConfig
+from openvino_genai import VLMPipeline, GenerationConfig, StreamingStatus

 from utils.generation_config import get_beam_search, get_multinomial_all_parameters
 from utils.constants import get_default_llm_properties
@@ -184,3 +184,38 @@ def test_perf_metrics(cache):
     mean_dur, std_dur = perf_metrics.get_prepare_embeddings_duration()
     assert np.allclose(mean_dur, np.mean(raw_dur))
     assert np.allclose(std_dur, np.std(raw_dur))
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_id", model_ids)
+def test_vlm_pipeline_chat(model_id, cache):
+    callback_questions = [
+        '1+1=',
+        'Why is the Sun yellow?',
+        'What is the previous answer?'
+    ]
+
+    current_iter = 0
+    num_iters = 3
+    def streamer(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return StreamingStatus.CANCEL if current_iter == num_iters else StreamingStatus.RUNNING
+
+
+    models_path = get_ov_model(model_id, cache)
+    ov_pipe = VLMPipeline(models_path, "CPU")
+    generation_config = ov_pipe.get_generation_config()
+    generation_config.max_new_tokens = 30
+    generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())
+
+    images = []
+    for link in image_links_for_testing[1]:
+        images.append(get_image_by_link(link))
+
+    ov_pipe.start_chat()
+    ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config)
+    ov_pipe.generate(callback_questions[1], generation_config=generation_config, streamer=streamer)
+    ov_pipe.generate(callback_questions[2], generation_config=generation_config)
+    ov_pipe.finish_chat()
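For reference, a hypothetical C++ counterpart of the Python chat-cancel test above, shown as a sketch only: it assumes the C++ streamer callback may return ov::genai::StreamingStatus (mirroring the Python StreamingStatus used in the test) and that the ov::genai::streamer and ov::genai::generation_config properties accept it in this GenAI version; image inputs are omitted for brevity, and helper names may differ between releases.

#include <iostream>
#include <string>

#include "openvino/genai/visual_language/pipeline.hpp"

int main(int argc, char* argv[]) {
    // argv[1]: path to an exported VLM (same layout the Python test builds via optimum-intel)
    ov::genai::VLMPipeline pipe(argv[1], "CPU");

    ov::genai::GenerationConfig config = pipe.get_generation_config();
    config.max_new_tokens = 30;

    size_t streamed_chunks = 0;
    auto streamer = [&](std::string subword) {
        std::cout << subword << std::flush;
        // Cancel after a few streamed chunks; the pipeline should roll back the
        // chat history and KV cache so the follow-up prompt still works.
        return ++streamed_chunks == 3 ? ov::genai::StreamingStatus::CANCEL
                                      : ov::genai::StreamingStatus::RUNNING;
    };

    pipe.start_chat();
    pipe.generate("1+1=", ov::genai::generation_config(config));
    pipe.generate("Why is the Sun yellow?",
                  ov::genai::generation_config(config),
                  ov::genai::streamer(streamer));
    pipe.generate("What is the previous answer?", ov::genai::generation_config(config));
    pipe.finish_chat();
    return 0;
}

The expectation mirrors the test: cancelling the second turn mid-stream must leave the chat history and KV cache in a state where the third prompt still generates normally.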
