 from openvino.runtime.exceptions import OVTypeError, UserInputError
 from openvino.runtime.utils.types import as_node, make_constant_node
 
-from . import _get_factory, _get_opset_factory
+from . import _get_factory
 from .constants import (
     ATTENTION_MASK_INPUT_NAME,
     DETOKENIZER_NAME,
     VOCAB_SIZE_CACHE_PROPORTION,
     UTF8ReplaceMode,
 )
-from .utils import (
-    apply_unicode_to_bytes,
-    create_unpacked_string,
-    generate_tokens_with_space_symbols,
-    has_incompatible_re2_op,
-    quote_meta,
-)
+from .str_pack import pack_string, pack_strings
+from .utils import apply_unicode_to_bytes, generate_tokens_with_space_symbols, has_incompatible_re2_op, quote_meta
 
 
 logger = logging.getLogger(__name__)
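The new `str_pack` import replaces `create_unpacked_string`: batches of strings are serialized into a single packed u8 constant up front and unpacked in-graph by `StringTensorUnpack`. The diff does not show `str_pack` itself; the sketch below is a hypothetical reconstruction, assuming the common packed-string convention of an int32 batch size, int32 offsets, and concatenated UTF-8 bytes, all viewed as u8.

```python
import numpy as np

def pack_strings_sketch(strings):
    # Assumed layout: [int32 batch_size][int32 offsets, N + 1][UTF-8 bytes], viewed as u8.
    symbols = b""
    offsets = [0]
    for s in strings:
        symbols += s.encode("utf-8")
        offsets.append(len(symbols))
    return np.concatenate(
        [
            np.array([len(strings)], dtype=np.int32).view(np.uint8),
            np.array(offsets, dtype=np.int32).view(np.uint8),
            np.frombuffer(symbols, dtype=np.uint8),
        ]
    )

print(pack_strings_sketch(["ab", "c"]).shape)  # (19,): 4 + 4 * 3 + 3 bytes
```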
@@ -71,15 +66,15 @@ def get_ov_subgraph(self, *input_nodes: List[Output]) -> List[Output]:
         raise NotImplementedError
 
     @staticmethod
-    def create_string_constant_node(value: Union[str, Iterable[str]]) -> List[Output]:
+    def create_string_constant_node(value: Union[str, Iterable[str]]) -> op.Constant:
         if isinstance(value, str):
             # string scalar
-            return op.Constant(np.frombuffer(bytes(value, "utf-8"), dtype=np.uint8)).outputs()
-        elif isinstance(value, Iterable):
-            # support only 1D strings for now
-            return create_unpacked_string(value)
+            ps = pack_string(value)
+            return op.Constant(ps)
         else:
-            raise ValueError(f"Unsupported value type {type(value)}")
+            # support only 1D strings for now
+            ps = pack_strings(value)
+            return _get_factory().create("StringTensorUnpack", op.Constant(ps).outputs())
 
     def finalize(self) -> None:
         """Called after the entire pipeline has been built"""
@@ -149,7 +144,7 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
             return list(input_nodes)
 
         split_pattern = "|".join(token.regex_repr() for token in self.special_tokens)
-        input_nodes.extend(self.create_string_constant_node(split_pattern))
+        input_nodes.extend(self.create_string_constant_node(split_pattern).outputs())
 
         return _get_factory().create("SpecialTokensSplit", input_nodes).outputs()
 
@@ -238,10 +233,10 @@ def del_control_chars_regex(cls) -> "RegexNormalizationStep":
 
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         input_nodes.extend(
-            [
-                *self.create_string_constant_node(self.regex_search_pattern),
-                *self.create_string_constant_node(self.replace_term),
-            ]
+            (
+                self.create_string_constant_node(self.regex_search_pattern),
+                self.create_string_constant_node(self.replace_term),
+            )
        )
        return (
            _get_factory().create("RegexNormalization", input_nodes, {"global_replace": self.global_replace}).outputs()
@@ -362,7 +357,7 @@ def punctuation_splitter(cls, behaviour="isolate") -> "RegexSplitStep":
         )
 
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
-        input_nodes.extend(self.create_string_constant_node(self.split_pattern))
+        input_nodes.extend(self.create_string_constant_node(self.split_pattern).outputs())
         return (
             _get_factory()
             .create(
@@ -428,7 +423,7 @@ def get_vocab_node_outputs(self) -> Optional[List[Output]]:
 
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         pipeline = self.get_pipeline()
-        pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab)
+        pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab).outputs()
 
         ragged_dims, other_dims = [], input_nodes
         if len(input_nodes) > 4:
@@ -480,7 +475,7 @@ def from_rwkv_vocab(cls, vocab_file_strings: Iterable[str]) -> TrieTokenizerStep
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         input_nodes.extend(
             (
-                *self.create_string_constant_node(self.vocab),
+                *self.create_string_constant_node(self.vocab).outputs(),
                 make_constant_node(np.array(self.indices, dtype=np.int32), Type.i32),
             )
         )
@@ -516,7 +511,7 @@ def from_hf_json(cls, tokenizer_json: Dict[str, Any]) -> "WordPieceTokenizationS
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         input_nodes.extend(
             (
-                *self.create_string_constant_node(self.vocab),
+                *self.create_string_constant_node(self.vocab).outputs(),
                 *as_node(self.unk_token_id).outputs(),
             )
         )
@@ -648,10 +643,10 @@ def merges_are_pairs(self) -> bool:
 
     def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         pipeline = self.get_pipeline()
-        pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab)
+        pipeline.vocab_node_outputs = self.create_string_constant_node(self.vocab).outputs()
 
         if self.added_tokens:
-            special_tokens_outputs = self.create_string_constant_node(self.added_tokens)
+            special_tokens_outputs = self.create_string_constant_node(self.added_tokens).outputs()
         else:
             special_tokens_outputs = []
 
@@ -664,12 +659,12 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
             left_merges, right_merges = zip(*self.merges)
             input_nodes.extend(
                 (
-                    *self.create_string_constant_node(left_merges),
-                    *self.create_string_constant_node(right_merges),
+                    *self.create_string_constant_node(left_merges).outputs(),
+                    *self.create_string_constant_node(right_merges).outputs(),
                 )
             )
         else:
-            input_nodes.extend(self.create_string_constant_node(self.merges))
+            input_nodes.extend(self.create_string_constant_node(self.merges).outputs())
 
         if special_tokens_outputs:
             input_nodes.extend(
@@ -1040,13 +1035,7 @@ def finalize(self) -> None:
         self.skip_tokens = pipeline.skip_tokens or []
 
     @classmethod
-    def from_hf_json(
-        cls,
-        tokenizer_json: Dict[str, Any],
-        pipeline_vocab: Optional[List[str]],
-        skip_tokens: Optional[List[int]] = None,
-        do_skip_tokens: bool = True,
-    ) -> "VocabDecoderStep":
+    def from_hf_json(cls, tokenizer_json: Dict[str, Any], pipeline_vocab: Optional[List[str]], skip_tokens: Optional[List[int]] = None, do_skip_tokens: bool = True) -> "VocabDecoderStep":
         model_type = tokenizer_json["model"]["type"]
 
         if pipeline_vocab is not None and model_type == "WordLevel":
@@ -1068,7 +1057,7 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
         if self.vocab is None:
             vocab_outputs = self.get_vocab_node_outputs()
         else:
-            vocab_outputs = self.create_string_constant_node(self.vocab)
+            vocab_outputs = self.create_string_constant_node(self.vocab).outputs()
         input_nodes.extend(vocab_outputs)
 
         # Put constant with skip tokens even if do_skip_tokens=False, so that it can be switched on/off at runtime.
@@ -1189,8 +1178,8 @@ def get_ov_subgraph(self, input_nodes: List[Output]) -> List[Output]:
 
         input_nodes.extend(
             (
-                *self.create_string_constant_node(self.regex_search_pattern),
-                *self.create_string_constant_node(self.replace_term),
+                *self.create_string_constant_node(self.regex_search_pattern).outputs(),
+                *self.create_string_constant_node(self.replace_term).outputs(),
             )
         )
         return ragged_dims + _get_factory().create("RegexNormalization", input_nodes).outputs()
@@ -1245,7 +1234,7 @@ def get_tokenizer_ov_subgraph(self) -> Model:
 
         processing_outputs = []
         for input_node in string_inputs:
-            input_node = _get_opset_factory("opset15").create("StringTensorUnpack", input_node.outputs()).outputs()
+            input_node = _get_factory().create("StringTensorUnpack", input_node.outputs()).outputs()
 
             ragged = []
             if isinstance(self.steps[0], SpecialTokensSplit):
@@ -1318,7 +1307,7 @@ def create_decoding_pipeline(self, input_nodes: List[Output]) -> List[Output]:
             pipeline_step = step.get_ov_subgraph(input_nodes)
             input_nodes = pipeline_step
 
-        return _get_opset_factory("opset15").create("StringTensorPack", input_nodes).outputs()
+        return _get_factory().create("StringTensorPack", input_nodes).outputs()
 
     def get_detokenizer_ov_subgraph(self) -> Model:
         self.finalize()
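With the last two hunks, both ends of the string pipeline (`StringTensorUnpack` at the tokenizer entry, `StringTensorPack` at the detokenizer exit) are created through the extension's own node factory rather than the core `opset15` factory, keeping the graph on the extension's versions of these ops. A hedged round-trip sketch, assuming the factory API used throughout this file:

```python
# Hypothetical round-trip using the same factory calls as the diff.
unpacked = _get_factory().create("StringTensorUnpack", string_input.outputs()).outputs()
# ... decoding steps transform the unpacked string outputs ...
packed = _get_factory().create("StringTensorPack", unpacked).outputs()
```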