@@ -35,8 +35,10 @@ class InputsEmbedder::IInputsEmbedder {
     ChatHistory m_history;
     // Templated chat history
     std::string m_templated_chat_history;
-    // Tokenized chat history
+    // Tokenized history
     std::vector<int64_t> m_tokenized_history;
+    // Tokenized chat history on previous step
+    std::vector<int64_t> m_prev_tokenized_history;
     // Tail of previous output for LM in chat mode is missing in KV cache.
     std::optional<int64_t> m_last_disappeared_token = std::nullopt;
     // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache
@@ -72,21 +74,33 @@ class InputsEmbedder::IInputsEmbedder {
         return m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
     }
 
+    bool should_reset_kv_cache() const {
+        return m_kv_history_manager.reset_kv_cache;
+    }
+
     void set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
         m_stop_token_ids = stop_token_ids;
     }
 
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
+    void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
         if (is_beam_search) {
             m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
             m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
         } else {
             m_kv_history_manager.reset();
         }
 
-        m_last_disappeared_token = last_disappeared_token;
-
-        std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        m_last_disappeared_token = generation_finish_info.probably_disappeared_token;
+
+        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // let's remove last answer and prompt
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = inputs_embeds_size + last_answer_len;
+            m_tokenized_history = std::move(m_prev_tokenized_history);
+            m_kv_history_manager.reset_kv_cache = m_tokenized_history.empty();
+        } else {
+            auto encoded_result = generation_finish_info.results.tokens[0];
+            std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        }
     }
 
     void set_apply_chat_template_status(bool apply_chat_template) {
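The hunk above decides between appending the freshly generated tokens and restoring the pre-step snapshot when generation was cancelled. The standalone sketch below reproduces that decision with plain standard-library types; the `Status` enum and the literal token values are simplified stand-ins for illustration only and are not part of the patch.

```cpp
// Minimal sketch of the CANCEL branch added above, using only std types.
// prev_history plays the role of m_prev_tokenized_history: a snapshot taken
// before the step, restored when the user cancels generation.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

enum class Status { FINISHED, CANCEL };  // simplified stand-in for GenerationStatus

int main() {
    std::vector<int64_t> history      = {1, 2, 3};   // tokenized history before this step
    std::vector<int64_t> prev_history = history;     // snapshot (m_prev_tokenized_history)
    std::vector<int64_t> answer       = {7, 8, 9};   // tokens produced during the step

    Status status = Status::CANCEL;
    if (status == Status::CANCEL) {
        history = std::move(prev_history);           // roll back: prompt and answer are dropped
    } else {
        history.insert(history.end(), answer.begin(), answer.end());
    }
    std::cout << history.size() << '\n';             // prints 3: as if the step never happened
}
```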
@@ -100,6 +114,7 @@ class InputsEmbedder::IInputsEmbedder {
             m_history.clear();
             m_templated_chat_history.clear();
             m_tokenized_history.clear();
+            m_prev_tokenized_history.clear();
         }
         if (system_message.empty()) {
             return;
@@ -109,11 +124,16 @@ class InputsEmbedder::IInputsEmbedder {
         m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
     }
 
-    void update_chat_history(const std::string& decoded_results) {
-        // Tail of chat template is missing in KV cache.
-        // Find the tail to concatenate it with the next input prompt.
-        m_templated_chat_history.append(decoded_results);
-        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // If chat generation process was cancelled by user, let's rollback to previous state of history
+            m_history.pop_back();
+        } else {
+            // Tail of chat template is missing in KV cache.
+            // Find the tail to concatenate it with the next input prompt.
+            m_templated_chat_history.append(decoded_results);
+            m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+        }
     }
 
     virtual void finish_chat() {
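The same rollback idea applies to the chat history itself: on CANCEL, the entry pushed for the unanswered user turn is popped instead of appending an assistant reply. A self-contained illustration follows; the pair-of-strings representation is an assumption made only for the sketch, whereas the real code works on ChatHistory.

```cpp
// Simplified illustration of the update_chat_history() change above.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Stand-in for ChatHistory: (role, content) pairs.
    std::vector<std::pair<std::string, std::string>> history = {
        {"user", "first question"},
        {"assistant", "first answer"},
        {"user", "cancelled question"}   // pushed before generation started
    };

    bool cancelled = true;               // GenerationStatus::CANCEL in the real code
    if (cancelled) {
        history.pop_back();              // drop the unanswered user turn
    } else {
        history.push_back({"assistant", "new answer"});
    }
    std::cout << history.size() << '\n'; // prints 2: back to the last consistent state
}
```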
@@ -123,6 +143,7 @@ class InputsEmbedder::IInputsEmbedder {
         m_history.clear();
         m_templated_chat_history.clear();
         m_tokenized_history.clear();
+        m_prev_tokenized_history.clear();
     }
 
 protected:
@@ -213,21 +234,29 @@ class InputsEmbedder::IInputsEmbedder {
                 trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, m_stop_token_ids);
             }
 
+            m_prev_tokenized_history.clear();
             if (m_tokenized_history.empty()) {
                 encoded_input_ids = new_chat_tokens;
-
             } else if (trusted_history_length != SIZE_MAX || m_kv_history_manager.does_history_cache_need_to_update()) {
                 // does_history_cache_need_to_update will be true here if beam search is activated
                 // in beam search mode we want to remove all history about last model answer from kv cache and add the best answer directly
                 // if we have difference in model answer and decoded answer it anyway will be less then entire history, so let's use data from m_kv_history_manager
                 if (m_kv_history_manager.does_history_cache_need_to_update()) {
                     trusted_history_length = m_kv_history_manager.trusted_history_length;
                 } else {
-                    m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
+                    auto num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
                     // last generated token is present in tokenized_history, but not included to attention mask, let's keep it in history
-                    m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= 1;
+                    if (num_tokens_to_remove_from_kv_cache > 0)
+                        num_tokens_to_remove_from_kv_cache -= 1;
+
+                    // if streaming was used and cancelled on prev step, m_kv_history_manager.num_tokens_to_remove_from_kv_cache could be already set
+                    // and it would be bigger as it includes answer + prompt
+                    m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_kv_history_manager.num_tokens_to_remove_from_kv_cache > num_tokens_to_remove_from_kv_cache ?
+                                                                              m_kv_history_manager.num_tokens_to_remove_from_kv_cache : num_tokens_to_remove_from_kv_cache;
                 }
 
+                std::copy_n(m_tokenized_history.data(), trusted_history_length, std::back_inserter(m_prev_tokenized_history));
+
                 ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(),
                                                    {1, new_chat_tokens.get_shape().at(1) - trusted_history_length},
                                                    new_chat_tokens.data<int64_t>() + trusted_history_length);
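The conditional expression at the end of the hunk above simply keeps the larger of the two candidate trim sizes, since a cancellation on the previous step may already have scheduled prompt + answer tokens for removal. A small worked example follows; the concrete numbers are illustrative only, and the expression behaves like std::max.

```cpp
// Worked example of the "keep the larger trim" rule.
#include <algorithm>
#include <cstddef>
#include <iostream>

int main() {
    std::size_t scheduled_by_cancel = 30; // prompt + answer from a cancelled previous step
    std::size_t from_history_diff   = 5;  // size() - trusted_history_length, minus the kept token

    // Same result as the conditional expression in the patch:
    std::size_t to_remove = std::max(scheduled_by_cancel, from_history_diff);
    std::cout << to_remove << '\n';       // prints 30: the cancelled turn is still evicted
}
```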
@@ -239,8 +268,12 @@ class InputsEmbedder::IInputsEmbedder {
                     {new_chat_tokens}, {prev_chat_tokens}
                 ).input_ids;
 
-                if (m_last_disappeared_token.has_value())
+                if (m_last_disappeared_token.has_value()) {
                     encoded_input_ids = ov::genai::utils::push_front_inputs(encoded_input_ids, *m_last_disappeared_token);
+                    std::copy_n(prev_chat_tokens.data<int64_t>(), prev_chat_tokens.get_size() - 1, std::back_inserter(m_prev_tokenized_history));
+                } else {
+                    std::copy_n(prev_chat_tokens.data<int64_t>(), prev_chat_tokens.get_size(), std::back_inserter(m_prev_tokenized_history));
+                }
             }
             m_tokenized_history.clear();
             std::copy_n(new_chat_tokens.data<int64_t>(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history));
@@ -1436,6 +1469,8 @@ ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, cons
             length,
             merged.data<int64_t>() + offset
         );
+        if (tokens_per_images.empty())
+            continue;
         offset += length;
         if (offset < merged_length) {
             std::fill_n(
@@ -1576,6 +1611,12 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         IInputsEmbedder::finish_chat();
         m_tokens_per_images.clear();
     }
+
+    virtual void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t full_len) {
+        IInputsEmbedder::update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len, full_len);
+        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL)
+            m_tokens_per_images.clear();
+    }
 };
 
 class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
@@ -2040,14 +2081,18 @@ std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
     return m_impl->get_tokenized_history();
 }
 
-void InputsEmbedder::update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
-    return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len);
+void InputsEmbedder::update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
+    return m_impl->update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len, inputs_embeds_size);
 }
 
 size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
     return m_impl->get_num_tokens_to_remove_from_hist();
 }
 
+bool InputsEmbedder::should_reset_kv_cache() const {
+    return m_impl->should_reset_kv_cache();
+}
+
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -2056,8 +2101,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
0 commit comments