Skip to content

Commit 7dc3257

Browse files
committed
WIP|
1 parent 2e7c556 commit 7dc3257

File tree

9 files changed

+460
-14
lines changed

9 files changed

+460
-14
lines changed

optimum/exporters/openvino/__main__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import gc
16+
import importlib
1617
import logging
1718
import operator
1819
import warnings
@@ -192,7 +193,6 @@ def main_export(
192193
```
193194
"""
194195
from optimum.exporters.openvino.convert import export_from_model
195-
196196
if use_auth_token is not None:
197197
warnings.warn(
198198
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
@@ -214,6 +214,7 @@ def main_export(
214214
revision=revision,
215215
cache_dir=cache_dir,
216216
token=token,
217+
library_name=library_name
217218
)
218219
if library_name == "sentence_transformers":
219220
logger.warning(
@@ -233,6 +234,9 @@ def main_export(
233234
library_name=library_name,
234235
)
235236

237+
logger.warn(task)
238+
logger.warn(library_name)
239+
236240
do_gptq_patching = False
237241
do_quant_patching = False
238242
custom_architecture = False
@@ -447,6 +451,7 @@ class StoreAttr(object):
447451
device=device,
448452
trust_remote_code=trust_remote_code,
449453
patch_16bit_model=patch_16bit,
454+
library_name=library_name,
450455
**kwargs_shapes,
451456
)
452457

optimum/exporters/openvino/convert.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def export_from_model(
606606
device: str = "cpu",
607607
trust_remote_code: bool = False,
608608
patch_16bit_model: bool = False,
609+
library_name: Optional[str] = None,
609610
**kwargs_shapes,
610611
):
611612
model_kwargs = model_kwargs or {}
@@ -615,9 +616,9 @@ def export_from_model(
615616
f"Compression of the weights to {ov_config.quantization_config} requires nncf, please install it with `pip install nncf`"
616617
)
617618

618-
library_name = _infer_library_from_model_or_model_class(model)
619+
library_name = _infer_library_from_model_or_model_class(model, library_name=library_name)
619620
if library_name != "open_clip":
620-
TasksManager.standardize_model_attributes(model)
621+
TasksManager.standardize_model_attributes(model, library_name=library_name)
621622

622623
if hasattr(model.config, "export_model_type"):
623624
model_type = model.config.export_model_type.replace("_", "-")
@@ -630,7 +631,7 @@ def export_from_model(
630631
task = TasksManager.map_from_synonym(task)
631632
else:
632633
try:
633-
task = TasksManager._infer_task_from_model_or_model_class(model=model)
634+
task = TasksManager._infer_task_from_model_or_model_class(model=model, library_name=library_name)
634635
except (ValueError, KeyError) as e:
635636
raise RuntimeError(
636637
f"The model task could not be automatically inferred in `export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"

0 commit comments

Comments
 (0)