Skip to content

Commit 8911601

Browse files
committed
fix opset
1 parent 5658213 commit 8911601

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

optimum/exporters/openvino/model_configs.py

+16
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
UNetOnnxConfig,
2929
VaeDecoderOnnxConfig,
3030
VaeEncoderOnnxConfig,
31+
Wav2Vec2OnnxConfig,
3132
)
3233
from optimum.exporters.tasks import TasksManager
3334
from optimum.utils import DEFAULT_DUMMY_SHAPES
@@ -87,6 +88,21 @@ def init_model_configs():
8788
register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)
8889

8990

91+
@register_in_tasks_manager(
92+
"wav2vec2",
93+
*[
94+
"feature-extraction",
95+
"automatic-speech-recognition",
96+
"audio-classification",
97+
"audio-frame-classification",
98+
"audio-xvector",
99+
],
100+
library_name="transformers",
101+
)
102+
class Wav2Vec2OpenVINOConfig(Wav2Vec2OnnxConfig):
103+
DEFAULT_ONNX_OPSET = 14
104+
105+
90106
@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers")
91107
class BaichaunOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
92108
DEFAULT_ONNX_OPSET = 13

optimum/intel/openvino/trainer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -906,17 +906,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
906906
output_path = os.path.join(output_dir, OV_XML_FILE_NAME)
907907
self.compression_controller.prepare_for_export()
908908
model_type = self.model.config.model_type.replace("_", "-")
909-
onnx_config_class = TasksManager.get_exporter_config_constructor(
910-
exporter="onnx",
909+
exporter_config_class = TasksManager.get_exporter_config_constructor(
910+
exporter="openvino",
911911
model=self.model,
912912
task=self.task,
913913
model_type=model_type,
914914
)
915915

916916
if self.task == "text-generation":
917-
onnx_config = onnx_config_class(self.model.config, use_past=self.model.config.use_cache)
917+
onnx_config = exporter_config_class(self.model.config, use_past=self.model.config.use_cache)
918918
else:
919-
onnx_config = onnx_config_class(self.model.config)
919+
onnx_config = exporter_config_class(self.model.config)
920920

921921
num_parameters = self.model.num_parameters()
922922
save_as_external_data = use_external_data_format(num_parameters) or self.ov_config.save_onnx_model

0 commit comments

Comments
 (0)