@@ -35,8 +35,10 @@ class InputsEmbedder::IInputsEmbedder {
     ChatHistory m_history;
     // Templated chat history
     std::string m_templated_chat_history;
-    // Tokenized chat history
+    // Tokenized history
     std::vector<int64_t> m_tokenized_history;
+    // Tokenized chat history on previous step
+    std::vector<int64_t> m_prev_tokenized_history;
     // Tail of previous output for LM in chat mode is missing in KV cache.
     std::optional<int64_t> m_last_disappeared_token = std::nullopt;
     // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache
@@ -72,21 +74,32 @@ class InputsEmbedder::IInputsEmbedder {
         return m_kv_history_manager.num_tokens_to_remove_from_kv_cache;
     }
 
+    bool should_reset_kv_cache() const {
+        return m_kv_history_manager.reset_kv_cache;
+    }
+
     void set_stop_token_ids(const std::set<int64_t>& stop_token_ids) {
         m_stop_token_ids = stop_token_ids;
     }
 
-    void update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
+    void update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
         if (is_beam_search) {
             m_kv_history_manager.trusted_history_length = m_tokenized_history.size();
             m_kv_history_manager.num_tokens_to_remove_from_kv_cache = last_answer_len;
         } else {
             m_kv_history_manager.reset();
         }
 
-        m_last_disappeared_token = last_disappeared_token;
-
-        std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        m_last_disappeared_token = generation_finish_info.probably_disappeared_token;
+
+        if (generation_finish_info.streaming_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // let's remove last answer and prompt
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = inputs_embeds_size + last_answer_len;
+            m_tokenized_history = std::move(m_prev_tokenized_history);
+        } else {
+            auto encoded_result = generation_finish_info.results.tokens[0];
+            std::copy(encoded_result.begin(), encoded_result.end(), std::back_inserter(m_tokenized_history));
+        }
     }
 
     void set_apply_chat_template_status(bool apply_chat_template) {
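A minimal caller-side sketch of how the reworked `update_tokenized_history` is meant to be fed; the pipeline wiring, the `run_generation` helper, and the way `inputs_embeds_size` is obtained are assumptions for illustration, not part of this diff:

```cpp
// Hypothetical call site: the pipeline bundles the generated tokens, the
// possibly-disappeared tail token and the streaming finish status into
// GenerationFinishInfo and hands it to the embedder in one call.
ov::genai::utils::GenerationFinishInfo finish_info = run_generation(prompt_embeds);  // assumed helper

embedder.update_tokenized_history(finish_info,
                                  /*is_beam_search=*/false,
                                  /*last_answer_len=*/answer_len,
                                  /*inputs_embeds_size=*/prompt_embeds.get_shape().at(1));

// On GenerationStatus::CANCEL the embedder schedules removal of the whole
// prompt + partial answer from the KV cache and restores the tokenized history
// saved on the previous step; otherwise the answer tokens are appended.
```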
@@ -100,6 +113,7 @@ class InputsEmbedder::IInputsEmbedder {
             m_history.clear();
             m_templated_chat_history.clear();
             m_tokenized_history.clear();
+            m_prev_tokenized_history.clear();
         }
         if (system_message.empty()) {
             return;
@@ -109,11 +123,16 @@ class InputsEmbedder::IInputsEmbedder {
         m_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
     }
 
-    void update_chat_history(const std::string& decoded_results) {
-        // Tail of chat template is missing in KV cache.
-        // Find the tail to concatenate it with the next input prompt.
-        m_templated_chat_history.append(decoded_results);
-        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+        if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+            // If chat generation process was cancelled by user, let's rollback to previous state of history
+            m_history.pop_back();
+        } else {
+            // Tail of chat template is missing in KV cache.
+            // Find the tail to concatenate it with the next input prompt.
+            m_templated_chat_history.append(decoded_results);
+            m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+        }
     }
 
     virtual void finish_chat() {
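For orientation, a rough sketch of a pipeline-side call under the new signature; the decode step and the exact call site are assumptions:

```cpp
// Hypothetical usage: the finish status observed by the streamer is forwarded,
// so a cancelled turn rolls the chat history back instead of appending a reply.
std::string decoded_results = tokenizer.decode(answer_tokens);  // assumed decode step
embedder.update_chat_history(decoded_results, finish_info.streaming_finish_status);
// CANCEL -> the last (user) entry is popped from m_history, nothing is appended.
// others -> decoded_results is appended to the templated history and pushed
//           as the assistant message.
```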
@@ -123,6 +142,7 @@ class InputsEmbedder::IInputsEmbedder {
         m_history.clear();
         m_templated_chat_history.clear();
         m_tokenized_history.clear();
+        m_prev_tokenized_history.clear();
     }
 
 protected:
@@ -213,6 +233,9 @@ class InputsEmbedder::IInputsEmbedder {
             trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, m_stop_token_ids);
         }
 
+        m_prev_tokenized_history.clear();
+        std::copy_n(prev_chat_tokens.data<int64_t>(), prev_chat_tokens.get_size(), std::back_inserter(m_prev_tokenized_history));
+
         if (m_tokenized_history.empty()) {
             encoded_input_ids = new_chat_tokens;
@@ -223,9 +246,14 @@ class InputsEmbedder::IInputsEmbedder {
         if (m_kv_history_manager.does_history_cache_need_to_update()) {
             trusted_history_length = m_kv_history_manager.trusted_history_length;
         } else {
-            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
+            auto num_tokens_to_remove_from_kv_cache = m_tokenized_history.size() - trusted_history_length;
             // last generated token is present in tokenized_history, but not included to attention mask, let's keep it in history
-            m_kv_history_manager.num_tokens_to_remove_from_kv_cache -= 1;
+            num_tokens_to_remove_from_kv_cache -= 1;
+
+            // if streaming was used and cancelled on prev step, m_kv_history_manager.num_tokens_to_remove_from_kv_cache could be already set
+            // and it would be bigger as it includes answer + prompt
+            m_kv_history_manager.num_tokens_to_remove_from_kv_cache = m_kv_history_manager.num_tokens_to_remove_from_kv_cache > num_tokens_to_remove_from_kv_cache ?
+                m_kv_history_manager.num_tokens_to_remove_from_kv_cache : num_tokens_to_remove_from_kv_cache;
         }
 
         ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.get_element_type(),
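Side note on the ternary at the end of this hunk: it just keeps the larger of the two removal counts, since a CANCEL on the previous step may already have scheduled a bigger trim covering answer + prompt. Assuming both operands share the same unsigned type, an equivalent formulation would be:

```cpp
#include <algorithm>

// Equivalent to the ternary above: never shrink an already-scheduled removal.
m_kv_history_manager.num_tokens_to_remove_from_kv_cache =
    std::max(m_kv_history_manager.num_tokens_to_remove_from_kv_cache,
             num_tokens_to_remove_from_kv_cache);
```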
@@ -2040,14 +2068,18 @@ std::vector<int64_t> InputsEmbedder::get_tokenized_history() const {
     return m_impl->get_tokenized_history();
 }
 
-void InputsEmbedder::update_tokenized_history(const std::vector<int64_t>& encoded_result, std::optional<int64_t> last_disappeared_token, bool is_beam_search, size_t last_answer_len) {
-    return m_impl->update_tokenized_history(encoded_result, last_disappeared_token, is_beam_search, last_answer_len);
+void InputsEmbedder::update_tokenized_history(const ov::genai::utils::GenerationFinishInfo generation_finish_info, bool is_beam_search, size_t last_answer_len, size_t inputs_embeds_size) {
+    return m_impl->update_tokenized_history(generation_finish_info, is_beam_search, last_answer_len, inputs_embeds_size);
 }
 
 size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
     return m_impl->get_num_tokens_to_remove_from_hist();
 }
 
+bool InputsEmbedder::should_reset_kv_cache() const {
+    return m_impl->should_reset_kv_cache();
+}
+
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
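The new `should_reset_kv_cache()` accessor simply exposes the history manager's `reset_kv_cache` flag through the public `InputsEmbedder` facade. A plausible, purely hypothetical consumer on the pipeline side could look like the sketch below; `m_language` and the use of `reset_state()` are assumptions, not something this diff adds:

```cpp
// Hypothetical consumer (not part of this diff): before prefilling the next
// turn, check whether the history manager decided a full KV-cache reset is
// needed instead of trimming individual tokens.
if (m_inputs_embedder->should_reset_kv_cache()) {
    m_language.reset_state();  // assumed: drop the whole KV cache of the LM infer request
}
```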
@@ -2056,8 +2088,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {