@@ -30,6 +30,7 @@
     TOKENIZER_NAME,
 )
 from .tokenizer_pipeline import (
+    AddToken,
     BasePipelineStep,
     BPETokenizationStep,
     ByteFallbackStep,
@@ -47,6 +48,7 @@
     RegexDecodingStep,
     RegexNormalizationStep,
     RegexSplitStep,
+    Sequence,
     StripStringStep,
     TokenizerPipeline,
     TruncationStep,
@@ -449,10 +451,26 @@ def is_sentencepiece_model(hf_tokenizer: PreTrainedTokenizerBase) -> bool:
         if not hasattr(hf_tokenizer, "vocab_files_names") or "vocab_file" not in hf_tokenizer.vocab_files_names:
             return False
         vocab_file = Path(tmp) / hf_tokenizer.vocab_files_names["vocab_file"]
-        return (
+        vocab_file_exists = (
             getattr(hf_tokenizer, "vocab_files_names", {}).get("vocab_file", "").endswith(".model")
             and vocab_file.exists()
         )
+        if vocab_file_exists:
+            try:
+                from google.protobuf.message import DecodeError
+            except (ImportError, ModuleNotFoundError):
+                return False
+
+            model_pb = import_protobuf()
+            model = model_pb.ModelProto()
+            try:
+                with open(vocab_file, "rb") as model_file:
+                    model.ParseFromString(model_file.read())
+                return True
+            except DecodeError:
+                pass  # protobuf file is corrupted
+
+        return False


 def modify_sentencepiece_model(
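The new branch goes beyond the filename heuristic: a `*.model` vocab file only counts as a SentencePiece model if it actually parses as a `ModelProto`. Below is a minimal standalone sketch of the same check, using the protobuf schema bundled with the `sentencepiece` package instead of the repo's `import_protobuf()` helper; the path in the usage comment is a placeholder.

```python
from google.protobuf.message import DecodeError
from sentencepiece import sentencepiece_model_pb2


def looks_like_sentencepiece(vocab_file: str) -> bool:
    # A genuine SentencePiece model deserializes cleanly into a ModelProto;
    # an arbitrary binary (or corrupted file) raises DecodeError instead.
    model = sentencepiece_model_pb2.ModelProto()
    try:
        with open(vocab_file, "rb") as f:
            model.ParseFromString(f.read())
    except DecodeError:
        return False
    return True


# looks_like_sentencepiece("tokenizer.model")  # -> True only for a real SP model
```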
@@ -831,11 +849,13 @@ def get_sp_detokenizer(
 def is_tiktoken_model(hf_tokenizer: PreTrainedTokenizerBase) -> bool:
     try:
         from tiktoken import Encoding
-    except ImportError:
+    except (ImportError, ModuleNotFoundError):
         return False

-    return getattr(hf_tokenizer, "vocab_files_names", {}).get("vocab_file", "").endswith(".tiktoken") or isinstance(
-        getattr(hf_tokenizer, "encoder", None), Encoding
+    return (
+        getattr(hf_tokenizer, "vocab_files_names", {}).get("vocab_file", "").endswith(".tiktoken")
+        or isinstance(getattr(hf_tokenizer, "encoder", None), Encoding)
+        or isinstance(getattr(hf_tokenizer, "tokenizer", None), Encoding)
     )
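The widened check matters because custom `trust_remote_code` tokenizers store their tiktoken `Encoding` under different attribute names. A hedged sketch of the two shapes the check now probes; the classes below are stand-ins, not real HF tokenizers, and the attribute-name attributions are assumptions:

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")  # any real Encoding works here


class EncoderStyle:    # some custom tokenizers expose the Encoding as .encoder
    encoder = enc


class TokenizerStyle:  # others (e.g. Qwen-style remote code) use .tokenizer
    tokenizer = enc


for wrapper in (EncoderStyle(), TokenizerStyle()):
    found = isinstance(getattr(wrapper, "encoder", None), tiktoken.Encoding) or isinstance(
        getattr(wrapper, "tokenizer", None), tiktoken.Encoding
    )
    print(type(wrapper).__name__, found)  # both print True
```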
@@ -854,13 +874,20 @@ def convert_tiktoken_model_tokenizer(
     if skip_special_tokens:
         skip_tokens = list(parse_special_tokens(hf_tokenizer))

+    add_prefix_steps = []
+    if hasattr(hf_tokenizer, "get_prefix_tokens"):
+        prefix_tokens = [AddToken(_token_id=token_id) for token_id in hf_tokenizer.get_prefix_tokens()]
+        add_prefix_steps.append(CombineSegmentsStep(inputs=prefix_tokens + [Sequence()]))
+
+    reference_vocab = getattr(hf_tokenizer, "get_vocab", lambda: None)()
     pipeline.add_steps(
         [
             NormalizeUnicode("NFC"),
             RegexSplitStep(split_pattern, behaviour="contiguous"),
             BytesToCharsStep(),
-            BPETokenizationStep.from_tiktoken_encoding(encoding),
+            BPETokenizationStep.from_tiktoken_encoding(encoding, reference_vocab=reference_vocab),
             TruncationStep.from_hf_object(hf_tokenizer),
+            *add_prefix_steps,
             PaddingStep(
                 token=getattr(hf_tokenizer, "pad_token"),
                 _token_id=getattr(hf_tokenizer, "pad_token_id"),
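The prefix-token branch targets tokenizers whose remote code implements `get_prefix_tokens()`, returning special-token ids (e.g. `[gMASK]`, `sop` in GLM-style models) that are prepended to every encoding; `CombineSegmentsStep(inputs=prefix_tokens + [Sequence()])` reproduces that in the converted pipeline, and placing it after `TruncationStep` means truncation applies to the text segment before the prefix is attached. A sketch of the hook being relied on, assuming a model whose remote code provides it; the model name is illustrative:

```python
from transformers import AutoTokenizer

# Assumption: a GLM-4-style tiktoken tokenizer that implements
# get_prefix_tokens(); models lacking the attribute skip the branch entirely.
hf_tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b", trust_remote_code=True)

if hasattr(hf_tokenizer, "get_prefix_tokens"):
    prefix_ids = hf_tokenizer.get_prefix_tokens()
    print(prefix_ids)                          # special-token ids prepended to every encode
    print(hf_tokenizer("hello")["input_ids"])  # starts with those same ids
```

The `reference_vocab` line serves the same fidelity goal: `get_vocab()`, when the wrapper provides it, lets `from_tiktoken_encoding` keep token ids aligned with the original HF vocabulary rather than relying on the raw tiktoken ranks alone.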