Skip to content

Commit 1c0c929

Browse files
authored
[CPU] MLAS backend integration (openvinotoolkit#17885)
- currently enabled only for FP32 FullyConnected node on x86 CPUs
1 parent 97b4b13 commit 1c0c929

20 files changed

+592
-8
lines changed

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,6 @@
7272
[submodule "ARMComputeLibrary"]
7373
path = src/plugins/intel_cpu/thirdparty/ComputeLibrary
7474
url = https://github.com/ARM-software/ComputeLibrary.git
75+
[submodule "src/plugins/intel_cpu/thirdparty/mlas"]
76+
path = src/plugins/intel_cpu/thirdparty/mlas
77+
url = https://github.com/openvinotoolkit/mlas.git

licensing/runtime-third-party-programs.txt

+26
Original file line numberDiff line numberDiff line change
@@ -1399,3 +1399,29 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
13991399
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
14001400
SOFTWARE.
14011401

1402+
-------------------------------------------------------------
1403+
1404+
21 MLAS (https://github.com/microsoft/onnxruntime)
1405+
1406+
MIT License
1407+
1408+
Copyright (c) Microsoft Corporation
1409+
1410+
Permission is hereby granted, free of charge, to any person obtaining a copy
1411+
of this software and associated documentation files (the "Software"), to deal
1412+
in the Software without restriction, including without limitation the rights
1413+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
1414+
copies of the Software, and to permit persons to whom the Software is
1415+
furnished to do so, subject to the following conditions:
1416+
1417+
The above copyright notice and this permission notice shall be included in all
1418+
copies or substantial portions of the Software.
1419+
1420+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
1421+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
1422+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
1423+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
1424+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
1425+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
1426+
SOFTWARE.
1427+

src/bindings/python/tests/test_onnx/test_zoo_models.py

+1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def tinyyolov3_post_processing(outputs : Sequence[Any]) -> Sequence[Any]:
115115
"GPT2": {"atol": 5e-06, "rtol": 0.01},
116116
"GPT-2-LM-HEAD": {"atol": 4e-06},
117117
"test_retinanet_resnet101": {"atol": 1.3e-06},
118+
"resnet34-v1-7" : {"atol": 1e-5}
118119
}
119120

120121
def tolerance_map_key_in_model_path(path):

src/plugins/intel_cpu/CMakeLists.txt

+15-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ elseif(OV_COMPILER_IS_CLANG)
2020
ie_add_compiler_flags(-Wno-delete-non-abstract-non-virtual-dtor)
2121
endif()
2222

23+
# enbale mlas for X86 cpus only
24+
ie_dependent_option(ENABLE_MLAS_FOR_CPU "MLAS GEMM for OpenVINO CPU Plugin" ON "X86 OR X86_64" OFF)
2325
add_subdirectory(thirdparty)
2426

2527
if(WIN32)
@@ -64,6 +66,10 @@ if(NOT (AARCH64 OR ARM))
6466
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*)
6567
endif()
6668

