@@ -31,8 +31,9 @@
                     help="For accuracy measurement only.")
 parser.add_argument("--save_accuracy_path", default=None,
                     help="Save accuracy results path.")
-parser.add_argument("--tasks", type=str, default="lambada_openai",
-                    help="tasks list for accuracy validation")
+parser.add_argument("--tasks", nargs="+", default=["lambada_openai"], type=str,
+                    help="tasks list for accuracy validation"
+)
 parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 
 args = parser.parse_args()
@@ -54,57 +55,41 @@ def get_user_model():
     return user_model, tokenizer
 
 user_model, tokenizer = get_user_model()
-if args.quantize:
-    from neural_compressor.torch.quantization import MXQuantConfig, quantize
-    quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
-    user_model = quantize(model=user_model, quant_config=quant_config)
-
+
+from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert
+quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
+user_model = prepare(model=user_model, quant_config=quant_config)
+user_model = convert(model=user_model)
+user_model.eval()
 
-if args.accuracy:
-    user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
-    args = LMEvalParser(
-        model="hf",
-        user_model=user_model,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        device="cpu",
-    )
-    results = evaluate(args)
-    dumped = json.dumps(results, indent=2)
-    if args.save_accuracy_path:
-        with open(args.save_accuracy_path, "w") as f:
-            f.write(dumped)
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Batch size = %d' % args.batch_size)
+from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+eval_args = LMEvalParser(
+    model="hf",
+    user_model=user_model,
+    tokenizer=tokenizer,
+    batch_size=args.batch_size,
+    tasks=','.join(args.tasks),
+    device="cpu",
+)
 
-if args.performance:
-    user_model.eval()
-    from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
-    import time
-    samples = args.iters * args.batch_size
-    start = time.time()
-    results = evaluate(
-        model="hf",
-        tokenizer=tokenizer,
-        user_model=user_model,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        limit=samples,
-    )
-    end = time.time()
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Throughput: %.3f samples/sec' % (samples / (end - start)))
-    print('Latency: %.3f ms' % ((end - start) * 1000 / samples))
-    print('Batch size = %d' % args.batch_size)
+results = evaluate(eval_args)
+dumped = json.dumps(results, indent=2)
+if args.save_accuracy_path:
+    with open(args.save_accuracy_path, "w") as f:
+        f.write(dumped)
+
+eval_acc = 0
+for task_name in args.tasks:
+    if task_name == "wikitext":
+        print("Accuracy for %s is: %s" %
+              (task_name, results["results"][task_name]["word_perplexity,none"]))
+        eval_acc += results["results"][task_name]["word_perplexity,none"]
+    else:
+        print("Accuracy for %s is: %s" %
+              (task_name, results["results"][task_name]["acc,none"]))
+        eval_acc += results["results"][task_name]["acc,none"]
+
+if len(args.tasks) != 0:
+    eval_acc /= len(args.tasks)
+print("Accuracy: %.5f" % eval_acc)
+print('Batch size = %d' % args.batch_size)
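Note on the first hunk: switching --tasks from a single string to nargs="+" means the flag now takes space-separated task names on the command line and yields a Python list, which the script joins back into the comma-separated string LMEvalParser expects. A minimal standalone sketch of that round trip, using only the standard library (the extra task names are illustrative, not taken from this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tasks", nargs="+", default=["lambada_openai"], type=str,
                    help="tasks list for accuracy validation")

# Simulate passing several tasks on the command line:
#   --tasks lambada_openai piqa winogrande
args = parser.parse_args(["--tasks", "lambada_openai", "piqa", "winogrande"])

print(args.tasks)            # ['lambada_openai', 'piqa', 'winogrande']
print(','.join(args.tasks))  # lambada_openai,piqa,winogrande

Having args.tasks as a list is also what makes the per-task loop and the averaged eval_acc straightforward. The result keys change from "acc"/"word_perplexity" to "acc,none"/"word_perplexity,none", matching newer lm-evaluation-harness releases, which report metrics as "<metric>,<filter>" pairs.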
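The second hunk replaces the one-shot quantize() entry point with the two-step prepare()/convert() flow from neural_compressor.torch.quantization and drops the args.quantize/args.accuracy guards, so the script now always quantizes and then evaluates. A minimal end-to-end sketch of the new flow, assuming neural-compressor 3.x and transformers are installed; the checkpoint name and dtype values are illustrative stand-ins, not taken from this diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert

model_name = "facebook/opt-125m"  # small stand-in checkpoint (assumption)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Same config object the diff builds from --w_dtype/--act_dtype/--woq;
# weight_only=True quantizes weights but leaves activations alone.
quant_config = MXQuantConfig(w_dtype="int8", act_dtype="int8", weight_only=True)

# prepare() wraps the model for MX quantization, convert() finalizes the
# quantized modules; together they replace the removed quantize() call.
model = prepare(model=model, quant_config=quant_config)
model = convert(model=model)
model.eval()

# Smoke-test the quantized model.
inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Splitting quantization into prepare/convert mirrors the rest of the 3.x torch API, which keeps a separate preparation step so flows that need calibration can run it between the two calls.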