
Commit e6fadb1

Merge pull request #700 from nikita-savelyevv/refactor-sd-calibration-data-collection
Refactor SD calibration data collection
2 parents 9bb4334 + 068236d commit e6fadb1

5 files changed: +213 -148 lines changed


notebooks/openvino/stable_diffusion_hybrid_quantization.ipynb (+10 -4)
@@ -52,7 +52,8 @@
     "import transformers\n",
     "from pathlib import Path\n",
     "from openvino.runtime import Core\n",
-    "from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig\n",
+    "from optimum.intel import OVConfig, OVQuantizer, OVStableDiffusionPipeline, OVWeightQuantizationConfig\n",
+    "from optimum.intel.openvino.configuration import OVQuantizationMethod\n",
     "\n",
     "transformers.logging.set_verbosity_error()\n",
     "datasets.logging.set_verbosity_error()"
@@ -198,9 +199,14 @@
    },
    "outputs": [],
    "source": [
-    "quantization_config = OVWeightQuantizationConfig(bits=8, dataset=calibration_dataset, num_samples=NUM_SAMPLES)\n",
-    "int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True, quantization_config=quantization_config)\n",
-    "int8_pipe.save_pretrained(int8_model_path)"
+    "int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True)\n",
+    "quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID)\n",
+    "quantizer = OVQuantizer(int8_pipe)\n",
+    "quantizer.quantize(\n",
+    "    ov_config=OVConfig(quantization_config=quantization_config),\n",
+    "    calibration_dataset=calibration_dataset,\n",
+    "    save_directory=int8_model_path\n",
+    ")"
    ]
   },
   {
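Taken together, the updated notebook cells boil down to the following standalone sketch of the new hybrid-quantization flow. The concrete values of MODEL_ID, NUM_SAMPLES, int8_model_path and the prompt list used as calibration_dataset are placeholders assumed for illustration; the notebook defines its own in earlier cells.

from optimum.intel import OVConfig, OVQuantizer, OVStableDiffusionPipeline, OVWeightQuantizationConfig
from optimum.intel.openvino.configuration import OVQuantizationMethod

# Assumed placeholders -- the notebook sets these in earlier cells.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
NUM_SAMPLES = 200
int8_model_path = "stable-diffusion-v1-5-int8-hybrid"
calibration_dataset = ["a photo of an astronaut riding a horse on mars"] * NUM_SAMPLES  # toy prompt list

# Export the FP pipeline to OpenVINO, then apply hybrid quantization through OVQuantizer.
int8_pipe = OVStableDiffusionPipeline.from_pretrained(model_id=MODEL_ID, export=True)
quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=NUM_SAMPLES, quant_method=OVQuantizationMethod.HYBRID)
quantizer = OVQuantizer(int8_pipe)
quantizer.quantize(
    ov_config=OVConfig(quantization_config=quantization_config),
    calibration_dataset=calibration_dataset,
    save_directory=int8_model_path,
)

The quantized pipeline can afterwards be reloaded from int8_model_path with OVStableDiffusionPipeline.from_pretrained and used like the original one.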

optimum/intel/openvino/configuration.py (+1)
@@ -57,6 +57,7 @@
 
 class OVQuantizationMethod(str, Enum):
     DEFAULT = "default"
+    HYBRID = "hybrid"
 
 
 @dataclass
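A minimal illustration of how the new enum member is meant to be consumed, mirroring the notebook change above (bits and num_samples are arbitrary example values):

from optimum.intel import OVWeightQuantizationConfig
from optimum.intel.openvino.configuration import OVQuantizationMethod

# Selecting the new hybrid mode explicitly; leaving quant_method unset keeps the
# pre-existing default behaviour.
config = OVWeightQuantizationConfig(bits=8, num_samples=128, quant_method=OVQuantizationMethod.HYBRID)

Since OVQuantizationMethod derives from str, OVQuantizationMethod.HYBRID also compares equal to the plain string "hybrid", which keeps serialized configs readable.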

optimum/intel/openvino/modeling_diffusion.py (+17 -80)
@@ -57,14 +57,13 @@
 )
 
 from ...exporters.openvino import main_export
-from .configuration import OVConfig, OVWeightQuantizationConfig
+from .configuration import OVConfig, OVQuantizationMethod, OVWeightQuantizationConfig
 from .loaders import OVTextualInversionLoaderMixin
 from .modeling_base import OVBaseModel
 from .utils import (
     ONNX_WEIGHTS_NAME,
     OV_TO_NP_TYPE,
     OV_XML_FILE_NAME,
-    PREDEFINED_SD_DATASETS,
     _print_compiled_model_properties,
 )

@@ -293,35 +292,27 @@ def _from_pretrained(
             else:
                 kwargs[name] = load_method(new_model_save_dir)
 
-        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-
         unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
-        if quantization_config is not None and quantization_config.dataset is not None:
-            # load the UNet model uncompressed to apply hybrid quantization further
-            unet = cls.load_model(unet_path)
-            # Apply weights compression to other `components` without dataset
-            weight_quantization_params = {
-                param: value for param, value in quantization_config.__dict__.items() if param != "dataset"
-            }
-            weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_params)
-        else:
-            weight_quantization_config = quantization_config
-            unet = cls.load_model(unet_path, weight_quantization_config)
-
         components = {
             "vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
             "vae_decoder": new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name,
             "text_encoder": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name,
             "text_encoder_2": new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER / text_encoder_2_file_name,
         }
 
-        for key, value in components.items():
-            components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None
-
         if model_save_dir is None:
             model_save_dir = new_model_save_dir
 
-        if quantization_config is not None and quantization_config.dataset is not None:
+        quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+        if quantization_config is None or quantization_config.dataset is None:
+            unet = cls.load_model(unet_path, quantization_config)
+            for key, value in components.items():
+                components[key] = cls.load_model(value, quantization_config) if value.is_file() else None
+        else:
+            # Load uncompressed models to apply hybrid quantization further
+            unet = cls.load_model(unet_path)
+            for key, value in components.items():
+                components[key] = cls.load_model(value) if value.is_file() else None
             sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)
 
             supported_pipelines = (
@@ -332,12 +323,14 @@ def _from_pretrained(
             if not isinstance(sd_model, supported_pipelines):
                 raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}")
 
-            nsamples = quantization_config.num_samples if quantization_config.num_samples else 200
-            unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples)
+            from optimum.intel import OVQuantizer
 
-            from .quantization import _hybrid_quantization
+            hybrid_quantization_config = deepcopy(quantization_config)
+            hybrid_quantization_config.quant_method = OVQuantizationMethod.HYBRID
+            quantizer = OVQuantizer(sd_model)
+            quantizer.quantize(ov_config=OVConfig(quantization_config=hybrid_quantization_config))
 
-            unet = _hybrid_quantization(sd_model.unet.model, weight_quantization_config, dataset=unet_inputs)
+            return sd_model
 
         return cls(
             unet=unet,
@@ -348,62 +341,6 @@ def _from_pretrained(
             **kwargs,
         )
 
-    def _prepare_unet_inputs(
-        self,
-        dataset: Union[str, List[Any]],
-        num_samples: int,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
-        seed: Optional[int] = 42,
-        **kwargs,
-    ) -> Dict[str, Any]:
-        self.compile()
-
-        size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor
-        height = height or min(size, 512)
-        width = width or min(size, 512)
-
-        if isinstance(dataset, str):
-            dataset = deepcopy(dataset)
-            available_datasets = PREDEFINED_SD_DATASETS.keys()
-            if dataset not in available_datasets:
-                raise ValueError(
-                    f"""You have entered a string value for dataset. You can only choose between
-                    {list(available_datasets)}, but the {dataset} was found"""
-                )
-
-            from datasets import load_dataset
-
-            dataset_metadata = PREDEFINED_SD_DATASETS[dataset]
-            dataset = load_dataset(dataset, split=dataset_metadata["split"], streaming=True).shuffle(seed=seed)
-            input_names = dataset_metadata["inputs"]
-            dataset = dataset.select_columns(list(input_names.values()))
-
-            def transform_fn(data_item):
-                return {inp_name: data_item[column] for inp_name, column in input_names.items()}
-
-        else:
-
-            def transform_fn(data_item):
-                return data_item if isinstance(data_item, (list, dict)) else [data_item]
-
-        from .quantization import InferRequestWrapper
-
-        calibration_data = []
-        self.unet.request = InferRequestWrapper(self.unet.request, calibration_data)
-
-        for inputs in dataset:
-            inputs = transform_fn(inputs)
-            if isinstance(inputs, dict):
-                self.__call__(**inputs, height=height, width=width)
-            else:
-                self.__call__(*inputs, height=height, width=width)
-            if len(calibration_data) >= num_samples:
-                break
-
-        self.unet.request = self.unet.request.request
-        return calibration_data[:num_samples]
-
     @classmethod
     def _from_transformers(
         cls,