
Commit 9e9b409

Use whole history in case of undetermined tokenization of sequence (openvinotoolkit#1254)
Task: [CVS-157295](https://jira.devtools.intel.com/browse/CVS-157295)
- the first commit is a cherry-pick from openvinotoolkit#1268 and openvinotoolkit#1361
- the next commit applies review comments from openvinotoolkit#1268 and adds usage of the KV cache for the LLM
1 parent 8ce5eb3 commit 9e9b409
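
The gist of the change, restated as a standalone sketch (plain C++ with illustrative names; the real helper added by this commit is `utils::get_first_history_difference` in src/cpp/src/utils.cpp below): when a chat turn re-encodes the templated history, the previously stored token history may no longer be a prefix of the result, because some character combinations tokenize differently once more text follows them. The pipeline therefore finds the first point of divergence; everything before it stays trusted in the KV cache, everything after it is trimmed and replayed.

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Illustrative sketch only (not the committed code): returns SIZE_MAX when the
// re-encoded history fully matches the stored token history (possibly minus a
// trailing eos/stop token), otherwise the index of the first mismatching token.
size_t first_history_difference(const std::vector<int64_t>& encoded_history,
                                const std::vector<int64_t>& tokenized_history,
                                const std::set<int64_t>& stop_tokens) {
    size_t idx = 0;
    while (idx < encoded_history.size() && idx < tokenized_history.size() &&
           encoded_history[idx] == tokenized_history[idx]) {
        ++idx;
    }

    const bool full_match = idx == tokenized_history.size() && idx == encoded_history.size();
    // Decoding can drop a final stop token, so a stored history that is exactly
    // one stop token longer than the re-encoded text still counts as trusted.
    const bool match_up_to_stop_token = encoded_history.size() < tokenized_history.size() &&
                                        idx == tokenized_history.size() - 1 &&
                                        stop_tokens.count(tokenized_history.back()) > 0;

    return (full_match || match_up_to_stop_token) ? SIZE_MAX : idx;
}
```

When the result is not SIZE_MAX, the pipeline sets `m_to_remove_from_hist` to the number of stored tokens past the divergence point, trims that many positions from the KV cache, and feeds the tail of the newly encoded history starting at that index; see the llm_pipeline.cpp hunks below.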

6 files changed: +264 -37 lines


src/cpp/src/llm_pipeline.cpp (+90 -18)
@@ -36,13 +36,15 @@ std::pair<EncodedResults, int32_t> beam_search(
 class StatefulLLMPipeline final : public LLMPipelineImplBase {
 public:
     ov::InferRequest m_model_runner;
-
     bool is_chat_conversation = false;
-    bool m_is_cache_empty = true;
+    bool m_trust_encoded_history = true;
     std::optional<int32_t> m_selected_beam = std::nullopt;
     ChatHistory m_history;
     std::string m_templated_chat_history = {};
-    TokenizedInputs m_tokenized_chat_history;
+    std::vector<int64_t> m_tokenized_chat_history;
+    ov::genai::utils::GenerationChatInputsType m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+    size_t m_to_remove_from_hist = 0;
+    size_t m_kv_cache_seq_length_axis = 2;

     StatefulLLMPipeline(
         const ov::InferRequest& request,
@@ -77,6 +79,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         ov::Core core;
         auto [core_plugin_config, plugin_config] = ov::genai::utils::split_core_compile_config(config);
         utils::slice_matmul_statefull_model(model);
+        m_kv_cache_seq_length_axis = ov::genai::utils::get_seq_len_axis(model);

         if (auto filtered_plugin_config = extract_adapters_from_properties(plugin_config, &m_generation_config.adapters)) {
             m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
@@ -102,8 +105,20 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF)
+            m_chat_input_type = ov::genai::utils::GenerationChatInputsType::STRING;
+
+        if (is_chat_conversation)
+            OPENVINO_ASSERT(m_chat_input_type != ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS,
+                            "Chat doesn't support switching between input types. Please, continue using EncodedInputs or restart the chat.");
+
         auto start_time = std::chrono::steady_clock::now();
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;
+        // If eos_token_id was not provided, take value from default m_generation_config
+        if (config.eos_token_id == -1)
+            config.set_eos_token_id(m_generation_config.eos_token_id);
+        config.validate();
+
         TokenizedInputs encoded_input;

         if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
@@ -127,19 +142,51 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                 auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
                 // Do not add special tokens in chat scenario to be aligned with HF.
                 auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
-                if (m_is_cache_empty) {
+                auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
+
+                // some symbols combinations can be encoded by the tokenizer in different ways
+                // if we met sequence with such combination of symbols, we cannot correctly subtract the new history from the old history
+                // so let's check it out, find the trusted part and use it in on the next step
+                size_t last_same_hist_token = 0;
+                if (!m_tokenized_chat_history.empty()) {
+                    std::set<int64_t> stop_tokens = config.stop_token_ids;
+                    last_same_hist_token = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_chat_history, stop_tokens);
+                    m_trust_encoded_history = last_same_hist_token == SIZE_MAX;
+                }
+
+                if (m_tokenized_chat_history.empty()) {
                     encoded_input = new_chat_tokens;
+                } else if (last_same_hist_token != SIZE_MAX) {
+                    m_to_remove_from_hist = m_tokenized_chat_history.size() - last_same_hist_token;
+
+                    ov::Tensor new_tensor = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
+                                                       {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token},
+                                                       new_chat_tokens.input_ids.data<int64_t>() + last_same_hist_token);
+
+                    ov::Tensor new_attention_mask(ov::element::i64, new_tensor.get_shape());
+                    std::fill_n(new_attention_mask.data<int64_t>(), new_tensor.get_shape()[1], 1);
+
+                    encoded_input.input_ids = ov::Tensor(new_chat_tokens.input_ids.get_element_type(),
+                                                         {1, new_chat_tokens.input_ids.get_shape().at(1) - last_same_hist_token});
+                    new_tensor.copy_to(encoded_input.input_ids);
+                    encoded_input.attention_mask = new_attention_mask;
+
+                    m_selected_beam = std::nullopt;
                 } else {
-                    auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
                     encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
                 }
                 m_templated_chat_history = new_templated_chat_history;
-                m_tokenized_chat_history = new_chat_tokens;
+                m_tokenized_chat_history.clear();
+                m_tokenized_chat_history.reserve(new_chat_tokens.input_ids.get_size());
+                std::copy_n(new_chat_tokens.input_ids.data<int64_t>(), new_chat_tokens.input_ids.get_size(),
+                            std::back_inserter(m_tokenized_chat_history));
+
                 // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
             } else {
                 encoded_input = m_tokenizer.encode(prompt);
             }
         }
+
         auto encode_stop_time = std::chrono::steady_clock::now();
         auto encoded_results = generate(encoded_input, config, streamer);

@@ -188,6 +235,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         OptionalGenerationConfig generation_config,
         StreamerVariant streamer
     ) override {
+        if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::UNDEF)
+            m_chat_input_type = ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS;
+
+        if (is_chat_conversation)
+            // if chat was run in StringInputs mode, but it was called EncodedInputs generate, last m_history entry will be with assistant role
+            OPENVINO_ASSERT(m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS || m_history.back()["role"] == "user",
+                            "Chat doesn't support switching between input types. Please, continue using StringInputs or restart the chat.");
+
         auto start_time = std::chrono::steady_clock::now();
         ov::Tensor input_ids;
         ov::Tensor attention_mask;
@@ -199,6 +254,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             attention_mask = data->attention_mask;
         }

+        if (is_chat_conversation && m_chat_input_type == ov::genai::utils::GenerationChatInputsType::ENCODED_INPUTS)
+            std::copy(input_ids.data<int64_t>(), input_ids.data<int64_t>() + input_ids.get_size(), std::back_inserter(m_tokenized_chat_history));
+
         GenerationConfig config = (generation_config.has_value()) ? *generation_config : m_generation_config;

         // If eos_token_id was not provided, take value from default m_generation_config
@@ -230,16 +288,17 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                         "(input_ids, attention_mask, position_ids, beam_idx) "
                         "but you have '" + std::to_string(num_inputs) + "' inputs");

+        ov::genai::utils::trim_kv_cache(m_model_runner, m_to_remove_from_hist, m_kv_cache_seq_length_axis, m_adapter_controller);

         size_t kv_cache_len = 0;
         ov::Tensor concatenated_attention_mask;
-        if (is_chat_conversation && !m_is_cache_empty) {
+        if (is_chat_conversation && !m_tokenized_chat_history.empty()) {
            OPENVINO_ASSERT(batch_size == 1, "continuation of generation is possible only for batch 1");
             // If history is saved in KV cache, concatenate new attention_mask with the already existing.
             // Between subsequent runs attention_mask should not be modified.
             auto atten_mask_history = m_model_runner.get_tensor("attention_mask");
             auto prompt_len = attention_mask.get_shape()[1];
-            kv_cache_len = atten_mask_history.get_shape()[1];
+            kv_cache_len = atten_mask_history.get_shape()[1] - m_to_remove_from_hist;

             ov::Tensor new_atten_mask = ov::Tensor{ov::element::i64, {batch_size, kv_cache_len + prompt_len}};
             auto start_atten_hst = atten_mask_history.data<int64_t>() + kv_cache_len * (*m_selected_beam);
@@ -263,6 +322,11 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
             m_adapter_controller->apply(m_model_runner, config.adapters);
         }

+        if (is_chat_conversation && !m_trust_encoded_history) {
+            m_trust_encoded_history = true;
+            m_to_remove_from_hist = 0;
+        }
+
         ov::genai::EncodedResults result;
         if (config.is_beam_search() && is_chat_conversation) {
             std::tie(result, m_selected_beam) = beam_search(m_model_runner, input_ids, concatenated_attention_mask,
@@ -274,8 +338,9 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

         for (size_t request_id = 0; request_id < batch_size; request_id++) {
             SequenceGroup::Ptr sequence_group;
-            if (is_chat_conversation && !m_is_cache_empty) {
-                sequence_group = std::make_shared<SequenceGroup>(request_id, m_tokenized_chat_history.input_ids, config, block_size, enable_prefix_caching);
+            if (is_chat_conversation) {
+                ov::Tensor tokenized_chat_history = ov::Tensor(ov::element::i64, {1, m_tokenized_chat_history.size()}, m_tokenized_chat_history.data());
+                sequence_group = std::make_shared<SequenceGroup>(request_id, tokenized_chat_history, config, block_size, enable_prefix_caching);
             } else {
                 size_t seq_len = input_ids.get_shape().at(1);
                 size_t batch_offset = request_id * seq_len;
@@ -294,12 +359,13 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
                                     sampler, requests, position_ids, std::nullopt, m_selected_beam);
         }

-        if (!is_chat_conversation) {
+        if (is_chat_conversation) {
+            std::copy(result.tokens[0].begin(), result.tokens[0].end(), std::back_inserter(m_tokenized_chat_history));
+        } else {
             reset_kv_state();
             m_selected_beam = std::nullopt;
-        } else {
-            m_is_cache_empty = false;
         }
+
         auto stop_time = std::chrono::steady_clock::now();

         // If is called without tokenization then that stat will not be reported.
@@ -313,12 +379,15 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {

     void start_chat(const std::string& system_message) override {
         is_chat_conversation = true;
-        m_selected_beam = std::nullopt;
-        if (!m_is_cache_empty) {
+        m_selected_beam = std::nullopt;
+        m_trust_encoded_history = true;
+        m_to_remove_from_hist = 0;
+        m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+        if (!m_tokenized_chat_history.empty()) {
             reset_kv_state();
-            m_is_cache_empty = true;
             m_history = {};
             m_templated_chat_history = "";
+            m_tokenized_chat_history.clear();
         }
         if (system_message.empty())
             return;
@@ -332,11 +401,14 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
     void finish_chat() override {
         is_chat_conversation = false;
         m_selected_beam = std::nullopt;
-        if (!m_is_cache_empty) {
+        m_trust_encoded_history = true;
+        m_to_remove_from_hist = 0;
+        m_chat_input_type = ov::genai::utils::GenerationChatInputsType::UNDEF;
+        if (!m_tokenized_chat_history.empty()) {
             reset_kv_state();
-            m_is_cache_empty = true;
             m_history.clear();
             m_templated_chat_history.clear();
+            m_tokenized_chat_history.clear();
         }
     }
 };
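
For orientation, a typical multi-turn chat through the public pipeline looks like the sketch below (model directory and device are placeholders, not part of this commit). Each generate() call re-applies the chat template to the whole history and re-encodes it; with this change, a divergence between that re-encoding and the stored token history no longer corrupts the continuation, because the untrusted tail of the KV cache is trimmed and replayed. Note also the new guard added above: a single chat session has to stick to one input kind, either string prompts or EncodedInputs, otherwise the added OPENVINO_ASSERT fires.

```cpp
#include <iostream>
#include <string>

#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder model directory and device.
    ov::genai::LLMPipeline pipe("TinyLlama-1.1B-Chat-v1.0", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;

    pipe.start_chat();
    // Every turn re-encodes the full templated history; the KV cache is reused
    // and, after this commit, trimmed back to the trusted prefix when needed.
    std::string first = pipe.generate("What is OpenVINO GenAI?", config);
    std::cout << first << '\n';
    std::string second = pipe.generate("How does chat mode reuse the KV cache?", config);
    std::cout << second << '\n';
    pipe.finish_chat();
    return 0;
}
```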

src/cpp/src/utils.cpp (+75)
@@ -13,6 +13,8 @@
 #include "openvino/op/tanh.hpp"
 #include "openvino/op/transpose.hpp"

+#include "sampler.hpp"
+
 namespace ov {
 namespace genai {
 namespace utils {
@@ -306,6 +308,79 @@ ov::Core singleton_core() {
     return core;
 }

+size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector<int64_t> tokenized_history, std::set<int64_t> stop_tokens) {
+    size_t idx = 0;
+    auto encoded_history_data = encoded_history.data<int64_t>();
+    while(idx < encoded_history.get_size() && idx < tokenized_history.size()) {
+        if (encoded_history_data[idx] != tokenized_history[idx])
+            break;
+        idx++;
+    }
+
+    // encoded_history after decode of tokenizer could lose one last token (eos/stop token)
+    if ((idx == tokenized_history.size() && idx == encoded_history.get_size()) ||
+        (encoded_history.get_size() < tokenized_history.size() && idx == tokenized_history.size() - 1 && stop_tokens.find(tokenized_history.back()) != stop_tokens.end()))
+        return SIZE_MAX;
+    else
+        return idx;
+}
+
+size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model) {
+    // sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
+    // therefore usually seq_length_axis = 2
+    size_t seq_length_axis = 2;
+
+    // "ReadValue" node is KV cache representation in stateful model
+    std::string kv_node_type_name = std::string(ov::op::v6::ReadValue::get_type_info_static().name);
+
+    for (const auto op : model->get_ops()) {
+        // check input size, as in LoRA adapters case it could be 0
+        if (op->get_type_name() != kv_node_type_name || op->get_input_size() < 1) {
+            continue;
+        }
+
+        // Shape example: [-1,4,0,64]
+        auto shape = op->get_input_partial_shape(0);
+
+        for (size_t i = 0; i < shape.rank().get_length(); i++) {
+            // Find axis = 0. This would be sequence length axis.
+            if (shape[i] == 0) {
+                seq_length_axis = i;
+            }
+        }
+        break;
+    }
+
+    return seq_length_axis;
+}
+
+void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller) {
+    // nothing to trim in this case
+    if (remove_from_end == 0)
+        return;
+
+    auto states = request.query_state();
+    for (auto& state : states) {
+        if(adapter_controller && adapter_controller->has_state_name(state.get_name()))
+            continue;
+
+        ov::Tensor old_tensor = state.get_state();
+        // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
+        auto shape = old_tensor.get_shape();
+        shape[seq_length_axis] -= remove_from_end;
+
+        ov::Coordinate new_shape_begin{0, 0, 0, 0};
+        ov::Coordinate new_shape_end{shape};
+
+        auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);
+
+        ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
+        trimmed_tensor.copy_to(new_tensor);
+
+        state.set_state(new_tensor);
+    }
+}
+
 } // namespace utils
 } // namespace genai
 } // namespace ov
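
Two OpenVINO details make the trimming above work: a stateful model's ReadValue inputs expose the sequence-length dimension as 0 in their partial shape (which is what get_seq_len_axis scans for), and ov::Tensor has a region-of-interest constructor that creates a view between begin/end coordinates. A minimal standalone sketch of that ROI-then-copy pattern on a dummy tensor follows (shapes and values are made up; the data itself is left uninitialized, which is fine for illustration):

```cpp
#include <iostream>

#include <openvino/runtime/tensor.hpp>

int main() {
    // Dummy "KV cache" tensor: [batch, num_kv_heads, seq_len, head_size].
    ov::Tensor kv(ov::element::f32, {1, 4, 10, 64});

    const size_t seq_len_axis = 2;     // axis marked by a 0 dimension in the ReadValue shape
    const size_t remove_from_end = 3;  // positions whose cache entries are no longer trusted

    auto shape = kv.get_shape();
    shape[seq_len_axis] -= remove_from_end;  // keep only the first 7 positions

    // Zero-copy view over the kept region...
    ov::Tensor view(kv, ov::Coordinate{0, 0, 0, 0}, ov::Coordinate{shape});
    // ...materialized into a dense copy, mirroring what trim_kv_cache does before set_state.
    ov::Tensor trimmed(kv.get_element_type(), shape);
    view.copy_to(trimmed);

    std::cout << "trimmed seq_len: " << trimmed.get_shape()[seq_len_axis] << '\n';  // prints 7
    return 0;
}
```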

src/cpp/src/utils.hpp (+11)
@@ -22,6 +22,11 @@ constexpr bool is_container<T,
                             std::void_t<decltype(std::declval<T>().begin()),
                                         decltype(std::declval<T>().end())>> = true;

+enum class GenerationChatInputsType {
+    UNDEF = 0,           // Default value, type of inputs is not defined
+    STRING = 1,          // Type of inputs is StringInputs
+    ENCODED_INPUTS = 2,  // Type of inputs is EncodedInputs
+};

 Tensor init_attention_mask(const Tensor& position_ids);

@@ -93,6 +98,12 @@ ov::Core singleton_core();
 template <typename T>
 void read_rt_info(std::shared_ptr<ov::Model>& model, const char* name, T& value);

+size_t get_first_history_difference(const ov::Tensor& encoded_history, const std::vector<int64_t> tokenized_history, std::set<int64_t> stop_tokens);
+
+size_t get_seq_len_axis(std::shared_ptr<const ov::Model> model);
+
+void trim_kv_cache(ov::InferRequest request, uint64_t remove_from_end, size_t seq_length_axis, std::optional<AdapterController> adapter_controller);
+
 } // namespace utils
 } // namespace genai
 } // namespace ov
