Skip to content

Commit 72516c3

Browse files
azhai219sunxiaoxia2022wangleis
authored
[CPU] FC node tensor parallel (openvinotoolkit#25088)
### Details: - *Enable sub stream to compile model on each socket* - *Limit CPU cores to 32 when TP is enabled* - *Split fc into two sub-fcs at node level instead of graph level* - *Concat fc output in the horizontal direction. Drop allgatherv and allreduce for performance reasons* - *Support LLM fc int8 and 4-bit modes* - *Remove previous version fc TP logic* ### Tickets: - *ticket-id* --------- Co-authored-by: sunxiaoxia2022 <xiaoxia.sun@intel.com> Co-authored-by: Shen, Wanglei <wanglei.shen@intel.com>
1 parent e0392dc commit 72516c3

35 files changed

+1090
-364
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
/**
6+
* @brief OpenVINO Runtime CPU Message Manager
7+
* @file openvino/runtime/threading/cpu_message.hpp
8+
*/
9+
10+
#pragma once
11+
12+
#include <atomic>
13+
#include <condition_variable>
14+
#include <memory>
15+
#include <mutex>
16+
#include <queue>
17+
#include <set>
18+
#include <thread>
19+
#include <vector>
20+
#include <assert.h>
21+
22+
#include "openvino/runtime/common.hpp"
23+
#include "openvino/runtime/threading/istreams_executor.hpp"
24+
#include "openvino/runtime/threading/itask_executor.hpp"
25+
26+
27+
namespace ov {
28+
29+
namespace threading {
30+
/**
 * @brief Kinds of messages that can be posted to a MessageManager.
 *
 * CALL_BACK is the only kind defined here; MessageManager::server_wait() counts
 * messages of this kind.
 */
enum MsgType {
    CALL_BACK
};
31+
32+
struct MessageInfo {
33+
MsgType msg_type;
34+
};
35+
36+
class OPENVINO_RUNTIME_API MessageManager {
37+
public:
38+
MessageManager();
39+
40+
void send_message(const MessageInfo& msg_info);
41+
42+
void server_wait();
43+
44+
~MessageManager();
45+
46+
void set_num_sub_streams(int num_sub_streams);
47+
48+
int get_num_sub_streams();
49+
private:
50+
int _num_sub_streams = 0;
51+
std::vector<MessageInfo> _messageQueue;
52+
std::mutex _msgMutex;
53+
std::condition_variable _msgCondVar;
54+
};
55+
56+
/**
 * @brief Returns the MessageManager instance shared by current users.
 *
 * The instance is tracked through a weak reference: a new MessageManager is
 * created when no live instance exists, otherwise the existing one is returned.
 * @return Shared pointer to the MessageManager.
 */
OPENVINO_RUNTIME_API std::shared_ptr<MessageManager> message_manager();
57+
} // namespace threading
58+
} // namespace ov

src/inference/dev_api/openvino/runtime/threading/cpu_streams_executor.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class OPENVINO_RUNTIME_API CPUStreamsExecutor : public IStreamsExecutor {
5555

5656
int get_socket_id() override;
5757

58-
void run_sub_stream(Task task, int id) override;
58+
std::vector<int> get_rank() override;
5959

6060
private:
6161
struct Impl;

src/inference/dev_api/openvino/runtime/threading/istreams_executor.hpp

+28-22
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ class OPENVINO_RUNTIME_API IStreamsExecutor : virtual public ITaskExecutor {
3838
*/
3939
using Ptr = std::shared_ptr<IStreamsExecutor>;
4040

41+
enum MsgType{
42+
TP,
43+
START_INFER,
44+
CALL_BACK
45+
};
46+
47+
struct MessageInfo{
48+
MsgType msg_type;
49+
std::vector<int> rank;
50+
void* buf;
51+
Task task;
52+
};
53+
4154
/**
4255
* @brief Defines inference thread binding type
4356
*/
@@ -95,6 +108,7 @@ class OPENVINO_RUNTIME_API IStreamsExecutor : virtual public ITaskExecutor {
95108
std::vector<std::vector<int>> _streams_info_table = {};
96109
std::vector<std::vector<int>> _stream_processor_ids;
97110
int _sub_streams = 0;
111+
std::vector<int> _rank = {};
98112

99113
/**
100114
* @brief Get and reserve cpu ids based on configuration and hardware information,
@@ -124,21 +138,24 @@ class OPENVINO_RUNTIME_API IStreamsExecutor : virtual public ITaskExecutor {
124138
* @param[in] cpu_reservation @copybrief Config::_cpu_reservation
125139
* @param[in] cpu_pinning @copybrief Config::_cpu_pinning
126140
* @param[in] streams_info_table @copybrief Config::_streams_info_table
141+
* @param[in] rank @copybrief Config::_rank
127142
*/
128143
Config(std::string name = "StreamsExecutor",
129144
int streams = 1,
130145
int threads_per_stream = 0,
131146
ov::hint::SchedulingCoreType thread_preferred_core_type = ov::hint::SchedulingCoreType::ANY_CORE,
132147
bool cpu_reservation = false,
133148
bool cpu_pinning = false,
134-
std::vector<std::vector<int>> streams_info_table = {})
149+
std::vector<std::vector<int>> streams_info_table = {},
150+
std::vector<int> rank = {})
135151
: _name{std::move(name)},
136152
_streams{streams},
137153
_threads_per_stream{threads_per_stream},
138154
_thread_preferred_core_type(thread_preferred_core_type),
139155
_cpu_reservation{cpu_reservation},
140156
_cpu_pinning{cpu_pinning},
141-
_streams_info_table{std::move(streams_info_table)} {
157+
_streams_info_table{std::move(streams_info_table)},
158+
_rank{rank} {
142159
update_executor_config();
143160
}
144161

@@ -197,6 +214,9 @@ class OPENVINO_RUNTIME_API IStreamsExecutor : virtual public ITaskExecutor {
197214
int get_sub_streams() const {
198215
return _sub_streams;
199216
}
217+
std::vector<int> get_rank() const {
218+
return _rank;
219+
}
200220
StreamsMode get_sub_stream_mode() const {
201221
const auto proc_type_table = get_proc_type_table();
202222
int sockets = proc_type_table.size() > 1 ? static_cast<int>(proc_type_table.size()) - 1 : 1;
@@ -250,31 +270,17 @@ class OPENVINO_RUNTIME_API IStreamsExecutor : virtual public ITaskExecutor {
250270
virtual int get_socket_id() = 0;
251271

252272
/**
253-
* @brief Execute the task in the current thread using streams executor configuration and constraints
254-
* @param task A task to start
273+
* @brief Return the rank of current stream
274+
* Return {} when current stream has no rank
275+
* @return Rank array, or throws exceptions if called not from stream thread
255276
*/
256-
virtual void execute(Task task) = 0;
277+
virtual std::vector<int> get_rank() = 0;
257278

258279
/**
259-
* @brief Execute ov::Task inside sub stream of task executor context
280+
* @brief Execute the task in the current thread using streams executor configuration and constraints
260281
* @param task A task to start
261-
* @param id Sub stream id
262-
*/
263-
virtual void run_sub_stream(Task task, int id) = 0;
264-
265-
/**
266-
* @brief Execute all of the tasks and waits for its completion.
267-
* Default run_sub_stream_and_wait() method implementation uses run_sub_stream() pure virtual method
268-
* and higher level synchronization primitives from STL.
269-
* The task is wrapped into std::packaged_task which returns std::future.
270-
* std::packaged_task will call the task and signal to std::future that the task is finished
271-
* or the exception is thrown from task
272-
* Than std::future is used to wait for task execution completion and
273-
* task exception extraction
274-
* @note run_sub_stream_and_wait() does not copy or capture tasks!
275-
* @param tasks A vector of tasks to execute
276282
*/
277-
void run_sub_stream_and_wait(const std::vector<Task>& tasks);
283+
virtual void execute(Task task) = 0;
278284
};
279285

280286
} // namespace threading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright (C) 2018-2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
#include "openvino/runtime/threading/cpu_message.hpp"
5+
6+
#include <memory>
7+
#include <mutex>
8+
#include <queue>
9+
#include <set>
10+
#include <thread>
11+
#include <vector>
12+
13+
namespace ov {
14+
namespace threading {
15+
16+
// Starts with no sub streams configured; set_num_sub_streams() must be called
// before server_wait() is used.
MessageManager::MessageManager() : _num_sub_streams(0) {}
19+
20+
// Nothing to release explicitly: all members clean up after themselves.
MessageManager::~MessageManager() = default;
21+
22+
void MessageManager::send_message(const MessageInfo& msg_info) {
23+
{
24+
std::lock_guard<std::mutex> lock(_msgMutex);
25+
_messageQueue.push_back(msg_info);
26+
}
27+
_msgCondVar.notify_all();
28+
}
29+
30+
void MessageManager::server_wait() {
31+
assert(_num_sub_streams);
32+
MsgType msg_type;
33+
int count = 0;
34+
bool isStopped = false;
35+
while (!isStopped) {
36+
std::vector<MessageInfo> msgQueue;
37+
{
38+
std::unique_lock<std::mutex> lock(_msgMutex);
39+
_msgCondVar.wait(lock, [&] {
40+
return !_messageQueue.empty();
41+
});
42+
std::swap(_messageQueue, msgQueue);
43+
}
44+
45+
for (auto rec_info : msgQueue) {
46+
msg_type = rec_info.msg_type;
47+
if (msg_type == CALL_BACK) { // CALL_BACK
48+
count++;
49+
if (count == _num_sub_streams) {
50+
count = 0;
51+
isStopped = true;
52+
}
53+
}
54+
}
55+
};
56+
}
57+
58+
void MessageManager::set_num_sub_streams(int num_sub_streams) {
59+
_num_sub_streams = num_sub_streams;
60+
}
61+
62+
int MessageManager::get_num_sub_streams() {
63+
return _num_sub_streams;
64+
}
65+
66+
namespace {
67+
68+
class MessageManageHolder {
69+
std::mutex _mutex;
70+
std::weak_ptr<MessageManager> _manager;
71+
72+
public:
73+
MessageManageHolder(const MessageManageHolder&) = delete;
74+
MessageManageHolder& operator=(const MessageManageHolder&) = delete;
75+
76+
MessageManageHolder() = default;
77+
78+
std::shared_ptr<ov::threading::MessageManager> get() {
79+
std::lock_guard<std::mutex> lock(_mutex);
80+
auto manager = _manager.lock();
81+
if (!manager) {
82+
_manager = manager = std::make_shared<MessageManager>();
83+
}
84+
return manager;
85+
}
86+
};
87+
88+
} // namespace
89+
90+
std::shared_ptr<MessageManager> message_manager() {
91+
static MessageManageHolder message_manage;
92+
return message_manage.get();
93+
}
94+
95+
} // namespace threading
96+
} // namespace ov

0 commit comments

Comments
 (0)