Commit e3fd1fe

manual insert FQ

1 parent c7f7eb3

File tree

3 files changed: +174 -1 lines changed


CMakeLists.txt

+5
@@ -7,3 +7,8 @@ add_executable(ov_model_splitter ov_model_splitter.cpp)
 
 target_include_directories(ov_model_splitter PUBLIC ${InferenceEngine_INCLUDE_DIRS})
 target_link_libraries(ov_model_splitter ${InferenceEngine_LIBRARIES})
+
+add_executable(ov_test main.cpp)
+target_include_directories(ov_test PUBLIC ${InferenceEngine_INCLUDE_DIRS})
+target_link_libraries(ov_test ${InferenceEngine_LIBRARIES})
+

insert_fq.py

+167
@@ -0,0 +1,167 @@
# insert_fq.py: manually insert FakeQuantize (FQ) ops in front of the bottom-MLP
# MatMuls of an int8 DLRM model, rebuilding the weight dequantization subgraphs
# with per-channel scales taken from int8_configure.json.
from openvino.runtime import Core, serialize, opset8, Type
from openvino.runtime.passes import Matcher, MatcherPass, WrapType, Manager, AnyInput, ConstantFolding
from openvino.runtime.utils import replace_node
import numpy as np
import json

core = Core()


class WeightQuantizationReplacement(MatcherPass):
    # Matches the weight dequantization chain Constant -> Convert -> Subtract -> Multiply
    # and rewrites it in place: the zero point becomes 0 and the per-channel scales are
    # replaced with values from weight_scales (keyed by the weight's friendly name).
    def __init__(self, weight_scales):
        MatcherPass.__init__(self)
        self.model_changed = False
        self.weight_scales = weight_scales
        weight = WrapType("opset8::Constant")
        convert = WrapType("opset8::Convert", weight.output(0))
        zero_point = WrapType("opset8::Constant")
        subtract = WrapType("opset8::Subtract", [convert.output(0), zero_point.output(0)])
        scales = WrapType("opset8::Constant")
        multiply = WrapType("opset8::Multiply", [subtract.output(0), scales.output(0)])

        def callback(m: Matcher) -> bool:
            self.applied = True
            weight_node = m.get_pattern_value_map()[weight].get_node()
            zero_point_node = m.get_pattern_value_map()[zero_point].get_node()
            new_zp = opset8.constant(0, zero_point_node.get_element_type(), zero_point_node.friendly_name)
            # For testing purpose
            self.model_changed = False
            replace_node(zero_point_node, new_zp)
            scale_node = m.get_pattern_value_map()[scales].get_node()
            new_scale = opset8.constant(
                np.expand_dims(np.array(self.weight_scales[weight_node.friendly_name]), axis=1),
                scale_node.get_element_type(), scale_node.friendly_name)
            replace_node(scale_node, new_scale)
            return False

        self.register_matcher(Matcher(multiply, "WeightQuantizationReplacement"), callback)


class InsertQuantization(MatcherPass):
    # Matches Concat -> MatMul whose weight input is the dequantization chain above,
    # dumps the weight constants to .npy files, and rebuilds the MatMul with a
    # FakeQuantize inserted on the activation (Concat) input.
    def __init__(self, weight_scales):
        MatcherPass.__init__(self)
        self.model_changed = False
        self.weight_scales = weight_scales
        concat = WrapType("opset8::Concat", [AnyInput(), AnyInput()])
        weight = WrapType("opset8::Constant")
        convert = WrapType("opset8::Convert", weight.output(0))
        zero_point = WrapType("opset8::Constant")
        subtract = WrapType("opset8::Subtract", [convert.output(0), zero_point.output(0)])
        scales = WrapType("opset8::Constant")
        multiply = WrapType("opset8::Multiply", [subtract.output(0), scales.output(0)])
        matmul = WrapType("opset8::MatMul", [concat.output(0), multiply.output(0)])

        def callback(m: Matcher) -> bool:
            self.applied = True
            concat_node = m.get_pattern_value_map()[concat].get_node()
            matmul_node = m.get_pattern_value_map()[matmul].get_node()
            multiply_node = m.get_pattern_value_map()[multiply].get_node()
            scales_node = m.get_pattern_value_map()[scales].get_node()
            zp_node = m.get_pattern_value_map()[zero_point].get_node()
            weight_node = m.get_pattern_value_map()[weight].get_node()

            const_scales = scales_node.get_vector()
            const_zp = zp_node.get_vector()
            const_weight = weight_node.get_vector()
            np.save("fc_scale", const_scales)
            np.save("fc_zp", const_zp)
            np.save("fc_weight", const_weight)
            # Hard-coded activation range for this model.
            input_low = opset8.constant(-5.12978, Type.f32, "input_low")
            input_high = opset8.constant(5.089652, Type.f32, "input_high")
            output_low = opset8.constant(-5.12978, Type.f32, "output_low")
            output_high = opset8.constant(5.089652, Type.f32, "output_high")
            new_fq = opset8.fake_quantize(concat_node, input_low, input_high, output_low, output_high, 256)
            new_matmul = opset8.matmul(new_fq, multiply_node, False, True)
            # For testing purpose
            self.model_changed = False
            replace_node(matmul_node, new_matmul)
            return False

        self.register_matcher(Matcher(matmul, "InsertQuantization"), callback)


class InsertQuantization2(MatcherPass):
    # Matches Concat -> MatMul with a plain (already folded) weight constant and
    # rebuilds it as FakeQuantize(activation) -> MatMul, where the weight is an int8
    # constant with an explicit Convert/Subtract/Multiply dequantization chain
    # reloaded from the .npy dumps produced by InsertQuantization.
    def __init__(self, weight_scales):
        MatcherPass.__init__(self)
        self.model_changed = False
        self.weight_scales = weight_scales
        concat = WrapType("opset8::Concat", [AnyInput(), AnyInput()])
        weight = WrapType("opset8::Constant")
        matmul = WrapType("opset8::MatMul", [concat.output(0), weight.output(0)])

        def callback(m: Matcher) -> bool:
            self.applied = True
            concat_node = m.get_pattern_value_map()[concat].get_node()
            matmul_node = m.get_pattern_value_map()[matmul].get_node()
            weight_node = m.get_pattern_value_map()[weight].get_node()
            input_low = opset8.constant(-5.12978, Type.f32, "input_low")
            input_high = opset8.constant(5.089652, Type.f32, "input_high")
            output_low = opset8.constant(-5.12978, Type.f32, "output_low")
            output_high = opset8.constant(5.089652, Type.f32, "output_high")
            const_scales = np.load("full_connected_scales.npy")
            const_zp = np.load("fc_zp.npy")
            const_weight = np.load("fc_weight.npy")
            scales2 = opset8.constant(const_scales.reshape(512, 1), Type.f32, "scales2")
            zp = opset8.constant(const_zp.reshape(512, 1), Type.f32, "zero_points")
            new_weight = opset8.constant(const_weight.reshape(512, 415), Type.i8, "fake_weight")
            convert = opset8.convert(new_weight, "F32")
            sub = opset8.subtract(convert, zp)
            mul = opset8.multiply(sub, scales2)
            new_fq = opset8.fake_quantize(concat_node, input_low, input_high, output_low, output_high, 256)
            new_matmul = opset8.matmul(new_fq, mul, False, True)
            # For testing purpose
            self.model_changed = False
            replace_node(matmul_node, new_matmul)
            return False

        self.register_matcher(Matcher(matmul, "InsertQuantization"), callback)


# model_path = "./bottom_mlp_int8/90_bottom_mlp_int8.xml"
# model_path = "/home/zhangyi7/ov_dlrm/results/dlrm_2048_10GB_int8_MinMaxQuantization/2022-05-20_21-41-16/optimized/dlrm_2048_10GB_int8.xml"
model_path = "/home/zhangyi7/ov_dlrm/results/dlrm_2048_10GB_int8_MinMaxQuantization/2022-06-11_12-51-12/optimized/dlrm_2048_10GB_int8.xml"
model = core.read_model(model_path)
print(model.get_ordered_ops())

ops_to_modify = [
    {"Gemm_0/WithoutBiases/fq_input_0": 0},
    {"Gemm_2/WithoutBiases/fq_input_0": 2},
    {"Gemm_4/WithoutBiases/fq_input_0": 4}
]

weight_to_modify = [
    {"bot_l.0.weight2993579/quantized39156452": 0},
    {"bot_l.2.weight3043581/quantized40355795": 2},
    {"bot_l.4.weight3093583/quantized40656443": 4}
]

with open("int8_configure.json", "r") as f:
    int8_config = json.load(f)
print(len(int8_config[0]["weight_scales"][0]))
print(len(int8_config[2]["weight_scales"][0]))
print(len(int8_config[4]["weight_scales"][0]))
weight_scales = {
    "bot_l.0.weight3043676/quantized40836625": int8_config[0]["weight_scales"][0],
    "bot_l.2.weight3093678/quantized41736343": int8_config[2]["weight_scales"][0],
    "bot_l.4.weight3143680/quantized40236520": int8_config[4]["weight_scales"][0]
}

m = Manager()
p = m.register_pass(InsertQuantization2(weight_scales))
# p = m.register_pass(ConstantFolding())
m.run_passes(model)
serialize(model, "dlrm_10_final.xml", "dlrm_10_final.bin")
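Not part of the commit, but a minimal sketch of how the rewritten model can be sanity-checked, assuming the same openvino.runtime API the script uses and that it is run from the directory where insert_fq.py serialized dlrm_10_final.xml:

# Minimal post-transformation check (assumption: dlrm_10_final.xml / .bin exist
# in the current directory, produced by insert_fq.py above).
from openvino.runtime import Core

core = Core()
model = core.read_model("dlrm_10_final.xml")

# The pass should have left FakeQuantize nodes in front of the rewritten MatMuls.
# Reference semantics of FakeQuantize(x, il, ih, ol, oh, levels): clamp x to
# [il, ih], round((x - il) / (ih - il) * (levels - 1)), then rescale to [ol, oh].
fq_ops = [op for op in model.get_ordered_ops()
          if op.get_type_name() == "FakeQuantize"]
print("FakeQuantize nodes:", [op.friendly_name for op in fq_ops])

# Compiling catches shape or element-type mismatches introduced by the rewrite.
core.compile_model(model, "CPU")
print("model compiles on CPU")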

ov_model_splitter.cpp

+2 -1
@@ -46,7 +46,8 @@ int main(int args, char *argv[]) {
     }
     std::vector<std::string> target_input = {argv[2]};
     std::vector<std::string> target_output = {argv[3]};
-
+    std::cout << "Start " << argv[2] << std::endl;
+    std::cout << "End " << argv[3] << std::endl;
     std::vector<std::shared_ptr<opset8::Parameter> > subgraph_parameters = {};
     std::vector<std::shared_ptr<opset8::Result> > subgraph_results = {};
     for(auto& input_name : target_input) {
