
Commit 90539d7

Merge pull request #2 from eaidova/ea/fix_glm_input
Fix incorrect filling of the chatglm position_ids input
2 parents a099280 + 64ef340 commit 90539d7

File tree

1 file changed: +22 -4 lines changed

optimum/intel/openvino/modeling_decoder.py

+22 -4
@@ -16,7 +16,7 @@
 import os
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -31,7 +31,7 @@
 from transformers.generation.logits_process import LogitsProcessorList
 from transformers.generation.stopping_criteria import StoppingCriteriaList
 from transformers.generation.utils import GenerateOutput, GenerationMode
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
 
 from optimum.utils.normalized_config import NormalizedConfigManager
 
@@ -504,8 +504,8 @@ def prepare_inputs(
             else:
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                if past_key_values:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
 
             inputs["position_ids"] = position_ids
 
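For context (not part of the commit; values and shapes are illustrative): a minimal NumPy sketch of what this branch computes. position_ids are derived from a left-padded attention mask, and the past_key_values slice keeps only the positions of the newly fed tokens. The point of the fix is that this slice now runs only when position_ids were derived here, not when the caller (e.g. chatglm, which supplies its own position_ids) passed them in explicitly.

import numpy as np

# Left-padded batch: real lengths 3 and 5, padded to length 5.
attention_mask = np.array([[0, 0, 1, 1, 1],
                           [1, 1, 1, 1, 1]])

# Positions count only real tokens; padding slots get a dummy value of 1.
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
print(position_ids)
# [[1 1 0 1 2]
#  [0 1 2 3 4]]

# With a KV cache, only the positions of the tokens being fed this step
# are kept (one token per step during decoding).
input_len = 1
print(position_ids[:, -input_len:])
# [[2]
#  [4]]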
@@ -604,6 +604,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs
 
         return model_inputs
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs
+        )
+
+        if "position_ids" in model_kwargs:
+            position_ids = model_kwargs["position_ids"]
+            new_position_id = position_ids[..., -1:].clone()
+            new_position_id += 1
+            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
+        return model_kwargs
+
     def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
         batch_size = logits.shape[0]
         if indicies.shape[0] != 1:
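For context (not part of the commit; tensor values are illustrative): a minimal torch sketch of what the added _update_model_kwargs_for_generation override does between decoding steps. It clones the last position, increments it, and appends it, so the position_ids carried in model_kwargs stay in sync with the growing sequence for the next step.

import torch

# position_ids as they stand after a prefill of 4 tokens (batch of 1).
model_kwargs = {"position_ids": torch.tensor([[0, 1, 2, 3]])}

# Same logic as the new hook: extend by one incremented position.
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id += 1
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

print(model_kwargs["position_ids"])  # tensor([[0, 1, 2, 3, 4]])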
