Commit 26b260e

ONNXRT LLM examples support latest optimum version (#1578)

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent ac47d9b commit 26b260e

12 files changed: +316, -232 lines

examples/.config/model_params_onnxrt.json (+49, -7)

```diff
@@ -756,45 +756,87 @@
         "main_script": "main.py",
         "batch_size": 1
     },
-    "llama-2-7b": {
+    "llama-2-7b-sq": {
         "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
         "dataset_location": "",
-        "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
+        "main_script": "main.py",
+        "batch_size": 1
+    },
+    "llama-2-7b-sq-with-past": {
+        "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
+        "dataset_location": "",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
         "main_script": "main.py",
         "batch_size": 1
     },
     "llama-2-7b-lwq": {
         "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
         "dataset_location": "",
-        "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
+        "main_script": "main.py",
+        "batch_size": 1
+    },
+    "llama-2-7b-with-past-lwq": {
+        "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/ptq_static",
+        "dataset_location": "",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
         "main_script": "main.py",
         "batch_size": 1
     },
     "llama-2-7b-rtn": {
         "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
         "dataset_location": "",
-        "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
+        "main_script": "main.py",
+        "batch_size": 1
+    },
+    "llama-2-7b-rtn-with-past": {
+        "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
+        "dataset_location": "",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
         "main_script": "main.py",
         "batch_size": 1
     },
     "llama-2-7b-awq": {
         "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
         "dataset_location": "",
-        "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
+        "main_script": "main.py",
+        "batch_size": 1
+    },
+    "llama-2-7b-awq-with-past": {
+        "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
+        "dataset_location": "",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
         "main_script": "main.py",
         "batch_size": 1
     },
     "llama-2-7b-gptq": {
         "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
         "dataset_location": "",
-        "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
+        "main_script": "main.py",
+        "batch_size": 1
+    },
+    "llama-2-7b-gptq-with-past": {
+        "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
+        "dataset_location": "",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
         "main_script": "main.py",
         "batch_size": 1
     },
     "llama-2-7b-woq_tune": {
         "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
         "dataset_location": "",
-        "input_model": "/tf_dataset2/models/onnx/llama-2-7b",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf",
+        "main_script": "main.py",
+        "batch_size": 1
+    },
+    "llama-2-7b-woq_tune-with-past": {
+        "model_src_dir": "nlp/huggingface_model/text_generation/llama/quantization/weight_only",
+        "dataset_location": "",
+        "input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past",
         "main_script": "main.py",
         "batch_size": 1
     },
```
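Each quantization approach now has a plain entry and a "-with-past" twin; the two input_model directories are expected to hold ONNX exports without and with the KV cache, respectively. A minimal sketch of how such a pair of directories could be produced with the updated prepare_model.py from this commit (the script location and local output paths are illustrative, not part of the change):

```python
# Hypothetical driver: export the two ONNX variants that the plain and "-with-past"
# config entries point at. Output directory names here are illustrative only.
import subprocess

MODEL_ID = "meta-llama/Llama-2-7b-hf"  # HF model id used in the example README

for task, out_dir in [
    ("text-generation", "./Llama-2-7b-hf"),
    ("text-generation-with-past", "./Llama-2-7b-hf-with-past"),
]:
    # prepare_model.py (see the diff further below) wraps `optimum-cli export onnx`
    # and forwards --task; it requires optimum >= 1.14.0.
    subprocess.run(
        ["python", "prepare_model.py",
         f"--input_model={MODEL_ID}",
         f"--output_model={out_dir}",
         f"--task={task}"],
        check=True,
    )
```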

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/README.md (+4, -2)

````diff
@@ -27,7 +27,9 @@ Note that this README.md uses meta-llama/Llama-2-7b-hf as an example. There are
 
 Export to ONNX model:
 ```bash
-python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" --output_model="./llama-2-7b-hf"
+python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" \
+                        --output_model="./llama-2-7b-hf" \
+                        --task=text-generation-with-past \ # or text-generation
 ```
 
 # Run
@@ -41,7 +43,7 @@ bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
                   --output_model=/path/to/model_tune \ # folder path to save onnx model
                   --batch_size=batch_size # optional \
                   --dataset NeelNanda/pile-10k \
-                  --alpha 0.6 \ # 0.6 for llama-7b, 0.8 for llama-13b
+                  --alpha 0.75 \
                   --tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer
                   --quant_format="QOperator" # or QDQ, optional
 ```
````

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/main.py (+60, -68)

```diff
@@ -84,7 +84,7 @@
 parser.add_argument(
     '--quant_format',
     type=str,
-    default='QOperator', 
+    default='QOperator',
     choices=['QOperator', 'QDQ'],
     help="quantization format"
 )
@@ -124,8 +124,9 @@
 )
 args = parser.parse_args()
 
-# load model
+# load model tokenize and config
 tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer)
+config = LlamaConfig.from_pretrained(args.model_path)
 
 def tokenize_function(examples):
     example = tokenizer(examples['text'])
@@ -134,29 +135,20 @@ def tokenize_function(examples):
 def benchmark(model):
     import json
     import time
-    config = LlamaConfig.from_pretrained(args.model_path)
     sess_options = ort.SessionOptions()
     sess_options.intra_op_num_threads = args.intra_op_num_threads
-
-    if os.path.exists(os.path.join(model, "decoder_with_past_model.onnx")):
-        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
-            os.path.join(model, "decoder_model.onnx"),
-            os.path.join(model, "decoder_with_past_model.onnx"),
-            session_options=sess_options)
-        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
-                                    config,
-                                    model,
-                                    sessions[1],
-                                    use_cache=True)
-    else:
-        sessions = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
-            os.path.join(model, "decoder_model.onnx"),
-            session_options=sess_options)
-        model = ORTModelForCausalLM(sessions[0],  # pylint: disable=E1121
-                                    config,
-                                    model,
-                                    use_cache=False,
-                                    use_io_binding=False)
+
+    session = ORTModelForCausalLM.load_model(  # pylint: disable=E1123
+        os.path.join(model, "model.onnx"),
+        session_options=sess_options)
+    inputs_names = session.get_inputs()
+    key_value_input_names = [key.name for key in inputs_names if (".key" in key.name) or (".value" in key.name)]
+    use_cache = len(key_value_input_names) > 0
+
+    model = ORTModelForCausalLM(session,  # pylint: disable=E1121
+                                config,
+                                use_cache=True if use_cache else False,
+                                use_io_binding=True if use_cache else False,)
 
     input_tokens = '32'
     max_new_tokens = 32
@@ -192,7 +184,7 @@ def benchmark(model):
     print(args)
     throughput = (num_iter - num_warmup) / total_time
     print("Throughput: {} samples/s".format(throughput))
-    
+
 
 def replace_architectures(json_path):
     # replace 'LLaMATokenizer' to lowercase 'LlamaTokenizer'
@@ -201,7 +193,7 @@ def replace_architectures(json_path):
     with open(json_path, "r") as file:
         data = json.load(file)
     data["architectures"] = ["LlamaForCausalLM"]
-    
+
     with open(json_path, 'w') as file:
         json.dump(data, file, indent=4)
 
@@ -234,6 +226,7 @@ def eval_func(model):
 
     return eval_acc
 
+
 class KVDataloader:
     def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
         self.pad_max = pad_max
@@ -247,10 +240,11 @@ def __init__(self, model_path, pad_max=196, batch_size=1, sub_folder='train'):
             shuffle=False,
             collate_fn=self.collate_batch,
         )
-        self.sess = None
-        if not model_path.endswith('decoder_model.onnx'):
-            self.sess = ort.InferenceSession(os.path.join(os.path.dirname(model_path), 'decoder_model.onnx'))
-
+        session = ort.InferenceSession(model_path)
+        inputs_names = [input.name for input in session.get_inputs()]
+        self.key_value_input_names = [key for key in inputs_names if (".key" in key) or (".value" in key)]
+        self.use_cache = len(self.key_value_input_names) > 0
+        self.session = session if self.use_cache else None
 
     def collate_batch(self, batch):
 
@@ -269,23 +263,26 @@ def collate_batch(self, batch):
             attention_mask_padded.append(attention_mask)
         return (torch.vstack(input_ids_padded), torch.vstack(attention_mask_padded)), torch.tensor(last_ind)
 
-
     def __iter__(self):
         try:
             for (input_ids, attention_mask), last_ind in self.dataloader:
-                if self.sess is None:
-                    yield {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
-                           'attention_mask':attention_mask[:, :-1].detach().cpu().numpy().astype('int64')}, last_ind.detach().cpu().numpy()
-                else:
-                    outputs = self.sess.run(None, {'input_ids': input_ids[:, :-1].detach().cpu().numpy().astype('int64'),
-                                                   'attention_mask':attention_mask[:, :-1].detach().cpu().numpy().astype('int64')})
-                    ort_input = {}
-                    ort_input['input_ids'] = input_ids[:, -1].unsqueeze(0).detach().cpu().numpy().astype('int64')
-                    for i in range(int((len(outputs) - 1) / 2)):
-                        ort_input['past_key_values.{}.key'.format(i)] = outputs[i*2+1]
-                        ort_input['past_key_values.{}.value'.format(i)] = outputs[i*2+2]
-                    ort_input['attention_mask'] = np.zeros([self.batch_size, ort_input['past_key_values.0.key'].shape[2]+1], dtype='int64')
-                    yield ort_input, last_ind.detach().cpu().numpy()
+                ort_input = {}
+                ort_input["input_ids"] = input_ids[:, :-1].detach().cpu().numpy().astype("int64")
+                ort_input["attention_mask"] = attention_mask[:, :-1].detach().cpu().numpy().astype("int64")
+                position_ids = attention_mask.long().cumsum(-1) - 1
+                position_ids.masked_fill_(attention_mask == 0, 1)
+                ort_input["position_ids"] = position_ids[:,:-1].detach().cpu().numpy().astype("int64")
+                if self.use_cache:
+                    # Create dummy past_key_values for decoder
+                    num_attention_heads = config.num_key_value_heads
+                    embed_size_per_head = config.hidden_size // config.num_attention_heads
+                    shape = (self.batch_size, num_attention_heads, 0, embed_size_per_head)
+                    key_or_value = np.zeros(shape, dtype=np.float32)
+                    for key_value_input_name in self.key_value_input_names:
+                        ort_input[key_value_input_name] = key_or_value
+
+                yield ort_input, last_ind.detach().cpu().numpy()
+
         except StopIteration:
             return
 
@@ -294,43 +291,38 @@ def __iter__(self):
     set_workspace(args.workspace)
 
     if args.benchmark:
-        if args.mode == 'performance':            
+        if args.mode == 'performance':
             benchmark(args.model_path)
         elif args.mode == 'accuracy':
             eval_func(args.model_path)
 
     if args.tune:
         from neural_compressor import quantization, PostTrainingQuantConfig
+
+        model_name = "model.onnx" # require optimum >= 1.14.0
+        model_path = os.path.join(args.model_path, model_name)
+
         if args.layer_wise:
             # layer-wise quantization for ONNX models is still under development and only support W8A8 quantization now
-            config = PostTrainingQuantConfig(
+            ptq_config = PostTrainingQuantConfig(
                 calibration_sampling_size=[8],
                 recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                         'layer_wise_quant': True},
+                         'layer_wise_quant': True,
+                         'graph_optimization_level': 'ENABLE_EXTENDED'},
                 op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
-            for model in ['decoder_model.onnx']:
-                # only test decoder_model
-                q_model = quantization.fit(
-                    os.path.join(args.model_path, model),
-                    config,
-                    calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-                q_model.save(os.path.join(args.output_model, model))
-
-            tokenizer.save_pretrained(args.output_model)
-
         else:
-            config = PostTrainingQuantConfig(
+            ptq_config = PostTrainingQuantConfig(
                 calibration_sampling_size=[8],
                 recipes={'optypes_to_exclude_output_quant': ['MatMul'],
-                         'smooth_quant': True,
-                         'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
-                         },
+                         'smooth_quant': True,
+                         'smooth_quant_args': {'alpha': args.smooth_quant_alpha},
+                         'graph_optimization_level': 'ENABLE_EXTENDED'},
                 op_type_dict={'^((?!(MatMul|Gather|Conv)).)*$': {'weight': {'dtype': ['fp32']}, 'activation': {'dtype': ['fp32']}}})
-            for model in ['decoder_model.onnx', 'decoder_with_past_model.onnx']:
-                q_model = quantization.fit(
-                    os.path.join(args.model_path, model),
-                    config,
-                    calib_dataloader=KVDataloader(os.path.join(args.model_path, model), pad_max=args.pad_max, batch_size=1))
-                q_model.save(os.path.join(args.output_model, model))
-
-            tokenizer.save_pretrained(args.output_model)
+
+        q_model = quantization.fit(
+            model_path,
+            ptq_config,
+            calib_dataloader=KVDataloader(model_path, pad_max=args.pad_max, batch_size=1))
+        q_model.save(os.path.join(args.output_model, model_name))
+
+        tokenizer.save_pretrained(args.output_model)
```
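The updated calibration dataloader and benchmark path treat the single merged decoder uniformly: they always feed position_ids, and when the export carries a KV cache they add zero-length past_key_values tensors so the graph can run its cache-empty first step. A standalone sketch of that pattern, assuming a LLaMA-style HF config and KV inputs named past_key_values.N.key/.value as in the diff above:

```python
# Sketch of the KV-cache-aware input construction the new KVDataloader relies on.
# Assumptions: config exposes num_key_value_heads, num_attention_heads and hidden_size
# (as LlamaConfig does); session is an onnxruntime.InferenceSession over model.onnx;
# input_ids and attention_mask are 2-D numpy integer arrays of equal shape.
import numpy as np

def first_step_inputs(input_ids, attention_mask, config, session, batch_size=1):
    # KV-cache inputs exist only if the model was exported with --task text-generation-with-past
    kv_names = [i.name for i in session.get_inputs()
                if ".key" in i.name or ".value" in i.name]
    ort_input = {
        "input_ids": input_ids.astype("int64"),
        "attention_mask": attention_mask.astype("int64"),
    }
    # position_ids: running index over non-padded tokens, padded positions pinned to 1
    position_ids = np.cumsum(attention_mask, axis=-1) - 1
    position_ids[attention_mask == 0] = 1
    ort_input["position_ids"] = position_ids.astype("int64")
    if kv_names:
        heads = config.num_key_value_heads
        head_dim = config.hidden_size // config.num_attention_heads
        # Zero-length sequence dimension represents an empty cache for the first step
        empty_kv = np.zeros((batch_size, heads, 0, head_dim), dtype=np.float32)
        for name in kv_names:
            ort_input[name] = empty_kv
    return ort_input
```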

examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/prepare_model.py (+24, -33)

```diff
@@ -10,46 +10,37 @@ def parse_arguments():
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_model", type=str, required=False, default="")
     parser.add_argument("--output_model", type=str, required=True)
+    parser.add_argument("--task",
+                        type=str,
+                        required=False,
+                        default="text-generation-with-past",
+                        choices=["text-generation-with-past", "text-generation"])
     return parser.parse_args()
 
 
-def prepare_model(input_model, output_model):
+def prepare_model(input_model, output_model, task):
     print("\nexport model...")
-    if Version(optimum.version.__version__) >= OPTIMUM114_VERSION:
-        subprocess.run(
-            [
-                "optimum-cli",
-                "export",
-                "onnx",
-                "--model",
-                f"{input_model}",
-                "--task",
-                "text-generation-with-past",
-                "--legacy",
-                f"{output_model}",
-            ],
-            stdout=subprocess.PIPE,
-            text=True,
-        )
-    else:
-        subprocess.run(
-            [
-                "optimum-cli",
-                "export",
-                "onnx",
-                "--model",
-                f"{input_model}",
-                "--task",
-                "text-generation-with-past",
-                f"{output_model}",
-            ],
-            stdout=subprocess.PIPE,
-            text=True,
-        )
+    if Version(optimum.version.__version__) < OPTIMUM114_VERSION:
+        raise ImportError("Please upgrade optimum to >= 1.14.0")
+
+    subprocess.run(
+        [
+            "optimum-cli",
+            "export",
+            "onnx",
+            "--model",
+            f"{input_model}",
+            "--task",
+            task,
+            f"{output_model}",
+        ],
+        stdout=subprocess.PIPE,
+        text=True,
+    )
 
     assert os.path.exists(output_model), f"{output_model} doesn't exist!"
 
 
 if __name__ == "__main__":
     args = parse_arguments()
-    prepare_model(args.input_model, args.output_model)
+    prepare_model(args.input_model, args.output_model, args.task)
```
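With optimum >= 1.14.0 the export writes a single model.onnx per output directory, and main.py decides whether to run with a KV cache by inspecting the graph inputs. A quick way to check which variant a given export directory contains, mirroring that detection logic (the directory path below is illustrative):

```python
# Report whether an export directory holds a "with-past" (KV cache) decoder.
# Assumes the optimum >= 1.14.0 single-file layout (model.onnx), as main.py does.
import os
import onnxruntime as ort

export_dir = "./Llama-2-7b-hf-with-past"  # illustrative path
session = ort.InferenceSession(os.path.join(export_dir, "model.onnx"))
kv_inputs = [i.name for i in session.get_inputs() if ".key" in i.name or ".value" in i.name]
print("with-past export" if kv_inputs else "plain text-generation export")
```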
