
Commit a7b766e

Add load_in_4bit option for OVModelForCausalLM (#538)
* Initial code for load_in_4_bit
* Dataset does not work
* Intermediate changes
* Make it working with dataset
* Style
* Fixed small issue
* Fixed failed tests
* Style
* Comment failed tests due to NNCF 2.8
* Commented failed tests until new NNCF release
* Added tests for load_in_4bit
* Added awq option. Included NNCF package into openvino extra.
* Rolled back including nncf into openvino extra
* Style
* Fixed tests
* Fixed issues with models larger than 1B. Added tests.
* Style
* Fixed issues. Applied comments.
* Removed unnecessary exception
* Applied more comments
* Fixed issue
* Make quantization_config a part of OVConfig in OVQuantizer
* Fixed issue with Transformers
* Fixed test
* Changed the naming. Added additional tests
* Fixed tests
* Fixed tests
* Applied more comments
* Style
1 parent e40e627 commit a7b766e

14 files changed, +431 -64 lines
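
In user-facing terms, the new flag looks like the sketch below (the checkpoint name is illustrative; as the modeling_base.py diff further down shows, the 4-bit path is wired up only for OVModelForCausalLM):

from optimum.intel import OVModelForCausalLM

# Export from Transformers and apply 4-bit weight quantization on load.
# Without an explicit quantization_config, default 4-bit settings apply.
model = OVModelForCausalLM.from_pretrained(
    "gpt2",  # illustrative checkpoint
    export=True,
    load_in_4bit=True,
)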

optimum/exporters/openvino/convert.py (+1)

@@ -95,6 +95,7 @@ def _save_model(model, path: str, compression_option: Optional[str] = None, comp
                 "ratio": compression_ratio,
             },
         }
+
         model = nncf.compress_weights(model, **COMPRESSION_OPTIONS[compression_option])

     compress_to_fp16 = compression_option == "fp16"
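
The added line is only a blank separator; the surrounding context shows how _save_model maps a named preset from COMPRESSION_OPTIONS onto nncf.compress_weights. A standalone sketch of a call of the same shape on an OpenVINO IR model (the path and the exact preset values are illustrative; mode, group_size, and ratio are real nncf.compress_weights parameters):

import nncf
import openvino.runtime as ov

core = ov.Core()
model = core.read_model("model.xml")  # illustrative path to an exported IR

# The 4-bit presets in COMPRESSION_OPTIONS resolve to calls of this shape:
model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,  # 4-bit symmetric weight quantization
    group_size=128,                          # quantization group size
    ratio=0.8,                               # share of layers kept at 4 bit; the rest fall back to 8 bit
)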

optimum/intel/__init__.py (+12 -3)

@@ -62,9 +62,12 @@
         "OVQuantizer",
         "OVTrainer",
         "OVTrainingArguments",
+        "OVWeightQuantizationConfig",
     ]
 else:
-    _import_structure["openvino"].extend(["OVConfig", "OVQuantizer", "OVTrainer", "OVTrainingArguments"])
+    _import_structure["openvino"].extend(
+        ["OVConfig", "OVQuantizer", "OVTrainer", "OVTrainingArguments", "OVWeightQuantizationConfig"]
+    )

 try:
     if not (is_openvino_available() and is_diffusers_available()):

@@ -176,9 +179,15 @@
     if not (is_openvino_available() and is_nncf_available()):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
-    from .utils.dummy_openvino_and_nncf_objects import OVConfig, OVQuantizer, OVTrainer, OVTrainingArguments
+    from .utils.dummy_openvino_and_nncf_objects import (
+        OVConfig,
+        OVQuantizer,
+        OVTrainer,
+        OVTrainingArguments,
+        OVWeightQuantizationConfig,
+    )
 else:
-    from .openvino import OVConfig, OVQuantizer, OVTrainer, OVTrainingArguments
+    from .openvino import OVConfig, OVQuantizer, OVTrainer, OVTrainingArguments, OVWeightQuantizationConfig

 try:
     if not (is_openvino_available() and is_diffusers_available()):
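
With this plumbing the new config class is part of the public API, and the dummy-objects fallback keeps `import optimum.intel` working when NNCF is absent (the error surfaces only when the class is actually used). A minimal check, assuming the nncf package is installed:

from optimum.intel import OVWeightQuantizationConfig

config = OVWeightQuantizationConfig()  # raises a helpful dependency error if NNCF is missing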

optimum/intel/openvino/__init__.py (+1)

@@ -32,6 +32,7 @@
 from .quantization import OVQuantizer
 from .trainer import OVTrainer
 from .training_args import OVTrainingArguments
+from .weight_quantization import OVWeightQuantizationConfig

 from .modeling import (
     OVModelForAudioClassification,

optimum/intel/openvino/configuration.py (+10)

@@ -15,9 +15,12 @@
 from typing import Dict, List, Optional, Union

 import torch
+from transformers.utils.quantization_config import QuantizationConfigMixin

 from optimum.configuration_utils import BaseConfig

+from .weight_quantization import OVWeightQuantizationConfig
+

 DEFAULT_QUANTIZATION_CONFIG = {
     "algorithm": "quantization",

@@ -83,6 +86,7 @@ def __init__(
         compression: Union[List[Dict], Dict, None] = None,
         input_info: Optional[List] = None,
         save_onnx_model: bool = False,
+        quantization_config: Optional[QuantizationConfigMixin] = None,
         **kwargs,
     ):
         super().__init__()

@@ -91,6 +95,7 @@ def __init__(
         self.save_onnx_model = save_onnx_model
         self._enable_standard_onnx_export_option()
         self.optimum_version = kwargs.pop("optimum_version", None)
+        self.quantization_config = quantization_config

     def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False):
         self.input_info = [

@@ -102,6 +107,11 @@ def add_input_info(self, model_inputs: Dict, force_batch_one: bool = False):
             for name, value in model_inputs.items()
         ]

+    def save_pretrained(self, *args, **kwargs):
+        if self.quantization_config is None:
+            self.quantization_config = OVWeightQuantizationConfig()
+        super().save_pretrained(*args, **kwargs)
+
     def _enable_standard_onnx_export_option(self):
         # This method depends on self.save_onnx_model.
         # save_onnx_model is defaulted to false so that the final model output is
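
Quantization settings are now carried by OVConfig itself, and save_pretrained backfills a default OVWeightQuantizationConfig so the serialized config always records the weight-compression setup. A hedged sketch (the bits argument is an assumption about OVWeightQuantizationConfig's signature, which is defined in weight_quantization.py, outside this diff):

from optimum.intel import OVConfig, OVWeightQuantizationConfig

ov_config = OVConfig(quantization_config=OVWeightQuantizationConfig(bits=4))  # `bits` assumed
ov_config.save_pretrained("./ov_model")  # writes the config, quantization section included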

optimum/intel/openvino/modeling_base.py (+13 -5)

@@ -164,6 +164,7 @@ def _from_pretrained(
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
+        load_in_4bit: bool = False,
         **kwargs,
     ):
         """

@@ -186,13 +187,18 @@
             force_download (`bool`, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            file_name(`str`, *optional*):
+            file_name (`str`, *optional*):
                 The file name of the model to load. Overwrites the default file name and allows one to load the model
                 with a different name.
-            local_files_only(`bool`, *optional*, defaults to `False`):
+            local_files_only (`bool`, *optional*, defaults to `False`):
                 Whether or not to only look at local files (i.e., do not try to download the model).
+            load_in_8bit (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply 8-bit weight quantization.
+            load_in_4bit (`bool`, *optional*, defaults to `False`):
+                Whether or not to apply 4-bit weight quantization.
         """
-
+        if load_in_4bit:
+            raise ValueError("load_in_4bit is available for OVModelForCausalLM only.")
         model_path = Path(model_id)
         default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME
         file_name = file_name or default_file_name

@@ -260,6 +266,7 @@ def _from_transformers(
         task: Optional[str] = None,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
+        load_in_4bit: Optional[bool] = None,
         **kwargs,
     ):
         """

@@ -283,9 +290,10 @@
         save_dir = TemporaryDirectory()
         save_dir_path = Path(save_dir.name)

+        # If load_in_8bit is not specified then compression_option should be set to None and will be set by default in main_export depending on the model size
         compression_option = None
         if load_in_8bit is not None:
-            compression_option = "int8" if load_in_8bit else "fp32"
+            compression_option = "fp32"

         main_export(
             model_name_or_path=model_id,

@@ -302,7 +310,7 @@
         )

         config.save_pretrained(save_dir_path)
-        return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=False, **kwargs)
+        return cls._from_pretrained(model_id=save_dir_path, config=config, load_in_8bit=load_in_8bit, **kwargs)

     @classmethod
     def _to_load(
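
The base _from_pretrained rejects the new flag outright, so only OVModelForCausalLM (which overrides it, see modeling_decoder.py below) honors 4-bit loading. Illustrative behavior against an already-exported model (the local path is hypothetical):

from optimum.intel import OVModelForSequenceClassification

try:
    OVModelForSequenceClassification.from_pretrained("./ov_bert", load_in_4bit=True)  # hypothetical path
except ValueError as err:
    print(err)  # load_in_4bit is available for OVModelForCausalLM only.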

optimum/intel/openvino/modeling_base_seq2seq.py (+2 -2)

@@ -253,7 +253,7 @@ def _from_transformers(
         compression_option = None
         if load_in_8bit is not None:
-            compression_option = "int8" if load_in_8bit else "fp32"
+            compression_option = "fp32"
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,

@@ -270,7 +270,7 @@
         config.save_pretrained(save_dir_path)
         return cls._from_pretrained(
-            model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, **kwargs
+            model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=load_in_8bit, **kwargs
         )

     def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_length: int, is_decoder=True):
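
The same two-step change as in modeling_base.py: the export itself now stays full-precision, and the 8-bit compression is applied when the exported model is loaded back (inside _from_pretrained) rather than inside main_export. Usage is unchanged, e.g.:

from optimum.intel import OVModelForSeq2SeqLM

# Export to OpenVINO, then 8-bit-compress the weights at load time.
model = OVModelForSeq2SeqLM.from_pretrained("t5-small", export=True, load_in_8bit=True)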

optimum/intel/openvino/modeling_decoder.py (+46 -8)

@@ -35,6 +35,7 @@
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
 from .utils import ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE
+from .weight_quantization import OVWeightQuantizationConfig, compress_decoder_weights


 if is_transformers_version("<", "4.25.0"):

@@ -243,6 +244,8 @@ def _from_transformers(
         use_cache: bool = True,
         trust_remote_code: bool = False,
         load_in_8bit: Optional[bool] = None,
+        load_in_4bit: Optional[bool] = None,
+        quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
         **kwargs,
     ):
         if config.model_type.replace("_", "-") not in _SUPPORTED_ARCHITECTURES:

@@ -259,9 +262,10 @@
         if use_cache:
             task = task + "-with-past"

+        # If load_in_8bit is not specified then compression_option should be set to None and will be set by default in main_export depending on the model size
         compression_option = None
-        if load_in_8bit is not None:
-            compression_option = "int8" if load_in_8bit else "fp32"
+        if load_in_8bit is not None or load_in_4bit is not None:
+            compression_option = "fp32"
         stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
         main_export(
             model_name_or_path=model_id,

@@ -282,7 +286,14 @@
         config.is_encoder_decoder = False
         config.save_pretrained(save_dir_path)
         return cls._from_pretrained(
-            model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, stateful=None, **kwargs
+            model_id=save_dir_path,
+            config=config,
+            use_cache=use_cache,
+            load_in_8bit=load_in_8bit,
+            stateful=None,
+            load_in_4bit=load_in_4bit,
+            quantization_config=quantization_config,
+            **kwargs,
         )

     def _reshape(

@@ -356,15 +367,14 @@ class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
             checkpoint="gpt2",
         )
     )
-    def forward(
+    def prepare_inputs(
         self,
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
-    ) -> CausalLMOutputWithPast:
-        self.compile()
+    ) -> Dict:
         if self.use_cache and past_key_values is not None:
             input_ids = input_ids[:, -1:]

@@ -449,6 +459,26 @@ def forward(
                 self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
             )

+        return inputs
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        self.compile()
+
+        inputs = self.prepare_inputs(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+            **kwargs,
+        )
+
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
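
Factoring input preparation out of forward() means the weight-compression path (the commit message mentions dataset-driven compression and an AWQ option) can push calibration samples through exactly the preprocessing inference uses, without duplicating it. A minimal sketch of calling the new method directly (illustrative checkpoint; compile() is invoked first because forward() normally does it, and prepare_inputs may touch the compiled request's state):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model = OVModelForCausalLM.from_pretrained("gpt2", export=True)  # illustrative
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.compile()

tokens = tokenizer("a calibration sample", return_tensors="pt")
# The same dict that forward() hands to the OpenVINO infer request.
inputs = model.prepare_inputs(**tokens)
print(sorted(inputs))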
@@ -532,6 +562,8 @@ def _from_pretrained(
         from_onnx: bool = False,
         local_files_only: bool = False,
         load_in_8bit: bool = False,
+        load_in_4bit: bool = False,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
         **kwargs,
     ):
         model_path = Path(model_id)

@@ -549,7 +581,9 @@
             local_files_only=local_files_only,
         )

-        model = cls.load_model(model_cache_path, load_in_8bit=load_in_8bit)
+        if load_in_8bit and load_in_4bit:
+            raise ValueError("Either load_in_8bit or load_in_4bit should be set to True.")
+        model = cls.load_model(model_cache_path, load_in_8bit=False if load_in_4bit else load_in_8bit)

         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":

@@ -563,7 +597,11 @@
         else:
             init_cls = cls

-        return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
+        causal_model = init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
+
+        if load_in_4bit:
+            compress_decoder_weights(causal_model, quantization_config)
+        return causal_model


 class OVBloomForCausalLM(OVModelForCausalLM):
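
End to end for the decoder: the IR is loaded without 8-bit compression when 4-bit is requested, the model object is built, then compress_decoder_weights applies NNCF 4-bit compression, optionally steered by a config. A sketch; the keyword arguments to OVWeightQuantizationConfig (bits, group_size, ratio) are assumptions mirroring nncf.compress_weights options, since the class body lives in weight_quantization.py, outside this diff:

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

config = OVWeightQuantizationConfig(bits=4, group_size=128, ratio=0.8)  # argument names assumed

model = OVModelForCausalLM.from_pretrained(
    "gpt2",                      # illustrative checkpoint
    export=True,
    load_in_4bit=True,
    quantization_config=config,  # a plain dict is accepted too, per the type hint
)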
