@@ -90,13 +90,13 @@ def _from_transformers(
         cls,
         model_id: str,
         config: PretrainedConfig,
+        use_cache: bool = True,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         subfolder: str = "",
         local_files_only: bool = False,
-        use_cache: bool = True,
         torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
         trust_remote_code: bool = False,
     ):
@@ -134,6 +134,7 @@ def _from_transformers(
             cache_dir=cache_dir,
             local_files_only=local_files_only,
             use_cache=use_cache,
+            model_dtype=torch_dtype,
         )

     @classmethod
@@ -325,9 +326,11 @@ def __init__(
     ):
         # Perform the initial warmup at the end of __init__
         super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
+        GenerationMixin.__init__(self)

         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.model_dtype = kwargs.get("model_dtype", self.dtype)
+        self._dtype = self.model_dtype
         self.use_cache = "past_key_values" in self.input_names

         if use_cache ^ self.use_cache:
@@ -346,15 +349,44 @@ def __init__(
             )
         except AttributeError:
             self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
-        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
-        self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self)
         if hasattr(self.model_cls, "_convert_to_standard_cache"):
             self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache
         if hasattr(self.model_cls, "_convert_to_bloom_cache"):
             self._convert_to_bloom_cache = self.model_cls._convert_to_bloom_cache
         if warmup:
             self._init_warmup()

+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+        past_key_values = past_key_values or kwargs.get("past", None)
+
+        if self.use_cache and past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        # `past_key_values` may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
+        if past_key_values is not None and self.config.model_type == "bloom":
+            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
+                past_key_values = self._convert_to_bloom_cache(past_key_values)
+
+        position_ids = kwargs.get("position_ids", None)
+
+        attention_mask = kwargs.get("attention_mask", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "use_cache": self.use_cache,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": None,
+        }
+
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
         nb_pkv = 2
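For reference, the `position_ids` recipe in the added `prepare_inputs_for_generation` can be checked standalone. A minimal sketch of the same cumsum-and-mask logic; the example tensors are made up for illustration and are not part of this commit:

import torch

# Left-padded batch of two prompts (0 marks padding in the mask).
attention_mask = torch.tensor([[0, 1, 1],
                               [1, 1, 1]])

# Cumulative position over real tokens, then clamp padding slots to 1,
# exactly as done in the method above.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 0, 1], [0, 1, 2]])

During cached decoding only the last column of these positions is fed to the model, which matches the `position_ids[:, -1].unsqueeze(-1)` branch.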