@@ -29,7 +29,6 @@ std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::IInputsEmbedder::g
 
 void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_message) {
     m_is_chat_conversation = true;
-    m_kv_history_trim_manager.reset();
     if (!m_kv_cache_state.get_state().empty()) {
         m_history.clear();
         m_kv_cache_state.reset_state();
@@ -40,17 +39,26 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
     m_history = {{{"role", "system"}, {"content", system_message}}};
 }
 
-void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    // Tail of chat template is missing in KV cache.
-    // Find the tail to concatenate it with the next input prompt.
-    m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-    m_kv_history_trim_manager.reset();
+void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    m_kv_cache_state.num_tokens_to_trim = 0;
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+        // If chat generation was cancelled by the user, roll history back to its previous state.
+        m_history.pop_back();
+
+        std::vector<int64_t>& state = m_kv_cache_state.get_state();
+
+        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        state.resize(m_prev_hist_length);
+        m_kv_cache_state.reset_mem_state = state.empty();
+    } else {
+        // Tail of chat template is missing in KV cache.
+        // Find the tail to concatenate it with the next input prompt.
+        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    }
 }
 
 void InputsEmbedder::IInputsEmbedder::finish_chat() {
     m_is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
-
     m_history.clear();
     m_kv_cache_state.reset_state();
 }
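
The CANCEL branch above rewinds the cached token state to the length captured before the cancelled turn. Below is a minimal, self-contained sketch of that trim arithmetic; `KVCacheStateSketch` is a hypothetical stand-in for `ov::genai::utils::KVCacheState`, not the real type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ov::genai::utils::KVCacheState: tracks the tokens
// currently materialized in the KV cache, how many trailing tokens the model
// should trim before the next step, and whether to fully reset the cache.
struct KVCacheStateSketch {
    std::vector<int64_t> tokens;
    size_t num_tokens_to_trim = 0;
    bool reset_mem_state = false;
};

// Mirrors the CANCEL branch of update_chat_history(): rewind the cached
// tokens to the length recorded before the cancelled generation started.
void rollback_on_cancel(KVCacheStateSketch& state, size_t prev_hist_length) {
    assert(state.tokens.size() >= prev_hist_length);
    state.num_tokens_to_trim = state.tokens.size() - prev_hist_length; // tokens to drop from the cache
    state.tokens.resize(prev_hist_length);                             // forget the cancelled turn
    state.reset_mem_state = state.tokens.empty();                      // empty state => full reset
}

int main() {
    KVCacheStateSketch s{{1, 2, 3, 4, 5}};
    rollback_on_cancel(s, 3); // s.tokens == {1, 2, 3}, s.num_tokens_to_trim == 2
    return 0;
}
```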
@@ -123,7 +131,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::apply_chat_template_tokenize(const s
 ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new_chat_tokens) {
     ov::Tensor encoded_inputs;
     if (m_is_chat_conversation) {
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
         encoded_inputs = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state).input_ids;
     } else {
         encoded_inputs = new_chat_tokens;
@@ -135,6 +143,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new
 ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
     const auto new_chat_tokens = apply_chat_template_tokenize(prompt, metrics);
     auto new_input_ids = update_history(new_chat_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_input_ids);
 
     return new_input_ids;
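
`m_prev_hist_length` is snapshotted after `update_history()` but before `add_inputs()`, so it records exactly the KV-cache length a later CANCEL rollback must restore. A toy illustration of why that ordering matters, with a plain `std::vector` standing in for the tracked state (names are illustrative, not the library's):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    std::vector<int64_t> kv_state = {1, 2, 3};   // tokens already committed in the KV cache
    std::vector<int64_t> new_input_ids = {4, 5}; // tokens of the new user turn

    // Snapshot BEFORE the new turn is appended, as in get_encoded_input_ids().
    std::size_t prev_hist_length = kv_state.size(); // == 3

    // add_inputs(): the new prompt tokens join the tracked state.
    kv_state.insert(kv_state.end(), new_input_ids.begin(), new_input_ids.end());

    // ... generation appends answer tokens, then the user cancels ...
    kv_state.push_back(6);

    // CANCEL rollback restores {1, 2, 3}. A snapshot taken after add_inputs()
    // would have left the cancelled prompt tokens {4, 5} in the cache.
    kv_state.resize(prev_hist_length);
    return 0;
}
```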
@@ -225,14 +234,10 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
     return m_impl->get_embedding_model();
 }
 
-KVCacheState& InputsEmbedder::get_kv_cache_state() {
+ov::genai::utils::KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }
 
-size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
-    return m_impl->get_num_tokens_to_remove_from_hist();
-}
-
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -241,8 +246,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
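
With the widened public signature, the caller reports how generation ended so the embedder can either commit the assistant turn or roll it back. A hedged usage sketch follows; only `update_chat_history(results, status)` and `GenerationStatus::CANCEL` are confirmed by this diff, while the `finish_turn` helper and the surrounding decode loop are illustrative and assume the OpenVINO GenAI headers are available:

```cpp
#include <string>

// Illustrative helper: forward the decode loop's outcome to the embedder.
// On CANCEL the embedder pops the last history entry and schedules a KV-cache
// trim; on normal completion it appends the assistant reply to the history.
void finish_turn(ov::genai::InputsEmbedder& embedder,
                 const std::string& decoded_results,
                 ov::genai::GenerationStatus status) {
    embedder.update_chat_history(decoded_results, status);
}
```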