@@ -45,61 +45,126 @@ class TensorMmapAllocator {
 #endif
 
 namespace ov::genai {
+
 class CacheManager {
-    DeviceConfig m_device_config;
-    std::vector<ov::Tensor> m_key_cache;
-    std::vector<ov::Tensor> m_value_cache;
-    size_t m_num_allocated_kv_blocks = 0;
+    size_t m_num_decoder_layers = 0;
+    std::string m_device;
+    std::vector<ov::element::Type> m_key_precisions, m_value_precisions;
+    std::vector<ov::PartialShape> m_key_shapes, m_value_shapes;
+    std::vector<ov::Tensor> m_key_cache, m_value_cache;
+    size_t m_num_allocated_kv_blocks = 0, m_block_size_in_bytes = 0;
     ov::InferRequest m_request;
-    ov::Core m_core;
 
-    ov::Shape set_first_dim_and_make_static(const ov::PartialShape& shape, size_t dim) {
-        ov::PartialShape res_shape = shape;
-        res_shape[0] = dim;
-        OPENVINO_ASSERT(res_shape.is_static());
-        return res_shape.to_shape();
+    static ov::Shape set_kv_blocks(ov::PartialShape pshape, size_t num_kv_blocks) {
+        pshape[0] = num_kv_blocks;
+        return pshape.get_shape();
     }
 
     void update_request_tensor(size_t decoder_layer_id) {
         m_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_key_cache[decoder_layer_id]);
         m_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_value_cache[decoder_layer_id]);
     }
 
+    ov::PartialShape patch_shape(ov::PartialShape pshape, ov::element::Type cache_type) {
+        OPENVINO_ASSERT(!m_device.empty(), "Internal error: device is not set");
+
+        if (m_device.find("CPU") != std::string::npos && cache_type == ov::element::u8) {
+            // Scale, zero point and quantized data will be stored together.
+            // The per-token, per-head layout is:
+            // |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)|
+            // so we have to extend head_size by 8 bytes: sizeof(float)
+            // for the scale and sizeof(float) for the zero point
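+            // e.g. a head_size of 64 u8 elements is padded to 64 + 2 * sizeof(float) = 72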
+            pshape[3] += 2 * sizeof(float);
+        }
+
+        return pshape;
+    }
+
 public:
-    explicit CacheManager(const DeviceConfig &device_config, ov::InferRequest request, ov::Core core) :
-        m_device_config(device_config),
-        m_request(request),
-        m_core(core) {
-        m_key_cache.reserve(m_device_config.get_num_layers());
-        m_value_cache.reserve(m_device_config.get_num_layers());
+    CacheManager(ov::InferRequest request, const DeviceConfig& device_config) :
+        m_request(request) {
+        // extract information about inference device
+        ov::CompiledModel compiled_model = request.get_compiled_model();
+        std::vector<std::string> execution_devices = compiled_model.get_property(ov::execution_devices);
+        OPENVINO_ASSERT(execution_devices.size() == 1, "Continuous batching: execution device is expected to be CPU or GPU, but got ", execution_devices.size(), " devices");
+        m_device = execution_devices[0];
+
+        // extract information about KV cache precisions and shapes
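+        // each "key_cache.<layer>" / "value_cache.<layer>" input of the compiled model contributes
+        // dims [1..3] of its (possibly patched) shape, times the element size, to m_block_size_in_bytes,
+        // i.e. the byte size of one KV block across all decoder layers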
+        size_t kv_input_index = 0;
+        for (const auto& input : compiled_model.inputs()) {
+            for (auto& name : input.get_names()) {
+                auto cache_precision = input.get_element_type();
+
+                if (name.find("key_cache.") == 0) {
+                    auto pshape = patch_shape(device_config.get_key_cache_shape(kv_input_index), cache_precision);
+                    m_key_shapes.push_back(pshape);
+                    m_key_precisions.push_back(cache_precision);
+                    m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
+                    break;
+                } else if (name.find("value_cache.") == 0) {
+                    auto pshape = patch_shape(device_config.get_value_cache_shape(kv_input_index), cache_precision);
+                    m_value_shapes.push_back(pshape);
+                    m_value_precisions.push_back(cache_precision);
+                    m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
+                    ++kv_input_index;
+                    break;
+                }
+            }
+        }
+
+        m_num_decoder_layers = m_value_precisions.size();
+        OPENVINO_ASSERT(m_num_decoder_layers == m_key_precisions.size(), "Invalid case: a different number of K and V caches in an LLM model");
+    }
+
+    size_t get_num_decoder_layers() const {
+        return m_num_decoder_layers;
+    }
+
+    std::string get_device() const {
+        return m_device;
+    }
+
+    ov::element::Type get_key_cache_precision(size_t decoder_layer_id) const {
+        OPENVINO_ASSERT(decoder_layer_id < m_key_precisions.size());
+        return m_key_precisions[decoder_layer_id];
+    }
+
+    ov::element::Type get_value_cache_precision(size_t decoder_layer_id) const {
+        OPENVINO_ASSERT(decoder_layer_id < m_value_precisions.size());
+        return m_value_precisions[decoder_layer_id];
+    }
+
+    size_t get_block_size_in_bytes() const {
+        return m_block_size_in_bytes;
     }
 
     void allocate_cache_if_needed(size_t num_kv_blocks) {
         if (m_num_allocated_kv_blocks >= num_kv_blocks) {
             return;
         }
-        OPENVINO_ASSERT(m_key_cache.size() == m_value_cache.size());
-        m_num_allocated_kv_blocks = num_kv_blocks;
 
-        const std::string device_name = m_device_config.get_device();
+        m_num_allocated_kv_blocks = num_kv_blocks;
 
         ov::Coordinate start_key{0,0,0,0};
         ov::Coordinate start_value{0,0,0,0};
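+        // ROI start coordinates used further below when carrying over already-written cache contents into the newly allocated tensors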
 
-        if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
-            for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
-                ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
-                ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
+        if (m_device.find("GPU") == std::string::npos) {// Allocate KV caches
+            for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
+                ov::Shape value_cache_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], num_kv_blocks);
+                ov::Shape key_cache_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], num_kv_blocks);
+
+                ov::element::Type key_precision = get_key_cache_precision(decoder_layer_id);
+                ov::element::Type value_precision = get_value_cache_precision(decoder_layer_id);
+
 #ifdef _WIN32
-                ov::Tensor key_cache(m_device_config.get_cache_precision(), key_cache_shape);
-                ov::Tensor value_cache(m_device_config.get_cache_precision(), value_cache_shape);
+                ov::Tensor key_cache(key_precision, key_cache_shape);
+                ov::Tensor value_cache(value_precision, value_cache_shape);
 #else
-                auto key_size = ov::shape_size(key_cache_shape) * m_device_config.get_cache_precision().size();
-                auto value_size = ov::shape_size(value_cache_shape) * m_device_config.get_cache_precision().size();
-
-                ov::Tensor key_cache = ov::Tensor(m_device_config.get_cache_precision(), key_cache_shape, TensorMmapAllocator(key_size));
-                ov::Tensor value_cache = ov::Tensor(m_device_config.get_cache_precision(), value_cache_shape, TensorMmapAllocator(value_size));
+                auto key_size = ov::shape_size(key_cache_shape) * key_precision.size();
+                auto value_size = ov::shape_size(value_cache_shape) * value_precision.size();
 
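+                // non-Windows: back the cache tensors with the mmap-based TensorMmapAllocator declared above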
+                ov::Tensor key_cache(key_precision, key_cache_shape, TensorMmapAllocator(key_size));
+                ov::Tensor value_cache(value_precision, value_cache_shape, TensorMmapAllocator(value_size));
 #endif
 
                 auto key_cache_roi_end = static_cast<unsigned char*>(key_cache.data());
@@ -137,24 +202,23 @@ class CacheManager {
                 if (m_key_cache.size() > decoder_layer_id) {
                     m_key_cache[decoder_layer_id] = key_cache;
                     m_value_cache[decoder_layer_id] = value_cache;
-                }
-                else {
+                } else {
                     m_key_cache.emplace_back(key_cache);
                     m_value_cache.emplace_back(value_cache);
                 }
 
                 update_request_tensor(decoder_layer_id);
             }
         } else {
-            auto remote_context = m_core.get_default_context(device_name);
-            for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
-                ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
-                ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
-                ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
-                                                                    key_cache_shape);
-                ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
-                                                                      value_cache_shape);
-
+            auto remote_context = m_request.get_compiled_model().get_context();
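+            // GPU: create the cache tensors through the compiled model's remote context so they reside in device memory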
+
+            for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
+                ov::Shape value_cache_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], num_kv_blocks);
+                ov::Shape key_cache_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], num_kv_blocks);
+
+                ov::Tensor key_cache = remote_context.create_tensor(get_key_cache_precision(decoder_layer_id), key_cache_shape);
+                ov::Tensor value_cache = remote_context.create_tensor(get_value_cache_precision(decoder_layer_id), value_cache_shape);
+
                 if (m_key_cache.size() > decoder_layer_id) {
                     ov::Coordinate end_key = m_key_cache[decoder_layer_id].get_shape();
                     ov::Coordinate end_value = m_value_cache[decoder_layer_id].get_shape();
@@ -167,23 +231,23 @@ class CacheManager {
 
                     m_key_cache[decoder_layer_id] = key_cache;
                     m_value_cache[decoder_layer_id] = value_cache;
-                }
-                else {
+                } else {
                     m_key_cache.emplace_back(key_cache);
                     m_value_cache.emplace_back(value_cache);
                 }
+
                 update_request_tensor(decoder_layer_id);
             }
         }
     }
 
     ov::Tensor get_key_cache(size_t decoder_layer_id) const {
-        OPENVINO_ASSERT(decoder_layer_id < m_key_cache.size());
+        OPENVINO_ASSERT(decoder_layer_id < m_key_cache.size(), "decoder_layer_id = ", decoder_layer_id, ", num_layers = ", m_key_cache.size());
         return m_key_cache[decoder_layer_id];
     }
 
     ov::Tensor get_value_cache(size_t decoder_layer_id) const {
-        OPENVINO_ASSERT(decoder_layer_id < m_value_cache.size());
+        OPENVINO_ASSERT(decoder_layer_id < m_value_cache.size(), "decoder_layer_id = ", decoder_layer_id, ", num_layers = ", m_value_cache.size());
         return m_value_cache[decoder_layer_id];
     }
 
@@ -192,9 +256,9 @@ class CacheManager {
             size_t src_block_id = blocks_pair.first;
             const std::list<size_t>& dst_block_ids = blocks_pair.second;
             for (size_t dst_block_id : dst_block_ids) {
-                for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
-                    ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
-                    ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
+                for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
+                    ov::Shape key_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], m_num_allocated_kv_blocks);
+                    ov::Shape value_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], m_num_allocated_kv_blocks);
                     ov::Coordinate key_src_start_roi(key_shape.size(), 0);
                     ov::Coordinate key_src_end_roi = key_shape;
                     ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
@@ -221,13 +285,6 @@ class CacheManager {
             }
         }
     }
-
-    std::shared_ptr<Core> get_core() {
-        return std::make_shared<Core>(m_core);
-    }
-
-    std::shared_ptr<DeviceConfig> get_device_config() {
-        return std::make_shared<DeviceConfig>(m_device_config);
-    }
 };
+
 }