13
13
from typing import Any , Callable , Iterable , List , Optional , Tuple , TypeVar , Union
14
14
15
15
import nncf
16
- from nncf .common .factory import EngineFactory
17
16
from nncf .common .logging import nncf_logger
18
17
from nncf .common .utils .backend import BackendType
19
18
from nncf .common .utils .backend import get_backend
20
19
from nncf .common .utils .timer import timer
21
20
from nncf .data .dataset import Dataset
21
+ from nncf .quantization .algorithms .accuracy_control .backend import PreparedModel
22
22
23
23
TModel = TypeVar ("TModel" )
24
- TPModel = TypeVar ("TPModel" )
25
24
TTensor = TypeVar ("TTensor" )
26
25
27
26
@@ -112,7 +111,7 @@ def is_metric_mode(self) -> bool:
112
111
"""
113
112
return self ._metric_mode
114
113
115
- def prepare_model_for_inference (self , model : TModel ) -> TPModel :
114
+ def prepare_model (self , model : TModel ) -> PreparedModel :
116
115
"""
117
116
Prepares model for inference.
118
117
@@ -122,21 +121,19 @@ def prepare_model_for_inference(self, model: TModel) -> TPModel:
122
121
backend = get_backend (model )
123
122
124
123
if backend == BackendType .OPENVINO :
125
- import openvino . runtime as ov
124
+ from nncf . quantization . algorithms . accuracy_control . openvino_backend import OVPreparedModel
126
125
127
- return ov . compile_model (model )
126
+ return OVPreparedModel (model )
128
127
129
- raise NotImplementedError (
130
- f"The `prepare_model_for_inference()` method is not implemented for the { backend } backend."
131
- )
128
+ raise NotImplementedError (f"The `prepare_model()` method is not implemented for the { backend } backend." )
132
129
133
- def validate_model_for_inference (
134
- self , model_for_inference : TPModel , dataset : Dataset , indices : Optional [List [int ]] = None
130
+ def validate_prepared_model (
131
+ self , prepared_model : PreparedModel , dataset : Dataset , indices : Optional [List [int ]] = None
135
132
):
136
133
"""
137
134
Validates prepared model for inference.
138
135
139
- :param model : Prepared model to validate.
136
+ :param prepared_model : Prepared model to validate.
140
137
:param dataset: Dataset to validate the model.
141
138
:param indices: Zero-based indices of data items that should be selected from
142
139
the dataset.
@@ -148,7 +145,7 @@ def validate_model_for_inference(
148
145
item.
149
146
"""
150
147
if self ._metric_mode is None :
151
- self ._metric_mode = Evaluator .determine_mode (model_for_inference , dataset , self ._validation_fn )
148
+ self ._metric_mode = Evaluator .determine_mode (prepared_model , dataset , self ._validation_fn )
152
149
153
150
if not self .is_metric_mode () and indices is not None :
154
151
raise ValueError ("The `indices` parameter can be used only if Evaluator.is_metric_mode() = True" )
@@ -157,7 +154,7 @@ def validate_model_for_inference(
157
154
if self ._enable_iteration_count :
158
155
validation_dataset = IterationCounter (validation_dataset )
159
156
160
- metric , values_for_each_item = self ._validation_fn (model_for_inference , validation_dataset )
157
+ metric , values_for_each_item = self ._validation_fn (prepared_model . model_for_inference , validation_dataset )
161
158
162
159
self ._num_passed_iterations = validation_dataset .num_iterations if self ._enable_iteration_count else 0
163
160
@@ -190,20 +187,20 @@ def validate(
190
187
Otherwise, if the condition is false, it represents list of logits for each
191
188
item.
192
189
"""
193
- model_for_inference = self .prepare_model_for_inference (model )
194
- return self .validate_model_for_inference ( model_for_inference , dataset , indices )
190
+ prepared_model = self .prepare_model (model )
191
+ return self .validate_prepared_model ( prepared_model , dataset , indices )
195
192
196
193
@staticmethod
197
194
def determine_mode (
198
- model_for_inference : TPModel ,
195
+ prepared_model : PreparedModel ,
199
196
dataset : Dataset ,
200
197
validation_fn : Callable [[Any , Iterable [Any ]], Tuple [float , Union [None , List [float ], List [List [TTensor ]]]]],
201
198
) -> bool :
202
199
"""
203
200
Determines mode based on the type of returned value from the
204
201
validation function.
205
202
206
- :param model_for_inference : Model to validate.
203
+ :param prepared_model : Model to validate.
207
204
:param dataset: Dataset to validate the model.
208
205
:param validation_fn: Validation function to validate model.
209
206
:return: A boolean indicator where `True` means that the `Evaluator` collects
@@ -215,7 +212,7 @@ def determine_mode(
215
212
data_item = dataset .get_data ([0 ])
216
213
217
214
try :
218
- metric_value , values_for_each_item = validation_fn (model_for_inference , data_item )
215
+ metric_value , values_for_each_item = validation_fn (prepared_model . model_for_inference , data_item )
219
216
except Exception :
220
217
metric_mode = False
221
218
@@ -262,15 +259,15 @@ def determine_mode(
262
259
263
260
return metric_mode
264
261
265
- def collect_values_for_each_item_using_model_for_inference (
266
- self , model_for_inference : TPModel , dataset : Dataset , indices : Optional [List [int ]] = None
262
+ def collect_values_for_each_item_using_prepared_model (
263
+ self , prepared_model : PreparedModel , dataset : Dataset , indices : Optional [List [int ]] = None
267
264
) -> Union [List [float ], List [List [TTensor ]]]:
268
265
"""
269
266
Collects value for each item from the dataset using prepared model for inference.
270
267
If `is_metric_mode()` returns `True` then i-th value is a metric for i-th data item.
271
268
It is an output of the model for i-th data item otherwise.
272
269
273
- :param model : Model to infer.
270
+ :param prepared_model : Model to infer.
274
271
:param dataset: Dataset to collect values.
275
272
:param indices: The zero-based indices of data items that should be selected from
276
273
the dataset.
@@ -279,15 +276,14 @@ def collect_values_for_each_item_using_model_for_inference(
279
276
if self ._metric_mode :
280
277
# Collect metrics for each item
281
278
values_for_each_item = [
282
- self ._validation_fn (model_for_inference , [data_item ])[0 ] for data_item in dataset .get_data (indices )
279
+ self ._validation_fn (prepared_model .model_for_inference , [data_item ])[0 ]
280
+ for data_item in dataset .get_data (indices )
283
281
]
284
282
else :
285
283
# Collect outputs for each item
286
- engine = EngineFactory .create (model_for_inference )
287
-
288
284
values_for_each_item = []
289
285
for data_item in dataset .get_inference_data (indices ):
290
- logits = engine . infer (data_item )
286
+ logits = prepared_model (data_item )
291
287
values_for_each_item .append (list (logits .values ()))
292
288
293
289
self ._num_passed_iterations = len (values_for_each_item ) if self ._enable_iteration_count else 0
@@ -308,8 +304,8 @@ def collect_values_for_each_item(
308
304
the dataset.
309
305
:return: Collected values.
310
306
"""
311
- model_for_inference = self .prepare_model_for_inference (model )
312
- return self .collect_values_for_each_item_using_model_for_inference ( model_for_inference , dataset , indices )
307
+ prepared_model = self .prepare_model (model )
308
+ return self .collect_values_for_each_item_using_prepared_model ( prepared_model , dataset , indices )
313
309
314
310
def collect_metric_results (self , model : TModel , dataset : Dataset , model_name : str = "" ) -> MetricResults :
315
311
"""
@@ -323,18 +319,16 @@ def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: st
323
319
nncf_logger .info (f"Validation of { model_name } model was started" )
324
320
325
321
with timer () as preparation_time :
326
- model_for_inference = self .prepare_model_for_inference (model )
322
+ prepared_model = self .prepare_model (model )
327
323
328
324
with timer () as validation_time :
329
- metric , values_for_each_item = self .validate_model_for_inference ( model_for_inference , dataset )
325
+ metric , values_for_each_item = self .validate_prepared_model ( prepared_model , dataset )
330
326
331
327
nncf_logger .info (f"Metric of { model_name } model: { metric } " )
332
328
333
329
if values_for_each_item is None :
334
330
nncf_logger .info (f"Collecting values for each data item using the { model_name } model" )
335
331
with timer ():
336
- values_for_each_item = self .collect_values_for_each_item_using_model_for_inference (
337
- model_for_inference , dataset
338
- )
332
+ values_for_each_item = self .collect_values_for_each_item_using_prepared_model (prepared_model , dataset )
339
333
340
334
return MetricResults (metric , values_for_each_item , preparation_time (), validation_time ())
0 commit comments