diff --git a/src/cpp/src/llm_pipeline_stateful.cpp b/src/cpp/src/llm_pipeline_stateful.cpp
index 2d8ba48f3c..5050bae790 100644
--- a/src/cpp/src/llm_pipeline_stateful.cpp
+++ b/src/cpp/src/llm_pipeline_stateful.cpp
@@ -60,7 +60,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
     }
 
     if (!m_use_full_chat_history)
-        m_kv_history_trim_manager.kv_cache_seq_length_axis = kv_pos.seq_len;
+        m_kv_cache_state.seq_length_axis = kv_pos.seq_len;
 
     auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters);
     if (m_generation_config.adapters) {
@@ -143,7 +143,7 @@ DecodedResults StatefulLLMPipeline::generate(
         if (m_use_full_chat_history) {
             encoded_input = new_chat_tokens;
         } else {
-            ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens.input_ids, m_kv_cache_state);
+            ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);
            encoded_input = get_chat_encoded_input(new_chat_tokens.input_ids, m_kv_cache_state);
         }
         // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
@@ -212,7 +212,6 @@ EncodedResults StatefulLLMPipeline::generate(
             reset_kv_state();
             m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
             m_kv_cache_state.reset_state();
-            m_kv_history_trim_manager.reset();
         }
 
    auto start_time = std::chrono::steady_clock::now();
@@ -238,7 +237,7 @@ EncodedResults StatefulLLMPipeline::generate(
     // Tail of previous output in chat mode is missing in KV cache.
     if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
         ov::Tensor new_chat_tokens = ov::Tensor{ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()};
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
 
         auto encoded_input = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state);
         input_ids = encoded_input.input_ids;
@@ -281,11 +280,10 @@ EncodedResults StatefulLLMPipeline::generate(
                     "but you have '" + std::to_string(num_inputs) + "' inputs");
 
     if (is_chat_conversation) {
-        if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
+        if (m_use_full_chat_history)
             reset_kv_state();
         else
-            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_trim_manager.num_tokens_to_trim,
-                                            m_kv_history_trim_manager.kv_cache_seq_length_axis, m_adapter_controller);
+            ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state, m_adapter_controller);
     }
 
     size_t kv_cache_len = 0;
@@ -358,7 +356,7 @@ EncodedResults StatefulLLMPipeline::generate(
     m_chat_generation_finish_status = finish_info.streaming_finish_status;
 
     if (is_chat_conversation) {
-        m_kv_history_trim_manager.reset();
+        m_kv_cache_state.num_tokens_to_trim = 0;
 
         if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
             if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
@@ -367,7 +365,7 @@ EncodedResults StatefulLLMPipeline::generate(
                 std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
             }
         } else if (config.is_beam_search()) {
-            m_kv_history_trim_manager.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
+            m_kv_cache_state.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
         }
     }
 
@@ -406,7 +404,6 @@ void StatefulLLMPipeline::reset_kv_state() {
 
 void StatefulLLMPipeline::finish_chat() {
     is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
     m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
     bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
     if (!m_kv_cache_state.get_state().empty() || have_state) {
diff --git a/src/cpp/src/llm_pipeline_stateful.hpp b/src/cpp/src/llm_pipeline_stateful.hpp
index 3558c4c1f3..04c510a0c9 100644
--- a/src/cpp/src/llm_pipeline_stateful.hpp
+++ b/src/cpp/src/llm_pipeline_stateful.hpp
@@ -20,18 +20,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     ChatHistory m_history;
     std::vector<int64_t> m_tokenized_chat_history;
     ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
-    // If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
-    // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
-    // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
-    ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
     // Finish reason of last generation for chat scenario
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // if True, full history will be used as prompt on each chat generation
     bool m_use_full_chat_history = false;
     size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
     bool m_is_npu = false;
-    // reflection of tokens contained in the kv cache
-    KVCacheState m_kv_cache_state;
+    // reflection of tokens contained in the kv cache and the number of tokens to trim from it on the next chat step
+    utils::KVCacheState m_kv_cache_state;
 
     void reset_kv_state();
 
 public:
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 61f9a169b5..53d95df3f6 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -82,7 +82,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
     Sampler& sampler,
     std::vector<SequenceGroup::Ptr> sequence_groups,
     std::optional<ov::Tensor> position_ids,
-    KVCacheState& kv_cache_state,
+    utils::KVCacheState& kv_cache_state,
     std::optional<EmbeddingsModel> m_embedding,
     std::optional<int64_t> rope_delta,
     const size_t max_kv_cache_size
@@ -298,7 +298,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
 }
 
 
-TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
     TokenizedInputs encoded_input;
     size_t kv_cache_len = kv_cache_state.get_state().size();
     if (kv_cache_len == 0) {
@@ -325,7 +325,7 @@ TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCach
 }
 
 
-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
     // KV cache in model already contains prompts and answers from previous iterations.
     // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
     // token_ids = {<bos token>, ...}. So if tokenizer applies only to the new prompt,
@@ -343,8 +343,9 @@ void align_kv_cache_and_history(ov::genai::KVCacheTrimManage
     size_t first_diverse_tokens_idx = ov::genai::utils::get_first_history_difference(new_chat_tokens, state);
     // in the case of beam_search the longest answer is in the kv cache, but the best one is needed
     // so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated serquence
-    kv_history_manager.num_tokens_to_trim = kv_history_manager.num_tokens_to_trim > 0 ? kv_history_manager.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
+    kv_cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
     state.resize(first_diverse_tokens_idx);
+    kv_cache_state.reset_mem_state = state.empty();
 }
 
 }  // namespace genai
diff --git a/src/cpp/src/lm_encoding.hpp b/src/cpp/src/lm_encoding.hpp
index c817ef19a6..69a787713f 100644
--- a/src/cpp/src/lm_encoding.hpp
+++ b/src/cpp/src/lm_encoding.hpp
@@ -8,44 +8,16 @@
 namespace ov {
 namespace genai {
 
-class KVCacheState {
-    std::vector<int64_t> state;
-public:
-    std::vector<int64_t>& get_state() {
-        return state;
-    }
-
-    void add_inputs(const ov::Tensor& inputs_ids) {
-        std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
-    }
-
-    void reset_state() {
-        return state.clear();
-    }
-};
-
-
-struct KVCacheTrimManager
-{
-    size_t num_tokens_to_trim = 0;
-    size_t kv_cache_seq_length_axis = 2;
-
-    void reset() {
-        num_tokens_to_trim = 0;
-    }
-};
-
-
 ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
                                                               const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
-                                                              std::optional<ov::Tensor> position_ids, KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
+                                                              std::optional<ov::Tensor> position_ids, utils::KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
                                                               std::optional<int64_t> rope_delta = std::nullopt, const size_t max_kv_cache_size = std::numeric_limits<size_t>::max());
 
-void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);
 
-TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
+TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);
 
 }
 }
diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp
index aca1693562..9c75a537d6 100644
--- a/src/cpp/src/utils.cpp
+++ b/src/cpp/src/utils.cpp
@@ -325,13 +325,27 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {
     return kv_pos;
 }
 
-void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller) {
+void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller) {
+    if (kv_cache_state.reset_mem_state) {
+        if (adapter_controller) {
+            for(auto& state: request.query_state()) {
+                if(!adapter_controller->has_state_name(state.get_name())) {
+                    state.reset();
+                }
+            }
+        } else {
+            request.reset_state();
+        }
+
+        return;
+    }
+
     // nothing to trim in this case
-    if (remove_from_end == 0)
+    if (kv_cache_state.num_tokens_to_trim == 0)
         return;
 
     auto states = request.query_state();
-
+    OPENVINO_ASSERT(states.size() > 0, "Request contains no states.");
 
     for (auto& state : states) {
@@ -341,7 +355,7 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
         ov::Tensor old_tensor = state.get_state();
         // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
         auto shape = old_tensor.get_shape();
-        shape[seq_length_axis] -= remove_from_end;
+        shape[kv_cache_state.seq_length_axis] -= kv_cache_state.num_tokens_to_trim;
 
         ov::Coordinate new_shape_begin{0, 0, 0, 0};
         ov::Coordinate new_shape_end{shape};
diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp
index 4c8453b97b..9b60875b22 100644
--- a/src/cpp/src/utils.hpp
+++ b/src/cpp/src/utils.hpp
@@ -102,7 +102,29 @@ struct KVAxesPosition {
 
 KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model);
 
-void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);
+class KVCacheState {
+    std::vector<int64_t> state;
+public:
+    size_t num_tokens_to_trim = 0;
+    size_t seq_length_axis = 2;
+    bool reset_mem_state = false;
+
+    std::vector<int64_t>& get_state() {
+        return state;
+    }
+
+    void add_inputs(const ov::Tensor& inputs_ids) {
+        std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
+    }
+
+    void reset_state() {
+        reset_mem_state = false;
+        num_tokens_to_trim = 0;
+        state.clear();
+    }
+};
+
+void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller);
 
 ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front);
 
diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp
index cbc6cf2120..2d48db0ff1 100644
--- a/src/cpp/src/visual_language/inputs_embedder.cpp
+++ b/src/cpp/src/visual_language/inputs_embedder.cpp
@@ -29,7 +29,6 @@ std::pair> InputsEmbedder::IInputsEmbedder::g
 
 void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_message) {
     m_is_chat_conversation = true;
-    m_kv_history_trim_manager.reset();
     if (!m_kv_cache_state.get_state().empty()) {
         m_history.clear();
         m_kv_cache_state.reset_state();
@@ -40,17 +39,26 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
     m_history = {{{"role", "system"}, {"content", system_message}}};
 }
 
-void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    // Tail of chat template is missing in KV cache.
-    // Find the tail to concatenate it with the next input prompt.
-    m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-    m_kv_history_trim_manager.reset();
+void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    m_kv_cache_state.num_tokens_to_trim = 0;
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+        // If the chat generation process was cancelled by the user, roll back to the previous state of the history
+        m_history.pop_back();
+
+        std::vector<int64_t>& state = m_kv_cache_state.get_state();
+
+        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        state.resize(m_prev_hist_length);
+        m_kv_cache_state.reset_mem_state = state.empty();
+    } else {
+        // Tail of chat template is missing in KV cache.
+        // Find the tail to concatenate it with the next input prompt.
+        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    }
 }
 
 void InputsEmbedder::IInputsEmbedder::finish_chat() {
     m_is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
-
     m_history.clear();
     m_kv_cache_state.reset_state();
 }
@@ -123,7 +131,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::apply_chat_template_tokenize(const s
 ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new_chat_tokens) {
     ov::Tensor encoded_inputs;
     if (m_is_chat_conversation) {
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
         encoded_inputs = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state).input_ids;
     } else {
         encoded_inputs = new_chat_tokens;
@@ -135,6 +143,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new
 ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
     const auto new_chat_tokens = apply_chat_template_tokenize(prompt, metrics);
     auto new_input_ids = update_history(new_chat_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_input_ids);
 
     return new_input_ids;
@@ -225,14 +234,10 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
     return m_impl->get_embedding_model();
 }
 
-KVCacheState& InputsEmbedder::get_kv_cache_state() {
+ov::genai::utils::KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }
 
-size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
-    return m_impl->get_num_tokens_to_remove_from_hist();
-}
-
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -241,8 +246,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp
index 5c8fcbbce9..5eec6cd41e 100644
--- a/src/cpp/src/visual_language/inputs_embedder.hpp
+++ b/src/cpp/src/visual_language/inputs_embedder.hpp
@@ -45,16 +45,13 @@ class InputsEmbedder {
     Tokenizer get_tokenizer() const;
 
     // get reflection of tokens contained in the kv cache
-    KVCacheState& get_kv_cache_state();
-
-    // returns amount of elements, which need to remove from the end of the KV cache
-    size_t get_num_tokens_to_remove_from_hist() const;
+    utils::KVCacheState& get_kv_cache_state();
 
     // starts chat and adds optional system_message to chat history
     void start_chat(const std::string& system_message);
 
     // adds currently generated text to chat history
-    void update_chat_history(const std::string& decoded_results);
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);
 
     // set the apply_chat_template flag, which determines whether chat template should be applied for non-chat scenarios
     void set_apply_chat_template_status(bool apply_chat_template);
@@ -80,16 +77,14 @@ class InputsEmbedder {
     bool m_is_chat_conversation = false;
     // Chat history
     ChatHistory m_history;
-    // If sequence contains some symbols, which could be ambiguous encoded by tokenizer, we need to trim kv cache
-    // If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
-    // so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
-    ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
     // True if chat template should be applied for non-chat scenario
     bool m_apply_chat_template = true;
     // Finish reason of last generation for chat scenario
     ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
     // reflection of tokens contained in the kv cache
-    KVCacheState m_kv_cache_state;
+    utils::KVCacheState m_kv_cache_state;
+    // length of attention_mask/kv cache at the beginning of generation()
+    size_t m_prev_hist_length = 0;
 
 public:
     virtual ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) = 0;
@@ -103,21 +98,17 @@ class InputsEmbedder {
         return m_tokenizer;
     }
 
-    KVCacheState& get_kv_cache_state() {
+    utils::KVCacheState& get_kv_cache_state() {
         return m_kv_cache_state;
     }
 
-    size_t get_num_tokens_to_remove_from_hist() const {
-        return m_kv_history_trim_manager.num_tokens_to_trim;
-    }
-
     void set_apply_chat_template_status(bool apply_chat_template) {
         m_apply_chat_template = apply_chat_template;
     }
 
     virtual void start_chat(const std::string& system_message);
 
-    void update_chat_history(const std::string& decoded_results);
+    virtual void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status);
 
     virtual void finish_chat();
 
diff --git a/src/cpp/src/visual_language/minicpm/classes.cpp b/src/cpp/src/visual_language/minicpm/classes.cpp
index 697ea64e50..e24cd22438 100644
--- a/src/cpp/src/visual_language/minicpm/classes.cpp
+++ b/src/cpp/src/visual_language/minicpm/classes.cpp
@@ -667,6 +667,14 @@ ov::Tensor InputsEmbedderMiniCPM::get_inputs_embeds(const std::string& prompt, c
     return inputs_embeds;
 }
 
+void InputsEmbedderMiniCPM::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
+        m_image_id = m_prev_image_id;
+    else
+        m_prev_image_id = m_image_id;
+}
+
 void InputsEmbedderMiniCPM::start_chat(const std::string& system_message) {
     IInputsEmbedder::start_chat(system_message);
     m_image_id = 0;
diff --git a/src/cpp/src/visual_language/minicpm/classes.hpp b/src/cpp/src/visual_language/minicpm/classes.hpp
index 0ddc160231..99e71faf44 100644
--- a/src/cpp/src/visual_language/minicpm/classes.hpp
+++ b/src/cpp/src/visual_language/minicpm/classes.hpp
@@ -30,6 +30,7 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
     ov::Tensor m_pos_embed_cache;
     // Used to insert <image_id>i</image_id> per image (not a slice).
     size_t m_image_id = 0;
+    size_t m_prev_image_id = 0;
 
 public:
     InputsEmbedderMiniCPM(
@@ -48,6 +49,8 @@ class InputsEmbedderMiniCPM : public InputsEmbedder::IInputsEmbedder {
 
     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override;
 
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) override;
+
     void start_chat(const std::string& system_message) override;
 
     void finish_chat() override;
diff --git a/src/cpp/src/visual_language/phi3_vision/classes.cpp b/src/cpp/src/visual_language/phi3_vision/classes.cpp
index f1f094d1d6..ac0d8adbaa 100644
--- a/src/cpp/src/visual_language/phi3_vision/classes.cpp
+++ b/src/cpp/src/visual_language/phi3_vision/classes.cpp
@@ -471,6 +471,8 @@ ov::Tensor insert_image_placeholders(const std::vector& chunks, cons
             length,
             merged.data<int64_t>() + offset
         );
+        if (tokens_per_images.empty())
+            continue;
         offset += length;
         if (offset < merged_length) {
             std::fill_n(
@@ -548,6 +550,7 @@ ov::Tensor InputsEmbedderPhi3V::get_inputs_embeds(const std::string& prompt, con
     }
     ov::Tensor new_merged_tokens = insert_image_placeholders(new_chat_tokens, m_tokens_per_images);
     ov::Tensor new_tokens = update_history(new_merged_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_tokens);
 
     std::vector<int64_t> tokens = drop_image_placeholders(new_tokens);
@@ -602,6 +605,14 @@ ov::Tensor InputsEmbedderPhi3V::get_inputs_embeds(const std::string& prompt, con
     return inputs_embeds;
 }
 
+void InputsEmbedderPhi3V::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    IInputsEmbedder::update_chat_history(decoded_results, generation_finish_status);
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL)
+        m_tokens_per_images = m_prev_tokens_per_images;
+    else
+        m_prev_tokens_per_images = m_tokens_per_images;
+}
+
 void InputsEmbedderPhi3V::start_chat(const std::string& system_message) {
     IInputsEmbedder::start_chat(system_message);
     m_tokens_per_images.clear();
diff --git a/src/cpp/src/visual_language/phi3_vision/classes.hpp b/src/cpp/src/visual_language/phi3_vision/classes.hpp
index 6fd922125e..006429723a 100644
--- a/src/cpp/src/visual_language/phi3_vision/classes.hpp
+++ b/src/cpp/src/visual_language/phi3_vision/classes.hpp
@@ -30,6 +30,8 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
 
     ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override;
 
+    void update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) override;
+
     void start_chat(const std::string& system_message) override;
 
     void finish_chat() override;
@@ -38,6 +40,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
     ov::InferRequest m_hd_feature_transformer;
     ov::InferRequest m_vision_projection;
     std::vector<size_t> m_tokens_per_images;
+    std::vector<size_t> m_prev_tokens_per_images;
 };
 
 } // namespace ov::genai
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp
index 612f34187f..57f325d7fa 100644
--- a/src/cpp/src/visual_language/pipeline.cpp
+++ b/src/cpp/src/visual_language/pipeline.cpp
@@ -40,8 +40,6 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
     bool m_is_chat_conversation = false;
     // InputsEmbedder
     std::shared_ptr<InputsEmbedder> m_inputs_embedder;
-    // Axis num in kv cache from m_language model, which contains information about history len
-    size_t m_kv_cache_seq_length_axis = 2;
     // Component for applying sampling to lm outputs
     Sampler m_sampler;
     size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
@@ -63,7 +61,6 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         auto language_model_path = models_dir / "openvino_language_model.xml";
         auto language_model = utils::singleton_core().read_model(language_model_path, {}, properties_copy);
         auto kv_pos = ov::genai::utils::get_kv_axes_pos(language_model);
-        m_kv_cache_seq_length_axis = kv_pos.seq_len;
 
         // In case user provided properties per-device
         // {
@@ -93,7 +90,6 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
             ov::genai::utils::print_compiled_model_properties(compiled_language_model, "VLM language model");
 
             m_language = compiled_language_model.create_infer_request();
-            m_kv_cache_seq_length_axis = utils::get_kv_axes_pos(language_model).seq_len;
             m_language.get_tensor("attention_mask").set_shape({1, 0});
 
             auto embedder_properties = device_propertes.empty()
@@ -103,6 +99,9 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         m_tokenizer = m_inputs_embedder->get_tokenizer();
         m_embedding = m_inputs_embedder->get_embedding_model();
 
+        utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+        kv_cache_state.seq_length_axis = kv_pos.seq_len;
+
         // If eos_token_id was not provided, take value
         if (m_generation_config.eos_token_id == -1) {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
@@ -186,17 +185,17 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
         ov::Tensor inputs_embeds = m_inputs_embedder->get_inputs_embeds(prompt, rgbs, perf_metrics);
         auto end_get_inputs_embeds = std::chrono::steady_clock::now();
 
-        auto to_remove_from_hist = m_inputs_embedder->get_num_tokens_to_remove_from_hist();
-        utils::trim_kv_cache(m_language, to_remove_from_hist, m_kv_cache_seq_length_axis, std::nullopt);
+        utils::KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
+        if (m_is_chat_conversation)
+            utils::trim_kv_cache(m_language, kv_cache_state, std::nullopt);
 
         std::vector<SequenceGroup::Ptr> requests;
         size_t request_id = 0;
         size_t block_size = 1; // not used
 
-        size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - to_remove_from_hist;
+        size_t history_size = m_language.get_tensor("attention_mask").get_shape().at(1) - kv_cache_state.num_tokens_to_trim;
         size_t inputs_embeds_size = inputs_embeds.get_shape().at(1);
 
-        KVCacheState& kv_cache_state = m_inputs_embedder->get_kv_cache_state();
         std::vector<int64_t> tokenized_history = kv_cache_state.get_state();
         ov::Tensor prompt_ids(ov::element::i64, { history_size + inputs_embeds_size });
         OPENVINO_ASSERT(prompt_ids.get_size() >= tokenized_history.size(), "Prompt ids size is less than tokenized history size");
@@ -237,7 +236,7 @@ class VLMPipeline::VLMPipelineImpl : public VLMPipelineBase{
 
         std::string decoded_results = decoded.texts.at(0);
         if (m_is_chat_conversation)
-            m_inputs_embedder->update_chat_history(decoded_results);
+            m_inputs_embedder->update_chat_history(decoded_results, finish_info.streaming_finish_status);
         else
             kv_cache_state.reset_state();
 
diff --git a/tests/python_tests/test_vlm_pipeline.py b/tests/python_tests/test_vlm_pipeline.py
index f2b697285a..575c848388 100644
--- a/tests/python_tests/test_vlm_pipeline.py
+++ b/tests/python_tests/test_vlm_pipeline.py
@@ -8,7 +8,7 @@
 import sys
 import transformers
 from optimum.intel.openvino import OVModelForVisualCausalLM
-from openvino_genai import VLMPipeline, GenerationConfig, SchedulerConfig, ContinuousBatchingPipeline, GenerationStatus
+from openvino_genai import VLMPipeline, GenerationConfig, SchedulerConfig, ContinuousBatchingPipeline, GenerationStatus, StreamingStatus
 from utils.network import retry_request
 from utils.generation_config import get_beam_search, get_multinomial_all_parameters, get_greedy
 
@@ -224,7 +224,12 @@ def test_vlm_with_scheduler_vs_default(config, cache):
 @pytest.mark.nightly
 @pytest.mark.parametrize("model_id", model_ids)
 @pytest.mark.parametrize("system_message", ["", "You are a helpful assistant."])
-def test_vlm_pipeline_chat(model_id, system_message, cache):
+@pytest.mark.parametrize("iteration_images", [[image_links_for_testing[0], image_links_for_testing[0]], # generation with text input only
+                                              [image_links_for_testing[0], image_links_for_testing[2], image_links_for_testing[0]], # combination of generations with text input and text + image input, empty image first
+                                              [image_links_for_testing[2], image_links_for_testing[1]], # generation with text + image input
+                                              [image_links_for_testing[2], image_links_for_testing[0], image_links_for_testing[1]]] # combination of generations with text input and text + image input, image input first
+                                              )
+def test_vlm_pipeline_chat(model_id, system_message, iteration_images, cache):
     def streamer(word: str) -> bool:
         nonlocal result_from_streamer
         result_from_streamer.append(word)
@@ -236,23 +241,26 @@ def streamer(word: str) -> bool:
     generation_config.max_new_tokens = 30
     generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())
 
-    for links in image_links_for_testing:
+    ov_pipe.start_chat(system_message)
+
+    images = []
+    for link in iteration_images[0]:
+        images.append(get_image_by_link(link))
+
+    result_from_streamer = []
+    res = ov_pipe.generate(prompts[0], images=images, generation_config=generation_config, streamer=streamer)
+    assert res.texts[0] == ''.join(result_from_streamer)
+
+    for image_set in iteration_images[1:]:
         images = []
-        for link in links:
+        for link in image_set:
            images.append(get_image_by_link(link))
 
-        ov_pipe.start_chat(system_message)
-
         result_from_streamer = []
-        res = ov_pipe.generate(prompts[0], images=images, generation_config=generation_config, streamer=streamer)
+        res = ov_pipe.generate(prompts[1], images=images, generation_config=generation_config, streamer=streamer)
         assert res.texts[0] == ''.join(result_from_streamer)
 
-        for prompt in prompts[1:]:
-            result_from_streamer = []
-            res = ov_pipe.generate(prompt, generation_config=generation_config, streamer=streamer)
-            assert res.texts[0] == ''.join(result_from_streamer)
-
-        ov_pipe.finish_chat()
+    ov_pipe.finish_chat()
 
 
 @pytest.mark.precommit
@@ -359,3 +367,95 @@ def test_vlm_npu_no_exception(model_id, cache):
     for link in image_links_for_testing[2]:
         image = get_image_by_link(link)
         out = ov_pipe.generate(prompts[0], images=[image], generation_config=generation_config)
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_id", model_ids)
+@pytest.mark.parametrize("iteration_images", [image_links_for_testing[1], []])
+def test_vlm_pipeline_chat_streamer_cancel_second_generate(model_id, iteration_images, cache):
+    callback_questions = [
+        '1+1=',
+        'Why is the Sun yellow?',
+        'What is the previous answer?'
+    ]
+
+    current_iter = 0
+    num_iters = 3
+    def streamer(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return StreamingStatus.CANCEL if current_iter == num_iters else StreamingStatus.RUNNING
+
+
+    models_path = get_ov_model(model_id, cache)
+    ov_pipe = VLMPipeline(models_path, "CPU")
+    generation_config = ov_pipe.get_generation_config()
+    generation_config.max_new_tokens = 30
+    generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())
+    generation_config.ignore_eos = True
+
+    images = []
+    for link in iteration_images:
+        images.append(get_image_by_link(link))
+
+    results_with_cancel = ""
+    ov_pipe.start_chat()
+    results_with_cancel += ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config).texts[0]
+    # doesn't add to results_with_cancel as it should be completely removed from the history
+    ov_pipe.generate(callback_questions[1], images=images, generation_config=generation_config, streamer=streamer)
+    results_with_cancel += ov_pipe.generate(callback_questions[2], images=images, generation_config=generation_config).texts[0]
+    ov_pipe.finish_chat()
+
+    results = ""
+    ov_pipe.start_chat()
+    results += ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config).texts[0]
+    results += ov_pipe.generate(callback_questions[2], images=images, generation_config=generation_config).texts[0]
+    ov_pipe.finish_chat()
+
+    assert results_with_cancel == results
+
+    results = ""
+    ov_pipe.start_chat()
+    results += ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config).texts[0]
+    results += ov_pipe.generate(callback_questions[2], images=images, generation_config=generation_config).texts[0]
+    ov_pipe.finish_chat()
+
+    assert results_with_cancel == results
+
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("model_id", model_ids)
+@pytest.mark.parametrize("iteration_images", [image_links_for_testing[1], []])
+def test_vlm_pipeline_chat_streamer_cancel_first_generate(model_id, iteration_images, cache):
+    callback_questions = [
+        'Why is the Sun yellow?',
+        '1+1=',
+    ]
+
+    current_iter = 0
+    num_iters = 3
+    def streamer(subword):
+        nonlocal current_iter
+        current_iter += 1
+        return StreamingStatus.CANCEL if current_iter == num_iters else StreamingStatus.RUNNING
+
+    models_path = get_ov_model(model_id, cache)
+    ov_pipe = VLMPipeline(models_path, "CPU")
+    generation_config = ov_pipe.get_generation_config()
+    generation_config.max_new_tokens = 30
+    generation_config.ignore_eos = True
+    generation_config.set_eos_token_id(ov_pipe.get_tokenizer().get_eos_token_id())
+
+    images = []
+    for link in iteration_images:
+        images.append(get_image_by_link(link))
+
+    ov_pipe.start_chat()
+    res_first = ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config, streamer=streamer).texts[0]
+    current_iter = 0
+    res_second = ov_pipe.generate(callback_questions[0], images=images, generation_config=generation_config, streamer=streamer).texts[0]
+    ov_pipe.finish_chat()
+
+    assert res_first == res_second
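
What follows is not part of the patch: it is a minimal, dependency-free sketch of the bookkeeping the diff consolidates into utils::KVCacheState, namely the token reflection plus num_tokens_to_trim, seq_length_axis and reset_mem_state, the prefix comparison done in align_kv_cache_and_history(), and the cancel rollback added to update_chat_history(). All names below (KVCacheStateSketch, align_with_history, rollback_cancelled_answer) are hypothetical illustrations; the real code works on ov::Tensor state via ov::genai::utils::get_first_history_difference() and trims the KV tensors along seq_length_axis inside trim_kv_cache().

// Simplified model of the KVCacheState bookkeeping; no OpenVINO dependency, illustrative only.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct KVCacheStateSketch {
    std::vector<int64_t> tokens;    // reflection of tokens currently held in the KV cache
    size_t num_tokens_to_trim = 0;  // how many trailing tokens to drop before the next step
    size_t seq_length_axis = 2;     // axis of the KV tensors that holds the sequence length
    bool reset_mem_state = false;   // true -> the whole cache must be reset, not trimmed
};

// Mirrors align_kv_cache_and_history(): compare the freshly tokenized chat history with the
// tokens already reflected in the cache and schedule a trim back to the common prefix.
void align_with_history(const std::vector<int64_t>& new_chat_tokens, KVCacheStateSketch& s) {
    size_t first_diff = 0;
    while (first_diff < s.tokens.size() && first_diff < new_chat_tokens.size() &&
           s.tokens[first_diff] == new_chat_tokens[first_diff]) {
        ++first_diff;
    }
    s.num_tokens_to_trim += s.tokens.size() - first_diff;  // drop the diverged tail
    s.tokens.resize(first_diff);
    s.reset_mem_state = s.tokens.empty();                  // nothing left -> full reset
}

// Mirrors the CANCEL branch of update_chat_history(): roll the reflection back to the
// length recorded before the cancelled generate() call (m_prev_hist_length in the patch).
void rollback_cancelled_answer(KVCacheStateSketch& s, size_t prev_hist_length) {
    s.num_tokens_to_trim = s.tokens.size() - prev_hist_length;
    s.tokens.resize(prev_hist_length);
    s.reset_mem_state = s.tokens.empty();
}

int main() {
    KVCacheStateSketch state;
    state.tokens = {1, 10, 11, 12, 20, 21};  // prompt + previously generated answer

    // A cancelled answer (tokens 20, 21) is removed: trim 2 tokens from the cache tail.
    rollback_cancelled_answer(state, /*prev_hist_length=*/4);
    assert(state.num_tokens_to_trim == 2 && state.tokens.size() == 4);

    // The re-tokenized history diverges after the third token: schedule one more trim.
    align_with_history({1, 10, 11, 13, 14}, state);
    assert(state.num_tokens_to_trim == 3 && state.tokens.size() == 3);

    std::cout << "tokens_to_trim=" << state.num_tokens_to_trim
              << " reset=" << std::boolalpha << state.reset_mem_state << "\n";
    return 0;
}

Accumulating into num_tokens_to_trim (+=) rather than overwriting it lets a trim scheduled at the end of the previous generate() call (for example after beam search or a cancelled answer) combine with a history-divergence trim detected on the next call, which is the behavioural change made in align_kv_cache_and_history() above.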