@@ -167,7 +167,7 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
167
167
ov::Tensor generate_attention_mask (const std::int32_t num_seqs, const std::int32_t max_context_len, ov::Tensor context_lens) {
168
168
OPENVINO_ASSERT (num_seqs == context_lens.get_size ());
169
169
170
- ov::Shape attention_mask_shape ({num_seqs, max_context_len, max_context_len});
170
+ ov::Shape attention_mask_shape ({num_seqs, 1 , max_context_len, max_context_len});
171
171
ov::Tensor attention_mask (ov::element::boolean, attention_mask_shape);
172
172
int attention_mask_stride = attention_mask.get_strides ()[0 ];
173
173
@@ -186,21 +186,6 @@ ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32
186
186
}
187
187
}
188
188
189
- // similar to torch.Tensor.view
190
- ov::Tensor view_as_3d (ov::Tensor tensor) {
191
- ov::Shape shape = tensor.get_shape ();
192
- OPENVINO_ASSERT (shape.size () == 4 );
193
- const std::uint32_t batch_size = shape[0 ], seq_len = shape[1 ], num_heads = shape[2 ], head_size = shape[3 ];
194
- return ov::Tensor (tensor.get_element_type (), ov::Shape ({batch_size, seq_len, num_heads * head_size}), tensor.data ());
195
- }
196
-
197
- ov::Tensor view_as_4d (ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
198
- ov::Shape shape = tensor.get_shape ();
199
- const std::uint32_t batch_size = shape[0 ], seq_len = shape[1 ];
200
- OPENVINO_ASSERT (shape.size () == 3 && num_heads * head_size == shape[3 ]);
201
- return ov::Tensor (tensor.get_element_type (), ov::Shape ({batch_size, seq_len, num_heads, head_size}), tensor.data ());
202
- }
203
-
204
189
bool TemplateExtension::PagedAttention::evaluate (ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
205
190
ov::Tensor query = inputs[0 ], key = inputs[1 ], value = inputs[2 ];
206
191
ov::Shape query_shape = query.get_shape ();
@@ -212,10 +197,10 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
212
197
ov::Tensor context_lens = inputs[8 ];
213
198
ov::Tensor block_tables = inputs[9 ];
214
199
215
- // reshape to [batch_size, seq_len, num_heads/ m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
216
- query = view_as_4d (query , m_num_heads, m_head_size);
217
- key = view_as_4d (key , m_num_kv_heads, m_head_size);
218
- value = view_as_4d (value, m_num_kv_heads, m_head_size );
200
+ // reshape to [batch_size * seq_len, m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
201
+ query. set_shape ({batch_size * seq_len , m_num_heads, m_head_size} );
202
+ key. set_shape ({batch_size * seq_len , m_num_kv_heads, m_head_size} );
203
+ value. set_shape (key. get_shape () );
219
204
220
205
// put current K, V values into key_cache and value_cache
221
206
reshape_and_cache (key, value, key_cache, value_cache, slot_mapping);
@@ -225,6 +210,12 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
225
210
outputs[0 ].set_shape (query.get_shape ());
226
211
227
212
if (is_prompt) {
213
+ // reshape to [batch_size, seq_len, m_num_kv_heads, head_size]
214
+ query.set_shape ({batch_size, seq_len, m_num_heads, m_head_size});
215
+ outputs[0 ].set_shape (query.get_shape ());
216
+ key.set_shape ({batch_size, seq_len, m_num_kv_heads, m_head_size});
217
+ value.set_shape (key.get_shape ());
218
+
228
219
auto attention_mask = generate_attention_mask (batch_size, max_context_len, context_lens);
229
220
ov::Tensor scale (ov::element::f32, ov::Shape{1 }, (void *)&m_scale);
230
221
@@ -237,15 +228,16 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
237
228
238
229
m_prefill_request.infer ();
239
230
} else {
231
+ // 'query' and 'output' are expected to be [batch_size * seq_len, m_num_kv_heads, head_size]
240
232
paged_attention_v1_cpu (outputs[0 ],
241
233
query, key_cache, value_cache,
242
234
m_num_kv_heads, m_scale,
243
235
block_tables, context_lens,
244
236
m_block_size, max_context_len);
245
237
}
246
238
247
- // reshape
248
- outputs[0 ] = view_as_3d (outputs[ 0 ]);
239
+ // reshape back to [batch_size, seq_len, num_heads * head_size]
240
+ outputs[0 ]. set_shape (query_shape); // works like reshape
249
241
250
242
return true ;
251
243
}
0 commit comments