|
16 | 16 | import os
|
17 | 17 | from pathlib import Path
|
18 | 18 | from tempfile import TemporaryDirectory
|
19 |
| -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union |
| 19 | +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union |
20 | 20 |
|
21 | 21 | import numpy as np
|
22 | 22 | import openvino
|
|
31 | 31 | from transformers.generation.logits_process import LogitsProcessorList
|
32 | 32 | from transformers.generation.stopping_criteria import StoppingCriteriaList
|
33 | 33 | from transformers.generation.utils import GenerateOutput, GenerationMode
|
34 |
| -from transformers.modeling_outputs import CausalLMOutputWithPast |
| 34 | +from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput |
35 | 35 |
|
36 | 36 | from optimum.utils.normalized_config import NormalizedConfigManager
|
37 | 37 |
|
@@ -504,8 +504,8 @@ def prepare_inputs(
|
504 | 504 | else:
|
505 | 505 | position_ids = np.cumsum(attention_mask, axis=1) - 1
|
506 | 506 | position_ids[attention_mask == 0] = 1
|
507 |
| - if past_key_values: |
508 |
| - position_ids = position_ids[:, -input_ids.shape[1] :] |
| 507 | + if past_key_values: |
| 508 | + position_ids = position_ids[:, -input_ids.shape[1] :] |
509 | 509 |
|
510 | 510 | inputs["position_ids"] = position_ids
|
511 | 511 |
|
@@ -604,6 +604,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
|
604 | 604 |
|
605 | 605 | return model_inputs
|
606 | 606 |
|
def _update_model_kwargs_for_generation(
    self,
    outputs: ModelOutput,
    model_kwargs: Dict[str, Any],
    is_encoder_decoder: bool = False,
    **kwargs,
) -> Dict[str, Any]:
    """Update generation kwargs after a step, extending ``position_ids`` by one.

    Delegates to the parent implementation first, then, when a non-None
    ``position_ids`` entry is present in the returned kwargs, appends one
    more position (last position + 1) along the last dimension.

    Args:
        outputs: Model outputs of the step that just ran; forwarded to the
            parent implementation.
        model_kwargs: Keyword arguments carried between generation steps.
        is_encoder_decoder: Forwarded unchanged to the parent implementation.
        **kwargs: Any additional arguments, forwarded unchanged to the parent.

    Returns:
        The kwargs dict returned by the parent implementation, with
        ``position_ids`` grown by one column when it was provided.
    """
    model_kwargs = super()._update_model_kwargs_for_generation(
        outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs
    )

    # Guard on the value, not just the key: a ``position_ids=None`` entry
    # must be left untouched instead of failing on the slice below.
    position_ids = model_kwargs.get("position_ids")
    if position_ids is not None:
        # Out-of-place add yields a fresh tensor, so the existing ids are
        # never mutated. NOTE(review): assumes a torch tensor here, since
        # torch.cat is used -- confirm against callers.
        next_position_id = position_ids[..., -1:] + 1
        model_kwargs["position_ids"] = torch.cat([position_ids, next_position_id], dim=-1)
    return model_kwargs
| 624 | + |
607 | 625 | def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
|
608 | 626 | batch_size = logits.shape[0]
|
609 | 627 | if indicies.shape[0] != 1:
|
|
0 commit comments