
Commit 4c98ac3

fix

1 parent c62d6ca · commit 4c98ac3

2 files changed: +34 −27 lines

optimum/intel/openvino/modeling_decoder.py (+2 −2)
@@ -35,7 +35,7 @@
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .configuration import OVWeightQuantizationConfig
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
-from .quantization import compress_decoder_weights
+from .quantization import _int4_weight_only_quantization
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
@@ -595,7 +595,7 @@ def _from_pretrained(
         causal_model = init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)

         if load_in_4bit:
-            compress_decoder_weights(causal_model, quantization_config)
+            _int4_weight_only_quantization(causal_model, quantization_config)
         return causal_model

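For context, a minimal usage sketch of the code path this hunk sits on (the model ID is a placeholder, not part of the commit): passing `load_in_4bit=True` is what reaches the `if load_in_4bit:` branch in `_from_pretrained`, which now calls the renamed helper.

```python
from optimum.intel import OVModelForCausalLM

# load_in_4bit=True reaches the `if load_in_4bit:` branch above, which now
# calls _int4_weight_only_quantization instead of compress_decoder_weights.
model = OVModelForCausalLM.from_pretrained(
    "gpt2",             # placeholder model ID
    export=True,        # convert the checkpoint to OpenVINO IR on the fly
    load_in_4bit=True,  # triggers 4-bit weight-only quantization at load time
)
```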
optimum/intel/openvino/quantization.py (+32 −25)
@@ -32,7 +32,7 @@
 from openvino.runtime import Core, Tensor
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader, RandomSampler
-from transformers import DataCollator, PreTrainedModel, default_data_collator
+from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator
 from transformers.pytorch_utils import Conv1D

 from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
@@ -46,7 +46,6 @@
 from ..utils.modeling_utils import get_model_device
 from .configuration import OVConfig, OVWeightQuantizationConfig, _check_default_4bit_configs
 from .modeling_base import OVBaseModel
-from .modeling_decoder import OVBaseDecoderModel
 from .utils import (
     MAX_ONNX_OPSET,
     MIN_ONNX_QDQ_OPSET,
@@ -233,27 +232,29 @@ def quantize(
         )
         ov_config = ov_config or quantization_config

-        if isinstance(self.model, OVBaseDecoderModel) and self.model.use_cache:
-            self._quantize_ovcausallm(
-                calibration_dataset,
-                save_directory,
-                batch_size,
-                data_collator,
-                remove_unused_columns,
-                weights_only,
-                ov_config,
-                **kwargs,
-            )
-        elif isinstance(self.model, OVBaseModel):
-            self._quantize_ovbasemodel(
-                calibration_dataset,
-                save_directory,
-                batch_size,
-                data_collator,
-                remove_unused_columns,
-                weights_only,
-                **kwargs,
-            )
+        if isinstance(self.model, OVBaseModel):
+            if self.model.export_feature == "text-generation" and self.model.use_cache:
+                self._quantize_ovcausallm(
+                    calibration_dataset,
+                    save_directory,
+                    batch_size,
+                    data_collator,
+                    remove_unused_columns,
+                    weights_only,
+                    ov_config,
+                    **kwargs,
+                )
+            else:
+                self._quantize_ovbasemodel(
+                    calibration_dataset,
+                    save_directory,
+                    batch_size,
+                    data_collator,
+                    remove_unused_columns,
+                    weights_only,
+                    **kwargs,
+                )
+
         elif isinstance(self.model, torch.nn.Module):
             self._quantize_torchmodel(
                 calibration_dataset,
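A hedged sketch of how the reworked dispatch behaves from the caller's side (the model ID, save directory, and calibration dataset are placeholders, not part of the diff): a causal LM with `use_cache=True` is now recognized by its `export_feature` inside the single `isinstance(self.model, OVBaseModel)` branch, which is what lets the `OVBaseDecoderModel` import above be dropped.

```python
from optimum.intel import OVModelForCausalLM, OVQuantizer

# Placeholder model ID; any causal LM exported to OpenVINO behaves the same way.
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, use_cache=True)
quantizer = OVQuantizer.from_pretrained(model)

# model.export_feature == "text-generation" and model.use_cache is True,
# so quantize() routes to _quantize_ovcausallm; any other OVBaseModel
# falls into the new else branch and reaches _quantize_ovbasemodel.
quantizer.quantize(
    calibration_dataset=calibration_dataset,  # assumed prepared earlier, e.g. via quantizer.get_calibration_dataset(...)
    save_directory="ov_model_int8",           # placeholder output path
)
```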
@@ -276,6 +277,7 @@ def _get_compression_options(self, config: OVConfig):
             options["ratio"] = config.compression["ratio"]
         return options

+    # TODO : add ov_config
     def _quantize_ovbasemodel(
         self,
         calibration_dataset: Dataset,
@@ -333,7 +335,7 @@ def _quantize_ovcausallm(
                 quantization_config = OVWeightQuantizationConfig(mode=nncf.CompressWeightsMode.INT8_SYM)
                 self.model.model = nncf.compress_weights(self.model.model)
             else:
-                compress_decoder_weights(self.model, quantization_config)
+                _int4_weight_only_quantization(self.model, quantization_config)

             self.model.save_pretrained(save_directory)
             return
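For orientation, a short sketch of how this branch is reached (the quantizer object and path are placeholders carried over from the previous example, not part of the diff):

```python
# weights_only=True with no quantization_config keeps the default INT8_SYM
# compression via nncf.compress_weights; passing a 4-bit
# OVWeightQuantizationConfig instead now routes to _int4_weight_only_quantization.
quantizer.quantize(save_directory="ov_model_w8", weights_only=True)
```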
@@ -580,7 +582,12 @@ def _remove_unused_columns(self, dataset: Dataset):
         return dataset.remove_columns(ignored_columns)


-def compress_decoder_weights(model, quantization_config: Union[OVWeightQuantizationConfig, Dict] = None):
+def _int4_weight_only_quantization(
+    model: OVBaseModel, quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None
+):
+    if model.export_feature != "text-generation":
+        raise ValueError("Only `OVModelForCausalLM` are supported for now")
+
     quantization_config = quantization_config or _check_default_4bit_configs(model.config)
     ov_model = model.model

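Finally, a minimal sketch of calling the renamed helper directly. The `INT4_SYM` mode mirrors the `INT8_SYM` usage earlier in this file but is an assumption, and `causal_model` stands in for a loaded `OVModelForCausalLM`:

```python
import nncf

from optimum.intel.openvino.configuration import OVWeightQuantizationConfig
from optimum.intel.openvino.quantization import _int4_weight_only_quantization

# Passing quantization_config=None falls back to
# _check_default_4bit_configs(model.config); an explicit config is used as-is.
config = OVWeightQuantizationConfig(mode=nncf.CompressWeightsMode.INT4_SYM)  # assumed mode
_int4_weight_only_quantization(causal_model, config)

# Any model whose export_feature is not "text-generation" now raises:
#   ValueError: Only `OVModelForCausalLM` are supported for now
```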