-
Notifications
You must be signed in to change notification settings - Fork 249
/
Copy pathcommands.py
230 lines (184 loc) · 8.46 KB
/
commands.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import numpy as np
import openvino as ov
from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.commands import TransformationType
from nncf.openvino.graph.node_utils import InplaceInsertionFnType
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters
class OVTargetPoint(TargetPoint):
    """
    Target point of an OpenVINO model, identified by a target type, a node name and a port id.
    """

    def __init__(self, target_type: TargetType, target_node_name: str, port_id: int):
        """
        :param target_type: Type of the target point.
        :param target_node_name: Name of the node the target point refers to.
        :param port_id: Port id of the target node (whether it is an input or an output port
            presumably depends on target_type — confirm with the transformation code).
        """
        super().__init__(target_type)
        self.target_node_name = target_node_name
        self.port_id = port_id

    def __eq__(self, other: object) -> bool:
        # `other` is typed as `object` to match object.__eq__; anything that is not an
        # OVTargetPoint simply compares unequal.
        return (
            isinstance(other, OVTargetPoint)
            and self.type == other.type
            and self.target_node_name == other.target_node_name
            and self.port_id == other.port_id
        )

    def __hash__(self) -> int:
        # Hashes the same identity as __eq__ (node name, port id, target type);
        # `_target_type` is assumed to be the attribute backing `self.type` set by
        # TargetPoint.__init__ — confirm in the base class.
        return hash((self.target_node_name, self.port_id, self._target_type))
class OVInsertionCommand(TransformationCommand):
    """
    Common base for OpenVINO commands of the INSERT transformation type.
    """

    def __init__(self, target_point: OVTargetPoint):
        """
        :param target_point: Where in the model the insertion takes place.
        """
        command_type = TransformationType.INSERT
        super().__init__(command_type, target_point)
class OVOutputInsertionCommand(OVInsertionCommand):
    """
    Insertion command for an output at the given target point.

    It only stores the target point and the requested output dtype; applying it is done by
    the model transformation code (presumably by adding an extra model output — confirm).
    """

    def __init__(
        self,
        target_point: OVTargetPoint,
        output_dtype: ov.Type = ov.Type.f32,
    ):
        """
        :param target_point: Target point for the output.
        :param output_dtype: Requested dtype of the output, ov.Type.f32 by default.
        """
        super().__init__(target_point)
        self.output_dtype = output_dtype

    def union(self, other: "TransformationCommand") -> "TransformationCommand":
        """
        Merging is not supported for this command; always raises NotImplementedError.
        """
        # Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
        raise NotImplementedError()
class OVInplaceFnInsertionCommand(OVInsertionCommand):
    """
    Insertion command that carries a callable of type InplaceInsertionFnType together with
    the data needed to use its result.
    """

    def __init__(
        self,
        target_point: OVTargetPoint,
        inplace_op_fn: InplaceInsertionFnType,
        fn_output_port_id: int,
        last_inplace_node_name: str,
        output_dtype: Optional[ov.Type] = None,
    ):
        """
        :param target_point: Target point the function is applied to.
        :param inplace_op_fn: Function that builds the inserted operation(s); its exact
            signature is defined by InplaceInsertionFnType in nncf.openvino.graph.node_utils.
        :param fn_output_port_id: Output port id to take from the nodes created by
            inplace_op_fn (assumed meaning — confirm with the transformer).
        :param last_inplace_node_name: Name of the last node created by inplace_op_fn
            (assumed meaning — confirm with the transformer).
        :param output_dtype: Optional output dtype; None when not specified.
        """
        super().__init__(target_point)
        self.inplace_op_fn = inplace_op_fn
        self.fn_output_port_id = fn_output_port_id
        self.last_inplace_node_name = last_inplace_node_name
        self.output_dtype = output_dtype

    def union(self, other: "TransformationCommand") -> "TransformationCommand":
        """
        Merging is not supported for this command; always raises NotImplementedError.
        """
        # Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
        raise NotImplementedError()
class OVFQNodeRemovingCommand(TransformationCommand):
    """
    Command of the REMOVE transformation type that removes FakeQuantize nodes from the model.
    """

    def __init__(self, target_point: OVTargetPoint):
        """
        :param target_point: Target point describing the layer to remove.
        """
        command_type = TransformationType.REMOVE
        super().__init__(command_type, target_point)
class OVQuantizerInsertionCommand(OVInsertionCommand):
    """
    Insertion command carrying FakeQuantize parameters for a quantizer at the target point.
    """

    def __init__(
        self,
        target_point: OVTargetPoint,
        quantizer_parameters: FakeQuantizeParameters,
    ):
        """
        :param target_point: Where the quantizer is inserted.
        :param quantizer_parameters: FakeQuantize parameters of the quantizer.
        """
        super().__init__(target_point)
        self.quantizer_parameters = quantizer_parameters
