Skip to content

Commit 75db113

Browse files
committed
add test
1 parent 5ba91d7 commit 75db113

File tree

1 file changed: +117 −0 lines changed

tests/openvino/test_export.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import unittest
17+
from pathlib import Path
18+
from tempfile import TemporaryDirectory
19+
from typing import Optional
20+
21+
from parameterized import parameterized
22+
from transformers import AutoConfig
23+
from utils_tests import MODEL_NAMES
24+
25+
from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
26+
from optimum.exporters.openvino import export_from_model
27+
from optimum.exporters.tasks import TasksManager
28+
from optimum.intel import (
29+
OVModelForAudioClassification,
30+
OVModelForCausalLM,
31+
OVModelForFeatureExtraction,
32+
OVModelForImageClassification,
33+
OVModelForMaskedLM,
34+
OVModelForPix2Struct,
35+
OVModelForQuestionAnswering,
36+
OVModelForSeq2SeqLM,
37+
OVModelForSequenceClassification,
38+
OVModelForSpeechSeq2Seq,
39+
OVModelForTokenClassification,
40+
OVStableDiffusionPipeline,
41+
OVStableDiffusionXLImg2ImgPipeline,
42+
OVStableDiffusionXLPipeline,
43+
)
44+
from optimum.intel.openvino.modeling_base import OVBaseModel
45+
from optimum.utils.save_utils import maybe_load_preprocessors
46+
47+
48+
class ExportModelTest(unittest.TestCase):
    """Smoke tests for `export_from_model`: export each supported architecture
    to OpenVINO IR and check the result can be reloaded as the matching
    `OVModel*` class."""

    # Maps a model-type key (looked up in MODEL_NAMES) to the OVModel class
    # expected to load the exported artifact.
    SUPPORTED_ARCHITECTURES = {
        "bert": OVModelForMaskedLM,
        "pix2struct": OVModelForPix2Struct,
        "t5": OVModelForSeq2SeqLM,
        "bart": OVModelForSeq2SeqLM,
        "gpt2": OVModelForCausalLM,
        "distilbert": OVModelForQuestionAnswering,
        "albert": OVModelForSequenceClassification,
        "vit": OVModelForImageClassification,
        "roberta": OVModelForTokenClassification,
        "wav2vec2": OVModelForAudioClassification,
        "whisper": OVModelForSpeechSeq2Seq,
        "blenderbot": OVModelForFeatureExtraction,
        "stable-diffusion": OVStableDiffusionPipeline,
        "stable-diffusion-xl": OVStableDiffusionXLPipeline,
        "stable-diffusion-xl-refiner": OVStableDiffusionXLImg2ImgPipeline,
    }

    def _openvino_export(
        self,
        model_type: str,
        compression_option: Optional[str] = None,
        stateful: bool = True,
    ):
        """Export `model_type` to OpenVINO and verify it reloads.

        Args:
            model_type: key into SUPPORTED_ARCHITECTURES / MODEL_NAMES.
            compression_option: forwarded to `export_from_model` (e.g. weight
                compression mode), or None for no compression.
            stateful: forwarded to `export_from_model`; checked back on the
                reloaded model for text-generation tasks.
        """
        ov_model_class = self.SUPPORTED_ARCHITECTURES[model_type]
        task = ov_model_class.export_feature
        model_name = MODEL_NAMES[model_type]
        library_name = TasksManager.infer_library_from_model(model_name)

        # Some architectures cannot be exported with SDPA attention; force eager.
        loading_kwargs = (
            {"attn_implementation": "eager"}
            if model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
            else {}
        )

        if library_name == "timm":
            # timm models are instantiated from the hub id rather than a config.
            timm_class = TasksManager.get_model_class_for_task(task, library=library_name)
            model = timm_class(f"hf_hub:{model_name}", pretrained=True, exportable=True)
            TasksManager.standardize_model_attributes(model_name, model, library_name=library_name)
        else:
            config = AutoConfig.from_pretrained(model_name)
            hf_class = TasksManager.get_model_class_for_task(
                task, model_type=config.model_type.replace("_", "-")
            )
            model = hf_class.from_pretrained(model_name, **loading_kwargs)

        # pix2struct needs its preprocessors at export time; others do not.
        preprocessors = (
            maybe_load_preprocessors(model_name)
            if model.config.model_type == "pix2struct"
            else None
        )

        # Text-generation models are additionally exported with KV-cache support.
        tasks_to_check = [task]
        if "text-generation" in task:
            tasks_to_check.append(task + "-with-past")

        for export_task in tasks_to_check:
            with TemporaryDirectory() as tmp_dir:
                export_from_model(
                    model=model,
                    output=Path(tmp_dir),
                    task=export_task,
                    preprocessors=preprocessors,
                    compression_option=compression_option,
                    stateful=stateful,
                )

                use_cache = export_task.endswith("-with-past")
                reloaded = ov_model_class.from_pretrained(tmp_dir, use_cache=use_cache)
                self.assertIsInstance(reloaded, OVBaseModel)

                if "text-generation" in task:
                    self.assertEqual(reloaded.use_cache, use_cache)

                if task == "text-generation":
                    self.assertEqual(reloaded.stateful, stateful)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_export(self, model_type: str):
        """Export each supported architecture with default options."""
        self._openvino_export(model_type)

Comments (0)