Skip to content

Commit 3438ec9

Browse files
authored
Merge branch 'main' into ea/relax_accelerate_deps
2 parents 3944382 + 55d419f commit 3438ec9

File tree

2 files changed

+30
-92
lines changed

2 files changed

+30
-92
lines changed

optimum/intel/openvino/quantization.py

+29-91
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,6 @@
5858
if TYPE_CHECKING:
5959
from datasets import Dataset
6060

61-
# TODO : remove as unused
62-
_COMPRESSION_OPTIONS = {
63-
"int8": {"mode": nncf.CompressWeightsMode.INT8},
64-
"int4_sym_g128": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 128},
65-
"int4_asym_g128": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 128},
66-
"int4_sym_g64": {"mode": nncf.CompressWeightsMode.INT4_SYM, "group_size": 64},
67-
"int4_asym_g64": {"mode": nncf.CompressWeightsMode.INT4_ASYM, "group_size": 64},
68-
}
69-
7061
register_module(ignored_algorithms=[])(Conv1D)
7162

7263
core = Core()
@@ -241,27 +232,16 @@ def quantize(
241232
ov_config = ov_config or quantization_config
242233

243234
if isinstance(self.model, OVBaseModel):
244-
if self.model.export_feature == "text-generation" and self.model.use_cache:
245-
self._quantize_ovcausallm(
246-
calibration_dataset,
247-
save_directory,
248-
batch_size,
249-
data_collator,
250-
remove_unused_columns,
251-
weights_only,
252-
ov_config,
253-
**kwargs,
254-
)
255-
else:
256-
self._quantize_ovbasemodel(
257-
calibration_dataset,
258-
save_directory,
259-
batch_size,
260-
data_collator,
261-
remove_unused_columns,
262-
weights_only,
263-
**kwargs,
264-
)
235+
self._quantize_ovbasemodel(
236+
calibration_dataset,
237+
save_directory,
238+
batch_size,
239+
data_collator,
240+
remove_unused_columns,
241+
weights_only,
242+
ov_config,
243+
**kwargs,
244+
)
265245

266246
elif isinstance(self.model, torch.nn.Module):
267247
self._quantize_torchmodel(
@@ -277,51 +257,7 @@ def quantize(
277257
else:
278258
raise TypeError(f"Unsupported model type: {type(self.model)}")
279259

280-
def _get_compression_options(self, config: OVConfig):
281-
options = {}
282-
if config is not None and "type" in config.compression:
283-
options = _COMPRESSION_OPTIONS[config.compression["type"]]
284-
if "ratio" in config.compression:
285-
options["ratio"] = config.compression["ratio"]
286-
return options
287-
288260
def _quantize_ovbasemodel(
289-
self,
290-
calibration_dataset: "Dataset",
291-
save_directory: Union[str, Path],
292-
batch_size: int = 1,
293-
data_collator: Optional[DataCollator] = None,
294-
remove_unused_columns: bool = True,
295-
weights_only: bool = False,
296-
**kwargs,
297-
):
298-
save_directory = Path(save_directory)
299-
save_directory.mkdir(parents=True, exist_ok=True)
300-
301-
if weights_only:
302-
self.model.model = nncf.compress_weights(self.model.model)
303-
self.model.save_pretrained(save_directory)
304-
return
305-
306-
calibration_dataloader = self._get_calibration_dataloader(
307-
calibration_dataset=calibration_dataset,
308-
batch_size=batch_size,
309-
remove_unused_columns=remove_unused_columns,
310-
data_collator=data_collator,
311-
)
312-
313-
quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
314-
quantized_model = nncf.quantize(
315-
self.model.model,
316-
quantization_dataset,
317-
model_type=nncf.ModelType.TRANSFORMER if not kwargs.get("model_type") else kwargs.get("model_type"),
318-
fast_bias_correction=kwargs.get("fast_bias_correction", True),
319-
**kwargs,
320-
)
321-
self.model.model = quantized_model
322-
self.model.save_pretrained(save_directory)
323-
324-
def _quantize_ovcausallm(
325261
self,
326262
calibration_dataset: "Dataset",
327263
save_directory: Union[str, Path],
@@ -336,11 +272,11 @@ def _quantize_ovcausallm(
336272
save_directory.mkdir(parents=True, exist_ok=True)
337273

338274
if weights_only:
339-
quantization_config = None if ov_config is None else ov_config.quantization_config
340-
if quantization_config is None:
341-
# Use default 8-bit compression
342-
quantization_config = OVWeightQuantizationConfig(bits=8, sym=True)
343-
_weight_only_quantization(self.model, quantization_config)
275+
# Use default 8-bit compression if not provided
276+
q_config = (
277+
OVWeightQuantizationConfig(bits=8, sym=True) if ov_config is None else ov_config.quantization_config
278+
)
279+
_weight_only_quantization(self.model, q_config)
344280

345281
self.model.save_pretrained(save_directory)
346282
return
@@ -352,21 +288,23 @@ def _quantize_ovcausallm(
352288
data_collator=data_collator,
353289
)
354290

355-
# Prefetch past_key_values
356-
self.model.update_pkv_precision(True)
357-
self.model.compile()
358-
subset_size = kwargs.get("subset_size", 300)
359-
data_cache = []
291+
if self.model.export_feature == "text-generation" and self.model.use_cache:
292+
# Prefetch past_key_values
293+
self.model.update_pkv_precision(True)
294+
self.model.compile()
295+
subset_size = kwargs.get("subset_size", 300)
296+
data_cache = []
360297

361-
self.model.request = InferRequestWrapper(self.model.request, data_cache)
362-
for _, data in enumerate(calibration_dataloader):
363-
self.model.generate(**data, max_new_tokens=1)
364-
if len(data_cache) >= subset_size:
365-
break
366-
self.model.request = self.model.request.request
298+
self.model.request = InferRequestWrapper(self.model.request, data_cache)
299+
for _, data in enumerate(calibration_dataloader):
300+
self.model.generate(**data, max_new_tokens=1)
301+
if len(data_cache) >= subset_size:
302+
break
303+
self.model.request = self.model.request.request
304+
calibration_dataloader = data_cache
367305

368306
# Actual model quantization
369-
quantization_dataset = nncf.Dataset(data_cache, lambda x: x)
307+
quantization_dataset = nncf.Dataset(calibration_dataloader, lambda x: x)
370308
quantized_model = nncf.quantize(
371309
self.model.model,
372310
quantization_dataset,

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
INSTALL_REQUIRE = [
1515
"torch>=1.11",
16-
"optimum @ git+https://github.com/huggingface/optimum.git", # TODO : 1.17.0
16+
"optimum>=1.17.0",
1717
"transformers>=4.26.0",
1818
"datasets>=1.4.0",
1919
"sentencepiece",

0 commit comments

Comments
 (0)