class OVConvertInsertionCommand(OVInsertionCommand):
    """
    Insertion command carrying FakeConvert parameters for the target point.
    """

    def __init__(
        self,
        target_point: OVTargetPoint,
        convert_parameters: FakeConvertParameters,
    ):
        """
        :param target_point: Where the convert operation is inserted.
        :param convert_parameters: FakeConvert parameters of the inserted operation.
        """
        super().__init__(target_point)
        self.convert_parameters = convert_parameters
class OVBiasCorrectionCommand(TransformationCommand):
    """
    Command of the CHANGE transformation type that corrects a bias value in the model.
    """

    def __init__(self, target_point: OVTargetPoint, bias_value: np.ndarray):
        """
        :param target_point: Target point of the layer whose bias is corrected.
        :param bias_value: Bias value (numpy array) used for the correction. The original
            documentation describes it as a shift added to the original bias.
            NOTE(review): confirm in the OpenVINO model transformer whether the value is
            added to the existing bias or replaces it.
        """
        command_type = TransformationType.CHANGE
        super().__init__(command_type, target_point)
        self.bias_value = bias_value
class OVWeightUpdateCommand(TransformationCommand):
    """
    Command of the CHANGE transformation type that sets a new weight value in the model.
    """

    def __init__(self, target_point: OVTargetPoint, weight_value: np.ndarray):
        """
        :param target_point: Target point of the weight to update.
        :param weight_value: The new weight value (numpy array).
        """
        command_type = TransformationType.CHANGE
        super().__init__(command_type, target_point)
        self.weight_value = weight_value
class OVModelExtractionCommand(Command):
    """
    Command of the EXTRACT transformation type that extracts a sub-graph delimited by the
    given sub-model input and output IDs.
    """

    def __init__(
        self,
        input_ids: List[Tuple[str, int]],
        output_ids: List[Tuple[str, int]],
    ):
        """
        :param input_ids: (node name, input port id) pairs; each pair marks where the
            sub-graph begins.
        :param output_ids: (node name, output port id) pairs; each pair marks where the
            sub-graph ends.
        """
        command_type = TransformationType.EXTRACT
        super().__init__(command_type)
        self.input_ids = input_ids
        self.output_ids = output_ids
class OVStateLessModelExtractionCommand(Command):
    """
    Command of the EXTRACT transformation type that extracts a stateless sub-graph delimited
    by the given sub-model input and output IDs.
    """

    def __init__(
        self,
        input_ids: List[Tuple[str, int]],
        output_ids: List[Tuple[str, int]],
    ):
        """
        :param input_ids: (node name, output port id) pairs; each pair marks where the
            sub-graph begins.
        :param output_ids: (node name, output port id) pairs; each pair marks where the
            sub-graph ends.
        """
        command_type = TransformationType.EXTRACT
        super().__init__(command_type)
        self.input_ids = input_ids
        self.output_ids = output_ids
class OVBiasInsertionCommand(TransformationCommand):
    """
    Command of the INSERT transformation type that adds a bias to the corresponding node.

    Unlike most insertion commands in this module, it derives directly from
    TransformationCommand rather than OVInsertionCommand.
    """

    def __init__(self, target_point: OVTargetPoint, bias_value: np.ndarray):
        """
        :param target_point: Target point of the node that receives the bias.
        :param bias_value: Constant value of the inserted bias layer.
        """
        command_type = TransformationType.INSERT
        super().__init__(command_type, target_point)
        self.bias_value = bias_value
class OVMultiplyInsertionCommand(OVInsertionCommand):
    """
    Insertion command for a Multiply node placed in front of the given destination nodes.
    """

    def __init__(
        self,
        target_point: OVTargetPoint,
        scale_value: np.ndarray,
        destination_node_names: List[str],
        multiply_node_name: str,
    ):
        """
        :param target_point: Target point of the insertion.
        :param scale_value: Scale constant of the Multiply layer.
        :param destination_node_names: Names of the nodes that will consume the new layer.
        :param multiply_node_name: Name given to the new Multiply layer.
        """
        super().__init__(target_point)
        self.scale_value = scale_value
        self.destination_node_names = destination_node_names
        self.multiply_node_name = multiply_node_name
class OVUpdateIfBodyCommand(TransformationCommand):
    """
    Command of the CHANGE transformation type that replaces the body of an If node.
    """

    def __init__(self, target_point: OVTargetPoint, body_model: ov.Model):
        """
        :param target_point: Target point of the If node being changed.
        :param body_model: The new body model; it is stored as `subgraph_model`.
        """
        command_type = TransformationType.CHANGE
        super().__init__(command_type, target_point)
        self.subgraph_model = body_model
class OVExtractIfBodyCommand(Command):
    """
    Command of the EXTRACT transformation type that extracts one body of an If node.
    """

    def __init__(self, if_node_name: str, if_body_condition: bool):
        """
        :param if_node_name: Name of the If node whose body is extracted.
        :param if_body_condition: If True, the then-body is extracted; otherwise the else-body.
        """
        super().__init__(TransformationType.EXTRACT)
        self.if_node_name = if_node_name
        self.if_body_condition = if_body_condition