35
35
from ..utils .modeling_utils import MULTI_QUERY_ATTN_MODELS
36
36
from .modeling import _TOKENIZER_FOR_DOC , INPUTS_DOCSTRING , MODEL_START_DOCSTRING , OVModel
37
37
from .utils import ONNX_WEIGHTS_NAME , OV_XML_FILE_NAME , STR_TO_OV_TYPE
38
+ from .weight_quantization import OVWeightQuantizationConfig , compress_decoder_weights
38
39
39
40
40
41
if is_transformers_version ("<" , "4.25.0" ):
@@ -243,6 +244,8 @@ def _from_transformers(
243
244
use_cache : bool = True ,
244
245
trust_remote_code : bool = False ,
245
246
load_in_8bit : Optional [bool ] = None ,
247
+ load_in_4bit : Optional [bool ] = None ,
248
+ quantization_config : Optional [Union [OVWeightQuantizationConfig , Dict ]] = None ,
246
249
** kwargs ,
247
250
):
248
251
if config .model_type .replace ("_" , "-" ) not in _SUPPORTED_ARCHITECTURES :
@@ -259,9 +262,10 @@ def _from_transformers(
259
262
if use_cache :
260
263
task = task + "-with-past"
261
264
265
+ # If neither load_in_8bit nor load_in_4bit is specified, compression_option stays None and main_export picks the default depending on the model size
262
266
compression_option = None
263
- if load_in_8bit is not None :
264
- compression_option = "int8" if load_in_8bit else " fp32"
267
+ if load_in_8bit is not None or load_in_4bit is not None :
268
+ compression_option = "fp32"
265
269
stateful = kwargs .pop ("stateful" , ensure_stateful_is_available (warn = False ) and use_cache )
266
270
main_export (
267
271
model_name_or_path = model_id ,
@@ -282,7 +286,14 @@ def _from_transformers(
282
286
config .is_encoder_decoder = False
283
287
config .save_pretrained (save_dir_path )
284
288
return cls ._from_pretrained (
285
- model_id = save_dir_path , config = config , use_cache = use_cache , load_in_8bit = False , stateful = None , ** kwargs
289
+ model_id = save_dir_path ,
290
+ config = config ,
291
+ use_cache = use_cache ,
292
+ load_in_8bit = load_in_8bit ,
293
+ stateful = None ,
294
+ load_in_4bit = load_in_4bit ,
295
+ quantization_config = quantization_config ,
296
+ ** kwargs ,
286
297
)
287
298
288
299
def _reshape (
@@ -356,15 +367,14 @@ class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
356
367
checkpoint = "gpt2" ,
357
368
)
358
369
)
359
- def forward (
370
+ def prepare_inputs (
360
371
self ,
361
372
input_ids : torch .LongTensor ,
362
373
attention_mask : Optional [torch .LongTensor ] = None ,
363
374
past_key_values : Optional [Tuple [Tuple [torch .FloatTensor ]]] = None ,
364
375
position_ids : Optional [torch .LongTensor ] = None ,
365
376
** kwargs ,
366
- ) -> CausalLMOutputWithPast :
367
- self .compile ()
377
+ ) -> Dict :
368
378
if self .use_cache and past_key_values is not None :
369
379
input_ids = input_ids [:, - 1 :]
370
380
@@ -449,6 +459,26 @@ def forward(
449
459
self .next_beam_idx if self .next_beam_idx is not None else np .arange (batch_size , dtype = int )
450
460
)
451
461
462
+ return inputs
463
+
464
+ def forward (
465
+ self ,
466
+ input_ids : torch .LongTensor ,
467
+ attention_mask : Optional [torch .LongTensor ] = None ,
468
+ past_key_values : Optional [Tuple [Tuple [torch .FloatTensor ]]] = None ,
469
+ position_ids : Optional [torch .LongTensor ] = None ,
470
+ ** kwargs ,
471
+ ) -> CausalLMOutputWithPast :
472
+ self .compile ()
473
+
474
+ inputs = self .prepare_inputs (
475
+ input_ids = input_ids ,
476
+ attention_mask = attention_mask ,
477
+ past_key_values = past_key_values ,
478
+ position_ids = position_ids ,
479
+ ** kwargs ,
480
+ )
481
+
452
482
# Run inference
453
483
self .request .start_async (inputs , share_inputs = True )
454
484
self .request .wait ()
@@ -532,6 +562,8 @@ def _from_pretrained(
532
562
from_onnx : bool = False ,
533
563
local_files_only : bool = False ,
534
564
load_in_8bit : bool = False ,
565
+ load_in_4bit : bool = False ,
566
+ quantization_config : Union [OVWeightQuantizationConfig , Dict ] = None ,
535
567
** kwargs ,
536
568
):
537
569
model_path = Path (model_id )
@@ -549,7 +581,9 @@ def _from_pretrained(
549
581
local_files_only = local_files_only ,
550
582
)
551
583
552
- model = cls .load_model (model_cache_path , load_in_8bit = load_in_8bit )
584
+ if load_in_8bit and load_in_4bit :
585
+ raise ValueError ("Either load_in_8bit or load_in_4bit should be set to True." )
586
+ model = cls .load_model (model_cache_path , load_in_8bit = False if load_in_4bit else load_in_8bit )
553
587
554
588
model_type = config .model_type .replace ("_" , "-" )
555
589
if model_type == "bloom" :
@@ -563,7 +597,11 @@ def _from_pretrained(
563
597
else :
564
598
init_cls = cls
565
599
566
- return init_cls (model = model , config = config , model_save_dir = model_cache_path .parent , ** kwargs )
600
+ causal_model = init_cls (model = model , config = config , model_save_dir = model_cache_path .parent , ** kwargs )
601
+
602
+ if load_in_4bit :
603
+ compress_decoder_weights (causal_model , quantization_config )
604
+ return causal_model
567
605
568
606
569
607
class OVBloomForCausalLM (OVModelForCausalLM ):
0 commit comments