Implement CANCEL for streaming with VLM Pipeline #1725

Merged: 8 commits, merged Mar 6, 2025
Changes from all commits
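Note (not part of the diff): a rough usage sketch of what cancelling a streamed VLM chat turn is expected to look like after this change. The API names (VLMPipeline, StreamingStatus::CANCEL, ov::genai::image/streamer), the model directory, and the image tensor layout are assumptions taken from the public GenAI samples, not from this PR.

```cpp
#include <cstddef>
#include <iostream>
#include <string>

#include "openvino/genai/visual_language/pipeline.hpp"

int main() {
    // Hypothetical model directory; any VLM exported for OpenVINO GenAI would do.
    ov::genai::VLMPipeline pipe("./MiniCPM-V-2_6/", "CPU");

    // Stand-in for a real image loaded from disk (layout is an assumption).
    ov::Tensor image(ov::element::u8, ov::Shape{1, 448, 448, 3});

    std::size_t chunks = 0;
    auto streamer = [&](std::string subword) -> ov::genai::StreamingStatus {
        std::cout << subword << std::flush;
        // Cancel after a few streamed chunks; the pipeline is expected to roll back
        // the chat history and KV cache so the cancelled turn leaves no trace.
        return ++chunks < 5 ? ov::genai::StreamingStatus::RUNNING
                            : ov::genai::StreamingStatus::CANCEL;
    };

    pipe.start_chat();
    pipe.generate("Describe the image.", ov::genai::image(image), ov::genai::streamer(streamer));
    // The next turn should behave as if the cancelled one never happened.
    pipe.generate("What is in the foreground?", ov::genai::image(image), ov::genai::streamer(streamer));
    pipe.finish_chat();
    return 0;
}
```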
17 changes: 7 additions & 10 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -60,7 +60,7 @@ StatefulLLMPipeline::StatefulLLMPipeline(
}

if (!m_use_full_chat_history)
m_kv_history_trim_manager.kv_cache_seq_length_axis = kv_pos.seq_len;
m_kv_cache_state.seq_length_axis = kv_pos.seq_len;

auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters);
if (m_generation_config.adapters) {
@@ -143,7 +143,7 @@ DecodedResults StatefulLLMPipeline::generate(
if (m_use_full_chat_history) {
encoded_input = new_chat_tokens;
} else {
ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens.input_ids, m_kv_cache_state);
ov::genai::align_kv_cache_and_history(new_chat_tokens.input_ids, m_kv_cache_state);
encoded_input = get_chat_encoded_input(new_chat_tokens.input_ids, m_kv_cache_state);
}
// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
@@ -212,7 +212,6 @@ EncodedResults StatefulLLMPipeline::generate(
reset_kv_state();
m_model_runner.get_tensor("attention_mask").set_shape({1, 0});
m_kv_cache_state.reset_state();
m_kv_history_trim_manager.reset();
}

auto start_time = std::chrono::steady_clock::now();
@@ -238,7 +237,7 @@ EncodedResults StatefulLLMPipeline::generate(
// Tail of previous output in chat mode is missing in KV cache.
if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
ov::Tensor new_chat_tokens = ov::Tensor{ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data()};
ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);

auto encoded_input = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state);
input_ids = encoded_input.input_ids;
@@ -281,11 +280,10 @@ EncodedResults StatefulLLMPipeline::generate(
"but you have '" + std::to_string(num_inputs) + "' inputs");

if (is_chat_conversation) {
if (m_kv_cache_state.get_state().empty() || m_use_full_chat_history)
if (m_use_full_chat_history)
reset_kv_state();
else
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_history_trim_manager.num_tokens_to_trim,
m_kv_history_trim_manager.kv_cache_seq_length_axis, m_adapter_controller);
ov::genai::utils::trim_kv_cache(m_model_runner, m_kv_cache_state, m_adapter_controller);
}

size_t kv_cache_len = 0;
@@ -358,7 +356,7 @@ EncodedResults StatefulLLMPipeline::generate(
m_chat_generation_finish_status = finish_info.streaming_finish_status;

if (is_chat_conversation) {
m_kv_history_trim_manager.reset();
m_kv_cache_state.num_tokens_to_trim = 0;

if (m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS) {
if (m_chat_generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
@@ -367,7 +365,7 @@ EncodedResults StatefulLLMPipeline::generate(
std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
}
} else if (config.is_beam_search()) {
m_kv_history_trim_manager.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
m_kv_cache_state.num_tokens_to_trim = m_model_runner.get_tensor("attention_mask").get_shape()[1] - prev_attn_mask_size;
}
}

@@ -406,7 +404,6 @@ void StatefulLLMPipeline::reset_kv_state() {

void StatefulLLMPipeline::finish_chat() {
is_chat_conversation = false;
m_kv_history_trim_manager.reset();
m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
bool have_state = 0 != m_model_runner.get_tensor("attention_mask").get_size();
if (!m_kv_cache_state.get_state().empty() || have_state) {
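Note (not part of the diff): a worked example of the beam-search bookkeeping above, with invented sizes. Only the subtraction mirrors the new code; the rest is illustrative.

```cpp
// The KV cache holds the longest beam after generation, but only the selected answer
// goes into the chat history, so everything generated this turn is scheduled for removal
// before the next prompt is appended.
#include <cstddef>

int main() {
    std::size_t prev_attn_mask_size = 128;  // attention_mask length captured before generation
    std::size_t attn_mask_after     = 170;  // length after generation (prompt + longest beam)

    // Mirrors: m_kv_cache_state.num_tokens_to_trim = attention_mask.shape[1] - prev_attn_mask_size
    std::size_t num_tokens_to_trim = attn_mask_after - prev_attn_mask_size;  // 42 tokens

    // trim_kv_cache() applies this at the start of the next chat turn.
    return num_tokens_to_trim == 42 ? 0 : 1;
}
```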
8 changes: 2 additions & 6 deletions src/cpp/src/llm_pipeline_stateful.hpp
@@ -20,18 +20,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
ChatHistory m_history;
std::vector<int64_t> m_tokenized_chat_history;
ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
// If sequence contains some symbols, which could be ambiguously encoded by tokenizer, we need to trim kv cache
// If we use beam search sampling with chat mode we need to remove last answer of the model from kv cache and add best answer to history
// so, let's keep info about amount of tokens to trim from kv cache and amount of tokens to keep in history
ov::genai::KVCacheTrimManager m_kv_history_trim_manager = {0, 2};
// Finish reason of last generation for chat scenario
ov::genai::GenerationStatus m_chat_generation_finish_status = ov::genai::GenerationStatus::RUNNING;
// if True, full history will be used as prompt on each chat generation
bool m_use_full_chat_history = false;
size_t m_max_kv_cache_size = std::numeric_limits<size_t>::max();
bool m_is_npu = false;
// reflection of tokens contained in the kv cache
KVCacheState m_kv_cache_state;
// reflects the tokens contained in the kv cache and the number of tokens to trim from the kv cache on the next chat step
utils::KVCacheState m_kv_cache_state;

void reset_kv_state();
public:
9 changes: 5 additions & 4 deletions src/cpp/src/lm_encoding.cpp
@@ -82,7 +82,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
Sampler& sampler,
std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids,
KVCacheState& kv_cache_state,
utils::KVCacheState& kv_cache_state,
std::optional<EmbeddingsModel> m_embedding,
std::optional<int64_t> rope_delta,
const size_t max_kv_cache_size
@@ -298,7 +298,7 @@ ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(
}


TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
TokenizedInputs encoded_input;
size_t kv_cache_len = kv_cache_state.get_state().size();
if (kv_cache_len == 0) {
@@ -325,7 +325,7 @@ TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCach
}


void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state) {
void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state) {
// KV cache in model already contains prompts and answers from previous iterations.
// So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
// token_ids = {<bos token>, ...<valuable tokens>}. So if tokenizer applies only to the new prompt,
@@ -343,8 +343,9 @@ void align_kv_cache_and_history(ov::genai::KVCacheTrimManage
size_t first_diverse_tokens_idx = ov::genai::utils::get_first_history_difference(new_chat_tokens, state);
// in the case of beam_search the longest answer is in the kv cache, but the best one is needed
// so generated tokens were not added to KVCacheState and num_tokens_to_trim was set to the size of the generated sequence
kv_history_manager.num_tokens_to_trim = kv_history_manager.num_tokens_to_trim > 0 ? kv_history_manager.num_tokens_to_trim : (state.size() - first_diverse_tokens_idx);
kv_cache_state.num_tokens_to_trim += state.size() - first_diverse_tokens_idx;
state.resize(first_diverse_tokens_idx);
kv_cache_state.reset_mem_state = state.empty();
}

} // namespace genai
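Note (not part of the diff): a minimal sketch of the alignment step above, using hypothetical token IDs. first_difference stands in for utils::get_first_history_difference and is not the real helper.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// First index at which the retokenized history diverges from what the KV cache reflects.
static std::size_t first_difference(const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
    std::size_t n = std::min(a.size(), b.size());
    for (std::size_t i = 0; i < n; ++i)
        if (a[i] != b[i]) return i;
    return n;
}

int main() {
    std::vector<int64_t> state       = {1, 10, 11, 12, 13, 14};      // tokens reflected in the KV cache
    std::vector<int64_t> new_history = {1, 10, 11, 12, 99, 98, 97};  // freshly retokenized chat history
    std::size_t num_tokens_to_trim = 0;                              // may already be non-zero after beam search

    std::size_t diverge = first_difference(new_history, state);      // == 4
    num_tokens_to_trim += state.size() - diverge;                    // 2 stale tokens to drop from the cache
    state.resize(diverge);                                           // keep only the agreed-upon prefix
    bool reset_mem_state = state.empty();                            // false here; a full reset only if nothing is left

    return (num_tokens_to_trim == 2 && !reset_mem_state) ? 0 : 1;
}
```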
34 changes: 3 additions & 31 deletions src/cpp/src/lm_encoding.hpp
@@ -8,44 +8,16 @@
namespace ov {
namespace genai {

class KVCacheState {
std::vector<int64_t> state;
public:
std::vector<int64_t>& get_state() {
return state;
}

void add_inputs(const ov::Tensor& inputs_ids) {
std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
}

void reset_state() {
return state.clear();
}
};


struct KVCacheTrimManager
{
size_t num_tokens_to_trim = 0;
size_t kv_cache_seq_length_axis = 2;

void reset() {
num_tokens_to_trim = 0;
}
};


ov::genai::utils::GenerationFinishInfo get_lm_encoded_results(ov::InferRequest& m_llm, const ov::Tensor& input_ids, const ov::Tensor& attention_mask,
const std::shared_ptr<StreamerBase>& streamer_ptr, Sampler& sampler, std::vector<SequenceGroup::Ptr> sequence_groups,
std::optional<ov::Tensor> position_ids, KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
std::optional<ov::Tensor> position_ids, utils::KVCacheState& m_kv_cache_state, std::optional<EmbeddingsModel> m_embedding,
std::optional<int64_t> rope_delta = std::nullopt, const size_t max_kv_cache_size = std::numeric_limits<size_t>::max());


void align_kv_cache_and_history(ov::genai::KVCacheTrimManager& kv_history_manager, const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
void align_kv_cache_and_history(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);


TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, KVCacheState& kv_cache_state);
TokenizedInputs get_chat_encoded_input(const ov::Tensor& new_chat_tokens, utils::KVCacheState& kv_cache_state);

}
}
22 changes: 18 additions & 4 deletions src/cpp/src/utils.cpp
@@ -325,13 +325,27 @@ KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model) {
return kv_pos;
}

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller) {
void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller) {
if (kv_cache_state.reset_mem_state) {
if (adapter_controller) {
for(auto& state: request.query_state()) {
if(!adapter_controller->has_state_name(state.get_name())) {
state.reset();
}
}
} else {
request.reset_state();
}

return;
}

// nothing to trim in this case
if (remove_from_end == 0)
if (kv_cache_state.num_tokens_to_trim == 0)
return;

auto states = request.query_state();

OPENVINO_ASSERT(states.size() > 0, "Request contains no states.");

for (auto& state : states) {
@@ -341,7 +355,7 @@ void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t se
ov::Tensor old_tensor = state.get_state();
// [BATCH_SIZE, num_kv_heads, seq_len, head_size]
auto shape = old_tensor.get_shape();
shape[seq_length_axis] -= remove_from_end;
shape[kv_cache_state.seq_length_axis] -= kv_cache_state.num_tokens_to_trim;

ov::Coordinate new_shape_begin{0, 0, 0, 0};
ov::Coordinate new_shape_end{shape};
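Note (not part of the diff): a standalone sketch of the ROI-based trim that trim_kv_cache performs on each state tensor. Shapes and values are invented; error handling and the set_state call are omitted.

```cpp
#include <openvino/runtime/tensor.hpp>

int main() {
    // [BATCH_SIZE, num_kv_heads, seq_len, head_size] with seq_len == 16
    ov::Tensor kv(ov::element::f32, ov::Shape{1, 8, 16, 64});

    const std::size_t seq_length_axis    = 2;  // default axis in utils::KVCacheState
    const std::size_t num_tokens_to_trim = 4;

    ov::Shape new_shape = kv.get_shape();
    new_shape[seq_length_axis] -= num_tokens_to_trim;

    ov::Coordinate begin{0, 0, 0, 0};
    ov::Coordinate end{new_shape};         // keep only the first seq_len - 4 positions
    ov::Tensor trimmed(kv, begin, end);    // ROI view over the original data

    return trimmed.get_shape()[seq_length_axis] == 12 ? 0 : 1;
}
```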
24 changes: 23 additions & 1 deletion src/cpp/src/utils.hpp
@@ -102,7 +102,29 @@ struct KVAxesPosition {

KVAxesPosition get_kv_axes_pos(std::shared_ptr<const ov::Model> model);

void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);
class KVCacheState {
std::vector<int64_t> state;
public:
size_t num_tokens_to_trim = 0;
size_t seq_length_axis = 2;
bool reset_mem_state = false;

std::vector<int64_t>& get_state() {
return state;
}

void add_inputs(const ov::Tensor& inputs_ids) {
std::copy_n(inputs_ids.data<int64_t>(), inputs_ids.get_size(), std::back_inserter(state));
}

void reset_state() {
reset_mem_state = false;
num_tokens_to_trim = 0;
state.clear();
}
};

void trim_kv_cache(ov::InferRequest request, KVCacheState& kv_cache_state, std::optional<AdapterController> adapter_controller);

ov::Tensor push_front_inputs(const ov::Tensor& base_tensor, int64_t add_to_front);

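Note (not part of the diff): a compact sketch of the decision the consolidated utils::KVCacheState now encodes for the runners. The struct below is a simplified stand-in, not the real class.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct KVCacheStateSketch {                  // simplified mirror of utils::KVCacheState
    std::vector<std::int64_t> tokens;        // tokens currently reflected in the KV cache
    std::size_t num_tokens_to_trim = 0;      // trailing tokens to drop before the next turn
    std::size_t seq_length_axis = 2;         // seq_len axis in [batch, heads, seq_len, head_size]
    bool reset_mem_state = false;            // true -> wipe the whole cache instead of trimming
};

enum class KVAction { Reset, Trim, Keep };

// What a runner should do with its KV cache before the next generate() call.
KVAction plan_kv_cache_update(const KVCacheStateSketch& s) {
    if (s.reset_mem_state)        return KVAction::Reset;  // e.g. history rolled back to empty after CANCEL
    if (s.num_tokens_to_trim > 0) return KVAction::Trim;   // e.g. cancelled tail or beam-search overshoot
    return KVAction::Keep;
}

int main() {
    KVCacheStateSketch s;
    s.num_tokens_to_trim = 12;
    return plan_kv_cache_update(s) == KVAction::Trim ? 0 : 1;
}
```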
37 changes: 21 additions & 16 deletions src/cpp/src/visual_language/inputs_embedder.cpp
@@ -29,7 +29,6 @@ std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::IInputsEmbedder::g

void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_message) {
m_is_chat_conversation = true;
m_kv_history_trim_manager.reset();
if (!m_kv_cache_state.get_state().empty()) {
m_history.clear();
m_kv_cache_state.reset_state();
@@ -40,17 +39,26 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
m_history = {{{"role", "system"}, {"content", system_message}}};
}

void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results) {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
m_kv_history_trim_manager.reset();
void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
m_kv_cache_state.num_tokens_to_trim = 0;
if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
// If the chat generation was cancelled by the user, roll the history back to its previous state
m_history.pop_back();

std::vector<int64_t>& state = m_kv_cache_state.get_state();

m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
state.resize(m_prev_hist_length);
m_kv_cache_state.reset_mem_state = state.empty();
} else {
// Tail of chat template is missing in KV cache.
// Find the tail to concatenate it with the next input prompt.
m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
}
}

void InputsEmbedder::IInputsEmbedder::finish_chat() {
m_is_chat_conversation = false;
m_kv_history_trim_manager.reset();

m_history.clear();
m_kv_cache_state.reset_state();
}
@@ -123,7 +131,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::apply_chat_template_tokenize(const s
ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new_chat_tokens) {
ov::Tensor encoded_inputs;
if (m_is_chat_conversation) {
ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
encoded_inputs = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state).input_ids;
} else {
encoded_inputs = new_chat_tokens;
@@ -135,6 +143,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
const auto new_chat_tokens = apply_chat_template_tokenize(prompt, metrics);
auto new_input_ids = update_history(new_chat_tokens);
m_prev_hist_length = m_kv_cache_state.get_state().size();
m_kv_cache_state.add_inputs(new_input_ids);

return new_input_ids;
@@ -225,14 +234,10 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
return m_impl->get_embedding_model();
}

KVCacheState& InputsEmbedder::get_kv_cache_state() {
ov::genai::utils::KVCacheState& InputsEmbedder::get_kv_cache_state() {
return m_impl->get_kv_cache_state();
}

size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
return m_impl->get_num_tokens_to_remove_from_hist();
}

Tokenizer InputsEmbedder::get_tokenizer() const {
return m_impl->get_tokenizer();
}
@@ -241,8 +246,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
return m_impl->start_chat(system_message);
}

void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
return m_impl->update_chat_history(decoded_results);
void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
return m_impl->update_chat_history(decoded_results, generation_finish_status);
}

void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
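Note (not part of the diff): a small self-contained example of the CANCEL rollback in update_chat_history, with invented counts. Only the token bookkeeping is shown; the matching m_history.pop_back() is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    std::vector<std::int64_t> state(40, 7);   // tokens recorded in KVCacheState before this turn
    const std::size_t prev_hist_length = 40;  // snapshot taken in get_encoded_input_ids()

    // The new turn records 12 more tokens ...
    state.resize(52, 7);

    // ... and then the user cancels the streamed generation.
    std::size_t num_tokens_to_trim = state.size() - prev_hist_length;  // 12 tokens to drop from the KV cache
    state.resize(prev_hist_length);                                    // forget them in the bookkeeping too
    bool reset_mem_state = state.empty();                              // full reset only if nothing is left

    assert(num_tokens_to_trim == 12 && !reset_mem_state);
    return 0;
}
```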