logger = logging.getLogger(__name__)


+IPEX_EXPORTED_LIST = ("LlamaForCausalLM",)
+
+
+def is_ipex_exported_model(model_name):
+    for name in IPEX_EXPORTED_LIST:
+        if model_name == name:
+            return True
+    return False
+
+
def ipex_jit_trace(model):
    sample_inputs = get_dummy_input(model, return_dict=True)
    model.config.return_dict = False
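A quick sketch (not part of the diff) of what the new helper amounts to: a membership check on the model's Python class name, so only architectures listed in IPEX_EXPORTED_LIST take the IPEX export path.

    print(is_ipex_exported_model("LlamaForCausalLM"))  # True, listed in IPEX_EXPORTED_LIST
    print(is_ipex_exported_model("GPT2LMHeadModel"))   # False, so the non-IPEX path is used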
@@ -138,8 +148,9 @@ def _from_transformers(
        }

        model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
-        is_ipex_exported = export_model(model)
+        is_ipex_exported = is_ipex_exported_model(model.__class__.__name__)
        if is_ipex_exported:
+            model = export_model(model)
            traced_model = ipex_jit_trace(model)
        else:
            model = patch_decoder_attention_mask(model)
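For orientation, a hedged sketch of how this branch is typically reached from user code, assuming this file backs optimum-intel's IPEXModelForCausalLM (the class name and checkpoint below are assumptions, not shown in the diff):

    # Assumed entry point: export=True routes through _from_transformers, which now
    # checks the class name before choosing between IPEX export and the patched fallback.
    from optimum.intel import IPEXModelForCausalLM  # assumed import path

    model = IPEXModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", export=True)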
@@ -199,6 +210,8 @@ def _from_pretrained(

        model = torch.jit.load(model_cache_path)
        torch.jit.freeze(model.eval())
+        is_ipex_exported = is_ipex_exported_model(model.original_name)
+        kwargs["is_ipex_exported"] = is_ipex_exported

        return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

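The original_name read here is the TorchScript attribute that records the class of the module that was originally traced, which is why the same class-name check works for a reloaded model. A minimal sketch, with ToyModule and the file name purely hypothetical:

    import torch

    class ToyModule(torch.nn.Module):  # hypothetical stand-in for a traced model class
        def forward(self, x):
            return x + 1

    torch.jit.save(torch.jit.trace(ToyModule(), torch.zeros(2)), "toy.pt")
    loaded = torch.jit.load("toy.pt")
    print(loaded.original_name)                          # "ToyModule"
    print(is_ipex_exported_model(loaded.original_name))  # False, not in IPEX_EXPORTED_LIST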
@@ -379,12 +392,15 @@ def __init__(
        except AttributeError:
            self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

-        self._reorder_cache = self.model_cls._reorder_cache
+        if self.is_ipex_exported:
+            self._reorder_cache = _ipex_reorder_cache
+        else:
+            self._reorder_cache = self.model_cls._reorder_cache.__get__(self)

        if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
            self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama
        else:
-            self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation
+            self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)

        if hasattr(self.model_cls, "_convert_to_standard_cache"):
            self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
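The added .__get__(self) calls are what make these borrowed functions behave as bound methods: model_cls is not a base class of this wrapper, so plain attribute assignment would not pass self implicitly. A minimal illustration with made-up classes:

    class Source:
        def describe(self):
            return f"bound to {type(self).__name__}"

    class Wrapper:
        pass

    w = Wrapper()
    w.describe = Source.describe.__get__(w)  # bind Source's function to the wrapper instance
    print(w.describe())                      # "bound to Wrapper"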
@@ -393,37 +409,6 @@ def __init__(
        if warmup:
            self._init_warmup()

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        past_key_values = past_key_values or kwargs.get("past", None)
-
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
-        # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed
-        if past_key_values is not None and self.config.model_type == "bloom":
-            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
-                past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        position_ids = kwargs.get("position_ids", None)
-
-        attention_mask = kwargs.get("attention_mask", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": self.use_cache,
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": None,
-        }
-
    def _prepare_past_key_values(self, input_ids):
        model_type = self.config.model_type.replace("_", "-")
        nb_pkv = 2
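The position_ids handling in the removed prepare_inputs_for_generation is the usual cumsum-over-attention-mask trick for left-padded batches; a small self-contained check of what it produces:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])
    position_ids = attention_mask.long().cumsum(-1) - 1  # running position per real token
    position_ids.masked_fill_(attention_mask == 0, 1)    # padded slots get a dummy position
    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])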
@@ -469,104 +454,6 @@ def _prepare_past_key_values(self, input_ids):

        return past_key_values

-    def _reorder_cache(
-        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called.
-        This is required to match `past_key_values` with the correct beam_idx at every generation step.
-        """
-        if self.config.model_type == "bloom":
-            return self._reorder_cache_bloom(past_key_values, beam_idx)
-
-        if self.is_ipex_exported:
-            if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:  # discrete kv_cache
-                for layer_past in past_key_values:
-                    layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
-                return past_key_values
-            elif len(past_key_values[0]) == 8:
-                for layer_past in past_key_values:
-                    layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
-                    layer_past[7][layer_past[0].size(-2) - 1] = beam_idx
-                return past_key_values
-            else:
-                return tuple(
-                    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-                    for layer_past in past_key_values
-                )
-        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past_key_values
-        )
-
-    # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
-    def _reorder_cache_bloom(
-        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called for bloom architecture.
-        This is required to match `past_key_values` with the correct beam_idx at every generation step.
-        """
-        standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
-
-        # Get a copy of `beam_idx` on all the devices where we need those indices.
-        device_to_beam_idx = {
-            past_state.device: beam_idx.to(past_state.device)
-            for layer_past in past_key_values
-            for past_state in layer_past
-        }
-        reordered_past = tuple(
-            (
-                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
-                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
-            )
-            for layer_past in standardized_past
-        )
-        return self._convert_to_bloom_cache(reordered_past)
-
-    # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
-    @staticmethod
-    def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
-        """
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
-        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
-        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
-    # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache
-    def _convert_to_standard_cache(
-        self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
-        """
-        if self.config.model_type != "bloom":
-            return past_key_value
-
-        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        num_heads = batch_size_times_num_heads // batch_size
-        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
-        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
-        return tuple(
-            (
-                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
-                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
-            )
-            for layer_past in past_key_value
-        )
-
    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -670,3 +557,28 @@ def _prepare_inputs_for_generation_for_llama(
        }
    )
    return model_inputs
+
+
+def _ipex_reorder_cache(
+    past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+) -> Tuple[Tuple[torch.Tensor]]:
+
+    if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:  # discrete kv_cache
+        for layer_past in past_key_values:
+            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
+        return past_key_values
+    elif len(past_key_values[0]) == 8:
+        for layer_past in past_key_values:
+            layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
+            layer_past[7][layer_past[0].size(-2) - 1] = beam_idx
+        return past_key_values
+    else:
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past_key_values
+        )
+
+    return tuple(
+        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+        for layer_past in past_key_values
+    )
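To see what the else branch of the new _ipex_reorder_cache does, here is a toy run on a plain tuple-of-tuples cache (the shapes are illustrative and much smaller than a real KV cache, so neither the 4-element nor the 8-element case applies and every tensor is simply index_select-ed along the batch dimension):

    import torch

    # one layer with a single cached tensor of batch size 3
    past_key_values = ((torch.arange(6, dtype=torch.float).reshape(3, 2),),)
    beam_idx = torch.tensor([2, 2, 0])  # beams 0 and 1 both continue from old beam 2

    reordered = _ipex_reorder_cache(past_key_values, beam_idx)
    print(reordered[0][0])
    # tensor([[4., 5.],
    #         [4., 5.],
    #         [0., 1.]])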