@@ -36,13 +36,15 @@ std::pair<EncodedResults, int32_t> beam_search(
 class StatefulLLMPipeline final : public LLMPipelineImplBase {
 public:
     ov::InferRequest m_model_runner;
-
     bool is_chat_conversation = false;
-    bool m_is_cache_empty = true;
+    bool m_trust_encoded_history = true;
     std::optional<int32_t> m_selected_beam = std::nullopt;
     ChatHistory m_history;
     std::string m_templated_chat_history = {};
-    TokenizedInputs m_tokenized_chat_history;
+    std::vector<int64_t> m_tokenized_chat_history;
+    ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+    size_t m_to_remove_from_hist = 0;
+    size_t m_kv_cache_seq_length_axis = 2;

     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -77,6 +79,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         ov::Core core;
         auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
         utils::slice_matmul_statefull_model(model);
+        m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

         if (auto filtered_plugin_config = extract_adapters_from_properties(plugin_config, &m_generation_config.adapters)) {
             m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
@@ -102,8 +105,20 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF)
+            m_chat_input_type = ov::genai::utils::GenerationChatInputsType::STRING;
+
+        if (is_chat_conversation)
+            OPENVINO_ASSERT(m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS,
+                            "Chat doesn't support switching between input types. Please, continue using EncodedInputs or restart the chat.");
+
         auto start_time = std::chrono::steady_clock::now();
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
+        // If eos_token_id was not provided, take value from default m_generation_config
+        if (config.eos_token_id == -1)
+            config.set_eos_token_id(m_generation_config.eos_token_id);
+        config.validate();
+
         TokenizedInputs encoded_input;

         if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
@@ -127,19 +142,51 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
                 // Do not add special tokens in chat scenario to be aligned with HF.
                 auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
-                if (m_is_cache_empty) {
+                auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
+
+                // Some combinations of symbols can be encoded by the tokenizer in different ways.
+                // If the sequence contains such a combination, the new history cannot be reliably subtracted from the old one,
+                // so find the trusted part of the history and use it on the next step.
+                size_t last_same_hist_token = 0;
+                if (!m_tokenized_chat_history.empty()) {
+                    std::set<int64_t> stop_tokens = config.stop_token_ids;
+                    last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
+                    m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
+                }
+
+                if (m_tokenized_chat_history.empty()) {
                     encoded_input = new_chat_tokens;
+                } else if (last_same_hist_token != SIZE_MAX) {
+                    m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;
+
+                    ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
+                                                       {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
+                                                       new_chat_tokens.input_ids.data<int64_t>() + last_same_hist_token);
+
+                    ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape());
+                    std::fill_n(new_attention_mask.data<int64_t>(), new_tensor.get_shape()[1], 1);
+
+                    encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
+                                                         {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
+                    new_tensor.copy_to(encoded_input.input_ids);
+                    encoded_input.attention_mask = new_attention_mask;
+
+                    m_selected_beam = std::nullopt;
                 } else {
-                    auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
                     encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
                 }
                 m_templated_chat_history = new_templated_chat_history;
-                m_tokenized_chat_history = new_chat_tokens;
+                m_tokenized_chat_history.clear();
+                m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size());
+                std::copy_n(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.get_size(),
+                            std::back_inserter(m_tokenized_chat_history));
+
                 // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
             } else {
                 encoded_input = m_tokenizer.encode(prompt);
             }
         }
+
         auto encode_stop_time = std::chrono::steady_clock::now();
         auto encoded_results = generate(encoded_input, config, streamer);

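For reference, the divergence check used above can be pictured with the minimal sketch below. It is not the actual ov::genai::utils::get_first_history_difference: the standalone function, its vector-based signature, and the stop-token handling are assumptions made for illustration.

// Illustrative sketch only; see the hedging note above.
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Returns SIZE_MAX when the re-encoded history matches the stored tokens (the KV cache
// can be fully trusted); otherwise returns the index of the first diverging token.
size_t first_history_difference(const std::vector<int64_t>& prev_encoded,    // re-encoded previous templated history
                                const std::vector<int64_t>& stored_history,  // tokens accumulated during the chat
                                const std::set<int64_t>& stop_tokens) {
    size_t idx = 0;
    while (idx < prev_encoded.size() && idx < stored_history.size() &&
           prev_encoded[idx] == stored_history[idx])
        ++idx;

    if (idx == prev_encoded.size()) {
        // Leftover stored tokens that are stop tokens (dropped by the chat template)
        // still count as a full match.
        bool only_stop_tokens_left = true;
        for (size_t i = idx; i < stored_history.size(); ++i)
            only_stop_tokens_left &= stop_tokens.count(stored_history[i]) > 0;
        if (only_stop_tokens_left)
            return SIZE_MAX;
    }
    return idx;
}

Returning SIZE_MAX for a full match lets the caller distinguish "history fully trusted" from "diverged at index 0", which is how m_trust_encoded_history and m_to_remove_from_hist are derived in the hunk above.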
@@ -188,6 +235,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF)
+            m_chat_input_type = ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS;
+
+        if (is_chat_conversation)
+            // If the chat was run in StringInputs mode but generate() is now called with EncodedInputs, the last m_history entry will have the assistant role.
+            OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
+                            "Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");
+
         auto start_time = std::chrono::steady_clock::now();
         ov::Tensor input_ids;
         ov::Tensor attention_mask;
@@ -199,6 +254,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             attention_mask = data->attention_mask;
         }

+        if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
+            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
+
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

         // If eos_token_id was not provided, take value from default m_generation_config
@@ -230,16 +288,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                         "(input_ids, attention_mask, position_ids, beam_idx) "
                         "but you have '" + std::to_string(num_inputs) + "' inputs");

+        ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller);

         size_t kv_cache_len = 0;
         ov::Tensor concatenated_attention_mask;
-        if (is_chat_conversation && !m_is_cache_empty) {
+        if (is_chat_conversation && !m_tokenized_chat_history.empty()) {
             OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
             // If history is saved in KV cache, concatenate new attention_mask with the already existing.
             // Between subsequent runs attention_mask should not be modified.
             auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
             auto prompt_len = attention_mask.get_shape()[1];
-            kv_cache_len = atten_mask_history.get_shape()[1];
+            kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

             ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
             auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
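For reference, trimming the KV cache along the sequence-length axis could look roughly like the sketch below. This is a simplified stand-in, not the actual ov::genai::utils::trim_kv_cache (which also receives the adapter controller); the helper name and the copy-based trimming are assumptions.

// Illustrative sketch only; see the hedging note above.
#include <openvino/openvino.hpp>

// Drops the last `tokens_to_remove` positions from every KV-cache state tensor,
// assuming the sequence-length dimension sits at `seq_len_axis` (2 in the class above).
void trim_kv_cache_sketch(ov::InferRequest& request, size_t tokens_to_remove, size_t seq_len_axis) {
    if (tokens_to_remove == 0)
        return;
    for (auto& state : request.query_state()) {
        ov::Tensor full = state.get_state();
        ov::Shape trimmed_shape = full.get_shape();       // e.g. [batch, num_heads, seq_len, head_size]
        trimmed_shape[seq_len_axis] -= tokens_to_remove;  // keep only the trusted prefix

        // View the kept region, materialize it contiguously, and write it back as the new state.
        ov::Coordinate begin(full.get_shape().size(), 0);
        ov::Coordinate end(trimmed_shape);
        ov::Tensor view(full, begin, end);
        ov::Tensor trimmed(full.get_element_type(), trimmed_shape);
        view.copy_to(trimmed);
        state.set_state(trimmed);
    }
}

Trimming only the untrusted tail keeps the already prefilled part of the conversation in the state, so the pipeline does not have to reset the KV cache and re-prefill the whole history.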
@@ -263,6 +322,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_adapter_controller->apply(m_model_runner, config.adapters);
         }

+        if (is_chat_conversation && !m_trust_encoded_history) {
+            m_trust_encoded_history = true;
+            m_to_remove_from_hist = 0;
+        }
+
         ov::genai::EncodedResults result;
         if (config.is_beam_search() && is_chat_conversation) {
             std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
@@ -274,8 +338,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

             for (size_t request_id = 0; request_id < batch_size; request_id++) {
                 SequenceGroup::Ptr sequence_group;
-                if (is_chat_conversation && !m_is_cache_empty) {
-                    sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
+                if (is_chat_conversation) {
+                    ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
+                    sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
                 } else {
                     size_t seq_len = input_ids.get_shape().at(1);
                     size_t batch_offset = request_id * seq_len;
@@ -294,12 +359,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                                                sampler, requests, position_ids, std::nullopt, m_selected_beam);
         }

-        if (!is_chat_conversation) {
+        if (is_chat_conversation) {
+            std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
+        } else {
             reset_kv_state();
             m_selected_beam = std::nullopt;
-        } else {
-            m_is_cache_empty = false;
         }
+
         auto stop_time = std::chrono::steady_clock::now();

         // If is called without tokenization then that stat will not be reported.
@@ -313,12 +379,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

     void start_chat(const std::string& system_message) override {
         is_chat_conversation = true;
-        m_selected_beam = std::nullopt;
-        if (!m_is_cache_empty) {
+        m_selected_beam = std::nullopt;
+        m_trust_encoded_history = true;
+        m_to_remove_from_hist = 0;
+        m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+        if (!m_tokenized_chat_history.empty()) {
             reset_kv_state();
-            m_is_cache_empty = true;
             m_history = {};
             m_templated_chat_history = "";
+            m_tokenized_chat_history.clear();
         }
         if (system_message.empty())
             return;
@@ -332,11 +401,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     void finish_chat() override {
         is_chat_conversation = false;
         m_selected_beam = std::nullopt;
-        if (!m_is_cache_empty) {
+        m_trust_encoded_history = true;
+        m_to_remove_from_hist = 0;
+        m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+        if (!m_tokenized_chat_history.empty()) {
             reset_kv_state();
-            m_is_cache_empty = true;
             m_history.clear();
             m_templated_chat_history.clear();
+            m_tokenized_chat_history.clear();
         }
     }
 };