@@ -60,7 +60,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
     }
 
     if (!m_use_full_chat_history)
-        m_kv_history_trim_manager.kv_cache_seq_length_axis = kv_pos.seq_len;
+        m_kv_cache_state.seq_length_axis = kv_pos.seq_len;
 
     auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters);
     if (m_generation_config.adapters) {
@@ -143,7 +143,7 @@ DecodedResults StatefulLLMPipeline::generate(
         if (m_use_full_chat_history) {
             encoded_input = new_chat_tokens;
         } else {
-            ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens.input_ids, m_kv_cache_state);
+            ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);
             encoded_input = get_chat_encoded_input(new_chat_tokens.input_ids, m_kv_cache_state);
         }
         // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
@@ -232,7 +232,7 @@ EncodedResults StatefulLLMPipeline::generate(
     // Tail of previous output in chat mode is missing in KV cache.
     if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
         ov::Tensor new_chat_tokens = ov::Tensor{ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()};
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
 
         auto encoded_input = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state);
         input_ids = encoded_input.input_ids;
@@ -278,8 +278,8 @@ EncodedResults StatefulLLMPipeline::generate(
         if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
             reset_kv_state();
         else
-            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_trim_manager.num_tokens_to_trim,
-                                            m_kv_history_trim_manager.kv_cache_seq_length_axis, m_adapter_controller);
+            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state.num_tokens_to_trim,
+                                            m_kv_cache_state.seq_length_axis, m_adapter_controller);
     }
 
     size_t kv_cache_len = 0;
@@ -352,7 +352,7 @@ EncodedResults StatefulLLMPipeline::generate(
     m_chat_generation_finish_status = finish_info.streaming_finish_status;
 
     if (is_chat_conversation) {
-        m_kv_history_trim_manager.reset();
+        m_kv_cache_state.num_tokens_to_trim = 0;
 
         if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
             if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
@@ -361,7 +361,7 @@ EncodedResults StatefulLLMPipeline::generate(
                 std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
             }
         } else if (config.is_beam_search()) {
-            m_kv_history_trim_manager.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
+            m_kv_cache_state.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
         }
     } else {
         m_kv_cache_state.reset_state();
@@ -402,7 +402,6 @@ void StatefulLLMPipeline::reset_kv_state() {
 
 void StatefulLLMPipeline::finish_chat() {
     is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
     m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
     if (!m_kv_cache_state.get_state().empty() || have_state) {
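For orientation, the hunks above fold the bookkeeping previously held by `m_kv_history_trim_manager` (`num_tokens_to_trim`, `kv_cache_seq_length_axis`, `reset()`) into the existing `m_kv_cache_state` object. Below is a minimal sketch of what such a consolidated state type could look like; only the members actually referenced in the diff (`num_tokens_to_trim`, `seq_length_axis`, `get_state()`, `reset_state()`) come from the hunks, while the namespace, defaults, and the private token buffer are assumptions for illustration, not the actual ov::genai definition.

```cpp
// Hypothetical sketch, not the real ov::genai header: a single state holder
// that carries both the cached token history and the trim bookkeeping the
// diff above now reads from m_kv_cache_state.
#include <cstddef>
#include <cstdint>
#include <vector>

namespace ov::genai::utils {

struct KVCacheState {
    size_t num_tokens_to_trim = 0;  // tail tokens to drop from the KV cache before the next request
    size_t seq_length_axis = 2;     // axis of the KV tensors that holds the sequence length (assumed default)

    // Token ids currently materialized in the KV cache.
    std::vector<int64_t>& get_state() { return m_state; }

    // Clears both the cached tokens and the pending trim, as finish_chat()/non-chat paths expect.
    void reset_state() {
        num_tokens_to_trim = 0;
        m_state.clear();
    }

private:
    std::vector<int64_t> m_state;
};

}  // namespace ov::genai::utils
```

With the trim fields living next to the cached tokens, a single `reset_state()` call (or setting `num_tokens_to_trim = 0` after a successful chat turn, as in the hunk at line 355) replaces the separate `m_kv_history_trim_manager.reset()` calls the diff removes.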