
Commit b49390e

Restoring state after inference only for stateful models (#2445)
### Changes

Restoring state after inference only for stateful models.

### Reason for changes

Inference optimization.

### Related tickets

131141

### Tests

test_examples 230
1 parent edb3987 commit b49390e

File tree: 8 files changed, +113 -74 lines changed

nncf/common/factory.py (-8)

@@ -18,9 +18,7 @@
 from nncf.common.graph.transformations.command_creation import CommandCreator
 from nncf.common.tensor_statistics import aggregator
 from nncf.common.utils.backend import BackendType
-from nncf.common.utils.backend import get_available_backends
 from nncf.common.utils.backend import get_backend
-from nncf.common.utils.backend import is_openvino_compiled_model
 from nncf.data.dataset import Dataset
 
 TModel = TypeVar("TModel")
@@ -88,12 +86,6 @@ def create(model: TModel) -> Engine:
         :param model: backend-specific model instance.
         :return: backend-specific Engine instance.
         """
-        available_backends = get_available_backends()
-        if BackendType.OPENVINO in available_backends and is_openvino_compiled_model(model):
-            from nncf.openvino.engine import OVCompiledModelEngine
-
-            return OVCompiledModelEngine(model)
-
         model_backend = get_backend(model)
         if model_backend == BackendType.ONNX:
             from nncf.onnx.engine import ONNXEngine

nncf/openvino/engine.py (+8 -7)

@@ -16,6 +16,7 @@
 
 import nncf
 from nncf.common.engine import Engine
+from nncf.openvino.graph.model_utils import model_has_state
 from nncf.parameters import TargetDevice
 
 
@@ -28,13 +29,12 @@ class OVCompiledModelEngine(Engine):
     to infer the compiled model.
     """
 
-    def __init__(self, model: ov.CompiledModel):
-        self.compiled_model = model
-        self.infer_request = model.create_infer_request()
-        self.reset_state = hasattr(self.infer_request, "reset_state")
+    def __init__(self, compiled_model: ov.CompiledModel, stateful: bool):
+        self.infer_request = compiled_model.create_infer_request()
+        self.reset_state = stateful and hasattr(self.infer_request, "reset_state")
         self.input_tensor_names = set()
-        self.number_of_inputs = len(model.inputs)
-        for model_input in model.inputs:
+        self.number_of_inputs = len(compiled_model.inputs)
+        for model_input in compiled_model.inputs:
             self.input_tensor_names.update(model_input.get_names())
 
     def _check_input_data_format(
@@ -95,8 +95,9 @@ def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.C
             target_device = TargetDevice.CPU
 
         ie = ov.Core()
+        stateful = model_has_state(model)
         compiled_model = ie.compile_model(model, target_device.value)
-        self.engine = OVCompiledModelEngine(compiled_model)
+        self.engine = OVCompiledModelEngine(compiled_model, stateful)
 
     def infer(
         self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]]
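
With this change, the per-inference `reset_state()` call happens only when the wrapped model actually carries state; stateless models skip it, which is the inference optimization named in the commit message. A minimal usage sketch, assuming a model file on disk and a 1x3x224x224 input (both placeholders, not from this commit):

import numpy as np
import openvino.runtime as ov

from nncf.openvino.engine import OVCompiledModelEngine
from nncf.openvino.graph.model_utils import model_has_state

core = ov.Core()
model = core.read_model("model.xml")  # placeholder path
compiled_model = core.compile_model(model, "CPU")

# State is reset after inference only if the model is stateful.
engine = OVCompiledModelEngine(compiled_model, stateful=model_has_state(model))
outputs = engine.infer(np.zeros((1, 3, 224, 224), dtype=np.float32))  # placeholder input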

nncf/openvino/graph/model_utils.py (+10)

@@ -60,3 +60,13 @@ def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[N
     :return: Target NNCFGraph input nodes.
     """
     return nncf_graph.get_input_nodes() + nncf_graph.get_nodes_by_metatypes([OVReadValueMetatype])
+
+
+def model_has_state(model: ov.Model) -> bool:
+    """
+    Returns True if model has state else False
+
+    :param model: OpenVINO model
+    :return: True if model has state else False
+    """
+    return len(model.get_sinks()) > 0
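
For context: OpenVINO represents state as ReadValue/Assign operation pairs, and the Assign side is registered as a model sink, so a non-empty `get_sinks()` result implies a stateful model. A minimal sketch of that relationship, using the classic opset6 variable-id API (illustrative only, not code from this commit):

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset6

from nncf.openvino.graph.model_utils import model_has_state

# Tiny stateful model: ReadValue/Assign share the variable "var_0".
param = opset6.parameter([1, 3], np.float32, name="input")
read = opset6.read_value(param, "var_0")  # reads the variable, initialized from `param`
add = opset6.add(read, param)
assign = opset6.assign(add, "var_0")      # writes the variable back; Assign nodes are sinks
model = ov.Model([opset6.result(add)], [assign], [param])  # results, sinks, parameters

assert model_has_state(model)             # model.get_sinks() contains the Assign node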

nncf/quantization/algorithms/accuracy_control/backend.py (+30 -12)

@@ -13,6 +13,7 @@
 from abc import abstractmethod
 from typing import Any, List, Optional, TypeVar
 
+from nncf.common.engine import Engine
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype
@@ -21,6 +22,35 @@
 TPModel = TypeVar("TPModel")
 
 
+class PreparedModel(ABC):
+    @property
+    @abstractmethod
+    def model_for_inference(self) -> TPModel:
+        """
+        Returns prepared model for inference.
+
+        :return: Prepared model for inference.
+        """
+
+    @property
+    @abstractmethod
+    def engine(self) -> Engine:
+        """
+        Returns the engine for inference the prepared model.
+
+        :return: The engine for inference the prepared model.
+        """
+
+    def __call__(self, input_data: Any) -> Any:
+        """
+        Runs model on the provided input data and returns the raw model outputs.
+
+        :param input_data: inputs for the model
+        :return: raw model outputs
+        """
+        return self.engine.infer(input_data)
+
+
 class AccuracyControlAlgoBackend(ABC):
     # Metatypes
 
@@ -158,15 +188,3 @@ def get_model_size(model: TModel) -> int:
         :param model: A model
         :return: Model size (in bytes)
         """
-
-    # Preparation of model
-
-    @staticmethod
-    @abstractmethod
-    def prepare_for_inference(model: TModel) -> TPModel:
-        """
-        Prepares model for inference.
-
-        :param model: A model that should be prepared.
-        :return: Prepared model for inference.
-        """

nncf/quantization/algorithms/accuracy_control/evaluator.py (+26 -32)

@@ -13,15 +13,14 @@
 from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
 
 import nncf
-from nncf.common.factory import EngineFactory
 from nncf.common.logging import nncf_logger
 from nncf.common.utils.backend import BackendType
 from nncf.common.utils.backend import get_backend
 from nncf.common.utils.timer import timer
 from nncf.data.dataset import Dataset
+from nncf.quantization.algorithms.accuracy_control.backend import PreparedModel
 
 TModel = TypeVar("TModel")
-TPModel = TypeVar("TPModel")
 TTensor = TypeVar("TTensor")
 
 
@@ -112,7 +111,7 @@ def is_metric_mode(self) -> bool:
         """
         return self._metric_mode
 
-    def prepare_model_for_inference(self, model: TModel) -> TPModel:
+    def prepare_model(self, model: TModel) -> PreparedModel:
         """
         Prepares model for inference.
 
@@ -122,21 +121,19 @@ def prepare_model_for_inference(self, model: TModel) -> TPModel:
         backend = get_backend(model)
 
         if backend == BackendType.OPENVINO:
-            import openvino.runtime as ov
+            from nncf.quantization.algorithms.accuracy_control.openvino_backend import OVPreparedModel
 
-            return ov.compile_model(model)
+            return OVPreparedModel(model)
 
-        raise NotImplementedError(
-            f"The `prepare_model_for_inference()` method is not implemented for the {backend} backend."
-        )
+        raise NotImplementedError(f"The `prepare_model()` method is not implemented for the {backend} backend.")
 
-    def validate_model_for_inference(
-        self, model_for_inference: TPModel, dataset: Dataset, indices: Optional[List[int]] = None
+    def validate_prepared_model(
+        self, prepared_model: PreparedModel, dataset: Dataset, indices: Optional[List[int]] = None
     ):
         """
         Validates prepared model for inference.
 
-        :param model: Prepared model to validate.
+        :param prepared_model: Prepared model to validate.
         :param dataset: Dataset to validate the model.
         :param indices: Zero-based indices of data items that should be selected from
             the dataset.
@@ -148,7 +145,7 @@ def validate_model_for_inference(
             item.
         """
         if self._metric_mode is None:
-            self._metric_mode = Evaluator.determine_mode(model_for_inference, dataset, self._validation_fn)
+            self._metric_mode = Evaluator.determine_mode(prepared_model, dataset, self._validation_fn)
 
         if not self.is_metric_mode() and indices is not None:
             raise ValueError("The `indices` parameter can be used only if Evaluator.is_metric_mode() = True")
@@ -157,7 +154,7 @@ def validate_model_for_inference(
         if self._enable_iteration_count:
             validation_dataset = IterationCounter(validation_dataset)
 
-        metric, values_for_each_item = self._validation_fn(model_for_inference, validation_dataset)
+        metric, values_for_each_item = self._validation_fn(prepared_model.model_for_inference, validation_dataset)
 
         self._num_passed_iterations = validation_dataset.num_iterations if self._enable_iteration_count else 0
 
@@ -190,20 +187,20 @@ def validate(
             Otherwise, if the condition is false, it represents list of logits for each
             item.
         """
-        model_for_inference = self.prepare_model_for_inference(model)
-        return self.validate_model_for_inference(model_for_inference, dataset, indices)
+        prepared_model = self.prepare_model(model)
+        return self.validate_prepared_model(prepared_model, dataset, indices)
 
     @staticmethod
     def determine_mode(
-        model_for_inference: TPModel,
+        prepared_model: PreparedModel,
         dataset: Dataset,
         validation_fn: Callable[[Any, Iterable[Any]], Tuple[float, Union[None, List[float], List[List[TTensor]]]]],
    ) -> bool:
        """
        Determines mode based on the type of returned value from the
        validation function.
 
-        :param model_for_inference: Model to validate.
+        :param prepared_model: Model to validate.
        :param dataset: Dataset to validate the model.
        :param validation_fn: Validation function to validate model.
        :return: A boolean indicator where `True` means that the `Evaluator` collects
@@ -215,7 +212,7 @@ def determine_mode(
        data_item = dataset.get_data([0])
 
        try:
-            metric_value, values_for_each_item = validation_fn(model_for_inference, data_item)
+            metric_value, values_for_each_item = validation_fn(prepared_model.model_for_inference, data_item)
        except Exception:
            metric_mode = False
 
@@ -262,15 +259,15 @@ def determine_mode(
 
        return metric_mode
 
-    def collect_values_for_each_item_using_model_for_inference(
-        self, model_for_inference: TPModel, dataset: Dataset, indices: Optional[List[int]] = None
+    def collect_values_for_each_item_using_prepared_model(
+        self, prepared_model: PreparedModel, dataset: Dataset, indices: Optional[List[int]] = None
    ) -> Union[List[float], List[List[TTensor]]]:
        """
        Collects value for each item from the dataset using prepared model for inference.
        If `is_metric_mode()` returns `True` then i-th value is a metric for i-th data item.
        It is an output of the model for i-th data item otherwise.
 
-        :param model: Model to infer.
+        :param prepared_model: Model to infer.
        :param dataset: Dataset to collect values.
        :param indices: The zero-based indices of data items that should be selected from
            the dataset.
@@ -279,15 +276,14 @@ def collect_values_for_each_item_using_model_for_inference(
        if self._metric_mode:
            # Collect metrics for each item
            values_for_each_item = [
-                self._validation_fn(model_for_inference, [data_item])[0] for data_item in dataset.get_data(indices)
+                self._validation_fn(prepared_model.model_for_inference, [data_item])[0]
+                for data_item in dataset.get_data(indices)
            ]
        else:
            # Collect outputs for each item
-            engine = EngineFactory.create(model_for_inference)
-
            values_for_each_item = []
            for data_item in dataset.get_inference_data(indices):
-                logits = engine.infer(data_item)
+                logits = prepared_model(data_item)
                values_for_each_item.append(list(logits.values()))
 
        self._num_passed_iterations = len(values_for_each_item) if self._enable_iteration_count else 0
@@ -308,8 +304,8 @@ def collect_values_for_each_item(
            the dataset.
        :return: Collected values.
        """
-        model_for_inference = self.prepare_model_for_inference(model)
-        return self.collect_values_for_each_item_using_model_for_inference(model_for_inference, dataset, indices)
+        prepared_model = self.prepare_model(model)
+        return self.collect_values_for_each_item_using_prepared_model(prepared_model, dataset, indices)
 
    def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: str = "") -> MetricResults:
        """
@@ -323,18 +319,16 @@ def collect_metric_results(self, model: TModel, dataset: Dataset, model_name: st
        nncf_logger.info(f"Validation of {model_name} model was started")
 
        with timer() as preparation_time:
-            model_for_inference = self.prepare_model_for_inference(model)
+            prepared_model = self.prepare_model(model)
 
        with timer() as validation_time:
-            metric, values_for_each_item = self.validate_model_for_inference(model_for_inference, dataset)
+            metric, values_for_each_item = self.validate_prepared_model(prepared_model, dataset)
 
        nncf_logger.info(f"Metric of {model_name} model: {metric}")
 
        if values_for_each_item is None:
            nncf_logger.info(f"Collecting values for each data item using the {model_name} model")
            with timer():
-                values_for_each_item = self.collect_values_for_each_item_using_model_for_inference(
-                    model_for_inference, dataset
-                )
+                values_for_each_item = self.collect_values_for_each_item_using_prepared_model(prepared_model, dataset)
 
        return MetricResults(metric, values_for_each_item, preparation_time(), validation_time())
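
End to end, the evaluator now compiles a model once via `prepare_model()` and reuses the same `PreparedModel` for both validation and per-item output collection. A hedged usage sketch (the model path, input shape, validation function, and the `Evaluator(validation_fn)` constructor call are all assumptions for illustration):

import numpy as np
import openvino.runtime as ov

import nncf
from nncf.quantization.algorithms.accuracy_control.evaluator import Evaluator


def validation_fn(compiled_model: ov.CompiledModel, dataset):
    # Placeholder metric: mean of the first model output across data items.
    scores = [float(np.mean(compiled_model(item)[0])) for item in dataset]
    return sum(scores) / max(len(scores), 1), scores


model = ov.Core().read_model("model.xml")                         # placeholder path
dataset = nncf.Dataset([np.zeros((1, 3, 224, 224), np.float32)])  # placeholder shape

evaluator = Evaluator(validation_fn)       # constructor signature assumed
prepared = evaluator.prepare_model(model)  # compiles once, returns OVPreparedModel
metric, per_item = evaluator.validate_prepared_model(prepared, dataset)
outputs = evaluator.collect_values_for_each_item_using_prepared_model(prepared, dataset)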

nncf/quantization/algorithms/accuracy_control/openvino_backend.py (+24 -6)

@@ -16,6 +16,7 @@
 
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
+from nncf.openvino.engine import OVCompiledModelEngine
 from nncf.openvino.graph.layer_attributes import OVLayerAttributes
 from nncf.openvino.graph.metatypes.groups import CONSTANT_OPERATIONS
 from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS
@@ -26,10 +27,33 @@
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConcatMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype
 from nncf.openvino.graph.model_utils import get_start_nodes_for_activation_path_tracing
+from nncf.openvino.graph.model_utils import model_has_state
 from nncf.openvino.graph.node_utils import get_bias_value
 from nncf.openvino.graph.node_utils import get_weight_value
 from nncf.openvino.graph.node_utils import is_node_with_bias
 from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend
+from nncf.quantization.algorithms.accuracy_control.backend import PreparedModel
+
+
+class OVPreparedModel(PreparedModel):
+    """
+    Implementation of the `PreparedModel` for OpenVINO backend.
+    """
+
+    def __init__(self, model: ov.Model):
+        self._stateful = model_has_state(model)
+        self._compiled_model = ov.compile_model(model)
+        self._engine = None
+
+    @property
+    def model_for_inference(self) -> ov.CompiledModel:
+        return self._compiled_model
+
+    @property
+    def engine(self) -> OVCompiledModelEngine:
+        if self._engine is None:
+            self._engine = OVCompiledModelEngine(self._compiled_model, self._stateful)
+        return self._engine
 
 
 class OVAccuracyControlAlgoBackend(AccuracyControlAlgoBackend):
@@ -97,9 +121,3 @@ def get_model_size(model: ov.Model) -> int:
             model_size += op.data.nbytes
 
         return model_size
-
-    # Preparation of model
-
-    @staticmethod
-    def prepare_for_inference(model: ov.Model) -> ov.CompiledModel:
-        return ov.compile_model(model)
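
A design note on `OVPreparedModel`: the engine (and its infer request) is created lazily on first access, so metric-mode validation, which only touches `model_for_inference`, never constructs an infer request. A hedged usage sketch (placeholder model path and input):

import numpy as np
import openvino.runtime as ov

from nncf.quantization.algorithms.accuracy_control.openvino_backend import OVPreparedModel

model = ov.Core().read_model("model.xml")  # placeholder path
prepared = OVPreparedModel(model)

compiled = prepared.model_for_inference  # no infer request created yet
logits = prepared(np.zeros((1, 3, 224, 224), np.float32))  # first call builds the engine lazily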

nncf/quantization/algorithms/accuracy_control/ranker.py (+4 -4)

@@ -200,7 +200,7 @@ def _sequential_calculation_ranking_score(
             self._algo_backend.get_op_with_weights_metatypes(),
         )
 
-        prepared_model = self._algo_backend.prepare_for_inference(modified_model)
+        prepared_model = self._evaluator.prepare_model(modified_model)
         ranking_score = self._calculate_ranking_score(
             prepared_model, ranking_subset_indices, reference_values_for_each_item
         )
@@ -229,7 +229,7 @@ def _multithreading_calculation_ranking_score(
             self._algo_backend.get_op_with_weights_metatypes(),
         )
 
-        prepared_model_queue.append(executor.submit(self._algo_backend.prepare_for_inference, modified_model))
+        prepared_model_queue.append(executor.submit(self._evaluator.prepare_model, modified_model))
 
         if idx >= (self._num_workers - 1):
             prepared_model = prepared_model_queue.pop(0).result()
@@ -263,12 +263,12 @@ def _calculate_ranking_score(
         """
         if self._evaluator.is_metric_mode():
             # Calculate ranking score based on metric
-            ranking_score, _ = self._evaluator.validate_model_for_inference(
+            ranking_score, _ = self._evaluator.validate_prepared_model(
                 prepared_model, self._dataset, ranking_subset_indices
             )
         else:
             # Calculate ranking score based on differences in logits
-            approximate_outputs = self._evaluator.collect_values_for_each_item_using_model_for_inference(
+            approximate_outputs = self._evaluator.collect_values_for_each_item_using_prepared_model(
                 prepared_model, self._dataset, ranking_subset_indices
             )
             reference_outputs = [reference_values_for_each_item[i] for i in ranking_subset_indices]
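
In the multithreaded path above, model preparation (now `Evaluator.prepare_model`, which compiles the model) is submitted to a thread pool while the main thread consumes results in submission order, keeping a bounded number of compilations in flight. A simplified, self-contained sketch of that pattern (stand-in names, not the ranker's actual code):

from concurrent.futures import ThreadPoolExecutor


def prepare(model_id: int) -> str:
    # Stand-in for Evaluator.prepare_model(): imagine an expensive compile step.
    return f"prepared-{model_id}"


num_workers = 4
with ThreadPoolExecutor(max_workers=num_workers) as executor:
    prepared_model_queue = []
    for idx in range(10):
        prepared_model_queue.append(executor.submit(prepare, idx))
        # Keep at most `num_workers` preparations in flight.
        if idx >= num_workers - 1:
            print(prepared_model_queue.pop(0).result())
    # Drain the remaining futures.
    while prepared_model_queue:
        print(prepared_model_queue.pop(0).result())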
