@@ -348,11 +348,95 @@ def __init__(
         if warmup:
             self._init_warmup()
 
+    def _reorder_cache(
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called.
+        This is required to match `past_key_values` with the correct beam_idx at every generation step.
+        """
+        if self.config.model_type == "bloom":
+            return self._reorder_cache_bloom(past_key_values, beam_idx)
+
+        # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past_key_values
+        )
+
+    # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
+    def _reorder_cache_bloom(
+        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called for the bloom architecture.
+        This is required to match `past_key_values` with the correct beam_idx at every generation step.
+        """
+        standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
+
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device)
+            for layer_past in past_key_values
+            for past_state in layer_past
+        }
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+            )
+            for layer_past in standardized_past
+        )
+        return self._convert_to_bloom_cache(reordered_past)
+
+    # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
+    @staticmethod
+    def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache
+    def _convert_to_standard_cache(
+        self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
+        """
+        if self.config.model_type != "bloom":
+            return past_key_value
+
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         past_key_values = past_key_values or kwargs.get("past", None)
 
         if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = self.get_past_length(past_key_values)
+            input_ids = input_ids[:, past_length:]
 
         # `past_key_values` may be in the standard format (e.g. in contrastive search), convert it to bloom's format if needed
         if past_key_values is not None and self.config.model_type == "bloom":
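The new cache helpers above exist because Bloom fuses the batch and head dimensions of its KV cache into one leading axis, while beam reordering needs a separate batch axis to `index_select` over. A minimal standalone sketch of the two layouts and the `view`-based round trip (shapes are invented for illustration, not taken from the PR):

```python
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Standard layout: key [batch, heads, head_dim, seq], value [batch, heads, seq, head_dim].
key = torch.randn(batch_size, num_heads, head_dim, seq_length)
value = torch.randn(batch_size, num_heads, seq_length, head_dim)

# Bloom layout fuses batch and heads into a single leading dimension.
bloom_key = key.view(batch_size * num_heads, head_dim, seq_length)
bloom_value = value.view(batch_size * num_heads, seq_length, head_dim)

# The conversion is lossless: `view` only reinterprets the layout.
assert torch.equal(bloom_key.view(batch_size, num_heads, head_dim, seq_length), key)
assert torch.equal(bloom_value.view(batch_size, num_heads, seq_length, head_dim), value)
```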
@@ -368,7 +452,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[-1] :]
 
         return {
             "input_ids": input_ids,
@@ -379,6 +463,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             "token_type_ids": None,
         }
 
+    def get_past_length(self, past_key_values):
+        model_type = self.config.model_type.replace("_", "-")
+        if model_type == "bloom":
+            return past_key_values[0][0].shape[-1]
+        elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[1]
+        else:
+            return past_key_values[0][0].shape[-2]
+
     def _prepare_past_key_values(self, input_ids):
         model_type = self.config.model_type.replace("_", "-")
         nb_pkv = 2
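`get_past_length` reads the cached sequence length from a different axis per architecture because the three cache layouts store it in different places. A standalone sketch with dummy tensors; the Bloom and default shapes follow the layouts documented above, while the multi-query shape (a single fused KV tensor per layer, as in gpt_bigcode) is an assumption made for illustration:

```python
import torch

batch, heads, head_dim, seq = 2, 4, 8, 5

# Bloom: key is [batch * heads, head_dim, seq] -> length sits on the last axis.
bloom_cache = ((torch.empty(batch * heads, head_dim, seq),) * 2,)
assert bloom_cache[0][0].shape[-1] == seq

# Multi-query attention: one fused KV tensor per layer, [batch, seq, 2 * head_dim]
# -> length sits on axis 1.
mqa_cache = (torch.empty(batch, seq, 2 * head_dim),)
assert mqa_cache[0].shape[1] == seq

# Default decoder layout: [batch, heads, seq, head_dim] -> length sits on axis -2.
default_cache = ((torch.empty(batch, heads, seq, head_dim),) * 2,)
assert default_cache[0][0].shape[-2] == seq
```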
@@ -431,7 +524,7 @@ def forward(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[-1] :]
 
             if "position_ids" in self.input_names or not self.input_names:
                 inputs["position_ids"] = position_ids
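For context on the `_reorder_cache` entry point added in the first hunk: `generate` calls it after every beam step so each cache row follows the beam it now belongs to, and the default (non-Bloom) path is just a batched `index_select`. A toy sketch with one layer and invented shapes:

```python
import torch

# One layer, standard layout [num_beams, heads, seq, head_dim];
# fill each beam's cache with its own index so the reorder is visible.
key = torch.arange(4.0).view(4, 1, 1, 1).expand(4, 2, 3, 5).contiguous()
past_key_values = ((key, key.clone()),)

# Suppose the search keeps two continuations of old beam 2, one of beam 0, one of beam 3.
beam_idx = torch.tensor([2, 2, 0, 3])

reordered = tuple(
    tuple(past.index_select(0, beam_idx.to(past.device)) for past in layer_past)
    for layer_past in past_key_values
)
assert torch.equal(reordered[0][0][0], key[2])  # beam 0 now carries old beam 2's cache
assert torch.equal(reordered[0][0][2], key[0])  # beam 2 now carries old beam 0's cache
```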