|
20 | 20 |
|
21 | 21 | from ...exporters import TasksManager
|
22 | 22 | from ..base import BaseOptimumCLICommand, CommandInfo
|
| 23 | +from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available |
23 | 24 |
|
24 | 25 |
|
25 | 26 | logger = logging.getLogger(__name__)
|
@@ -209,31 +210,51 @@ def run(self):
|
209 | 210 | ov_config = OVConfig(quantization_config=quantization_config)
|
210 | 211 |
|
211 | 212 | library_name = TasksManager.infer_library_from_model(self.args.model)
|
212 |
| - task = get_relevant_task(self.args.task, self.args.model) |
213 |
| - saved_dir = self.args.output |
214 | 213 |
|
215 | 214 | if library_name == "diffusers" and ov_config and ov_config.quantization_config.get("dataset"):
|
216 |
| - import tempfile |
217 |
| - from copy import deepcopy |
218 |
| - saved_dir = tempfile.mkdtemp() |
219 |
| - quantization_config = deepcopy(ov_config.quantization_config) |
220 |
| - ov_config.quantization_config = {} |
221 |
| - |
222 |
| - # TODO : add input shapes |
223 |
| - main_export( |
224 |
| - model_name_or_path=self.args.model, |
225 |
| - output=saved_dir, |
226 |
| - task=task, |
227 |
| - framework=self.args.framework, |
228 |
| - cache_dir=self.args.cache_dir, |
229 |
| - trust_remote_code=self.args.trust_remote_code, |
230 |
| - pad_token_id=self.args.pad_token_id, |
231 |
| - ov_config=ov_config, |
232 |
| - stateful=not self.args.disable_stateful, |
233 |
| - convert_tokenizer=self.args.convert_tokenizer, |
234 |
| - library_name=self.args.library |
235 |
| - # **input_shapes, |
236 |
| - ) |
237 | 215 |
|
238 |
| - if saved_dir != self.args.output: |
239 |
| - export_optimized_diffusion_model(saved_dir, self.args.output, task, quantization_config) |
| 216 | + if not is_diffusers_available(): |
| 217 | + raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models")) |
| 218 | + |
| 219 | + from diffusers import DiffusionPipeline |
| 220 | + |
| 221 | + diffusers_config = DiffusionPipeline.load_config(self.args.model) |
| 222 | + class_name = diffusers_config.get("_class_name", None) |
| 223 | + |
| 224 | + if class_name == "LatentConsistencyModelPipeline": |
| 225 | + |
| 226 | + from optimum.intel import OVLatentConsistencyModelPipeline |
| 227 | + |
| 228 | + model_cls = OVLatentConsistencyModelPipeline |
| 229 | + |
| 230 | + elif class_name == "StableDiffusionXLPipeline": |
| 231 | + |
| 232 | + from optimum.intel import OVStableDiffusionXLPipeline |
| 233 | + |
| 234 | + model_cls = OVStableDiffusionXLPipeline |
| 235 | + elif class_name == "StableDiffusionPipeline": |
| 236 | + from optimum.intel import OVStableDiffusionPipeline |
| 237 | + |
| 238 | + model_cls = OVStableDiffusionPipeline |
| 239 | + else: |
| 240 | + raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.") |
| 241 | + |
| 242 | + model = model_cls.from_pretrained(self.args.model, export=True, quantization_config=ov_config.quantization_config) |
| 243 | + model.save_pretrained(self.args.output) |
| 244 | + |
| 245 | + else: |
| 246 | + # TODO : add input shapes |
| 247 | + main_export( |
| 248 | + model_name_or_path=self.args.model, |
| 249 | + output=self.args.output, |
| 250 | + task=self.args.task, |
| 251 | + framework=self.args.framework, |
| 252 | + cache_dir=self.args.cache_dir, |
| 253 | + trust_remote_code=self.args.trust_remote_code, |
| 254 | + pad_token_id=self.args.pad_token_id, |
| 255 | + ov_config=ov_config, |
| 256 | + stateful=not self.args.disable_stateful, |
| 257 | + convert_tokenizer=self.args.convert_tokenizer, |
| 258 | + library_name=library_name, |
| 259 | + # **input_shapes, |
| 260 | + ) |
0 commit comments