Skip to content

Commit 27405ed

Browse files
echarlaixl-bat
authored and committed
Infer task by loading the diffusers config
1 parent 35410e9 commit 27405ed

File tree

1 file changed

+46
-25
lines changed

1 file changed

+46
-25
lines changed

optimum/commands/export/openvino.py

+46-25
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from ...exporters import TasksManager
2222
from ..base import BaseOptimumCLICommand, CommandInfo
23+
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
2324

2425

2526
logger = logging.getLogger(__name__)
@@ -209,31 +210,51 @@ def run(self):
209210
ov_config = OVConfig(quantization_config=quantization_config)
210211

211212
library_name = TasksManager.infer_library_from_model(self.args.model)
212-
task = get_relevant_task(self.args.task, self.args.model)
213-
saved_dir = self.args.output
214213

215214
if library_name == "diffusers" and ov_config and ov_config.quantization_config.get("dataset"):
216-
import tempfile
217-
from copy import deepcopy
218-
saved_dir = tempfile.mkdtemp()
219-
quantization_config = deepcopy(ov_config.quantization_config)
220-
ov_config.quantization_config = {}
221-
222-
# TODO : add input shapes
223-
main_export(
224-
model_name_or_path=self.args.model,
225-
output=saved_dir,
226-
task=task,
227-
framework=self.args.framework,
228-
cache_dir=self.args.cache_dir,
229-
trust_remote_code=self.args.trust_remote_code,
230-
pad_token_id=self.args.pad_token_id,
231-
ov_config=ov_config,
232-
stateful=not self.args.disable_stateful,
233-
convert_tokenizer=self.args.convert_tokenizer,
234-
library_name=self.args.library
235-
# **input_shapes,
236-
)
237215

238-
if saved_dir != self.args.output:
239-
export_optimized_diffusion_model(saved_dir, self.args.output, task, quantization_config)
216+
if not is_diffusers_available():
217+
raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models"))
218+
219+
from diffusers import DiffusionPipeline
220+
221+
diffusers_config = DiffusionPipeline.load_config(self.args.model)
222+
class_name = diffusers_config.get("_class_name", None)
223+
224+
if class_name == "LatentConsistencyModelPipeline":
225+
226+
from optimum.intel import OVLatentConsistencyModelPipeline
227+
228+
model_cls = OVLatentConsistencyModelPipeline
229+
230+
elif class_name == "StableDiffusionXLPipeline":
231+
232+
from optimum.intel import OVStableDiffusionXLPipeline
233+
234+
model_cls = OVStableDiffusionXLPipeline
235+
elif class_name == "StableDiffusionPipeline":
236+
from optimum.intel import OVStableDiffusionPipeline
237+
238+
model_cls = OVStableDiffusionPipeline
239+
else:
240+
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")
241+
242+
model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=ov_config.quantization_config)
243+
model.save_pretrained(self.args.output)
244+
245+
else:
246+
# TODO : add input shapes
247+
main_export(
248+
model_name_or_path=self.args.model,
249+
output=self.args.output,
250+
task=self.args.task,
251+
framework=self.args.framework,
252+
cache_dir=self.args.cache_dir,
253+
trust_remote_code=self.args.trust_remote_code,
254+
pad_token_id=self.args.pad_token_id,
255+
ov_config=ov_config,
256+
stateful=not self.args.disable_stateful,
257+
convert_tokenizer=self.args.convert_tokenizer,
258+
library_name=library_name,
259+
# **input_shapes,
260+
)

0 commit comments

Comments
 (0)