Skip to content

Commit 04e7014

Browse files
committed
Use Tensor::set_shape instead of view
1 parent e9d5e44 commit 04e7014

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

modules/custom_operations/tests/CMakeLists.txt

Whitespace-only changes.

modules/custom_operations/user_ie_extensions/paged_attention/paged_attention.cpp

+14-22
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ void reshape_and_cache(ov::Tensor key, ov::Tensor value,
167167
ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32_t max_context_len, ov::Tensor context_lens) {
168168
OPENVINO_ASSERT(num_seqs == context_lens.get_size());
169169

170-
ov::Shape attention_mask_shape({num_seqs, max_context_len, max_context_len});
170+
ov::Shape attention_mask_shape({num_seqs, 1, max_context_len, max_context_len});
171171
ov::Tensor attention_mask(ov::element::boolean, attention_mask_shape);
172172
int attention_mask_stride = attention_mask.get_strides()[0];
173173

@@ -186,21 +186,6 @@ ov::Tensor generate_attention_mask(const std::int32_t num_seqs, const std::int32
186186
}
187187
}
188188

189-
// similar to torch.Tensor.view
190-
ov::Tensor view_as_3d(ov::Tensor tensor) {
191-
ov::Shape shape = tensor.get_shape();
192-
OPENVINO_ASSERT(shape.size() == 4);
193-
const std::uint32_t batch_size = shape[0], seq_len = shape[1], num_heads = shape[2], head_size = shape[3];
194-
return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads * head_size}), tensor.data());
195-
}
196-
197-
ov::Tensor view_as_4d(ov::Tensor tensor, std::uint32_t num_heads, std::uint32_t head_size) {
198-
ov::Shape shape = tensor.get_shape();
199-
const std::uint32_t batch_size = shape[0], seq_len = shape[1];
200-
OPENVINO_ASSERT(shape.size() == 3 && num_heads * head_size == shape[3]);
201-
return ov::Tensor(tensor.get_element_type(), ov::Shape({batch_size, seq_len, num_heads, head_size}), tensor.data());
202-
}
203-
204189
bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
205190
ov::Tensor query = inputs[0], key = inputs[1], value = inputs[2];
206191
ov::Shape query_shape = query.get_shape();
@@ -212,10 +197,10 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
212197
ov::Tensor context_lens = inputs[8];
213198
ov::Tensor block_tables = inputs[9];
214199

215-
// reshape to [batch_size, seq_len, num_heads/m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
216-
query = view_as_4d(query, m_num_heads, m_head_size);
217-
key = view_as_4d(key, m_num_kv_heads, m_head_size);
218-
value = view_as_4d(value, m_num_kv_heads, m_head_size);
200+
// reshape to [batch_size * seq_len, m_num_kv_heads, head_size] from [batch_size, seq_len, num_heads/m_num_kv_heads * head_size]
201+
query.set_shape({batch_size * seq_len, m_num_heads, m_head_size});
202+
key.set_shape({batch_size * seq_len, m_num_kv_heads, m_head_size});
203+
value.set_shape(key.get_shape());
219204

220205
// put current K, V values into key_cache and value_cache
221206
reshape_and_cache(key, value, key_cache, value_cache, slot_mapping);
@@ -225,6 +210,12 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
225210
outputs[0].set_shape(query.get_shape());
226211

227212
if (is_prompt) {
213+
// reshape to [batch_size, seq_len, m_num_kv_heads, head_size]
214+
query.set_shape({batch_size, seq_len, m_num_heads, m_head_size});
215+
outputs[0].set_shape(query.get_shape());
216+
key.set_shape({batch_size, seq_len, m_num_kv_heads, m_head_size});
217+
value.set_shape(key.get_shape());
218+
228219
auto attention_mask = generate_attention_mask(batch_size, max_context_len, context_lens);
229220
ov::Tensor scale(ov::element::f32, ov::Shape{1}, (void *)&m_scale);
230221

@@ -237,15 +228,16 @@ bool TemplateExtension::PagedAttention::evaluate(ov::TensorVector& outputs, cons
237228

238229
m_prefill_request.infer();
239230
} else {
231+
// 'query' and 'output' are expected to be [batch_size * seq_len, m_num_kv_heads, head_size]
240232
paged_attention_v1_cpu(outputs[0],
241233
query, key_cache, value_cache,
242234
m_num_kv_heads, m_scale,
243235
block_tables, context_lens,
244236
m_block_size, max_context_len);
245237
}
246238

247-
// reshape
248-
outputs[0] = view_as_3d(outputs[0]);
239+
// reshape back to [batch_size, seq_len, num_heads * head_size]
240+
outputs[0].set_shape(query_shape); // works like reshape
249241

250242
return true;
251243
}

0 commit comments

Comments
 (0)