@@ -387,7 +387,7 @@ class StoreAttr(object):
         )
 
     if convert_tokenizer:
-        maybe_convert_tokenizers(library_name, output, model, preprocessors)
+        maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
 
     clear_class_registry()
     del model
@@ -399,7 +399,7 @@ class StoreAttr(object):
         GPTQQuantizer.post_init_model = orig_post_init_model
 
 
-def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None):
+def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
     """
     Tries to convert tokenizers to OV format and export them to disk.
 
@@ -412,6 +412,8 @@ def maybe_convert_tokenizers(library_name: str, output: Path, model=None, prepro
             Model instance.
         preprocessors (`Iterable`, *optional*, defaults to None):
             Iterable possibly containing tokenizers to be converted.
+        task (`str`, *optional*, defaults to None):
+            The task to export the model for. Affects tokenizer conversion parameters.
     """
     from optimum.exporters.openvino.convert import export_tokenizer
 
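As an aside, a minimal sketch of what the new `task` hint feeds into is shown below: calling `export_tokenizer` directly for a single tokenizer. The import path and the `task` keyword come from this diff; the model ID, output directory, and task value are illustrative assumptions, and the example assumes the `openvino-tokenizers` package is installed.

```python
# Sketch only (not part of the diff): export one Hugging Face tokenizer to OpenVINO
# format with a task hint, mirroring the call maybe_convert_tokenizers now makes.
from pathlib import Path

from transformers import AutoTokenizer
from optimum.exporters.openvino.convert import export_tokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # assumed example model
output = Path("gpt2_ov_tokenizer")                  # assumed output directory
output.mkdir(parents=True, exist_ok=True)

# The task value is an assumption; per the docstring above it only affects
# tokenizer conversion parameters.
export_tokenizer(tokenizer, output, task="text-generation")
```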
@@ -420,7 +422,7 @@ def maybe_convert_tokenizers(library_name: str, output: Path, model=None, prepro
             tokenizer = next(filter(lambda it: isinstance(it, PreTrainedTokenizerBase), preprocessors), None)
             if tokenizer:
                 try:
-                    export_tokenizer(tokenizer, output)
+                    export_tokenizer(tokenizer, output, task=task)
                 except Exception as exception:
                     logger.warning(
                         "Could not load tokenizer using specified model ID or path. OpenVINO tokenizer/detokenizer "
@@ -430,6 +432,6 @@ def maybe_convert_tokenizers(library_name: str, output: Path, model=None, prepro
         for tokenizer_name in ("tokenizer", "tokenizer_2"):
             tokenizer = getattr(model, tokenizer_name, None)
             if tokenizer:
-                export_tokenizer(tokenizer, output / tokenizer_name)
+                export_tokenizer(tokenizer, output / tokenizer_name, task=task)
     else:
         logger.warning("Tokenizer won't be converted.")
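End to end, the task that now reaches `export_tokenizer` originates at the export entry point. A hedged sketch follows, assuming the first hunk sits inside `main_export` (the `class StoreAttr(object):` context points there) and that `convert_tokenizer` is exposed as a parameter, as suggested by the `if convert_tokenizer:` guard; model ID, output directory, and task value are assumptions.

```python
# Hedged end-to-end sketch: the task passed to the exporter is now also forwarded
# to tokenizer conversion via maybe_convert_tokenizers(..., task=task).
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="gpt2",   # assumed example model
    output="gpt2_openvino",      # assumed output directory
    task="text-generation",      # now reaches export_tokenizer as well
    convert_tokenizer=True,      # assumed parameter, based on the `if convert_tokenizer:` guard
)
```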