Skip to content

Commit 831b982

Browse files
[CPU] Optimize FullyConnected op in dynamic quantization mode
1 parent 3c7f2af commit 831b982

File tree

3 files changed

+47
-28
lines changed

3 files changed

+47
-28
lines changed

src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp

+14-24
Original file line numberDiff line numberDiff line change
@@ -177,37 +177,27 @@ static bool useDynamicQuantizationImpl(size_t dqGroupSize,
177177
return false;
178178
}
179179

180-
// TODO: heuristic: disable avx2 asymmetric
181-
bool is_asymmetric_weights = one_of(weightsDesc->getPrecision(), ov::element::u8, ov::element::u4);
182-
if (is_asymmetric_weights && !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni)) {
183-
return false;
184-
}
185-
186180
const size_t simdWidth = 16;
187181
if (dqGroupSize % simdWidth) {
188182
return false;
189183
}
190184

191-
if (weightsDesc->getPrecision() == ov::element::u4) {
192-
int ic = weightsDesc->getShape().getStaticDims()[1];
193-
int minGroupSize = INT_MAX;
194-
195-
MemoryCPtr scalesPtr = memory.count(ARG_WEI | ARG_ATTR_SCALES) ? memory.at(ARG_WEI | ARG_ATTR_SCALES) : nullptr;
196-
197-
if (scalesPtr && scalesPtr->getShape().getRank() == 3) {
198-
auto scalesDims = scalesPtr->getShape().getStaticDims();
199-
auto groupsNum = needTranspose ? scalesDims[1] : scalesDims[0];
200-
minGroupSize = ic / groupsNum;
201-
}
202-
203-
if (zpPtr && zpPtr->getShape().getRank() == 3) {
204-
auto zpDims = zpPtr->getShape().getStaticDims();
205-
int groupsNum = needTranspose ? zpDims[1] : zpDims[0];
206-
minGroupSize = std::min(minGroupSize, ic / groupsNum);
185+
MemoryCPtr scalesPtr = memory.count(ARG_WEI | ARG_ATTR_SCALES) ? memory.at(ARG_WEI | ARG_ATTR_SCALES) : nullptr;
186+
int ic = weightsDesc->getShape().getStaticDims()[1];
187+
if (scalesPtr && scalesPtr->getShape().getRank() != 1) {
188+
auto scalesDims = scalesPtr->getShape().getStaticDims();
189+
auto groupsNum = scalesDims[1];
190+
size_t groupSize = ic / groupsNum;
191+
if (groupsNum != 1 && groupSize % std::min(dqGroupSize, groupSize)) {
192+
return false;
207193
}
194+
}
208195

209-
const size_t minLoopSize = 8;
210-
if (minGroupSize != INT_MAX && minGroupSize % minLoopSize) {
196+
if (zpPtr && zpPtr->getShape().getRank() != 1) {
197+
auto zpDims = zpPtr->getShape().getStaticDims();
198+
int groupsNum = zpDims[1];
199+
size_t groupSize = ic / groupsNum;
200+
if (groupsNum != 1 && groupSize % std::min(dqGroupSize, groupSize)) {
211201
return false;
212202
}
213203
}

src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/matmul_weights_decompression.cpp

+32-3
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ const std::vector<MatMulDecompressionShapeParams> input_shapes_basic_dyn_quant =
208208
{{{}, {{1, 1, 128}}}, {128, 32}},
209209
{{{}, {{1, 3, 144}}}, {144, 64}, 16lu},
210210
{{{}, {{1, 1, 1728}}}, {1728, 128}, 64lu},
211+
// jit_brgemm_kernel corner cases: ic iters > 1 && has oc tail
212+
{{{}, {{1, 1, 640}}}, {640, 90}},
211213
};
212214

213215
const std::vector<ov::test::ElementType> weights_precisions_dyn_quant = {ov::element::u8, ov::element::u4};
@@ -280,8 +282,6 @@ const std::vector<MatMulDecompressionShapeParams> input_shapes_scalar_scale = {
280282
{{{}, {{1, 10, 128}}}, {128, 32}},
281283
};
282284

283-
const std::vector<ov::test::ElementType> weights_precisions_scalar_scale = {ov::element::u8};
284-
285285
std::vector<ov::AnyMap> filter_additional_config_scalar_scale() {
286286
std::vector<ov::AnyMap> additional_config = {
287287
{{ov::hint::dynamic_quantization_group_size(0)}},
@@ -293,7 +293,7 @@ std::vector<ov::AnyMap> filter_additional_config_scalar_scale() {
293293
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_scalar_scale,
294294
MatmulWeightsDecompression,
295295
::testing::Combine(::testing::ValuesIn(input_shapes_scalar_scale),
296-
::testing::ValuesIn(weights_precisions_scalar_scale),
296+
::testing::Values(ov::element::u8),
297297
::testing::ValuesIn(decompression_precisions),
298298
::testing::Values(ov::element::undefined),
299299
::testing::Values(false),
@@ -305,6 +305,35 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_scalar_scale,
305305
::testing::Values(true)),
306306
MatmulWeightsDecompression::getTestCaseName);
307307

308+
309+
const std::vector<MatMulDecompressionShapeParams> input_shapes_non_multiples_groups = {
310+
{{{}, {{1, 3, 192}}}, {192, 128}, 96lu},
311+
};
312+
313+
std::vector<ov::AnyMap> filter_additional_config_non_multiples_groups() {
314+
std::vector<ov::AnyMap> additional_config = {
315+
{{ov::hint::dynamic_quantization_group_size(64)}}
316+
};
317+
return additional_config;
318+
}
319+
320+
// Dynamic quantization requires weights compression group size to be divisible on dq group size
321+
// The test is intended to chech such case is correctly handled via non dq path
322+
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_non_multiples_groups,
323+
MatmulWeightsDecompression,
324+
::testing::Combine(::testing::ValuesIn(input_shapes_non_multiples_groups),
325+
::testing::Values(ov::element::u8),
326+
::testing::ValuesIn(decompression_precisions),
327+
::testing::Values(ov::element::undefined),
328+
::testing::ValuesIn(transpose_weights),
329+
::testing::Values(DecompressionType::full),
330+
::testing::Values(DecompressionType::full),
331+
::testing::Values(false),
332+
::testing::ValuesIn(filter_additional_config_non_multiples_groups()),
333+
::testing::Values(emptyFusingSpec),
334+
::testing::Values(true)),
335+
MatmulWeightsDecompression::getTestCaseName);
336+
308337
} // namespace
309338
} // namespace test
310339
} // namespace ov

0 commit comments

Comments
 (0)