
Commit f4d2544

fix wrong filling of the chatglm position_ids input

1 parent a099280 commit f4d2544

File tree

1 file changed: +20 −2 lines

optimum/intel/openvino/modeling_decoder.py (+20 −2)
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -31,7 +31,7 @@
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.stopping_criteria import StoppingCriteriaList
 from transformers.generation.utils import GenerateOutput, GenerationMode
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 from optimum.utils.normalized_config import NormalizedConfigManager
 
@@ -604,6 +604,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
 
         return model_inputs
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs
+        )
+
+        if "position_ids" in model_kwargs:
+            position_ids = model_kwargs["position_ids"]
+            new_position_id = position_ids[..., -1:].clone()
+            new_position_id += 1
+            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+        return model_kwargs
+
     def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
         batch_size = logits.shape[0]
         if indicies.shape[0] != 1:
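
For context, here is a minimal sketch of the tensor bookkeeping the new override performs between decoding steps. It assumes position_ids is a (batch_size, seq_len) tensor, as chatglm-style models expect; the concrete values are illustrative only:

import torch

# position ids after the prompt pass (hypothetical 4-token prompt)
position_ids = torch.tensor([[0, 1, 2, 3]])

# Take the last position of each sequence, increment it, and append it so the
# next generated token is decoded at the correct position.
new_position_id = position_ids[..., -1:].clone()
new_position_id += 1
position_ids = torch.cat([position_ids, new_position_id], dim=-1)

print(position_ids)  # tensor([[0, 1, 2, 3, 4]])

The base _update_model_kwargs_for_generation in transformers extends the attention mask but does not touch position_ids, so without this override the stale positions from the previous step would be reused, which appears to be the "wrong filling" the commit title refers to.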
