@@ -85,14 +85,9 @@ class ChatGLM2DummyTextInputGenerator(DummyTextInputGenerator):
85
85
}
86
86
87
87
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
    """Generate a dummy tensor for the named model input.

    Delegates to the parent dummy-input generator, then overrides the
    ``attention_mask`` input with a constant all-ones tensor
    (``random_int_tensor`` with ``min_value == max_value == 1``) —
    presumably so no positions are masked during export; confirm against
    the exporter's expectations.

    Args:
        input_name: Name of the model input to generate
            (e.g. ``"input_ids"``, ``"attention_mask"``).
        framework: Target framework for the tensor (e.g. ``"pt"``).
        int_dtype: Dtype name used for integer tensors.
        float_dtype: Dtype name used for floating-point tensors.

    Returns:
        The generated dummy tensor for ``input_name``.
    """
    # Renamed from `input` to avoid shadowing the `input` builtin.
    dummy = super().generate(input_name, framework, int_dtype, float_dtype)
    if input_name == "attention_mask":
        # min_value == max_value == 1 yields a deterministic all-ones mask
        # of the same shape as the parent-generated tensor.
        dummy = self.random_int_tensor(dummy.shape, max_value=1, min_value=1)
    return dummy
98
93
@@ -141,11 +136,10 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
141
136
142
137
143
138
@register_in_tasks_manager ("chatglm" , * ["text-generation" , "text-generation-with-past" ])
144
- class ChatGLM2OpenVINOConfig (TextDecoderOnnxConfig ):
139
+ class ChatGLM2OpenVINOConfig (TextDecoderWithPositionIdsOnnxConfig ):
145
140
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig .with_args (vocab_size = "padded_vocab_size" , num_layers = "num_layers" )
146
141
DUMMY_INPUT_GENERATOR_CLASSES = (ChatGLM2DummyTextInputGenerator , ChatGLM2DummyPastKeyValuesGenerator )
147
142
DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator
148
- no_position_ids = False
149
143
150
144
def generate_dummy_inputs (self , framework : str = "pt" , ** kwargs ):
151
145
dummy_inputs_generators = self ._create_dummy_input_generator_classes (** kwargs )
@@ -173,34 +167,24 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
173
167
)
174
168
175
169
# refer to https://github.com/huggingface/optimum/pull/764
176
- cond1 = self .use_past_in_inputs
177
- cond2 = self .PAD_ATTENTION_MASK_TO_PAST
178
- cond3 = self .use_cache_branch is not False
179
- cond4 = "attention_mask" in dummy_inputs
180
- if cond1 and cond2 and cond3 and cond4 :
181
- # Obtain the past sequence length from the value instead of the key (Bloom).
182
- past_length = dummy_inputs ["past_key_values" ][0 ][1 ].shape [0 ]
183
- for k , v in dummy_inputs .items ():
184
- if k not in ["attention_mask" , "past_key_values" ]:
185
- dummy_inputs [k ] = v [:, - 1 :]
170
+ if (
171
+ self .use_past_in_inputs
172
+ and self .PAD_ATTENTION_MASK_TO_PAST
173
+ and self .use_cache_branch is not False
174
+ and "attention_mask" in dummy_inputs
175
+ ):
176
+ # Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2
177
+ past_present_length = dummy_inputs ["input_ids" ].shape [1 ] + dummy_inputs ["past_key_values" ][0 ][1 ].shape [0 ]
186
178
187
179
dummy_inputs ["attention_mask" ] = DummyInputGenerator .pad_input_on_dim (
188
180
dummy_inputs ["attention_mask" ],
189
- desired_length = past_length + 1 ,
181
+ desired_length = past_present_length ,
190
182
dim = 1 ,
191
183
dtype = dummy_inputs ["attention_mask" ].dtype ,
192
184
)
193
185
194
186
return dummy_inputs
195
187
196
- @property
197
- def inputs (self ) -> Dict [str , Dict [int , str ]]:
198
- common_inputs = super ().inputs
199
- if not self .no_position_ids and self .task == "text-generation" :
200
- common_inputs ["position_ids" ] = {0 : "batch_size" , 1 : "sequence_length" }
201
-
202
- return common_inputs
203
-
204
188
def add_past_key_values (self , inputs_or_outputs : Dict [str , Dict [int , str ]], direction : str ):
205
189
"""
206
190
Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
@@ -218,7 +202,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
218
202
decoder_sequence_name = "past_sequence_length"
219
203
name = "past_key_values"
220
204
else :
221
- decoder_sequence_name = "past_sequence_length + 1 "
205
+ decoder_sequence_name = "past_sequence_length + present_lenght "
222
206
name = "present"
223
207
224
208
for i in range (self ._normalized_config .num_layers ):
0 commit comments