Skip to content

Commit 0540b12

Browse files
eaidovaecharlaix
andauthored
Fix sentence transformer model export with openvino (huggingface#660)
* fix sentence transformer model export with openvino * Apply suggestions from code review Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> --------- Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
1 parent 402b9db commit 0540b12

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

optimum/commands/export/openvino.py

+9
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,14 @@ def parse_args_openvino(parser: "ArgumentParser"):
121121
help="Add converted tokenizer and detokenizer with OpenVINO Tokenizers",
122122
)
123123

124+
optional_group.add_argument(
125+
"--library",
126+
type=str,
127+
choices=["transformers", "diffusers", "timm", "sentence_transformers"],
128+
default=None,
129+
help=("The library on the model. If not provided, will attempt to infer the local checkpoint's library"),
130+
)
131+
124132

125133
class OVExportCommand(BaseOptimumCLICommand):
126134
COMMAND = CommandInfo(name="openvino", help="Export PyTorch models to OpenVINO IR.")
@@ -201,5 +209,6 @@ def run(self):
201209
ov_config=ov_config,
202210
stateful=not self.args.disable_stateful,
203211
convert_tokenizer=self.args.convert_tokenizer,
212+
library_name=self.args.library
204213
# **input_shapes,
205214
)

optimum/exporters/openvino/__main__.py

+8
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,18 @@ def main_export(
163163
original_task = task
164164
task = TasksManager.map_from_synonym(task)
165165
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
166+
library_name_is_not_provided = library_name is None
166167
library_name = TasksManager.infer_library_from_model(
167168
model_name_or_path, subfolder=subfolder, library_name=library_name
168169
)
169170

171+
if library_name == "sentence_transformers" and library_name_is_not_provided:
172+
logger.warning(
173+
"Library name is not specified. There are multiple possible variants: `sentence_tenasformers`, `transformers`."
174+
"`transformers` will be selected. If you want to load your model with the `sentence-transformers` library instead, please set --library sentence_transformers"
175+
)
176+
library_name = "transformers"
177+
170178
if task == "auto":
171179
try:
172180
task = TasksManager.infer_task_from_model(model_name_or_path)

optimum/exporters/openvino/convert.py

+2
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,8 @@ def ts_patched_forward(*args, **kwargs):
382382

383383
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
384384
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
385+
if not ordered_dummy_inputs:
386+
ordered_dummy_inputs = dummy_inputs
385387
ordered_input_names = list(inputs)
386388
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
387389
ov_model.validate_nodes_and_infer_types()

0 commit comments

Comments
 (0)