Skip to content

Commit b2751dc

Browse files
committed
fix sentence transformers export
1 parent a9a235b commit b2751dc

File tree

3 files changed

+48
-2
lines changed

3 files changed

+48
-2
lines changed

optimum/exporters/openvino/__main__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def main_export(
6565
compression_ratio: Optional[float] = None,
6666
stateful: bool = True,
6767
convert_tokenizer: bool = False,
68+
library_name: Optional[str] = None,
6869
**kwargs_shapes,
6970
):
7071
"""
@@ -139,7 +140,9 @@ def main_export(
139140
original_task = task
140141
task = TasksManager.map_from_synonym(task)
141142
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
142-
library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder=subfolder)
143+
library_name = TasksManager.infer_library_from_model(
144+
model_name_or_path, subfolder=subfolder, library_name=library_name
145+
)
143146

144147
if task == "auto":
145148
try:

optimum/exporters/openvino/convert.py

-1
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,6 @@ def export_pytorch(
374374
try:
375375
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
376376
check_dummy_inputs_are_allowed(model, dummy_inputs)
377-
378377
inputs = config.ordered_inputs(model)
379378
input_names = list(inputs.keys())
380379
output_names = list(config.outputs.keys())

optimum/intel/openvino/modeling.py

+44
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import os
1717
from pathlib import Path
18+
from tempfile import TemporaryDirectory
1819
from typing import Optional, Union
1920

2021
import numpy as np
@@ -50,6 +51,7 @@
5051

5152
from optimum.exporters import TasksManager
5253

54+
from ...exporters.openvino import main_export
5355
from ..utils.import_utils import is_timm_available, is_timm_version
5456
from .modeling_base import OVBaseModel
5557
from .utils import _is_timm_ov_dir
@@ -411,6 +413,48 @@ def forward(
411413
)
412414
return BaseModelOutput(last_hidden_state=last_hidden_state)
413415

416+
@classmethod
def _from_transformers(
    cls,
    model_id: str,
    config: PretrainedConfig,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
    task: Optional[str] = None,
    trust_remote_code: bool = False,
    load_in_8bit: Optional[bool] = None,
    load_in_4bit: Optional[bool] = None,
    **kwargs,
):
    """
    Export `model_id` with `main_export` into a temporary directory and load the result.

    The export is always requested with `library_name="transformers"`, the config is saved next to the
    exported files, and the model is then instantiated through `cls._from_pretrained`.

    Args:
        model_id: Passed to `main_export` as `model_name_or_path`.
        config: Saved into the export directory and forwarded to `cls._from_pretrained`.
        use_auth_token, revision, force_download, cache_dir, subfolder, local_files_only, trust_remote_code:
            Forwarded unchanged to `main_export`.
        task: Export task; when falsy, `cls.export_feature` is used instead.
        load_in_8bit: When not `None`, the export is requested with `compression_option="fp32"`; the value
            itself is forwarded to `cls._from_pretrained`.
        load_in_4bit: Accepted by the signature but not used inside this method.
        **kwargs: Forwarded to `cls._from_pretrained`.

    Returns:
        Whatever `cls._from_pretrained` returns for the exported directory.
    """
    # The TemporaryDirectory object stays bound to a local until this method returns.
    tmp_dir = TemporaryDirectory()
    export_dir = Path(tmp_dir.name)

    # If load_in_8bit is not specified, compression_option stays None and main_export picks its own
    # default depending on the model size.
    if load_in_8bit is None:
        compression_option = None
    else:
        compression_option = "fp32"

    export_task = task or cls.export_feature

    # OVModelForFeatureExtraction works with Transformers type of models, thus even sentence-transformers
    # models are loaded as such.
    main_export(
        model_name_or_path=model_id,
        output=export_dir,
        task=export_task,
        subfolder=subfolder,
        revision=revision,
        cache_dir=cache_dir,
        use_auth_token=use_auth_token,
        local_files_only=local_files_only,
        force_download=force_download,
        trust_remote_code=trust_remote_code,
        compression_option=compression_option,
        library_name="transformers",
    )

    config.save_pretrained(export_dir)

    loaded_model = cls._from_pretrained(
        model_id=export_dir,
        config=config,
        load_in_8bit=load_in_8bit,
        **kwargs,
    )
    return loaded_model
457+
414458

415459
MASKED_LM_EXAMPLE = r"""
416460
Example of masked language modeling using `transformers.pipelines`:

0 commit comments

Comments
 (0)