13
13
using namespace ov ;
14
14
15
15
16
- VocabEncoder::VocabEncoder (const ov::OutputVector& arguments) :
17
- ov::op::Op(arguments) {
18
- constructor_validate_and_infer_types ();
19
- }
20
-
21
-
22
16
void VocabEncoder::validate_and_infer_types () {
23
17
// main string input
24
18
check_string_input (this , 0 );
@@ -44,19 +38,21 @@ bool VocabEncoder::evaluate(ov::TensorVector& outputs, const ov::TensorVector& i
44
38
auto ends = inputs[1 ].data <const int32_t >();
45
39
auto chars = inputs[2 ].data <const uint8_t >();
46
40
47
- // vocab string keys
48
- auto vocab_begins = inputs[3 ].data <const int32_t >();
49
- auto vocab_ends = inputs[4 ].data <const int32_t >();
50
- auto vocab_chars = inputs[5 ].data <const uint8_t >();
41
+ if (m_vocab == nullptr ) {
42
+ // vocab string keys
43
+ auto vocab_begins = inputs[3 ].data <const int32_t >();
44
+ auto vocab_ends = inputs[4 ].data <const int32_t >();
45
+ auto vocab_chars = inputs[5 ].data <const uint8_t >();
51
46
52
- auto vocab_values = inputs[6 ].data <const int32_t >();
53
- auto vocab_size = inputs[6 ].get_size ();
47
+ auto vocab_values = inputs[6 ].data <const int32_t >();
48
+ auto vocab_size = inputs[6 ].get_size ();
54
49
55
- std::map<std::vector<uint8_t >, int32_t > vocab;
56
- for (size_t i = 0 ; i < vocab_size; ++i) {
57
- std::vector<uint8_t > token = std::vector<uint8_t >(vocab_chars + vocab_begins[i], vocab_chars + vocab_ends[i]);
58
- vocab[token] = vocab_values[i];
59
- };
50
+ m_vocab = std::make_shared<std::map<std::vector<unsigned char >, int32_t >>();
51
+ for (size_t i = 0 ; i < vocab_size; ++i) {
52
+ std::vector<uint8_t > token = std::vector<uint8_t >(vocab_chars + vocab_begins[i], vocab_chars + vocab_ends[i]);
53
+ m_vocab->insert (std::pair{token, vocab_values[i]});
54
+ };
55
+ }
60
56
61
57
auto default_value = *inputs[7 ].data <const int32_t >();
62
58
const size_t num_elements = inputs[0 ].get_size ();
@@ -66,8 +62,8 @@ bool VocabEncoder::evaluate(ov::TensorVector& outputs, const ov::TensorVector& i
66
62
auto token_ids = outputs[0 ].data <int32_t >();
67
63
68
64
for (size_t element_idx = 0 ; element_idx < num_elements; ++element_idx) {
69
- auto element = vocab. find (std::vector<uint8_t >(chars + begins[element_idx], chars + ends[element_idx]));
70
- if (element == vocab. end ()) {
65
+ auto element = m_vocab-> find (std::vector<uint8_t >(chars + begins[element_idx], chars + ends[element_idx]));
66
+ if (element == m_vocab-> end ()) {
71
67
token_ids[element_idx] = default_value;
72
68
} else {
73
69
token_ids[element_idx] = element->second ;
0 commit comments