52
52
arg_parser .add_argument ('--int8' , dest = 'int8' , action = 'store_true' , help = 'whether to use int8 model for benchmark' )
53
53
args = arg_parser .parse_args ()
54
54
55
- def evaluate (model , eval_dataloader , metric , postprocess = None ):
55
+ def evaluate (model , eval_dataloader , postprocess = None ):
56
56
"""Custom evaluate function to estimate the accuracy of the model.
57
57
58
58
Args:
@@ -61,12 +61,14 @@ def evaluate(model, eval_dataloader, metric, postprocess=None):
61
61
Returns:
62
62
accuracy (float): evaluation result, the larger is better.
63
63
"""
64
+ from neural_compressor import METRICS
64
65
from neural_compressor .model import Model
65
66
model = Model (model )
66
67
input_tensor = model .input_tensor
67
68
output_tensor = model .output_tensor if len (model .output_tensor )> 1 else \
68
69
model .output_tensor [0 ]
69
70
iteration = - 1
71
+ metric = METRICS ('tensorflow' )['topk' ]()
70
72
if args .benchmark and args .mode == 'performance' :
71
73
iteration = args .iters
72
74
@@ -136,9 +138,6 @@ def run(self):
136
138
accuracy_criterion = AccuracyCriterion (tolerable_loss = 0.01 ),
137
139
op_type_dict = {'conv2d' :{ 'weight' :{'dtype' :['fp32' ]}, 'activation' :{'dtype' :['fp32' ]} }}
138
140
)
139
- from neural_compressor import METRICS
140
- metrics = METRICS ('tensorflow' )
141
- top1 = metrics ['topk' ]()
142
141
from tensorflow .core .protobuf import saved_model_pb2
143
142
sm = saved_model_pb2 .SavedModel ()
144
143
with tf .io .gfile .GFile (args .input_graph , "rb" ) as f :
@@ -147,10 +146,9 @@ def run(self):
147
146
from neural_compressor .data import TensorflowShiftRescale
148
147
postprocess = TensorflowShiftRescale ()
149
148
def eval (model ):
150
- return evaluate (model , eval_dataloader , top1 , postprocess )
151
- q_model = quantization .fit (graph_def , conf = conf , calib_dataloader = calib_dataloader ,
152
- # eval_dataloader=eval_dataloader, eval_metric=top1)
153
- eval_func = eval )
149
+ return evaluate (model , eval_dataloader , postprocess )
150
+ q_model = quantization .fit (graph_def , conf = conf , eval_func = eval ,
151
+ calib_dataloader = calib_dataloader )
154
152
q_model .save (args .output_graph )
155
153
156
154
if args .benchmark :
@@ -163,9 +161,6 @@ def eval(model):
163
161
'filter' : None
164
162
}
165
163
dataloader = create_dataloader ('tensorflow' , dataloader_args )
166
- from neural_compressor import METRICS
167
- metrics = METRICS ('tensorflow' )
168
- top1 = metrics ['topk' ]()
169
164
170
165
if args .int8 or args .input_graph .endswith ("-tune.pb" ):
171
166
input_graph = args .input_graph
@@ -180,7 +175,7 @@ def eval(model):
180
175
from neural_compressor .data import TensorflowShiftRescale
181
176
postprocess = TensorflowShiftRescale ()
182
177
def eval (model ):
183
- return evaluate (model , dataloader , top1 , postprocess )
178
+ return evaluate (model , dataloader , postprocess )
184
179
185
180
if args .mode == 'performance' :
186
181
from neural_compressor .benchmark import fit
0 commit comments