
Commit c58ba64

Add/skip special tokens in runtime (openvinotoolkit#859)
CVS-152371

In Python:

```python
tok = pipe.get_tokenizer()
res_genai = tok.encode(prompt, add_special_tokens=False).input_ids
```

In C++:

```cpp
auto tok = pipe.get_tokenizer();
auto res_genai = tok.encode(prompt, ov::genai::add_special_tokens(false)).input_ids;
```
1 parent 4bb683e commit c58ba64

8 files changed: +226 −31 lines changed

.github/workflows/causal_lm_cpp.yml (+1 −1)

```diff
@@ -665,7 +665,7 @@ jobs:
 output.write('question:\n')
 chat_history.append(gen_prompt(prompt))
 chat_prompt = tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
-tokenized = tokenizer(chat_prompt, return_tensors='pt')
+tokenized = tokenizer(chat_prompt, return_tensors='pt', add_special_tokens=False)
 answer = model.generate(**tokenized, max_length=1000, do_sample=False)
 answer_str = tokenizer.decode(answer[0, tokenized['input_ids'].numel():], skip_special_tokens=True)
 chat_history.append(gen_answer(answer_str))
```

src/cpp/include/openvino/genai/tokenizer.hpp (+34 −5)

```diff
@@ -9,6 +9,7 @@
 
 #include "openvino/runtime/tensor.hpp"
 #include "openvino/genai/visibility.hpp"
+#include <openvino/runtime/properties.hpp>
 
 namespace ov {
 namespace genai {
@@ -33,19 +34,44 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
 
     /**
      * @brief encode a single prompt
+     * @param prompt std::string with input prompt
+     * @param tokenization_params AnyMap with tokenization parameters, e.g. {'add_special_tokens', false}
      * @return pair of [input_ids, attention_mask]
      */
-    TokenizedInputs encode(const std::string prompt);
+    TokenizedInputs encode(const std::string prompt, const ov::AnyMap& tokenization_params = {});
 
     /**
      * @brief encode batch of prompts. Left padding will be applied by default
      * @param prompts vector storing batch of prompts
+     * @param tokenization_params AnyMap with tokenization parameters, e.g. {'add_special_tokens', false}
      * @return pair of [input_ids, attention_mask]
      */
-    TokenizedInputs encode(std::vector<std::string>& prompts);
-    TokenizedInputs encode(std::vector<std::string>&& prompts);
-    TokenizedInputs encode(std::initializer_list<std::string>& prompts);
-
+    TokenizedInputs encode(std::vector<std::string>& prompt, const ov::AnyMap& tokenization_params = {});
+    TokenizedInputs encode(std::vector<std::string>&& prompts, const ov::AnyMap& tokenization_params = {});
+    TokenizedInputs encode(std::initializer_list<std::string>& prompts, const ov::AnyMap& tokenization_params = {});
+
+    /**
+     * @brief encode a single prompt
+     * @param prompt std::string with input prompt
+     * @param properties tokenization properties, e.g. ov::genai::add_special_tokens(false)
+     * @return pair of [input_ids, attention_mask]
+     */
+    template <typename... Properties>
+    util::EnableIfAllStringAny<TokenizedInputs, Properties...> encode(std::string& prompt, Properties&&... properties) {
+        return encode(prompt, AnyMap{std::forward<Properties>(properties)...});
+    }
+
+    /**
+     * @brief encode batch of prompts. Left padding will be applied by default
+     * @param prompts vector storing batch of prompts
+     * @param properties tokenization properties, e.g. ov::genai::add_special_tokens(false)
+     * @return pair of [input_ids, attention_mask]
+     */
+    template <typename... Properties>
+    util::EnableIfAllStringAny<TokenizedInputs, Properties...> encode(std::vector<std::string>& prompts, Properties&&... properties) {
+        return encode(prompts, AnyMap{std::forward<Properties>(properties)...});
+    }
+
     /**
      * @brief decode sequence of tokens
      * @param tokens vector storing tokens
@@ -103,5 +129,8 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
     class TokenizerImpl;
     std::shared_ptr<TokenizerImpl> m_pimpl;
 };
+
+static constexpr ov::Property<bool> add_special_tokens{"add_special_tokens"};
+
 } // namespace genai
 } // namespace ov
```
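With these overloads in place, a C++ caller can opt out of special tokens per call. A minimal usage sketch, assuming a model has already been exported for OpenVINO GenAI (the `models/llama` directory name is only a placeholder):

```cpp
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Placeholder path to an exported OpenVINO GenAI model directory.
    ov::genai::LLMPipeline pipe("models/llama", "CPU");
    ov::genai::Tokenizer tok = pipe.get_tokenizer();

    std::string prompt = "Why is the Sun yellow?";

    // Default behaviour: special tokens (e.g. BOS) are added.
    ov::genai::TokenizedInputs with_special = tok.encode(prompt);

    // New in this commit: skip special tokens for this call only.
    ov::genai::TokenizedInputs without_special = tok.encode(prompt, ov::genai::add_special_tokens(false));

    return 0;
}
```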

src/cpp/src/llm_pipeline.cpp (+3 −2)

```diff
@@ -148,11 +148,12 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
         m_history.push_back({{"role", "user"}, {"content", prompt}});
         constexpr bool add_generation_prompt = true;
         auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
-        auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history);
+        bool add_special_tokens_ = false;  // Do not add special tokens in chat scenario.
+        auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens_));
         if (m_is_cache_empty) {
             encoded_input = new_chat_tokens;
         } else {
-            auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history);
+            auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens_));
             encoded_input = subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
         }
         m_templated_chat_history = new_templated_chat_history;
```
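`subtract_chat_tokenized_inputs` itself is not part of this diff; the idea is to re-encode the whole templated history (with special tokens disabled) and keep only the tokens that were not already fed to the model. A rough illustration of that idea, not the library's actual implementation, assuming the new encoding extends the previous one as a prefix:

```cpp
#include <cstdint>
#include <vector>

// Illustration only: keep the suffix of new_ids that goes beyond prev_ids.
std::vector<int64_t> keep_new_suffix(const std::vector<int64_t>& new_ids,
                                     const std::vector<int64_t>& prev_ids) {
    return {new_ids.begin() + prev_ids.size(), new_ids.end()};
}
```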
src/cpp/src/make_combine_segments_stateful.cpp (new file, +46)

```diff
@@ -0,0 +1,46 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "make_combine_segments_stateful.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/read_value.hpp"
+#include "openvino/op/assign.hpp"
+
+
+using namespace ov;
+using namespace ov::op;
+
+bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
+
+    std::shared_ptr<ov::Node> combine_seg_node;
+    for (auto node: model->get_ordered_ops()) {
+        if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
+            combine_seg_node = node;
+        }
+    }
+    if (!combine_seg_node || combine_seg_node->input_value(1).get_element_type() != ov::element::i32) {
+        return false;
+    }
+
+    std::shared_ptr<v0::Constant> input_1_const = std::dynamic_pointer_cast<v0::Constant>(combine_seg_node->get_input_node_shared_ptr(1));
+    if (!input_1_const) {
+        return false;
+    }
+
+    op::util::VariableInfo var_info{ov::Shape{}, ov::element::boolean, ADD_SPECIAL_TOKENS_VAR_ID};
+    auto variable = std::make_shared<op::util::Variable>(var_info);
+
+    // Default mode is add_special_tokens.
+    auto default_mode_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{}, std::vector{true});
+    auto read_value = std::make_shared<v6::ReadValue>(default_mode_const, variable);
+    auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
+    auto select_node = std::make_shared<v1::Select>(read_value, input_1_const, zero_constant);
+    combine_seg_node->input(1).replace_source_output(select_node->output(0));
+
+    auto assign = std::make_shared<v6::Assign>(read_value, variable);
+
+    model->add_sinks({assign});
+    model->add_variables({variable});
+    return true;
+}
```
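The pass replaces the constant feeding input 1 of `CombineSegments` with `Select(ReadValue(flag), original_constant, 0)`, so the flag shows up as an ordinary variable state on the compiled tokenizer. A minimal sketch of flipping it directly on an `ov::CompiledModel` (this is essentially what `set_state_if_necessary` in tokenizer.cpp does through the infer-request queue):

```cpp
#include "openvino/openvino.hpp"

// Assumes `compiled` was built from the tokenizer model after this pass has run.
void disable_special_tokens(ov::CompiledModel& compiled) {
    ov::InferRequest request = compiled.create_infer_request();

    ov::Tensor flag(ov::element::boolean, ov::Shape{});
    *flag.data<bool>() = false;  // false -> Select routes the zero constant, so no special tokens

    for (auto& state : request.query_state()) {
        if (state.get_name().find("add_special_tokens") != std::string::npos) {
            state.set_state(flag);
            break;
        }
    }
}
```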
src/cpp/src/make_combine_segments_stateful.hpp (new file, +44)

```diff
@@ -0,0 +1,44 @@
+// Copyright (C) 2023-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "openvino/op/constant.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace genai {
+
+/**
+ * @brief This pass modifies the tokenizer ov::Model so that adding special tokens
+ * can be enabled or disabled depending on a stateful value.
+ *
+ *   +--------------+
+ *   | DefaultMode  |
+ *   +--------------+
+ *          |
+ *          |
+ *          v
+ *   +--------------+     +--------+     +------------------+
+ *   |  ReadValue   |     |  ends  |     | const value = 0  |
+ *   +--------------+     +--------+     +------------------+
+ *           \                 |                 /
+ *            \                |                /
+ *             v               v               v
+ *              +-------------------------+
+ *              |         Select          |
+ *              +-------------------------+
+ *                          |
+ *                          v
+ *              +-------------------------+
+ *              |     CombineSegments     |
+ *              +-------------------------+
+ **/
+class MakeCombineSegmentsSatateful : public ov::pass::ModelPass {
+public:
+    OPENVINO_RTTI("MakeCombineSegmentsSatateful", "0");
+    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
+};
+
+const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens";
+
+} // namespace genai
+} // namespace ov
```

src/cpp/src/tokenizer.cpp (+57 −16)

```diff
@@ -4,16 +4,17 @@
 #include <filesystem>
 #include <fstream>
 #include <memory>
-
 #include <jinja2cpp/template.h>
 #include <jinja2cpp/template_env.h>
 #include <jinja2cpp/user_callable.h>
 #include <jinja2cpp/generic_list.h>
 #include <jinja2cpp/generic_list_iterator.h>
 
+#include "openvino/pass/manager.hpp"
 #include "openvino/runtime/core.hpp"
 #include "openvino/genai/tokenizer.hpp"
 
+#include "make_combine_segments_stateful.hpp"
 #include "tokenizers_path.hpp"
 #include "circular_buffer_queue.hpp"
 #include "utils.hpp"
@@ -69,7 +70,10 @@ class Tokenizer::TokenizerImpl {
 
     std::unique_ptr<CircularBufferQueue<ov::InferRequest>> m_ireq_queue_tokenizer;
     std::unique_ptr<CircularBufferQueue<ov::InferRequest>> m_ireq_queue_detokenizer;
-
+    // To change the add_special_tokens mode we use a stateful subgraph;
+    // this flag holds the current state value of the CompiledModel.
+    bool m_add_special_tokens = true;
+
     int64_t m_pad_token_id = -1;
     int64_t m_bos_token_id = -1;
     int64_t m_eos_token_id = -1;
@@ -80,6 +84,29 @@ class Tokenizer::TokenizerImpl {
 
     std::string m_chat_template = "";
 
+    void set_state_if_necessary(CircularBufferQueueElementGuard<ov::InferRequest>& infer_request_guard, bool add_special_tokens) {
+        // If the user requested an add_special_tokens mode different from the current one,
+        // the state variable needs to be set.
+        // If the requested mode matches the stored state, don't touch the states.
+        if (add_special_tokens == m_add_special_tokens) {
+            return;
+        }
+
+        // auto states = m_ireq_queue_tokenizer->get(0).query_state();
+        ov::Tensor add_special_tensor = ov::Tensor(ov::element::boolean, {});
+        *add_special_tensor.data<bool>() = add_special_tokens;
+
+        for (auto& state: infer_request_guard.get().query_state()) {
+            if (state.get_name().find(ov::genai::ADD_SPECIAL_TOKENS_VAR_ID) == std::string::npos) {
+                // It's not the add_special_tokens flag state.
+                continue;
+            }
+            state.set_state(add_special_tensor);
+            break;
+        }
+        m_add_special_tokens = add_special_tokens;
+    }
+
     TokenizerImpl() = default;
 
     TokenizerImpl(std::filesystem::path tokenizer_path, const ov::AnyMap& plugin_config)
@@ -99,13 +126,18 @@ class Tokenizer::TokenizerImpl {
         read_tokenizer_config_if_necessary(tokenizer_path);
 
         auto device = "CPU"; // currently openvino_tokenizer supports only CPU
-        m_tokenizer = core.compile_model(tokenizer_path / "openvino_tokenizer.xml",
-                                         device, plugin_config);
+        auto ov_tokenizer = core.read_model(tokenizer_path / "openvino_tokenizer.xml");
+
+        ov::pass::Manager manager;
+        manager.register_pass<MakeCombineSegmentsSatateful>();
+        manager.run_passes(ov_tokenizer);
+
+        m_tokenizer = core.compile_model(ov_tokenizer, device, plugin_config);
         if (std::filesystem::exists(tokenizer_path / "openvino_detokenizer.xml")) {
-            m_detokenizer = core.compile_model(tokenizer_path / "openvino_detokenizer.xml",
-                                               device, plugin_config);
+            m_detokenizer = core.compile_model(tokenizer_path / "openvino_detokenizer.xml", device, plugin_config);
         }
 
+
         const size_t INFER_REQUEST_QUEUE_SIZE = m_tokenizer.get_property(ov::optimal_number_of_infer_requests);
         m_ireq_queue_tokenizer = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
             INFER_REQUEST_QUEUE_SIZE,
@@ -256,8 +288,12 @@ class Tokenizer::TokenizerImpl {
         get_id_from_str(m_eos_token, m_eos_token_id);
     }
 
-    TokenizedInputs encode(std::string prompt) {
+    TokenizedInputs encode(std::string prompt, const ov::AnyMap& tokenization_params = {}) {
+        bool add_special_tokens_flag = true;
+        ov::genai::utils::read_anymap_param(tokenization_params, add_special_tokens.name(), add_special_tokens_flag);
+
         CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
+        set_state_if_necessary(infer_request_guard, add_special_tokens_flag);
         size_t batch_size = 1;
         infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {batch_size}, &prompt});
         infer_request_guard.get().start_async();
@@ -268,10 +304,15 @@ class Tokenizer::TokenizerImpl {
         );
     }
 
-    TokenizedInputs encode(std::vector<std::string>& prompts) {
+    TokenizedInputs encode(std::vector<std::string>& prompts, const ov::AnyMap& tokenization_params = {}) {
+
         TokenizedInputs unpadded;
         {
+            bool add_special_tokens_flag = true;
+            ov::genai::utils::read_anymap_param(tokenization_params, add_special_tokens.name(), add_special_tokens_flag);
+
             CircularBufferQueueElementGuard<ov::InferRequest> infer_request_guard(this->m_ireq_queue_tokenizer.get());
+            set_state_if_necessary(infer_request_guard, add_special_tokens_flag);
             infer_request_guard.get().set_input_tensor(ov::Tensor{ov::element::string, {prompts.size()}, prompts.data()});
             auto size_ = infer_request_guard.get().get_input_tensor().get_shape();
             infer_request_guard.get().start_async();
@@ -454,20 +495,20 @@ Tokenizer::Tokenizer(const std::string& tokenizer_path, const ov::AnyMap& plugin_config) {
     m_pimpl = std::make_shared<TokenizerImpl>(tokenizer_path, plugin_config);
 }
 
-TokenizedInputs Tokenizer::encode(const std::string prompt) {
-    return m_pimpl->encode(std::move(prompt));
+TokenizedInputs Tokenizer::encode(const std::string prompt, const ov::AnyMap& tokenization_params) {
+    return m_pimpl->encode(std::move(prompt), tokenization_params);
 }
 
-TokenizedInputs Tokenizer::encode(std::vector<std::string>& prompts) {
-    return m_pimpl->encode(prompts);
+TokenizedInputs Tokenizer::encode(std::vector<std::string>& prompts, const ov::AnyMap& tokenization_params) {
+    return m_pimpl->encode(prompts, tokenization_params);
 }
 
-TokenizedInputs Tokenizer::encode(std::vector<std::string>&& prompts) {
-    return m_pimpl->encode(prompts);
+TokenizedInputs Tokenizer::encode(std::vector<std::string>&& prompts, const ov::AnyMap& tokenization_params) {
+    return m_pimpl->encode(prompts, tokenization_params);
 }
 
-TokenizedInputs Tokenizer::encode(std::initializer_list<std::string>& text) {
-    return encode(std::vector<std::string>(text.begin(), text.end()));
+TokenizedInputs Tokenizer::encode(std::initializer_list<std::string>& text, const ov::AnyMap& tokenization_params) {
+    return encode(std::vector<std::string>(text.begin(), text.end()), tokenization_params);
 }
 
 std::string Tokenizer::decode(std::vector<int64_t> tokens) {
```
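The Python binding below builds the `ov::AnyMap` explicitly; the same AnyMap form is also usable from C++ when the property syntax is not convenient. A small sketch, assuming `tok` is an already constructed `ov::genai::Tokenizer`:

```cpp
#include <string>
#include "openvino/genai/tokenizer.hpp"

// Encode a prompt with special tokens disabled via the AnyMap overload.
ov::genai::TokenizedInputs encode_without_special_tokens(ov::genai::Tokenizer& tok, const std::string& prompt) {
    ov::AnyMap tokenization_params;
    tokenization_params[ov::genai::add_special_tokens.name()] = false;
    return tok.encode(prompt, tokenization_params);
}
```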

src/python/py_generate_pipeline.cpp (+13 −4)

```diff
@@ -520,12 +520,21 @@ PYBIND11_MODULE(py_generate_pipeline, m) {
             return std::make_unique<ov::genai::Tokenizer>(tokenizer_path, utils::properties_to_any_map(plugin_config));
         }), py::arg("tokenizer_path"), py::arg("plugin_config") = ov::AnyMap({}))
 
-        .def("encode", [](Tokenizer& tok, std::vector<std::string>& prompts) { return tok.encode(prompts); },
+        .def("encode", [](Tokenizer& tok, std::vector<std::string>& prompts, bool add_special_tokens) {
+                ov::AnyMap tokenization_params;
+                tokenization_params[ov::genai::add_special_tokens.name()] = add_special_tokens;
+                return tok.encode(prompts, tokenization_params);
+            },
             py::arg("prompts"),
+            py::arg("add_special_tokens") = true,
             R"(Encodes a list of prompts into tokenized inputs.)")
-
-        .def("encode", py::overload_cast<const std::string>(&Tokenizer::encode),
-            py::arg("prompt"),
+
+        .def("encode", [](Tokenizer& tok, const std::string prompt, bool add_special_tokens) {
+                ov::AnyMap tokenization_params;
+                tokenization_params[ov::genai::add_special_tokens.name()] = add_special_tokens;
+                return tok.encode(prompt, tokenization_params);
+            },
+            py::arg("prompt"), py::arg("add_special_tokens") = true,
             R"(Encodes a single prompt into tokenized input.)")
 
         .def(
```

tests/python_tests/test_chat_generate_api.py (+28 −3)

```diff
@@ -42,9 +42,8 @@ def test_chat_compare_with_HF(model_descr, generation_config: Dict):
     chat_history_ov = []
     chat_prompt = ''
 
-    # HF in chat scenario does not add special tokens, but openvino tokenizer by default is converted with add_special_tokens=True.
-    # Need to regenerate openvino_tokenizer/detokenizer.
-    model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'), add_special_tokens=False)
+    # Will set add_special_tokens=False inside pipeline when start_chat() is called.
+    model_id, path, tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
 
     pipe.start_chat()
     for prompt in quenstions:
@@ -197,3 +196,29 @@ def test_set_chat_template():
     pipe.finish_chat()
     reference = pipe.generate("a", max_new_tokens=1)
     assert generated == reference
+
+prompts = [
+    '1+1=',
+    'What is the previous answer?',
+    'Why is the Sun yellow?',
+    'What was my first question?',
+    ['Why is the Sun yellow?'],
+    "若我有一亿美元,在人工智能盛行的今天,我怎样投资才能收益最大化?",
+    "מחרוזת בדיקה",
+    "Multiline\nstring!\nWow!",
+]
+
+@pytest.mark.precommit
+@pytest.mark.nightly
+@pytest.mark.parametrize("add_special_tokens", [True, False])
+@pytest.mark.parametrize("prompt", prompts)
+def test_add_special_tokens(add_special_tokens, prompt):
+    import numpy as np
+    model_descr = get_chat_models_list()[0]
+    model_id, path, hf_tokenizer, model_opt, pipe = read_model((model_descr[0], model_descr[1] / '_test_chat'))
+    genai_tokenizer = pipe.get_tokenizer()
+
+    # Calling encode with add_special_tokens will set the state flag.
+    res_genai = genai_tokenizer.encode(prompt, add_special_tokens).input_ids.data
+    res_hf = hf_tokenizer(prompt, return_tensors="np", add_special_tokens=add_special_tokens)["input_ids"]
+    assert np.all(res_genai == res_hf)
```
