 from optimum.utils import NormalizedConfigManager
 
 from ..generation.modeling import jit_trace, prepare_jit_inputs
-from ..utils.import_utils import is_torch_version
+from ..utils.import_utils import is_torch_version, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
 
 
@@ -326,7 +326,8 @@ def __init__(
         # Perform the initial warmup at the end of __init__
         super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
 
-        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
+        model_type = config.model_type.replace("_", "-")
+        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", self.dtype)
         self.use_cache = "past_key_values" in self.input_names
 
@@ -339,6 +340,7 @@ def __init__(
         )
         config.is_decoder = True
         config.is_encoder_decoder = False
+
         self.generation_config = GenerationConfig.from_model_config(config)
         try:
             self.model_cls = get_class_from_dynamic_module(
@@ -347,7 +349,12 @@ def __init__(
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
         self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
-        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+
+        if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
+            self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama
+        else:
+            self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
+
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
@@ -430,3 +437,62 @@ def forward(
         past_key_values = outputs["past_key_values"] if self.use_cache else None
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
+
+
+def _prepare_inputs_for_generation_for_llama(
+    input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+):
+    from transformers.cache_utils import Cache
+
+    if past_key_values is not None:
+        if isinstance(past_key_values, Cache):
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+            max_cache_length = past_key_values.get_max_length()
+        else:
+            cache_length = past_length = past_key_values[0][0].shape[2]
+            max_cache_length = None
+
+        # Keep only the unprocessed tokens:
+        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+        # input)
+        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+        # input_ids based on the past_length.
+        elif past_length < input_ids.shape[1]:
+            input_ids = input_ids[:, past_length:]
+        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+        if (
+            max_cache_length is not None
+            and attention_mask is not None
+            and cache_length + input_ids.shape[1] > max_cache_length
+        ):
+            attention_mask = attention_mask[:, -max_cache_length:]
+
+    position_ids = kwargs.get("position_ids", None)
+    if attention_mask is not None and position_ids is None:
+        # create position_ids on the fly for batch generation
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        if past_key_values:
+            position_ids = position_ids[:, -input_ids.shape[1] :]
+
+    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    if inputs_embeds is not None and past_key_values is None:
+        model_inputs = {"inputs_embeds": inputs_embeds}
+    else:
+        model_inputs = {"input_ids": input_ids}
+
+    model_inputs.update(
+        {
+            "position_ids": position_ids,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+        }
+    )
+    return model_inputs
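
Note on the model_type normalization in the first __init__ hunk: config.model_type as reported by transformers uses underscores (e.g. "gpt_neox"), while the lookup key expected by optimum's NormalizedConfigManager appears to be the dash-separated form, hence the replace("_", "-") before the call. A minimal sketch of that lookup, assuming "EleutherAI/gpt-neox-20b" as an example checkpoint and that its dash-separated type is registered:

    # Illustration only, not part of the PR; the dash-separated registry key is an assumption.
    from transformers import AutoConfig
    from optimum.utils import NormalizedConfigManager

    config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")
    print(config.model_type)                          # "gpt_neox" (underscores)
    model_type = config.model_type.replace("_", "-")  # "gpt-neox" (assumed registry key)
    normalized = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
    print(normalized.num_attention_heads, normalized.hidden_size)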
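For reference, a minimal sketch of how the new helper behaves during generation. It assumes _prepare_inputs_for_generation_for_llama (the module-level function added above) is in scope, and the tensor shapes for the legacy tuple cache are made up for illustration:

    import torch

    prompt = torch.tensor([[1, 2, 3, 4]])
    mask = torch.ones_like(prompt)

    # First step: no cache yet, so the full prompt is forwarded and position_ids
    # are derived from the attention mask (0..3 here).
    first = _prepare_inputs_for_generation_for_llama(prompt, attention_mask=mask, use_cache=True)
    assert first["input_ids"].shape[1] == 4 and first["past_key_values"] is None

    # Later step with a legacy tuple cache of past length 4: only the single
    # unprocessed token is kept and its position id continues from the cache.
    past = [(torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64))]
    ids = torch.tensor([[1, 2, 3, 4, 5]])
    step = _prepare_inputs_for_generation_for_llama(
        ids, past_key_values=past, attention_mask=torch.ones(1, 5, dtype=torch.long), use_cache=True
    )
    assert step["input_ids"].tolist() == [[5]] and step["position_ids"].tolist() == [[4]]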