|
19 | 19 | from typing import TYPE_CHECKING, Optional
|
20 | 20 |
|
21 | 21 | from ...exporters import TasksManager
|
| 22 | +from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available |
22 | 23 | from ..base import BaseOptimumCLICommand, CommandInfo
|
23 | 24 |
|
24 | 25 |
|
@@ -104,6 +105,16 @@ def parse_args_openvino(parser: "ArgumentParser"):
|
104 | 105 | default=None,
|
105 | 106 | help=("The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization."),
|
106 | 107 | )
|
| 108 | + optional_group.add_argument( |
| 109 | + "--dataset", |
| 110 | + type=str, |
| 111 | + default=None, |
| 112 | + help=( |
| 113 | + "The dataset used for data-aware compression or quantization with NNCF. " |
| 114 | +        "You can use the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLMs " |
| 115 | + "or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models." |
| 116 | + ), |
| 117 | + ) |
107 | 118 | optional_group.add_argument(
|
108 | 119 | "--disable-stateful",
|
109 | 120 | action="store_true",
|
@@ -195,20 +206,59 @@ def run(self):
|
195 | 206 | )
|
196 | 207 | quantization_config["sym"] = "asym" not in self.args.weight_format
|
197 | 208 | quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
|
| 209 | + quantization_config["dataset"] = self.args.dataset |
198 | 210 | ov_config = OVConfig(quantization_config=quantization_config)
|
199 | 211 |
|
200 |
| - # TODO : add input shapes |
201 |
| - main_export( |
202 |
| - model_name_or_path=self.args.model, |
203 |
| - output=self.args.output, |
204 |
| - task=self.args.task, |
205 |
| - framework=self.args.framework, |
206 |
| - cache_dir=self.args.cache_dir, |
207 |
| - trust_remote_code=self.args.trust_remote_code, |
208 |
| - pad_token_id=self.args.pad_token_id, |
209 |
| - ov_config=ov_config, |
210 |
| - stateful=not self.args.disable_stateful, |
211 |
| - convert_tokenizer=self.args.convert_tokenizer, |
212 |
| - library_name=self.args.library |
213 |
| - # **input_shapes, |
214 |
| - ) |
| 212 | + library_name = TasksManager.infer_library_from_model(self.args.model) |
| 213 | + |
| 214 | + if ( |
| 215 | + library_name == "diffusers" |
| 216 | + and ov_config |
| 217 | + and ov_config.quantization_config |
| 218 | + and ov_config.quantization_config.dataset is not None |
| 219 | + ): |
| 220 | + if not is_diffusers_available(): |
| 221 | + raise ValueError(DIFFUSERS_IMPORT_ERROR.format("Export of diffusers models")) |
| 222 | + |
| 223 | + from diffusers import DiffusionPipeline |
| 224 | + |
| 225 | + diffusers_config = DiffusionPipeline.load_config(self.args.model) |
| 226 | + class_name = diffusers_config.get("_class_name", None) |
| 227 | + |
| 228 | + if class_name == "LatentConsistencyModelPipeline": |
| 229 | + from optimum.intel import OVLatentConsistencyModelPipeline |
| 230 | + |
| 231 | + model_cls = OVLatentConsistencyModelPipeline |
| 232 | + |
| 233 | + elif class_name == "StableDiffusionXLPipeline": |
| 234 | + from optimum.intel import OVStableDiffusionXLPipeline |
| 235 | + |
| 236 | + model_cls = OVStableDiffusionXLPipeline |
| 237 | + elif class_name == "StableDiffusionPipeline": |
| 238 | + from optimum.intel import OVStableDiffusionPipeline |
| 239 | + |
| 240 | + model_cls = OVStableDiffusionPipeline |
| 241 | + else: |
| 242 | + raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.") |
| 243 | + |
| 244 | + model = model_cls.from_pretrained( |
| 245 | + self.args.model, export=True, quantization_config=ov_config.quantization_config |
| 246 | + ) |
| 247 | + model.save_pretrained(self.args.output) |
| 248 | + |
| 249 | + else: |
| 250 | + # TODO : add input shapes |
| 251 | + main_export( |
| 252 | + model_name_or_path=self.args.model, |
| 253 | + output=self.args.output, |
| 254 | + task=self.args.task, |
| 255 | + framework=self.args.framework, |
| 256 | + cache_dir=self.args.cache_dir, |
| 257 | + trust_remote_code=self.args.trust_remote_code, |
| 258 | + pad_token_id=self.args.pad_token_id, |
| 259 | + ov_config=ov_config, |
| 260 | + stateful=not self.args.disable_stateful, |
| 261 | + convert_tokenizer=self.args.convert_tokenizer, |
| 262 | + library_name=library_name, |
| 263 | + # **input_shapes, |
| 264 | + ) |
0 commit comments