 parser.add_argument(
     '--quant_format',
     type=str,
-    default='QOperator',
+    default='QOperator',
     choices=['QOperator', 'QDQ'],
     help="quantization format"
 )
@@ -124,8 +124,9 @@
 )
 args = parser.parse_args()
 
-# load model
+# load model, tokenizer and config
 tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer)
+config = LlamaConfig.from_pretrained(args.model_path)
 
 def tokenize_function(examples):
     example = tokenizer(examples['text'])
@@ -134,29 +135,20 @@ def tokenize_function(examples):
 def benchmark(model):
     import json
     import time
-    config = LlamaConfig.from_pretrained(args.model_path)
     sess_options = ort.SessionOptions()
     sess_options.intra_op_num_threads = args.intra_op_num_threads
-
-    if os.path.exists(os.path.join(model, "decoder_with_past_model.onnx")):
-        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
-            os.path.join(model, "decoder_model.onnx"),
-            os.path.join(model, "decoder_with_past_model.onnx"),
-            session_options=sess_options)
-        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
-                                    config,
-                                    model,
-                                    sessions[1],
-                                    use_cache=True)
-    else:
-        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
-            os.path.join(model, "decoder_model.onnx"),
-            session_options=sess_options)
-        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
-                                    config,
-                                    model,
-                                    use_cache=False,
-                                    use_io_binding=False)
+
+    session = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
+        os.path.join(model, "model.onnx"),
+        session_options=sess_options)
+    inputs_names = session.get_inputs()
+    key_value_input_names = [key.name for key in inputs_names if (".key" in key.name) or (".value" in key.name)]
+    use_cache = len(key_value_input_names) > 0
+
+    model = ORTModelForCausalLM(session,  # pylint: disable=E1121
+                                config,
+                                use_cache=True if use_cache else False,
+                                use_io_binding=True if use_cache else False,)
 
     input_tokens = '32'
     max_new_tokens = 32
@@ -192,7 +184,7 @@ def benchmark(model):
     print(args)
     throughput = (num_iter - num_warmup) / total_time
     print("Throughput: {} samples/s".format(throughput))
-
+
 
 def replace_architectures(json_path):
     # replace 'LLaMATokenizer' to lowercase 'LlamaTokenizer'
@@ -201,7 +193,7 @@ def replace_architectures(json_path):
     with open(json_path, "r") as file:
         data = json.load(file)
     data["architectures"] = ["LlamaForCausalLM"]
-
+
     with open(json_path, 'w') as file:
         json.dump(data, file, indent=4)
 
@@ -234,6 +226,7 @@ def eval_func(model):
 
     return eval_acc
 
+
 class KVDataloader:
     def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
         self.pad_max = pad_max
@@ -247,10 +240,11 @@ def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
             shuffle=False,
             collate_fn=self.collate_batch,
         )
-        self.sess = None
-        if not model_path.endswith('decoder_model.onnx'):
-            self.sess = ort.InferenceSession(os.path.join(os.path.dirname(model_path), 'decoder_model.onnx'))
-
+        session = ort.InferenceSession(model_path)
+        inputs_names = [input.name for input in session.get_inputs()]
+        self.key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)]
+        self.use_cache = len(self.key_value_input_names) > 0
+        self.session = session if self.use_cache else None
 
     def collate_batch(self, batch):
 
@@ -269,23 +263,26 @@ def collate_batch(self, batch):
             attention_mask_padded.append(attention_mask)
         return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)
 
-
     def __iter__(self):
         try:
             for (input_ids, attention_mask), last_ind in self.dataloader:
-                if self.sess is None:
-                    yield {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
-                           'attention_mask': attention_mask[:, :-1].detach().cpu().numpy().astype('int64')}, last_ind.detach().cpu().numpy()
-                else:
-                    outputs = self.sess.run(None, {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
-                                                   'attention_mask': attention_mask[:, :-1].detach().cpu().numpy().astype('int64')})
-                    ort_input = {}
-                    ort_input['input_ids'] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype('int64')
-                    for i in range(int((len(outputs) - 1) / 2)):
-                        ort_input['past_key_values.{}.key'.format(i)] = outputs[i * 2 + 1]
-                        ort_input['past_key_values.{}.value'.format(i)] = outputs[i * 2 + 2]
-                    ort_input['attention_mask'] = np.zeros([self.batch_size, ort_input['past_key_values.0.key'].shape[2] + 1], dtype='int64')
-                    yield ort_input, last_ind.detach().cpu().numpy()
+                ort_input = {}
+                ort_input["input_ids"] = input_ids[:, :-1].detach().cpu().numpy().astype("int64")
+                ort_input["attention_mask"] = attention_mask[:, :-1].detach().cpu().numpy().astype("int64")
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                ort_input["position_ids"] = position_ids[:, :-1].detach().cpu().numpy().astype("int64")
+                if self.use_cache:
+                    # Create dummy past_key_values for decoder
+                    num_attention_heads = config.num_key_value_heads
+                    embed_size_per_head = config.hidden_size // config.num_attention_heads
+                    shape = (self.batch_size, num_attention_heads, 0, embed_size_per_head)
+                    key_or_value = np.zeros(shape, dtype=np.float32)
+                    for key_value_input_name in self.key_value_input_names:
+                        ort_input[key_value_input_name] = key_or_value
+
+                yield ort_input, last_ind.detach().cpu().numpy()
+
 
         except StopIteration:
             return
@@ -294,43 +291,38 @@ def __iter__(self):
 set_workspace(args.workspace)
 
 if args.benchmark:
-    if args.mode == 'performance':
+    if args.mode == 'performance':
         benchmark(args.model_path)
     elif args.mode == 'accuracy':
         eval_func(args.model_path)
 
 if args.tune:
     from neural_compressor import quantization, PostTrainingQuantConfig
+
+    model_name = "model.onnx"  # require optimum >= 1.14.0
+    model_path = os.path.join(args.model_path, model_name)
+
     if args.layer_wise:
         # layer-wise quantization for ONNX models is still under development and only support W8A8 quantization now
-        config = PostTrainingQuantConfig(
+        ptq_config = PostTrainingQuantConfig(
             calibration_sampling_size=[8],
             recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                     'layer_wise_quant': True},
+                     'layer_wise_quant': True,
+                     'graph_optimization_level': 'ENABLE_EXTENDED'},
             op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
-        for model in ['decoder_model.onnx']:
-            # only test decoder_model
-            q_model = quantization.fit(
-                os.path.join(args.model_path, model),
-                config,
-                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-            q_model.save(os.path.join(args.output_model, model))
-
-        tokenizer.save_pretrained(args.output_model)
-
     else:
-        config = PostTrainingQuantConfig(
+        ptq_config = PostTrainingQuantConfig(
             calibration_sampling_size=[8],
             recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                     'smooth_quant': True,
-                     'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
-                     },
+                     'smooth_quant': True,
+                     'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
+                     'graph_optimization_level': 'ENABLE_EXTENDED'},
            op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
-        for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
-            q_model = quantization.fit(
-                os.path.join(args.model_path, model),
-                config,
-                calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-            q_model.save(os.path.join(args.output_model, model))
-
-        tokenizer.save_pretrained(args.output_model)
+
+    q_model = quantization.fit(
+        model_path,
+        ptq_config,
+        calib_dataloader=KVDataloader(model_path, pad_max=args.pad_max, batch_size=1))
+    q_model.save(os.path.join(args.output_model, model_name))
+
+    tokenizer.save_pretrained(args.output_model)
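
The updated flow above assumes a single merged model.onnx exported with optimum >= 1.14.0, rather than the separate decoder_model.onnx / decoder_with_past_model.onnx pair used before. A minimal sketch of producing such an export with optimum's Python API follows; the checkpoint id and output directory are placeholder assumptions, not values from this change.

    # Sketch: export a LLaMA checkpoint to a single merged ONNX decoder with
    # past_key_values inputs, matching what the quantization script above expects.
    # Assumes optimum >= 1.14.0; checkpoint id and output path are placeholders.
    from optimum.exporters.onnx import main_export

    main_export(
        model_name_or_path="meta-llama/Llama-2-7b-hf",  # placeholder HF checkpoint
        output="./llama-onnx",                          # directory later passed as the script's model path
        task="text-generation-with-past",               # emits model.onnx with KV-cache inputs
    )

The exported directory can then serve as the model path for both the KVDataloader calibration pass and the ORTModelForCausalLM benchmark shown in the diff.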