
Commit 434c2a9

Authored Mar 6, 2025
Implement CANCEL for streaming with VLM Pipeline (#1725)
- Added support for CANCEL in the streamer for the VLM pipeline. Note that for phi3 it is now possible to run without an image, to cover the case when CANCEL happens on a step that includes an image.
- Added tests for CANCEL for the VLM pipeline.
1 parent 50c45f0 commit 434c2a9

14 files changed: +232 −110 lines
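
For context, a minimal usage sketch (not part of the commit) of the behavior this change enables: a callback streamer returns CANCEL to abort a VLM generation mid-stream, after which the pipeline rolls the chat history and KV cache back to their state before the cancelled request. The model directory, image shape, and the StreamingStatus-returning callback variant are assumptions for illustration.

// Minimal sketch (assumptions: a VLM model exported to "./MiniCPM-V-2_6", CPU device,
// and the streamer callback variant that returns ov::genai::StreamingStatus).
#include <iostream>
#include <string>

#include "openvino/genai/visual_language/pipeline.hpp"

int main() {
    ov::genai::VLMPipeline pipe("./MiniCPM-V-2_6", "CPU");  // placeholder model directory

    // Placeholder image tensor in [1, H, W, C] u8 layout; a real application would fill it
    // with decoded image data.
    ov::Tensor image(ov::element::u8, ov::Shape{1, 448, 448, 3});

    size_t streamed_chunks = 0;
    auto streamer = [&](std::string subword) -> ov::genai::StreamingStatus {
        std::cout << subword << std::flush;
        // Cancel after a few chunks: generation stops and the pipeline rolls back
        // chat history and trims the cancelled tokens from the KV cache.
        return ++streamed_chunks < 5 ? ov::genai::StreamingStatus::RUNNING
                                     : ov::genai::StreamingStatus::CANCEL;
    };

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    pipe.start_chat();
    pipe.generate("Describe this image.",
                  ov::genai::image(image),
                  ov::genai::generation_config(config),
                  ov::genai::streamer(streamer));
    // The next generate() in the same chat starts from the pre-request history/KV-cache state.
    pipe.finish_chat();
    return 0;
}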
 

‎src/cpp/src/llm_pipeline_stateful.cpp

+7-10
@@ -60,7 +60,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
     }
 
     if (!m_use_full_chat_history)
-        m_kv_history_trim_manager.kv_cache_seq_length_axis = kv_pos.seq_len;
+        m_kv_cache_state.seq_length_axis = kv_pos.seq_len;
 
     auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters);
     if (m_generation_config.adapters) {
@@ -143,7 +143,7 @@ DecodedResults StatefulLLMPipeline::generate(
         if (m_use_full_chat_history) {
             encoded_input = new_chat_tokens;
         } else {
-            ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens.input_ids, m_kv_cache_state);
+            ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);
             encoded_input = get_chat_encoded_input(new_chat_tokens.input_ids, m_kv_cache_state);
         }
         // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
@@ -212,7 +212,6 @@ EncodedResults StatefulLLMPipeline::generate(
         reset_kv_state();
         m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
         m_kv_cache_state.reset_state();
-        m_kv_history_trim_manager.reset();
     }
 
     auto start_time = std::chrono::steady_clock::now();
@@ -238,7 +237,7 @@
     // Tail of previous output in chat mode is missing in KV cache.
     if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
         ov::Tensor new_chat_tokens = ov::Tensor{ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()};
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
 
         auto encoded_input = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state);
         input_ids = encoded_input.input_ids;
@@ -281,11 +280,10 @@
                     "but you have '" + std::to_string(num_inputs) + "' inputs");
 
     if (is_chat_conversation) {
-        if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
+        if (m_use_full_chat_history)
             reset_kv_state();
         else
-            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_trim_manager.num_tokens_to_trim,
-                                            m_kv_history_trim_manager.kv_cache_seq_length_axis, m_adapter_controller);
+            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state, m_adapter_controller);
     }
 
     size_t kv_cache_len = 0;
@@ -358,7 +356,7 @@
     m_chat_generation_finish_status = finish_info.streaming_finish_status;
 
     if (is_chat_conversation) {
-        m_kv_history_trim_manager.reset();
+        m_kv_cache_state.num_tokens_to_trim = 0;
 
         if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
            if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
@@ -367,7 +365,7 @@
                std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
            }
        } else if (config.is_beam_search()) {
-            m_kv_history_trim_manager.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
+            m_kv_cache_state.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
        }
     }
 
@@ -406,7 +404,6 @@ void StatefulLLMPipeline::reset_kv_state() {
 
 void StatefulLLMPipeline::finish_chat() {
     is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
     m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
     if (!m_kv_cache_state.get_state().empty() || have_state) {

‎src/cpp/src/llm_pipeline_stateful.hpp

+2-6
@@ -20,18 +20,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ChatHistory m_history;
     std::vector<int64_t> m_tokenized_chat_history;
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
-    // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
-    // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
-    // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
-    ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
     // Finish reason of last generation for chat scenario
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // if True, full history will be used as prompt on each chat generation
     bool m_use_full_chat_history = false;
     size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
     bool m_is_npu = false;
-    // reflection of tokens contained in the kv cache
-    KVCacheState m_kv_cache_state;
+    // include reflection of tokens contained in the kv cache and amount of tokens, which are needed to trim from kv cache on the next step of chat
+    utils::KVCacheState m_kv_cache_state;
 
     void reset_kv_state();
 public:

‎src/cpp/src/lm_encoding.cpp

+5-4
@@ -82,7 +82,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
     Sampler& sampler,
     std::vector<SequenceGroup::Ptr> sequence_groups,
     std::optional<ov::Tensor> position_ids,
-    KVCacheState& kv_cache_state,
+    utils::KVCacheState& kv_cache_state,
     std::optional<EmbeddingsModel> m_embedding,
     std::optional<int64_t> rope_delta,
     const size_t max_kv_cache_size
@@ -298,7 +298,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
 }
 
 
-TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
     TokenizedInputs encoded_input;
     size_t kv_cache_len = kv_cache_state.get_state().size();
     if (kv_cache_len == 0) {
@@ -325,7 +325,7 @@ TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCach
 }
 
 
-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
     // KV cache in model already contains prompts and answers from previous iterations.
     // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
     // token_ids = {<bos token>, ...<valuable tokens>}. So if tokenizer applies only to the new prompt,
@@ -343,8 +343,9 @@
     size_t first_diverse_tokens_idx = ov::genai::utils::get_first_history_difference(new_chat_tokens, state);
     // in the case of beam_search the longest answer is in the kv cache, but the best one is needed
     // so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated serquence
-    kv_history_manager.num_tokens_to_trim = kv_history_manager.num_tokens_to_trim > 0 ? kv_history_manager.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
+    kv_cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
     state.resize(first_diverse_tokens_idx);
+    kv_cache_state.reset_mem_state = state.empty();
 }
 
 } // namespace genai
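
The comments in align_kv_cache_and_history above describe why the KV cache and the re-tokenized history can disagree. As a standalone illustration (not code from the commit), the sketch below mimics the alignment step: find the first position where the re-tokenized history diverges from the cached tokens, record how many cached tokens must be trimmed, and keep only the still-valid prefix. first_difference is a stand-in for ov::genai::utils::get_first_history_difference, and the token values are placeholders.

// Standalone sketch of the alignment idea in align_kv_cache_and_history (illustrative only).
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for ov::genai::utils::get_first_history_difference.
static size_t first_difference(const std::vector<int64_t>& new_tokens, const std::vector<int64_t>& cached) {
    size_t i = 0;
    while (i < cached.size() && i < new_tokens.size() && cached[i] == new_tokens[i])
        ++i;
    return i;
}

int main() {
    std::vector<int64_t> cached     = {1, 100, 200, 300, 301};       // tokens reflected in the KV cache
    std::vector<int64_t> new_tokens = {1, 100, 200, 305, 306, 400};  // full chat history re-tokenized

    size_t idx = first_difference(new_tokens, cached);  // 3: first divergent position
    size_t num_tokens_to_trim = cached.size() - idx;    // 2: diverged tail to drop from the KV cache
    cached.resize(idx);                                 // history now reflects only the valid prefix
    bool reset_mem_state = cached.empty();              // empty prefix would mean "reset the whole KV memory"

    std::cout << "trim " << num_tokens_to_trim << " tokens, reset=" << reset_mem_state << "\n";
    return 0;
}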

‎src/cpp/src/lm_encoding.hpp

+3-31
@@ -8,44 +8,16 @@
 namespace ov {
 namespace genai {
 
-class KVCacheState {
-    std::vector<int64_t> state;
-public:
-    std::vector<int64_t>& get_state() {
-        return state;
-    }
-
-    void add_inputs(const ov::Tensor& inputs_ids) {
-        std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
-    }
-
-    void reset_state() {
-        return state.clear();
-    }
-};
-
-
-struct KVCacheTrimManager
-{
-    size_t num_tokens_to_trim = 0;
-    size_t kv_cache_seq_length_axis = 2;
-
-    void reset() {
-        num_tokens_to_trim = 0;
-    }
-};
-
-
 ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
     const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
-    std::optional<ov::Tensor> position_ids, KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
+    std::optional<ov::Tensor> position_ids, utils::KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
     std::optional<int64_t> rope_delta = std::nullopt, const size_t max_kv_cache_size = std::numeric_limits<size_t>::max());
 
 
-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);
 
 
-TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);
 
 }
 }

‎src/cpp/src/utils.cpp

+18-4
@@ -325,13 +325,27 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {
     return kv_pos;
 }
 
-void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller) {
+void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller) {
+    if (kv_cache_state.reset_mem_state) {
+        if (adapter_controller) {
+            for(auto& state: request.query_state()) {
+                if(!adapter_controller->has_state_name(state.get_name())) {
+                    state.reset();
+                }
+            }
+        } else {
+            request.reset_state();
+        }
+
+        return;
+    }
+
     // nothing to trim in this case
-    if (remove_from_end == 0)
+    if (kv_cache_state.num_tokens_to_trim == 0)
         return;
 
     auto states = request.query_state();
-
+
     OPENVINO_ASSERT(states.size() > 0, "Request contains no states.");
 
     for (auto& state : states) {
@@ -341,7 +355,7 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
         ov::Tensor old_tensor = state.get_state();
         // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
         auto shape = old_tensor.get_shape();
-        shape[seq_length_axis] -= remove_from_end;
+        shape[kv_cache_state.seq_length_axis] -= kv_cache_state.num_tokens_to_trim;
 
         ov::Coordinate new_shape_begin{0, 0, 0, 0};
         ov::Coordinate new_shape_end{shape};

‎src/cpp/src/utils.hpp

+23-1
@@ -102,7 +102,29 @@ struct KVAxesPosition {
 
 KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model);
 
-void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);
+class KVCacheState {
+    std::vector<int64_t> state;
+public:
+    size_t num_tokens_to_trim = 0;
+    size_t seq_length_axis = 2;
+    bool reset_mem_state = false;
+
+    std::vector<int64_t>& get_state() {
+        return state;
+    }
+
+    void add_inputs(const ov::Tensor& inputs_ids) {
+        std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
+    }
+
+    void reset_state() {
+        reset_mem_state = false;
+        num_tokens_to_trim = 0;
+        state.clear();
+    }
+};
+
+void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller);
 
 ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front);

‎src/cpp/src/visual_language/inputs_embedder.cpp

+21-16
@@ -29,7 +29,6 @@ std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::IInputsEmbedder::g
 
 void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_message) {
     m_is_chat_conversation = true;
-    m_kv_history_trim_manager.reset();
     if (!m_kv_cache_state.get_state().empty()) {
         m_history.clear();
         m_kv_cache_state.reset_state();
@@ -40,17 +39,26 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
     m_history = {{{"role", "system"}, {"content", system_message}}};
 }
 
-void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    // Tail of chat template is missing in KV cache.
-    // Find the tail to concatenate it with the next input prompt.
-    m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-    m_kv_history_trim_manager.reset();
+void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    m_kv_cache_state.num_tokens_to_trim = 0;
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+        // If chat generation process was cancelled by user, let's rollback to previous state of history
+        m_history.pop_back();
+
+        std::vector<int64_t>& state = m_kv_cache_state.get_state();
+
+        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        state.resize(m_prev_hist_length);
+        m_kv_cache_state.reset_mem_state = state.empty();
+    } else {
+        // Tail of chat template is missing in KV cache.
+        // Find the tail to concatenate it with the next input prompt.
+        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    }
 }
 
 void InputsEmbedder::IInputsEmbedder::finish_chat() {
     m_is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
-
     m_history.clear();
     m_kv_cache_state.reset_state();
 }
@@ -123,7 +131,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::apply_chat_template_tokenize(const s
 ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new_chat_tokens) {
     ov::Tensor encoded_inputs;
     if (m_is_chat_conversation) {
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
         encoded_inputs = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state).input_ids;
     } else {
         encoded_inputs = new_chat_tokens;
@@ -135,6 +143,7 @@
 ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
     const auto new_chat_tokens = apply_chat_template_tokenize(prompt, metrics);
     auto new_input_ids = update_history(new_chat_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_input_ids);
 
     return new_input_ids;
@@ -225,14 +234,10 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
     return m_impl->get_embedding_model();
 }
 
-KVCacheState& InputsEmbedder::get_kv_cache_state() {
+ov::genai::utils::KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }
 
-size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
-    return m_impl->get_num_tokens_to_remove_from_hist();
-}
-
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -241,8 +246,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {