Commit 7c7d03b

[Tokenizers] add max_length parametrisation to encode (#1518)
Works in collaboration with the tokenizers changes in openvinotoolkit/openvino_tokenizers#362. Tickets: CVS-157356, CVS-159924
1 parent 06a95e4 commit 7c7d03b

12 files changed: +374 −105 lines

src/README.md

+97
@@ -398,6 +398,103 @@ print(f'Median from token to token duration: {np.median(durations):.2f} ms')
For more examples of how metrics are used, please refer to the Python [benchmark_genai.py](../samples/python/text_generation/README.md) and C++ [benchmark_genai](../samples/cpp/text_generation/README.md) samples.

### Tokenization

OpenVINO™ GenAI provides a way to tokenize and detokenize text using the `ov::genai::Tokenizer` class. The `Tokenizer` is a high-level abstraction over the OpenVINO Tokenizers library.

It can be initialized from a path or an in-memory IR representation, or obtained from an `ov::genai::LLMPipeline` object.

```cpp
#include "openvino/genai/llm_pipeline.hpp"

// Initialize from a path.
auto tokenizer = ov::genai::Tokenizer(models_path);

// Or get the Tokenizer instance from an LLMPipeline.
ov::genai::LLMPipeline pipe(models_path, "CPU");
auto tokenizer = pipe.get_tokenizer();
```

```python
import openvino_genai as ov_genai
tokenizer = ov_genai.Tokenizer(models_path)

# Or from LLMPipeline.
pipe = ov_genai.LLMPipeline(models_path, "CPU")
tokenizer = pipe.get_tokenizer()
```

`Tokenizer` has `encode` and `decode` methods which support the following arguments: `add_special_tokens`, `skip_special_tokens`, `pad_to_max_length`, and `max_length`.

To disable adding special tokens, do the following in C++:
```cpp
auto tokens = tokenizer.encode("The Sun is yellow because", ov::genai::add_special_tokens(false));
```

In Python:
```python
tokens = tokenizer.encode("The Sun is yellow because", add_special_tokens=False)
```
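
Special tokens can likewise be kept or dropped during detokenization. A minimal C++ sketch, continuing the example above and assuming `decode` accepts the `skip_special_tokens` property (declared alongside `add_special_tokens` in `tokenizer.hpp` below) the same way `encode` accepts its properties:

```cpp
// Keep special tokens (e.g. <bos>) in the decoded text; by default they are skipped.
auto text = tokenizer.decode(tokens.input_ids, ov::genai::skip_special_tokens(false));
```
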
The `encode` method returns a `TokenizedInputs` object containing `input_ids` and `attention_mask`, both stored as `ov::Tensor`. Since `ov::Tensor` requires fixed-length sequences, padding is applied to match the longest sequence in a batch, ensuring a uniform shape. The resulting sequences are also truncated to `max_length`. If this value is not defined by the user, it is taken from the IR.

Both padding and `max_length` can be controlled by the user. If `pad_to_max_length` is set to true, the output is padded to `max_length` instead of to the longest sequence.

Below are examples of how padding can be controlled, in C++:
```cpp
#include "openvino/genai/llm_pipeline.hpp"
auto tokenizer = ov::genai::Tokenizer(models_path);
std::vector<std::string> prompts = {"The Sun is yellow because", "The"};

// Since the prompts are definitely shorter than the maximal length (which is taken from the IR), they will not affect the shape.
// The resulting shape is defined by the length of the longest tokens sequence.
// Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="longest", truncation=True)
auto tokens = tokenizer.encode({"The Sun is yellow because", "The"});
// or, equivalently
tokens = tokenizer.encode({"The Sun is yellow because", "The"}, ov::genai::pad_to_max_length(false));
// out_shape: [2, 6]

// The resulting tokens tensor will be padded to 1024; sequences which exceed this length will be truncated.
// Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="max_length", truncation=True, max_length=1024)
tokens = tokenizer.encode({"The Sun is yellow because",
                           "The",
                           std::string(2000, 'n')}, ov::genai::pad_to_max_length(true), ov::genai::max_length(1024));
// out_shape: [3, 1024]

// For single string prompts truncation and padding are also applied.
tokens = tokenizer.encode({"The Sun is yellow because"}, ov::genai::pad_to_max_length(true), ov::genai::max_length(128));
// out_shape: [1, 128]
```

In Python:
```python
import openvino_genai as ov_genai

tokenizer = ov_genai.Tokenizer(models_path)
prompts = ["The Sun is yellow because", "The"]

# Since the prompts are definitely shorter than the maximal length (which is taken from the IR), they will not affect the shape.
# The resulting shape is defined by the length of the longest tokens sequence.
# Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="longest", truncation=True)
tokens = tokenizer.encode(["The Sun is yellow because", "The"])
# or, equivalently
tokens = tokenizer.encode(["The Sun is yellow because", "The"], pad_to_max_length=False)
print(tokens.input_ids.shape)
# out_shape: [2, 6]

# The resulting tokens tensor will be padded to 1024; sequences which exceed this length will be truncated.
# Equivalent of HuggingFace hf_tokenizer.encode(prompt, padding="max_length", truncation=True, max_length=1024)
tokens = tokenizer.encode(["The Sun is yellow because",
                           "The",
                           "The longest string ever" * 2000], pad_to_max_length=True, max_length=1024)
print(tokens.input_ids.shape)
# out_shape: [3, 1024]

# For single string prompts truncation and padding are also applied.
tokens = tokenizer.encode("The Sun is yellow because", pad_to_max_length=True, max_length=128)
print(tokens.input_ids.shape)
# out_shape: [1, 128]
```

## How It Works

For information on how OpenVINO™ GenAI works, refer to the [How It Works Section](./docs/HOW_IT_WORKS.md).

src/cpp/include/openvino/genai/tokenizer.hpp

+9 −4
@@ -107,15 +107,15 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
     /**
      * @brief encode a single prompt
      * @param prompt std::string with input prompt
-     * @param tokenization_params AnyMap with tokenization parameters, e.g. {"add_special_tokens", false}
+     * @param tokenization_params AnyMap with tokenization parameters, e.g. {{"add_special_tokens", false}, {"max_length", 128}}
      * @return pair of [input_ids, attention_mask]
      */
     TokenizedInputs encode(const std::string prompt, const ov::AnyMap& tokenization_params = {});

     /**
      * @brief encode batch of prompts. Left padding will be applied by default
      * @param prompts vector storing batch of prompts
-     * @param tokenization_params AnyMap with tokenization parameters, e.g. {"add_special_tokens", false}
+     * @param tokenization_params AnyMap with tokenization parameters, e.g. {{"add_special_tokens", false}, {"max_length", 128}}
      * @return pair of [input_ids, attention_mask]
      */
     TokenizedInputs encode(std::vector<std::string>& prompt, const ov::AnyMap& tokenization_params = {});
@@ -125,7 +125,9 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
     /**
      * @brief encode a single prompt
      * @param prompt std::string with input prompt
-     * @param properties tokenization properties, e.g. ov::genai::add_special_tokens(false)
+     * @param add_special_tokens whether to add special tokens
+     * @param max_length optional maximum length to which output will be truncated and/or padded. If not defined, taken from IR.
+     * @param pad_to_max_length either pad to max_length, or pad to the longest sequence in the batch. Default is false.
      * @return pair of [input_ids, attention_mask]
      */
     template <typename... Properties>

@@ -136,7 +138,9 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {
     /**
      * @brief encode batch of prompts. Left padding will be applied by default
      * @param prompts vector storing batch of prompts
-     * @param properties tokenization properties, e.g. ov::genai::add_special_tokens(false)
+     * @param add_special_tokens whether to add special tokens
+     * @param max_length optional maximum length to which output will be truncated and/or padded. If not defined, taken from IR.
+     * @param pad_to_max_length either pad to max_length, or pad to the longest sequence in the batch. Default is false.
      * @return pair of [input_ids, attention_mask]
      */
     template <typename... Properties>
@@ -243,6 +247,7 @@ class OPENVINO_GENAI_EXPORTS Tokenizer {

 static constexpr ov::Property<bool> add_special_tokens{"add_special_tokens"};
 static constexpr ov::Property<bool> skip_special_tokens{"skip_special_tokens"};
+static constexpr ov::Property<bool> pad_to_max_length{"pad_to_max_length"};

 } // namespace genai
 } // namespace ov
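
For illustration, a minimal sketch of calling the `AnyMap` overload documented above; `models_path` is assumed to point to a converted tokenizer model:

```cpp
#include "openvino/genai/tokenizer.hpp"

ov::genai::Tokenizer tokenizer(models_path);
// Tokenization parameters passed through an AnyMap, mirroring the doxygen example.
auto tokens = tokenizer.encode("The Sun is yellow because",
                               ov::AnyMap{{"add_special_tokens", false}, {"max_length", 128}});
```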

src/cpp/src/make_tokenizer_stateful.cpp

+121 −1
@@ -4,6 +4,10 @@
 #include "make_tokenizer_stateful.hpp"
 #include "openvino/op/constant.hpp"
 #include "openvino/op/select.hpp"
+#include "openvino/op/maximum.hpp"
+#include "openvino/op/minimum.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/op/slice.hpp"
 #include "openvino/op/multiply.hpp"
 #include "openvino/op/read_value.hpp"
@@ -13,7 +17,7 @@
 using namespace ov;
 using namespace ov::op;

-bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
+bool ov::genai::MakeAddSpecialTokensSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
     std::shared_ptr<ov::Node> combine_seg_node;
     for (auto node: model->get_ordered_ops()) {
         if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {

@@ -56,6 +60,7 @@ bool ov::genai::MakeCombineSegmentsSatateful::run_on_model(const std::shared_ptr
     return true;
 }

+
 bool ov::genai::MakeVocabDecoderSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
     std::shared_ptr<ov::Node> vocab_decoder_node;
     for (auto node: model->get_ordered_ops()) {
@@ -97,3 +102,118 @@ bool ov::genai::MakeVocabDecoderSatateful::run_on_model(const std::shared_ptr<ov
     model->add_variables({variable});
     return true;
 }
+
+
+bool ov::genai::MakePaddingSatateful::run_on_model(const std::shared_ptr<ov::Model>& model) {
+    std::shared_ptr<ov::Node> combine_seg_node;
+    for (auto node: model->get_ordered_ops()) {
+        if (strcmp(node->get_type_info().name, "CombineSegments") == 0) {
+            combine_seg_node = node;
+        }
+    }
+    if (!combine_seg_node) { return false; }
+
+    size_t num_segments = (combine_seg_node->get_input_size() - 1) / 3;
+    size_t number_of_main_tokens_inputs = 0;
+    std::shared_ptr<Node> add_or_sub_node;
+    for (size_t i = 0; i < num_segments; i++) {
+        // Check all 'ends' inputs of the CombineSegments node.
+        // For special tokens they are Constant/Select;
+        // for the 'ends' input with the main tokens sequence it's Add/Subtract.
+        // If Add, it's a right truncation; if Subtract, a left truncation.
+        auto tmp_node = combine_seg_node->input_value(3*i + 1).get_node_shared_ptr();
+        if (ov::as_type_ptr<v1::Add>(tmp_node) || ov::as_type_ptr<v1::Subtract>(tmp_node)) {
+            number_of_main_tokens_inputs += 1;
+            add_or_sub_node = tmp_node;
+        }
+    }
+
+    // Exit if the main input couldn't be found or if there are several of them.
+    if (number_of_main_tokens_inputs != 1) { return false; }
+
+    // Minimum between max_length and the length of the token sequence.
+    auto min_node = ov::as_type_ptr<v1::Minimum>(add_or_sub_node->get_input_node_shared_ptr(1));
+    if (!min_node) { return false; }
+
+    // Constant containing the final max_length minus the number of added tokens at the end of the pipeline.
+    auto const_node = ov::as_type_ptr<v0::Constant>(min_node->get_input_node_shared_ptr(1));
+    if (!const_node) { return false; }
+
+    op::util::VariableInfo var_info{const_node->get_output_shape(0), const_node->get_output_element_type(0), MAX_LENGTH_VAR_ID};
+    auto variable_1 = std::make_shared<op::util::Variable>(var_info);
+
+    size_t num_added_tokens = num_segments - number_of_main_tokens_inputs;
+    // Constant which stores the number of added tokens.
+    auto num_added_tokens_const = std::make_shared<v0::Constant>(
+        const_node->get_output_element_type(0), const_node->get_output_shape(0), std::vector{num_added_tokens});
+
+    OPENVINO_ASSERT(const_node->get_element_type() == element::i32);
+    auto values = const_node->get_vector<int32_t>();
+    OPENVINO_ASSERT(values.size() == 1);
+    // const_node contains the value max_length - num_added_tokens, so recover the default max_length.
+    size_t default_max_length = values[0] + num_added_tokens;
+
+    auto default_max_length_const = std::make_shared<v0::Constant>(
+        const_node->get_output_element_type(0), const_node->get_output_shape(0), std::vector{default_max_length});
+
+    // Save targets before adding a new target with ReadValue to avoid recursion.
+    auto target_inputs = const_node->output(0).get_target_inputs();
+    auto max_length_rv = std::make_shared<v6::ReadValue>(default_max_length_const, variable_1);
+    auto subtract_node = std::make_shared<v1::Subtract>(max_length_rv, num_added_tokens_const);
+
+    for (auto target_input : target_inputs) {
+        target_input.replace_source_output(subtract_node->output(0));
+    }
+
+    // We need to check whether the user requested not to add special tokens.
+    std::shared_ptr<v6::ReadValue> read_value_spec_tokens;
+    for (const auto& sink : model->get_sinks()) {
+        // Check if the sink accepts input from Assign; if that's the case, get the ReadValue node input.
+        if (auto read_value = ov::as_type_ptr<v6::ReadValue>(sink->get_input_node_shared_ptr(0))) {
+            if (read_value->get_variable()->get_info().variable_id == ADD_SPECIAL_TOKENS_VAR_ID) {
+                read_value_spec_tokens = read_value;
+                break;
+            }
+        }
+    }
+
+    // If the user requested not to add special tokens, then in order to correctly calculate
+    // truncation we need to force num_added_tokens to 0 regardless of the hardcoded Constant value.
+    if (read_value_spec_tokens && num_added_tokens_const) {
+        auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
+        auto select_node = std::make_shared<v1::Select>(read_value_spec_tokens, num_added_tokens_const, zero_constant);
+        subtract_node->input(1).replace_source_output(select_node->output(0));
+    }
+
+    model->add_sinks({std::make_shared<v6::Assign>(max_length_rv, variable_1)});
+    model->add_variables({variable_1});
+
+    std::shared_ptr<ov::Node> ragged_to_dense_node;
+    for (auto node: model->get_ordered_ops()) {
+        if (strcmp(node->get_type_info().name, "RaggedToDense") == 0) {
+            ragged_to_dense_node = node;
+        }
+    }
+
+    if (!ragged_to_dense_node || ragged_to_dense_node->input_value(3).get_element_type() != ov::element::i32) {
+        return true;  // true, since at this point we have already modified the graph.
+    }
+
+    auto variable_2 = std::make_shared<op::util::Variable>(op::util::VariableInfo{ov::Shape{1}, ov::element::boolean, PAD_TO_LONGEST_VAR_ID});
+
+    // By default, do not pad to max_length.
+    auto default_false_const = std::make_shared<v0::Constant>(ov::element::boolean, ov::Shape{1}, std::vector{false});
+    auto pad_to_max_length_rv = std::make_shared<v6::ReadValue>(default_false_const, variable_2);
+
+    auto zero_constant = std::make_shared<v0::Constant>(ov::element::i32, ov::Shape{}, std::vector{0});
+    auto select_node = std::make_shared<v1::Select>(pad_to_max_length_rv, max_length_rv, zero_constant);
+
+    auto max_op = std::make_shared<v1::Maximum>(ragged_to_dense_node->input_value(3), select_node);
+    ragged_to_dense_node->input(3).replace_source_output(max_op->output(0));
+
+    model->add_sinks({std::make_shared<v6::Assign>(pad_to_max_length_rv, variable_2)});
+    model->add_variables({variable_2});
+
+    return true;
+}
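
For context, a hypothetical sketch of how such a stateful variable could be driven at runtime through the standard `ov::InferRequest` state API; `compiled_tokenizer` is an assumed handle to the compiled tokenizer model, and the variable name matches `MAX_LENGTH_VAR_ID` declared in the header below:

```cpp
// Variables created by ReadValue/Assign pairs are exposed via the infer request's state.
ov::InferRequest request = compiled_tokenizer.create_infer_request();
for (auto& state : request.query_state()) {
    if (state.get_name() == "max_length") {
        // Overwrite the default max_length (taken from the IR) with a user-requested value.
        ov::Tensor value(state.get_state().get_element_type(), state.get_state().get_shape());
        value.data<int32_t>()[0] = 128;
        state.set_state(value);
    }
}
```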

src/cpp/src/make_tokenizer_stateful.hpp

+17 −4
@@ -3,6 +3,7 @@

 #include "openvino/op/constant.hpp"
 #include "openvino/pass/pass.hpp"
+#include "openvino/pass/matcher_pass.hpp"

 namespace ov {
 namespace genai {
@@ -32,9 +33,19 @@ namespace genai {
 * | CombineSegments |
 * +-------------------------+
 **/
-class MakeCombineSegmentsSatateful : public ov::pass::ModelPass {
+class MakeAddSpecialTokensSatateful : public ov::pass::ModelPass {
 public:
-    OPENVINO_MODEL_PASS_RTTI("MakeCombineSegmentsSatateful");
+    OPENVINO_MODEL_PASS_RTTI("MakeAddSpecialTokensSatateful");
+    bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
+};
+
+/**
+ * @brief This pass modifies the tokenizer ov::Model so that the inputs to the RaggedToDense and
+ * CombineSegments nodes become modifiable at runtime, allowing padding to be controlled.
+ */
+class MakePaddingSatateful : public ov::pass::ModelPass {
+public:
+    OPENVINO_MODEL_PASS_RTTI("MakePaddingSatateful");
     bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
 };

@@ -74,8 +85,10 @@ class MakeVocabDecoderSatateful : public ov::pass::ModelPass {
     bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
 };

-const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens";
-const std::string SKIP_SPECIAL_TOKENS_VAR_ID = "skip_special_tokens";
+inline const std::string ADD_SPECIAL_TOKENS_VAR_ID = "add_special_tokens";
+inline const std::string SKIP_SPECIAL_TOKENS_VAR_ID = "skip_special_tokens";
+inline const std::string MAX_LENGTH_VAR_ID = "max_length";
+inline const std::string PAD_TO_LONGEST_VAR_ID = "PAD_TO_LONGEST";

 } // namespace genai
 } // namespace ov
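
A minimal sketch of applying these passes to a tokenizer `ov::Model` with the standard `ov::pass::Manager` (illustrative only; in practice OpenVINO GenAI registers the passes internally when the tokenizer is loaded):

```cpp
#include "make_tokenizer_stateful.hpp"
#include "openvino/pass/manager.hpp"

// 'model' is assumed to be the tokenizer ov::Model read from its IR.
ov::pass::Manager manager;
manager.register_pass<ov::genai::MakeAddSpecialTokensSatateful>();
manager.register_pass<ov::genai::MakePaddingSatateful>();
manager.run_passes(model);
```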
