@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -25,7 +25,7 @@
 from openvino.runtime import Core, Tensor, Type
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 from optimum.utils import NormalizedConfigManager
 
@@ -401,9 +401,8 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
@@ -413,6 +412,35 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
             "token_type_ids": None,
         }
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+
+        # update attention mask
+        if "attention_mask" in model_kwargs:
+            attention_mask = model_kwargs["attention_mask"]
+            model_kwargs["attention_mask"] = torch.cat(
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+            )
+
+        # update position ids
+        if "position_ids" in model_kwargs:
+            position_ids = model_kwargs["position_ids"]
+            new_position_id = position_ids[..., -1:].clone()
+            new_position_id += 1
+            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+
+        model_kwargs["is_first_forward"] = False
+        return model_kwargs
+
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
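For context, a minimal sketch of what the new _update_model_kwargs_for_generation override does to the generation kwargs between two decode steps. It is not part of the diff; the batch size and prompt length below are illustrative assumptions. After each generated token, the attention mask is extended by one column of ones and the position ids are extended with the previous last position plus one, so the next forward pass receives inputs consistent with the growing key/value cache.

import torch

# state right after the prompt (batch of 1, prompt length 5) -- illustrative values
attention_mask = torch.ones((1, 5), dtype=torch.long)
position_ids = attention_mask.long().cumsum(-1) - 1                 # tensor([[0, 1, 2, 3, 4]])

# what the override does once a token has been generated:
attention_mask = torch.cat(
    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)                                                                    # shape (1, 6)
new_position_id = position_ids[..., -1:].clone() + 1                 # tensor([[5]])
position_ids = torch.cat([position_ids, new_position_id], dim=-1)    # tensor([[0, 1, 2, 3, 4, 5]])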