
Commit 6733dab

update mx script (#1838)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
1 parent a0dee94 commit 6733dab

1 file changed: +38 -53 lines changed
examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/mx/run_clm_no_trainer.py

@@ -31,8 +31,9 @@
                     help="For accuracy measurement only.")
 parser.add_argument("--save_accuracy_path", default=None,
                     help="Save accuracy results path.")
-parser.add_argument("--tasks", type=str, default="lambada_openai",
-                    help="tasks list for accuracy validation")
+parser.add_argument("--tasks", nargs="+", default=["lambada_openai"], type=str,
+                    help="tasks list for accuracy validation"
+)
 parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 
 args = parser.parse_args()
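With nargs="+", multiple tasks are now passed space-separated on the command line instead of as a single string. A hypothetical invocation (script name from the file path above; all other flags omitted):

    python run_clm_no_trainer.py --tasks lambada_openai wikitext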
@@ -54,57 +55,41 @@ def get_user_model():
     return user_model, tokenizer
 
 user_model, tokenizer = get_user_model()
-if args.quantize:
-    from neural_compressor.torch.quantization import MXQuantConfig, quantize
-    quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
-    user_model = quantize(model=user_model, quant_config=quant_config)
 
+from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert
+quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
+user_model = prepare(model=user_model, quant_config=quant_config)
+user_model = convert(model=user_model)
+user_model.eval()
 
-if args.accuracy:
-    user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
-    args = LMEvalParser(
-        model="hf",
-        user_model=user_model,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        device="cpu",
-    )
-    results = evaluate(args)
-    dumped = json.dumps(results, indent=2)
-    if args.save_accuracy_path:
-        with open(args.save_accuracy_path, "w") as f:
-            f.write(dumped)
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Batch size = %d' % args.batch_size)
+from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+eval_args = LMEvalParser(
+    model="hf",
+    user_model=user_model,
+    tokenizer=tokenizer,
+    batch_size=args.batch_size,
+    tasks=','.join(args.tasks),
+    device="cpu",
+)
 
-if args.performance:
-    user_model.eval()
-    from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
-    import time
-    samples = args.iters * args.batch_size
-    start = time.time()
-    results = evaluate(
-        model="hf",
-        tokenizer=tokenizer,
-        user_model=user_model,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        limit=samples,
-    )
-    end = time.time()
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Throughput: %.3f samples/sec' % (samples / (end - start)))
-    print('Latency: %.3f ms' % ((end - start)*1000 / samples))
-    print('Batch size = %d' % args.batch_size)
+results = evaluate(eval_args)
+dumped = json.dumps(results, indent=2)
+if args.save_accuracy_path:
+    with open(args.save_accuracy_path, "w") as f:
+        f.write(dumped)
+
+eval_acc = 0
+for task_name in args.tasks:
+    if task_name == "wikitext":
+        print("Accuracy for %s is: %s" %
+              (task_name, results["results"][task_name]["word_perplexity,none"]))
+        eval_acc += results["results"][task_name]["word_perplexity,none"]
+    else:
+        print("Accuracy for %s is: %s" %
+              (task_name, results["results"][task_name]["acc,none"]))
+        eval_acc += results["results"][task_name]["acc,none"]
+
+if len(args.tasks) != 0:
+    eval_acc /= len(args.tasks)
+print("Accuracy: %.5f" % eval_acc)
print('Batch size = %d' % args.batch_size)
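The quantization entry point moves from the one-shot quantize() call to the prepare()/convert() pair, and the script now always quantizes and evaluates (the if args.quantize / if args.accuracy / if args.performance guards are gone, dropping the separate performance path). A minimal sketch of the new flow, assuming a Hugging Face causal-LM checkpoint and illustrative dtype values (the script reads these from --w_dtype and --act_dtype):

    from transformers import AutoModelForCausalLM
    from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert

    # Illustrative checkpoint; the script builds its model in get_user_model().
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    # Same config as before; only quantize() is split into two explicit steps.
    quant_config = MXQuantConfig(w_dtype="int8", act_dtype="int8", weight_only=False)
    model = prepare(model=model, quant_config=quant_config)  # step 1: prepare the model for MX quantization
    model = convert(model=model)                             # step 2: produce the quantized model
    model.eval()

On the evaluation side, tasks are comma-joined for LMEvalParser, the result keys switch to the newer lm-evaluation-harness "<metric>,<filter>" layout ("acc,none", "word_perplexity,none"), and the final printed accuracy is the average of the per-task metrics.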
