Commit 34351b1

rework chatglm config

1 parent dfb9ae0 commit 34351b1

2 files changed: +13 -29 lines changed

optimum/exporters/openvino/model_configs.py (+12 -28)
@@ -85,14 +85,9 @@ class ChatGLM2DummyTextInputGenerator(DummyTextInputGenerator):
         }
 
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
-        import torch
-
         input = super().generate(input_name, framework, int_dtype, float_dtype)
         if input_name == "attention_mask":
-            input = torch.ones(input.shape, dtype=input.dtype)
-        if input_name == "position_ids":
-            bs = input.shape[0]
-            input = torch.range(0, input.shape[1], dtype=input.dtype).repeat(bs, 1)
+            input = self.random_int_tensor(input.shape, max_value=1, min_value=1)
         return input
 
 
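Note on the reworked generate(): drawing random integers with min_value == max_value == 1 always yields an all-ones tensor, so the explicit torch dependency in this method goes away. A minimal sketch of that equivalence, with illustrative shapes only:

import torch

shape = (2, 8)  # (batch_size, sequence_length), illustrative values
mask_old = torch.ones(shape, dtype=torch.int64)  # behaviour of the removed torch.ones branch
# random_int_tensor with min_value == max_value == 1 amounts to sampling from {1} only:
mask_new = torch.randint(low=1, high=2, size=shape, dtype=torch.int64)
assert torch.equal(mask_old, mask_new)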

@@ -141,11 +136,10 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
 
 
 @register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"])
-class ChatGLM2OpenVINOConfig(TextDecoderOnnxConfig):
+class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers")
     DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLM2DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
-    no_position_ids = False
 
     def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
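Switching the base class to TextDecoderWithPositionIdsOnnxConfig is what makes the no_position_ids flag and the hand-written inputs property (removed in the next hunk) redundant: the parent config already exposes position_ids as a model input. Roughly, and only as a simplified approximation of the optimum base class rather than its actual source, it behaves like this:

from typing import Dict

from optimum.exporters.onnx.config import TextDecoderOnnxConfig


class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = super().inputs
        # Decoder-only models get an explicit position_ids input in addition to
        # input_ids and attention_mask, so the exported graph does not have to
        # recompute positions from a possibly padded attention mask.
        if self.task == "text-generation":
            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
        return common_inputs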
@@ -173,34 +167,24 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         )
 
         # refer to https://github.com/huggingface/optimum/pull/764
-        cond1 = self.use_past_in_inputs
-        cond2 = self.PAD_ATTENTION_MASK_TO_PAST
-        cond3 = self.use_cache_branch is not False
-        cond4 = "attention_mask" in dummy_inputs
-        if cond1 and cond2 and cond3 and cond4:
-            # Obtain the past sequence length from the value instead of the key (Bloom).
-            past_length = dummy_inputs["past_key_values"][0][1].shape[0]
-            for k, v in dummy_inputs.items():
-                if k not in ["attention_mask", "past_key_values"]:
-                    dummy_inputs[k] = v[:, -1:]
+        if (
+            self.use_past_in_inputs
+            and self.PAD_ATTENTION_MASK_TO_PAST
+            and self.use_cache_branch is not False
+            and "attention_mask" in dummy_inputs
+        ):
+            # Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2
+            past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[0]
 
             dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["attention_mask"],
-                desired_length=past_length + 1,
+                desired_length=past_present_length,
                 dim=1,
                 dtype=dummy_inputs["attention_mask"].dtype,
             )
 
         return dummy_inputs
 
-    @property
-    def inputs(self) -> Dict[str, Dict[int, str]]:
-        common_inputs = super().inputs
-        if not self.no_position_ids and self.task == "text-generation":
-            common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
-
-        return common_inputs
-
     def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
         """
         Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
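ChatGLM stores its past key/value tensors as (past_sequence_length, batch_size, num_heads, head_dim), so the past length sits in dim 0 rather than dim -2, and the dummy attention mask has to cover past plus present tokens. A toy illustration of the padding target with made-up shapes (not taken from the real model config):

import torch

batch_size, past_len = 2, 16
input_ids = torch.ones(batch_size, 1, dtype=torch.int64)        # one new token per step when using the past
past_value = torch.zeros(past_len, batch_size, 2, 64)           # (past_seq_len, batch, kv_heads, head_dim)
attention_mask = torch.ones(batch_size, 1, dtype=torch.int64)

past_present_length = input_ids.shape[1] + past_value.shape[0]  # 1 + 16 = 17
# pad_input_on_dim extends the mask on dim=1 up to that length; the effect is roughly:
pad = past_present_length - attention_mask.shape[1]
attention_mask = torch.cat([attention_mask, torch.ones(batch_size, pad, dtype=torch.int64)], dim=1)
assert attention_mask.shape == (batch_size, past_present_length)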
@@ -218,7 +202,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             decoder_sequence_name = "past_sequence_length"
             name = "past_key_values"
         else:
-            decoder_sequence_name = "past_sequence_length + 1"
+            decoder_sequence_name = "past_sequence_length + present_lenght"
             name = "present"
 
         for i in range(self._normalized_config.num_layers):
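With both changes in place the config exports ChatGLM through the regular optimum-intel path. A minimal usage sketch, assuming a checkpoint the exporter supports and trust_remote_code for the custom modeling code (model id and prompt are illustrative only):

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "THUDM/chatglm2-6b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# export=True converts the PyTorch model to OpenVINO IR using the config registered above
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))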

setup.py (+1 -1)
@@ -49,7 +49,7 @@
         "onnx",
         "onnxruntime",
         "transformers>=4.36.0",
-        "optimum @ git+https://github.com/huggingface/optimum.git#egg=optimum"
+        "optimum @ git+https://github.com/huggingface/optimum.git#egg=optimum",
     ],
     "openvino-tokenizers": ["openvino-tokenizers[transformers]"],
     "nncf": ["nncf>=2.8.1"],