69+
if (NOT ENABLE_MLAS_FOR_CPU)
70+
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/mlas/*)
71+
endif()
72+
6773
file(GLOB_RECURSE FILES_TO_REMOVE ${EXCLUDE_PATHS})
6874
list(REMOVE_ITEM SOURCES ${FILES_TO_REMOVE})
6975
list(REMOVE_ITEM HEADERS ${FILES_TO_REMOVE})
@@ -94,8 +100,12 @@ target_link_libraries(${TARGET_NAME} PRIVATE dnnl
94100

95101
target_compile_definitions(${TARGET_NAME} PRIVATE IMPLEMENT_INFERENCE_EXTENSION_API)
96102
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
103+
if (ENABLE_MLAS_FOR_CPU)
104+
target_link_libraries(${TARGET_NAME} PRIVATE mlas)
105+
target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $<TARGET_PROPERTY:mlas,INCLUDE_DIRECTORIES>)
106+
add_definitions(-DOV_CPU_WITH_MLAS)
107+
endif()
97108
target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $<TARGET_PROPERTY:dnnl,INCLUDE_DIRECTORIES>)
98-
99109
# Cross compiled function
100110
# TODO: The same for proposal, proposalONNX, topk
101111
cross_compiled_file(${TARGET_NAME}
@@ -133,6 +143,10 @@ if(BUILD_SHARED_LIBS)
133143
$<TARGET_PROPERTY:openvino::conditional_compilation,INTERFACE_INCLUDE_DIRECTORIES>)
134144

135145
target_include_directories(${TARGET_NAME}_obj SYSTEM PUBLIC $<TARGET_PROPERTY:dnnl,INCLUDE_DIRECTORIES>)
146+
147+
if(ENABLE_MLAS_FOR_CPU)
148+
target_include_directories(${TARGET_NAME}_obj SYSTEM PUBLIC $<TARGET_PROPERTY:mlas,INCLUDE_DIRECTORIES>)
149+
endif()
136150

137151
set_ie_threading_interface_for(${TARGET_NAME}_obj)
138152

src/plugins/intel_cpu/src/graph_optimizer.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,9 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
705705
if (parent->getType() == Type::Convert && parent->isConstant() && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
706706
&& parent->getOriginalInputPrecisionAtPort(0) == Precision::FP16
707707
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16)) {
708+
auto childNode = parent->getChildEdgeAt(0)->getChild();
709+
// set correct weight precision
710+
childNode->setOriginalInputPrecisionAtPort(1, parent->getOriginalInputPrecisionAtPort(0));
708711
graph.DropNode(parent);
709712
}
710713
}
+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#include "sgemm.hpp"
5+
6+
#include <string>
7+
#include <vector>
8+
9+
#include "mlas.h"
10+
#include "onednn/dnnl.h"
11+
#include "openvino/core/parallel.hpp"
12+
#include "thread_pool.hpp"
13+
14+
namespace ov {
15+
namespace intel_cpu {
16+
17+
size_t mlas_sgemm_pack_get_size(const int64_t N, const int64_t K) {
18+
return MlasGemmPackBSize(N, K);
19+
}
20+
21+
void mlas_sgemm_pack(const char* transb,
22+
const int64_t N,
23+
const int64_t K,
24+
const int64_t ldb,
25+
const float* src,
26+
float* dst) {
27+
MlasGemmPackB(*transb == 'T' ? CblasTrans : CblasNoTrans, N, K, src, ldb, dst);
28+
}
29+
30+
void mlas_sgemm(const char* transa,
31+
const char* transb,
32+
const int64_t M,
33+
const int64_t N,
34+
const int64_t K,
35+
const float alpha,
36+
const float* A,
37+
const int64_t lda,
38+
const float* B,
39+
const int64_t ldb,
40+
const float beta,
41+
float* C,
42+
const int64_t ldc,
43+
size_t thread_num) {
44+
// C = alpha*op( A )op( B ) + beta * C
45+
MLAS_SGEMM_DATA_PARAMS sgemmParam;
46+
sgemmParam.BIsPacked = false;
47+
sgemmParam.A = A;
48+
sgemmParam.lda = lda;
49+
sgemmParam.B = B;
50+
sgemmParam.ldb = ldb;
51+
sgemmParam.C = C;
52+
sgemmParam.ldc = ldc;
53+
sgemmParam.alpha = alpha;
54+
sgemmParam.beta = beta;
55+
auto _transa = *transa == 'N' ? CblasNoTrans : CblasTrans;
56+
auto _transb = *transb == 'N' ? CblasNoTrans : CblasTrans;
57+
ov::cpu::OVMlasThreadPool threadPool(0 == thread_num ? parallel_get_num_threads() : thread_num);
58+
MlasGemmBatch(_transa, _transb, M, N, K, &sgemmParam, 1, &threadPool);
59+
}
60+
61+
void mlas_sgemm_compute(const char* transa,
62+
const char* transb,
63+
const int64_t M,
64+
const int64_t N,
65+
const int64_t K,
66+
const float alpha,
67+
const float* A,
68+
const int64_t lda,
69+
const float* B,
70+
const int64_t ldb,
71+
const float beta,
72+
float* C,
73+
const int64_t ldc,
74+
const float* bias,
75+
size_t thread_num) {
76+
// C = alpha*op( A )op( B ) + beta * C
77+
ov::cpu::OVMlasThreadPool threadPool(0 == thread_num ? parallel_get_num_threads() : thread_num);
78+
MLAS_SGEMM_DATA_PARAMS sgemmParam;
79+
sgemmParam.BIsPacked = true;
80+
sgemmParam.A = A;
81+
sgemmParam.lda = lda;
82+
sgemmParam.B = B;
83+
sgemmParam.ldb = ldb;
84+
sgemmParam.C = C;
85+
sgemmParam.ldc = ldc;
86+
sgemmParam.alpha = alpha;
87+
sgemmParam.beta = beta;
88+
sgemmParam.bias = bias;
89+
auto _transa = *transa == 'N' ? CblasNoTrans : CblasTrans;
90+
auto _transb = *transb == 'N' ? CblasNoTrans : CblasTrans;
91+
MlasGemmBatch(_transa, _transb, M, N, K, &sgemmParam, 1, &threadPool);
92+
}
93+
} // namespace intel_cpu
94+
} // namespace ov
+109
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <cstddef>
8+
#include <cstdint>
9+
10+
namespace ov {
11+
namespace intel_cpu {
12+
/**
13+
* @brief Computes the length in bytes for the packed matrix B buffer(SGEMM).
14+
*
15+
* @param N Supplies the number of columns of matrix B.
16+
* @param K Supplies the number of rows of matrix B.
17+
* @return bytes of the packing buffer
18+
*/
19+
size_t mlas_sgemm_pack_get_size(const int64_t N, const int64_t K);
20+
21+
/**
22+
* @brief Packs the contents of matrix B
23+
*
24+
* @param transb T for transpose B, N for none-tranpose B
25+
* @param N Supplies the number of columns of matrix B and matrix C.
26+
* @param K Supplies the number of columns of matrix A and the number
27+
of rows of matrix B.
28+
* @param ldb Supplies the first dimension of matrix B.
29+
* @param src Supplies the address of matrix B
30+
* @param dst Supplies pointer to prePacked B buffer
31+
*/
32+
void mlas_sgemm_pack(const char* transb,
33+
const int64_t N,
34+
const int64_t K,
35+
const int64_t ldb,
36+
const float* src,
37+
float* dst);
38+
39+
/**
40+
* @brief SGEMM with planar B matrix
41+
*
42+
* @param transa T for transpose A, N for none-tranpose A.
43+
* @param transb T for transpose B, N for none-tranpose B.
44+
* @param M Supplies the number of rows of matrix A and matrix C.
45+
* @param N Supplies the number of columns of matrix B and matrix C.
46+
* @param K Supplies the number of columns of matrix A and the number
47+
of rows of matrix B.
48+
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
49+
* @param A Supplies the address of matrix A
50+
* @param lda Supplies the first dimension of matrix A.
51+
* @param B Supplies the address of matrix B
52+
* @param ldb Supplies the first dimension of matrix B.
53+
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
54+
* @param C Supplies the address of matrix C
55+
* @param ldc Supplies the first dimension of matrix C.
56+
* @param thread_num 0 for all threads, otherwise use thread_num
57+
*/
58+
void mlas_sgemm(const char* transa,
59+
const char* transb,
60+
const int64_t M,
61+
const int64_t N,
62+
const int64_t K,
63+
const float alpha,
64+
const float* A,
65+
const int64_t lda,
66+
const float* B,
67+
const int64_t ldb,
68+
const float beta,
69+
float* C,
70+
const int64_t ldc,
71+
size_t thread_num = 0);
72+
73+
/**
74+
* @brief SGEMM with B matrix prepacked
75+
*
76+
* @param transa T for transpose A, N for none-tranpose A.
77+
* @param transb T for transpose B, N for none-tranpose B.
78+
* @param M Supplies the number of rows of matrix A and matrix C.
79+
* @param N Supplies the number of columns of matrix B and matrix C.
80+
* @param K Supplies the number of columns of matrix A and the number
81+
of rows of matrix B.
82+
* @param alpha Supplies the scalar alpha multiplier (see SGEMM definition)
83+
* @param A Supplies the address of matrix A
84+
* @param lda Supplies the first dimension of matrix A.
85+
* @param B Supplies the address of matrix B
86+
* @param ldb Supplies the first dimension of matrix B.
87+
* @param beta Supplies the scalar beta multiplier (see SGEMM definition)
88+
* @param C Supplies the address of matrix C
89+
* @param ldc Supplies the first dimension of matrix C.
90+
* @param bias Supplies the address of by-channel bias
91+
* @param thread_num 0 for all threads, otherwise use thread_num
92+
*/
93+
void mlas_sgemm_compute(const char* transa,
94+
const char* transb,
95+
const int64_t M,
96+
const int64_t N,
97+
const int64_t K,
98+
const float alpha,
99+
const float* A,
100+
const int64_t lda,
101+
const float* B,
102+
const int64_t ldb,
103+
const float beta,
104+
float* C,
105+
const int64_t ldc,
106+
const float* bias = nullptr,
107+
size_t thread_num = 0);
108+
} // namespace intel_cpu
109+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "thread_pool.hpp"
6+
7+
#include "onednn/dnnl.h"
8+
#include "openvino/core/parallel.hpp"
9+
10+
// This function impl the forward declaration in MLAS
11+
size_t getCacheSizeMlas(int level, bool perCore) {
12+
return dnnl::utils::get_cache_size(level, perCore);
13+
}
14+
15+
namespace ov {
16+
namespace cpu {
17+
18+
size_t OVMlasThreadPool::DegreeOfParallelism() {
19+
// threadpool nullptr means single threaded
20+
return threadNum;
21+
}
22+
23+
void OVMlasThreadPool::TrySimpleParallelFor(const std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) {
24+
ov::parallel_nt(threadNum, [&](const size_t ithr, const size_t nthr) {
25+
std::ptrdiff_t start = 0, end = 0;
26+
ov::splitter(total, nthr, ithr, start, end);
27+
for (std::ptrdiff_t i = start; i < end; i++) {
28+
fn(i);
29+
}
30+
});
31+
}
32+
}; // namespace cpu
33+
}; // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (C) 2018-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include <cstddef>
8+
#include <cstdint>
9+
#include <functional>
10+
#include "mlas.h"
11+
12+
namespace ov {
13+
namespace cpu {
14+
class OVMlasThreadPool : public IMlasThreadPool {
15+
public:
16+
OVMlasThreadPool() = delete;
17+
explicit OVMlasThreadPool(const size_t& threadNum) : threadNum(threadNum) {}
18+
size_t DegreeOfParallelism() override;
19+
void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) override;
20+
public:
21+
// the actual threads used for sgemm
22+
size_t threadNum = 0;
23+
};
24+
}; // namespace cpu
25+
}; // namespace ov

src/plugins/intel_cpu/src/node.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ std::string Node::getPrimitiveDescriptorType() const {
471471
SEARCH_TYPE(avx);
472472
SEARCH_TYPE(sse42);
473473
SEARCH_TYPE(blas);
474+
SEARCH_TYPE(mlas);
474475
SEARCH_TYPE(any);
475476
SEARCH_TYPE(uni);
476477

0 commit comments

Comments
 (0)