Commit 5cbadd1

CB: preparation for relying on KV cache precisions from plugins (openvinotoolkit#1634)
- Currently we have logic to detect the KV cache precision, and this logic has become more and more complex.
- The idea is to rely on the plugin's logic instead and compile the PA model with `ov::element::dynamic` precisions for the KV cache inputs.
- Then, take the `ov::CompiledModel` and extract the precisions from its `inputs()`.
- Finally, create tensors based on the computed `num_kv_blocks`, which depends on the KV cache precisions.
- For now, the logic that mimics the plugin's choice of KV cache precisions is still here, but it will be dropped once plugins support `ov::element::dynamic`.
Parent: 4fb48de

18 files changed: +352 −340 lines
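
As a rough illustration of the direction described in the commit message, here is a minimal sketch (not code from this commit; the model path is a placeholder, and the `key_cache.<i>` / `value_cache.<i>` input naming mirrors the diff below) of reading KV cache precisions back from a compiled model's `inputs()`:

```cpp
// Minimal sketch: query KV cache input precisions from an ov::CompiledModel.
// Assumes a PA-transformed model at "model_pa.xml" (placeholder path) whose
// KV cache inputs are named "key_cache.<i>" / "value_cache.<i>".
#include <openvino/openvino.hpp>
#include <iostream>

int main() {
    ov::Core core;
    ov::CompiledModel compiled_model = core.compile_model("model_pa.xml", "CPU");

    for (const auto& input : compiled_model.inputs()) {
        for (const auto& name : input.get_names()) {
            if (name.find("key_cache.") == 0 || name.find("value_cache.") == 0) {
                // Once plugins accept ov::element::dynamic for these inputs,
                // the precision reported here is the plugin's own choice.
                std::cout << name << " -> " << input.get_element_type() << std::endl;
            }
        }
    }
    return 0;
}
```

With this in place, tensor allocation can be driven entirely by what `get_element_type()` reports, rather than by precision-detection logic duplicated on the GenAI side.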

.github/labeler.yml (+2 −2)

@@ -103,8 +103,8 @@
     - 'src/cpp/src/generation_handle.cpp'
     - 'src/cpp/src/generation_stream.hpp'
     - 'src/cpp/src/model_runner.hpp'
-    - 'src/cpp/src/utils/paged_attention_transformations.cpp'
-    - 'src/cpp/src/utils/paged_attention_transformations.hpp'
+    - 'src/cpp/src/paged_attention_transformations.cpp'
+    - 'src/cpp/src/paged_attention_transformations.hpp'
     - 'src/cpp/src/scheduler.hpp'
     - 'src/cpp/src/sequence_group.cpp'
     - 'src/cpp/src/sequence_group.hpp'
src/cpp/src/cache_manager.hpp (+113 −56)

@@ -45,61 +45,126 @@ class TensorMmapAllocator {
 #endif
 
 namespace ov::genai {
+
 class CacheManager {
-    DeviceConfig m_device_config;
-    std::vector<ov::Tensor> m_key_cache;
-    std::vector<ov::Tensor> m_value_cache;
-    size_t m_num_allocated_kv_blocks = 0;
+    size_t m_num_decoder_layers = 0;
+    std::string m_device;
+    std::vector<ov::element::Type> m_key_precisions, m_value_precisions;
+    std::vector<ov::PartialShape> m_key_shapes, m_value_shapes;
+    std::vector<ov::Tensor> m_key_cache, m_value_cache;
+    size_t m_num_allocated_kv_blocks = 0, m_block_size_in_bytes = 0;
     ov::InferRequest m_request;
-    ov::Core m_core;
 
-    ov::Shape set_first_dim_and_make_static(const ov::PartialShape& shape, size_t dim) {
-        ov::PartialShape res_shape = shape;
-        res_shape[0] = dim;
-        OPENVINO_ASSERT(res_shape.is_static());
-        return res_shape.to_shape();
+    static ov::Shape set_kv_blocks(ov::PartialShape pshape, size_t num_kv_blocks) {
+        pshape[0] = num_kv_blocks;
+        return pshape.get_shape();
     }
 
     void update_request_tensor(size_t decoder_layer_id) {
         m_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_key_cache[decoder_layer_id]);
         m_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_value_cache[decoder_layer_id]);
     }
 
+    ov::PartialShape patch_shape(ov::PartialShape pshape, ov::element::Type cache_type) {
+        OPENVINO_ASSERT(!m_device.empty(), "Internal error: device is not set");
+
+        if (m_device.find("CPU") != std::string::npos && cache_type == ov::element::u8) {
+            // Scale, zero point and quantized data will be stored together.
+            // The layout for per token per head:
+            // |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)|
+            // so, we have to extend head_size by 8, which is sizeof(float)
+            // for scale and sizeof(float) for zeropoint
+            pshape[3] += 2 * sizeof(float);
+        }
+
+        return pshape;
+    }
+
 public:
-    explicit CacheManager(const DeviceConfig &device_config, ov::InferRequest request, ov::Core core) :
-        m_device_config(device_config),
-        m_request(request),
-        m_core(core) {
-        m_key_cache.reserve(m_device_config.get_num_layers());
-        m_value_cache.reserve(m_device_config.get_num_layers());
+    CacheManager(ov::InferRequest request, const DeviceConfig& device_config) :
+        m_request(request) {
+        // extract information about inference device
+        ov::CompiledModel compiled_model = request.get_compiled_model();
+        std::vector<std::string> execution_devices = compiled_model.get_property(ov::execution_devices);
+        OPENVINO_ASSERT(execution_devices.size() == 1, "Contituous batching: execution device is expected to be CPU or GPU, but got ", execution_devices.size(), " devices");
+        m_device = execution_devices[0];
+
+        // extract information about KV cache precisions and shapes
+        size_t kv_input_index = 0;
+        for (const auto& input : compiled_model.inputs()) {
+            for (auto & name : input.get_names()) {
+                auto cache_precision = input.get_element_type();
+
+                if (name.find("key_cache.") == 0) {
+                    auto pshape = patch_shape(device_config.get_key_cache_shape(kv_input_index), cache_precision);
+                    m_key_shapes.push_back(pshape);
+                    m_key_precisions.push_back(cache_precision);
+                    m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
+                    break;
+                } else if (name.find("value_cache.") == 0) {
+                    auto pshape = patch_shape(device_config.get_value_cache_shape(kv_input_index), cache_precision);
+                    m_value_shapes.push_back(pshape);
+                    m_value_precisions.push_back(cache_precision);
+                    m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
+                    ++kv_input_index;
+                    break;
+                }
+            }
+        }
+
+        m_num_decoder_layers = m_value_precisions.size();
+        OPENVINO_ASSERT(m_num_decoder_layers == m_key_precisions.size(), "Invalid case: a different number of K and V caches in a LLM model");
+    }
+
+    size_t get_num_decoder_layers() const {
+        return m_num_decoder_layers;
+    }
+
+    std::string get_device() const {
+        return m_device;
+    }
+
+    ov::element::Type get_key_cache_precision(size_t decoder_layer_id) const {
+        OPENVINO_ASSERT(decoder_layer_id < m_key_precisions.size());
+        return m_key_precisions[decoder_layer_id];
+    }
+
+    ov::element::Type get_value_cache_precision(size_t decoder_layer_id) const {
+        OPENVINO_ASSERT(decoder_layer_id < m_value_precisions.size());
+        return m_value_precisions[decoder_layer_id];
+    }
+
+    size_t get_block_size_in_bytes() const {
+        return m_block_size_in_bytes;
     }
 
     void allocate_cache_if_needed(size_t num_kv_blocks) {
         if (m_num_allocated_kv_blocks >= num_kv_blocks) {
             return;
         }
-        OPENVINO_ASSERT(m_key_cache.size() == m_value_cache.size());
-        m_num_allocated_kv_blocks = num_kv_blocks;
 
-        const std::string device_name = m_device_config.get_device();
+        m_num_allocated_kv_blocks = num_kv_blocks;
 
         ov::Coordinate start_key{0,0,0,0};
        ov::Coordinate start_value{0,0,0,0};
 
-        if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
-            for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
-                ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
-                ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
+        if (m_device.find("GPU") == std::string::npos) {// Allocate KV caches
+            for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
+                ov::Shape value_cache_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], num_kv_blocks);
+                ov::Shape key_cache_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], num_kv_blocks);
+
+                ov::element::Type key_precision = get_key_cache_precision(decoder_layer_id);
+                ov::element::Type value_precision = get_value_cache_precision(decoder_layer_id);
+
 #ifdef _WIN32
-                ov::Tensor key_cache(m_device_config.get_cache_precision(), key_cache_shape);
-                ov::Tensor value_cache(m_device_config.get_cache_precision(), value_cache_shape);
+                ov::Tensor key_cache(key_precision, key_cache_shape);
+                ov::Tensor value_cache(value_precision, value_cache_shape);
 #else
-                auto key_size = ov::shape_size(key_cache_shape) * m_device_config.get_cache_precision().size();
-                auto value_size = ov::shape_size(value_cache_shape) * m_device_config.get_cache_precision().size();
-
-                ov::Tensor key_cache = ov::Tensor(m_device_config.get_cache_precision(), key_cache_shape, TensorMmapAllocator(key_size));
-                ov::Tensor value_cache = ov::Tensor(m_device_config.get_cache_precision(), value_cache_shape, TensorMmapAllocator(value_size));
+                auto key_size = ov::shape_size(key_cache_shape) * key_precision.size();
+                auto value_size = ov::shape_size(value_cache_shape) * value_precision.size();
 
+                ov::Tensor key_cache(key_precision, key_cache_shape, TensorMmapAllocator(key_size));
+                ov::Tensor value_cache(value_precision, value_cache_shape, TensorMmapAllocator(value_size));
 #endif
 
                 auto key_cache_roi_end = static_cast<unsigned char*>(key_cache.data());
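Worth noting in the hunk above: on CPU with a u8 KV cache, `patch_shape` widens the last (head_size) dimension by `2 * sizeof(float)` so that each token/head row can carry its f32 scale and f32 zero point next to the quantized data, and `m_block_size_in_bytes` accumulates dims 1..3 of the patched shape times the element size. A small worked example with hypothetical dimensions:

```cpp
// Worked example (hypothetical numbers, not from the diff): per-block bytes for
// a u8 KV cache on CPU with shape [num_blocks, num_kv_heads=8, block_size=32, head_size=128].
#include <cstddef>
#include <iostream>

int main() {
    const size_t num_kv_heads = 8, block_size = 32, head_size = 128;
    // patch_shape extends head_size by 2 * sizeof(float) = 8 bytes to hold
    // the per-token-per-head f32 scale and f32 zero point beside the u8 data.
    const size_t patched_head_size = head_size + 2 * sizeof(float);  // 136
    const size_t u8_bytes = 1;  // ov::element::u8.size()
    const size_t bytes_per_block = num_kv_heads * block_size * patched_head_size * u8_bytes;
    std::cout << bytes_per_block << " bytes per key block per layer" << std::endl;  // 34816
    return 0;
}
```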
@@ -137,24 +202,23 @@ class CacheManager {
                 if (m_key_cache.size() > decoder_layer_id) {
                     m_key_cache[decoder_layer_id] = key_cache;
                     m_value_cache[decoder_layer_id] = value_cache;
-                }
-                else {
+                } else {
                     m_key_cache.emplace_back(key_cache);
                     m_value_cache.emplace_back(value_cache);
                 }
 
                 update_request_tensor(decoder_layer_id);
             }
         } else {
-            auto remote_context = m_core.get_default_context(device_name);
-            for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
-                ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
-                ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
-                ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
-                                                                    key_cache_shape);
-                ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
-                                                                      value_cache_shape);
-
+            auto remote_context = m_request.get_compiled_model().get_context();
+
+            for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
+                ov::Shape value_cache_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], num_kv_blocks);
+                ov::Shape key_cache_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], num_kv_blocks);
+
+                ov::Tensor key_cache = remote_context.create_tensor(get_key_cache_precision(decoder_layer_id), key_cache_shape);
+                ov::Tensor value_cache = remote_context.create_tensor(get_value_cache_precision(decoder_layer_id), value_cache_shape);
+
                 if (m_key_cache.size() > decoder_layer_id) {
                     ov::Coordinate end_key = m_key_cache[decoder_layer_id].get_shape();
                     ov::Coordinate end_value = m_value_cache[decoder_layer_id].get_shape();
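The hunk above also swaps `m_core.get_default_context(device_name)` for the compiled model's own remote context, so cache tensors land in the same device memory pool the model was compiled against. A minimal sketch of that pattern (the shape and precision below are placeholders; `infer_request` is assumed to be compiled for GPU):

```cpp
// Minimal sketch: allocate a device-side tensor from the compiled model's
// own remote context rather than the Core's default context.
#include <openvino/openvino.hpp>

ov::Tensor make_gpu_cache_tensor(ov::InferRequest& infer_request) {
    ov::CompiledModel compiled_model = infer_request.get_compiled_model();
    // get_context() returns the remote context the model was compiled with,
    // so tensors created from it share the model's device memory pool.
    ov::RemoteContext remote_context = compiled_model.get_context();
    return remote_context.create_tensor(ov::element::f16, ov::Shape{1, 8, 32, 128});
}
```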
@@ -167,23 +231,23 @@ class CacheManager {
 
                     m_key_cache[decoder_layer_id] = key_cache;
                     m_value_cache[decoder_layer_id] = value_cache;
-                }
-                else {
+                } else {
                     m_key_cache.emplace_back(key_cache);
                     m_value_cache.emplace_back(value_cache);
                 }
+
                 update_request_tensor(decoder_layer_id);
             }
         }
     }
 
     ov::Tensor get_key_cache(size_t decoder_layer_id) const {
-        OPENVINO_ASSERT(decoder_layer_id < m_key_cache.size());
+        OPENVINO_ASSERT(decoder_layer_id < m_key_cache.size(), "decoder_layer_id = ", decoder_layer_id, ", num_layers = ", m_key_cache.size());
         return m_key_cache[decoder_layer_id];
     }
 
     ov::Tensor get_value_cache(size_t decoder_layer_id) const {
-        OPENVINO_ASSERT(decoder_layer_id < m_value_cache.size());
+        OPENVINO_ASSERT(decoder_layer_id < m_value_cache.size(), "decoder_layer_id = ", decoder_layer_id, ", num_layers = ", m_value_cache.size());
         return m_value_cache[decoder_layer_id];
     }
 
@@ -192,9 +256,9 @@ class CacheManager {
             size_t src_block_id = blocks_pair.first;
             const std::list<size_t>& dst_block_ids = blocks_pair.second;
             for (size_t dst_block_id : dst_block_ids) {
-                for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
-                    ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
-                    ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
+                for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
+                    ov::Shape key_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], m_num_allocated_kv_blocks);
+                    ov::Shape value_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], m_num_allocated_kv_blocks);
                     ov::Coordinate key_src_start_roi(key_shape.size(), 0);
                     ov::Coordinate key_src_end_roi = key_shape;
                     ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
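For context, the ROI coordinates set up above feed block-to-block copies within the cache tensors. A minimal sketch (not this file's exact code; names and the shape layout comment are illustrative) of copying one KV block via `ov::Tensor` ROI views:

```cpp
// Minimal sketch: copy one KV block (dimension 0 indexes blocks) from
// src_block_id to dst_block_id using ov::Tensor ROI views.
#include <openvino/openvino.hpp>

void copy_one_block(ov::Tensor& cache, size_t src_block_id, size_t dst_block_id) {
    ov::Shape shape = cache.get_shape();  // e.g. [num_blocks, heads, block_size, head_size]
    ov::Coordinate src_start = {src_block_id, 0, 0, 0};
    ov::Coordinate src_end = {src_block_id + 1, shape[1], shape[2], shape[3]};
    ov::Coordinate dst_start = {dst_block_id, 0, 0, 0};
    ov::Coordinate dst_end = {dst_block_id + 1, shape[1], shape[2], shape[3]};

    // ov::Tensor's ROI constructor creates a view over [start, end).
    ov::Tensor src_view(cache, src_start, src_end);
    ov::Tensor dst_view(cache, dst_start, dst_end);
    src_view.copy_to(dst_view);  // copies the block's contents into the destination view
}
```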
@@ -221,13 +285,6 @@ class CacheManager {
             }
         }
     }
-
-    std::shared_ptr<Core> get_core() {
-        return std::make_shared<Core>(m_core);
-    }
-
-    std::shared_ptr<DeviceConfig> get_device_config() {
-        return std::make_shared<DeviceConfig>(m_device_config);
-    }
 };
+
 }
