32
32
from openvino .runtime import Core , Tensor
33
33
from torch .utils ._pytree import tree_map
34
34
from torch .utils .data import DataLoader , RandomSampler
35
- from transformers import DataCollator , PreTrainedModel , default_data_collator
35
+ from transformers import AutoTokenizer , DataCollator , PreTrainedModel , default_data_collator
36
36
from transformers .pytorch_utils import Conv1D
37
37
38
38
from optimum .exporters .onnx .convert import check_dummy_inputs_are_allowed
46
46
from ..utils .modeling_utils import get_model_device
47
47
from .configuration import OVConfig , OVWeightQuantizationConfig , _check_default_4bit_configs
48
48
from .modeling_base import OVBaseModel
49
- from .modeling_decoder import OVBaseDecoderModel
50
49
from .utils import (
51
50
MAX_ONNX_OPSET ,
52
51
MIN_ONNX_QDQ_OPSET ,
@@ -233,27 +232,29 @@ def quantize(
233
232
)
234
233
ov_config = ov_config or quantization_config
235
234
236
- if isinstance (self .model , OVBaseDecoderModel ) and self .model .use_cache :
237
- self ._quantize_ovcausallm (
238
- calibration_dataset ,
239
- save_directory ,
240
- batch_size ,
241
- data_collator ,
242
- remove_unused_columns ,
243
- weights_only ,
244
- ov_config ,
245
- ** kwargs ,
246
- )
247
- elif isinstance (self .model , OVBaseModel ):
248
- self ._quantize_ovbasemodel (
249
- calibration_dataset ,
250
- save_directory ,
251
- batch_size ,
252
- data_collator ,
253
- remove_unused_columns ,
254
- weights_only ,
255
- ** kwargs ,
256
- )
235
+ if isinstance (self .model , OVBaseModel ):
236
+ if self .model .export_feature == "text-generation" and self .model .use_cache :
237
+ self ._quantize_ovcausallm (
238
+ calibration_dataset ,
239
+ save_directory ,
240
+ batch_size ,
241
+ data_collator ,
242
+ remove_unused_columns ,
243
+ weights_only ,
244
+ ov_config ,
245
+ ** kwargs ,
246
+ )
247
+ else :
248
+ self ._quantize_ovbasemodel (
249
+ calibration_dataset ,
250
+ save_directory ,
251
+ batch_size ,
252
+ data_collator ,
253
+ remove_unused_columns ,
254
+ weights_only ,
255
+ ** kwargs ,
256
+ )
257
+
257
258
elif isinstance (self .model , torch .nn .Module ):
258
259
self ._quantize_torchmodel (
259
260
calibration_dataset ,
@@ -276,6 +277,7 @@ def _get_compression_options(self, config: OVConfig):
276
277
options ["ratio" ] = config .compression ["ratio" ]
277
278
return options
278
279
280
+ # TODO : add ov_config
279
281
def _quantize_ovbasemodel (
280
282
self ,
281
283
calibration_dataset : Dataset ,
@@ -333,7 +335,7 @@ def _quantize_ovcausallm(
333
335
quantization_config = OVWeightQuantizationConfig (mode = nncf .CompressWeightsMode .INT8_SYM )
334
336
self .model .model = nncf .compress_weights (self .model .model )
335
337
else :
336
- compress_decoder_weights (self .model , quantization_config )
338
+ _int4_weight_only_quantization (self .model , quantization_config )
337
339
338
340
self .model .save_pretrained (save_directory )
339
341
return
@@ -580,7 +582,12 @@ def _remove_unused_columns(self, dataset: Dataset):
580
582
return dataset .remove_columns (ignored_columns )
581
583
582
584
583
- def compress_decoder_weights (model , quantization_config : Union [OVWeightQuantizationConfig , Dict ] = None ):
585
+ def _int4_weight_only_quantization (
586
+ model : OVBaseModel , quantization_config : Optional [Union [OVWeightQuantizationConfig , Dict ]] = None
587
+ ):
588
+ if model .export_feature != "text-generation" :
589
+ raise ValueError ("Only `OVModelForCausalLM` are supported for now" )
590
+
584
591
quantization_config = quantization_config or _check_default_4bit_configs (model .config )
585
592
ov_model = model .model
586
593
0 commit comments