
Commit bb7f0e7
[CPU] Disable ConvertGatherToGatherCompressed optimization for quantized models (openvinotoolkit#25478)
### Details:
- Skip the `ConvertGatherToGatherCompressed` rewrite for quantized (u8/i8) weights when LPT is enabled (`useLpt == true`), so the LPT pipeline handles the dequantization part instead.

### Tickets:
- 138337

Signed-off-by: xipingya <xiping.yan@intel.com>
1 parent 554e6fe commit bb7f0e7

File tree: 3 files changed, +157 −2 lines

src/common/transformations/src/transformations/op_conversions/convert_gather_to_compressed.cpp (+3 −1)

@@ -134,7 +134,9 @@ ov::pass::ConvertGatherToGatherCompressed::ConvertGatherToGatherCompressed() {
                                                      gather_input_scale);
         }
 
-        transformation_callback(new_gather_node);
+        if (transformation_callback(new_gather_node)) {
+            return false;
+        }
 
         result_nodes.push_back(new_gather_node);
         new_gather_node->set_friendly_name(gather_node->get_friendly_name());
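The change above turns the callback from advisory into a veto: previously the matcher called `transformation_callback(new_gather_node)` and ignored the result; now a `true` return aborts the rewrite and keeps the original `Gather`. A minimal sketch of how a consumer can use this, assuming the public `ov::pass::Manager` / `PassConfig` API; the wrapper function and the never-veto policy are illustrative, not part of this commit:

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/op_conversions/convert_gather_to_compressed.hpp"

void run_decompression_passes(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::ConvertGatherToGatherCompressed>();
    // Returning true from this callback now skips the Gather -> GatherCompressed
    // rewrite for that node (before this commit the return value was ignored).
    manager.get_pass_config()->set_callback<ov::pass::ConvertGatherToGatherCompressed>(
        [](const std::shared_ptr<const ov::Node>& node) -> bool {
            return false;  // placeholder policy: never veto
        });
    manager.run_passes(model);
}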

src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp (+10 −1)

@@ -310,6 +310,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecisions) {
     ov::pass::Manager decompression_handling_manager;
     decompression_handling_manager.set_per_pass_validation(false);
     CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::InitNodeInfo);
+    const bool useLpt = !defaultPrecisions.empty();
     CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::ConvertGatherToGatherCompressed);
     CPU_REGISTER_PASS_COMMON(decompression_handling_manager, ov::pass::MarkShapeOfSubgraphs);
     // We need to fuse Transpose to MatMul to have a simpler callback for the next transformation
@@ -330,6 +331,15 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecisions) {
         if (ov::is_type<ov::op::internal::GatherCompressed>(node)) {
             // It is necessary to avoid precision conversion for constant node (compressed weights)
             ov::enable_keep_const_precision(node->get_input_node_shared_ptr(0));
+
+            // Prioritize the LPT pipeline to handle the dequantization part for quantized models,
+            // as it is more optimal in the general case
+            if (ov::intel_cpu::one_of(node->get_input_node_shared_ptr(0)->get_element_type(),
+                                      ov::element::u8,
+                                      ov::element::i8) &&
+                useLpt) {
+                return true;
+            }
         }
         return false;
     },
@@ -338,7 +348,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecisions) {
 
     ov::pass::Manager manager;
     manager.set_per_pass_validation(false);
-    const bool useLpt = !defaultPrecisions.empty();
     if (useLpt)
         CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions);
 
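The net effect of this file's changes: `useLpt` is now computed before the decompression-handling passes run, and the callback above vetoes the `GatherCompressed` rewrite for u8/i8 weights whenever LPT is enabled, leaving the dequantization subgraph to the LPT pipeline. A standalone sketch of that predicate, assuming `get_input_element_type(0)` on the Gather is equivalent to querying the weights constant's type (the function name is hypothetical):

#include <memory>

#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"

// True when the Gather -> GatherCompressed rewrite should be skipped:
// quantized (u8/i8) weights are left for the LPT pipeline when it is enabled.
bool skip_gather_compressed(const std::shared_ptr<const ov::Node>& gather, bool use_lpt) {
    const auto weights_type = gather->get_input_element_type(0);
    return use_lpt && (weights_type == ov::element::u8 || weights_type == ov::element::i8);
}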
New test file (+144 −0)

@@ -0,0 +1,144 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "common_test_utils/data_utils.hpp"
#include "common_test_utils/node_builders/constant.hpp"
#include "openvino/runtime/exec_model_info.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {
/*
 *                                input2
 *                                  |
 *    Constant(i8)               Softmax
 *         |                    /
 *      Convert          Multiply
 *         |            /
 *      Multiply    Convert      input1(u8/i8)
 *            \    /                  |
 *            Gather             FakeQuantize
 *                 \                 /
 *                  \               /
 *                       MatMul
 */
using DisableGatherCompressedForQuantizedModelParams = std::tuple<element::Type, InputShape, InputShape>;

class DisableGatherCompressedForQuantizedModel
    : public testing::WithParamInterface<DisableGatherCompressedForQuantizedModelParams>,
      virtual public SubgraphBaseTest {
public:
    static std::string getTestCaseName(testing::TestParamInfo<DisableGatherCompressedForQuantizedModelParams> obj) {
        element::Type weight_prec;
        InputShape inputShape1, inputShape2;
        std::tie(weight_prec, inputShape1, inputShape2) = obj.param;
        std::ostringstream result;
        result << "weight_prec=" << weight_prec << "_"
               << "inputShape1=" << inputShape1 << "_"
               << "inputShape2=" << inputShape2;
        return result.str();
    }

protected:
    void SetUp() override {
        targetDevice = utils::DEVICE_CPU;
        element::Type weight_prec;
        InputShape inputShape1, inputShape2;
        std::tie(weight_prec, inputShape1, inputShape2) = GetParam();

        init_input_shapes({inputShape1, inputShape2});

        auto type = element::f32;

        auto input1 = std::make_shared<op::v0::Parameter>(type, inputDynamicShapes[0]);
        auto input2 = std::make_shared<op::v0::Parameter>(type, inputDynamicShapes[1]);

        auto shared_il = op::v0::Constant::create(type, {1, 1, 1, 1}, {0.f});
        auto shared_ih = op::v0::Constant::create(type, {1, 1, 1, 1}, {12.5f});
        auto shared_ol = op::v0::Constant::create(type, {1, 1, 1, 1}, {0.f});
        auto shared_oh = op::v0::Constant::create(type, {1, 1, 1, 1}, {12.5f});
        auto fq = std::make_shared<op::v0::FakeQuantize>(input1, shared_il, shared_ih, shared_ol, shared_oh, 256);

        // Weights
        auto weights_shape = Shape{64, 64};
        auto weights = utils::make_constant(weight_prec, weights_shape, utils::InputGenerateData(-1, 2, 32768));
        auto convert = std::make_shared<op::v0::Convert>(weights, element::f32);
        auto multiply = std::make_shared<op::v1::Multiply>(convert, op::v0::Constant::create(type, {1, 1}, {0.625}));
        // Indices
        auto softmax = std::make_shared<op::v1::Softmax>(input2, 0);
        auto multiply2 = std::make_shared<op::v1::Multiply>(softmax, op::v0::Constant::create(type, {1}, {64}));
        auto indices = std::make_shared<op::v0::Convert>(multiply2, element::i64);
        // Gather
        auto gather =
            std::make_shared<op::v8::Gather>(multiply, indices, op::v0::Constant::create(element::i32, Shape{1}, {0}));

        auto matMul = std::make_shared<ov::op::v0::MatMul>(fq, gather, false, true);

        function = std::make_shared<Model>(matMul, ParameterVector{input1, input2});
    }

    void check_results() {
        const auto& test_param = GetParam();
        const auto compressed_weights_precision = std::get<0>(test_param);

        const auto runtime_model = compiledModel.get_runtime_model();
        const auto matmul = runtime_model->get_result()->get_input_node_shared_ptr(0);

        bool have_gather = false;
        bool have_gather_compressed = false;
        for (const auto& n : runtime_model->get_ordered_ops()) {
            const auto type = n->get_rt_info().at(ov::exec_model_info::LAYER_TYPE).as<std::string>();
            if (type == "Gather") {
                // A Gather with >= 4 inputs is a GatherCompressed node.
                if (n->get_input_size() >= 4) {
                    have_gather_compressed = true;
                } else {
                    have_gather = true;
                }
            }
        }

        switch (compressed_weights_precision) {
        case element::i8:
            EXPECT_TRUE(have_gather);
            EXPECT_EQ(matmul->get_input_element_type(1), element::i8);
            // FakeQuantize (matmul's input(0)) has u8 output precision
            EXPECT_EQ(matmul->get_rt_info().at(ov::exec_model_info::RUNTIME_PRECISION).as<ov::element::Type>(),
                      element::u8);
            break;
        case element::u8:
            EXPECT_TRUE(have_gather);
            // oneDNN MatMul officially supports Source(u8, s8) and Weights(s8), so a reorder
            // is inserted when the weights are not s8; no need to check matmul's input(1) precision.
            break;
        case element::u4:
        case element::i4:
            EXPECT_TRUE(have_gather_compressed);
            break;
        default:
            break;
        }
    }
};

TEST_P(DisableGatherCompressedForQuantizedModel, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    check_results();
}

namespace {

const std::vector<InputShape> inputShapes1 = {{{-1, 3, -1, -1}, {{1, 3, 64, 64}}}};
const std::vector<InputShape> inputShapes2 = {{{}, {{32}}}};
const std::vector<element::Type> weightsPrecisions = {element::i8, element::u8, element::u4, element::i4};

INSTANTIATE_TEST_SUITE_P(smoke_DisableGatherCompressedForQuantizedModel_basic,
                         DisableGatherCompressedForQuantizedModel,
                         ::testing::Combine(::testing::ValuesIn(weightsPrecisions),
                                            ::testing::ValuesIn(inputShapes1),
                                            ::testing::ValuesIn(inputShapes2)),
                         DisableGatherCompressedForQuantizedModel::getTestCaseName);

}  // namespace
}  // namespace test
}  // namespace ov
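The test compiles the subgraph from the header diagram on CPU and then inspects the runtime model: for i8/u8 weights a plain Gather must appear (the rewrite was vetoed and LPT dequantized the weights), while u4/i4 weights, which the callback does not cover, should still produce GatherCompressed. Assuming the usual CPU functional-test binary name, the suite can be run with something like `./bin/intel64/Release/ov_cpu_func_tests --gtest_filter='*DisableGatherCompressedForQuantizedModel*'`.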
