Commit 0b260ff
[CPU] Add PagedAttention support (#23524)
### Details:
- *Add PagedAttention support, which depends on:*
  - openvino_contrib: openvinotoolkit/openvino_contrib#867
  - vLLM: ilya-lavrenov/vllm#4
- *TODO:* models with the alibi feature

### Tickets:
- *[134329](https://jira.devtools.intel.com/browse/CVS-134329)*
- *[134327](https://jira.devtools.intel.com/browse/CVS-134327)*
1 parent 1c0ca0e commit 0b260ff

File tree: 11 files changed, +544 -198 lines

src/common/transformations/src/transformations/convert_precision.cpp

+5 -1

```diff
@@ -607,7 +607,11 @@ bool fuse_type_to_parameter(const std::shared_ptr<ov::Node>& node,
         auto convert = std::make_shared<opset4::Convert>(param, to);
         for (auto& input : param_consumers) {
             const auto consumer = input.get_node();
-            if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer)) {
+            if (ov::is_type<ov::op::v0::Result>(consumer) || ov::is_type<ov::op::v0::Convert>(consumer) ||
+                // TODO: refactor after ngraph op defined
+                // The fourth and fifth inputs are kvcache and should be directly connected to parameters
+                (consumer->get_type_name() == std::string("PagedAttentionExtension") &&
+                 (input.get_index() == 3 || input.get_index() == 4))) {
                 continue;
             }
             input.replace_source_output(convert);
```
src/plugins/intel_cpu/CMakeLists.txt

+1 -1

```diff
@@ -176,7 +176,7 @@ cross_compiled_file(${TARGET_NAME}
     ARCH AVX512F AVX2 ANY
         src/nodes/kernels/scaled_attn/attn_memcpy.cpp
     API src/nodes/kernels/scaled_attn/attn_memcpy.hpp
-    NAME attn_memcpy
+    NAME attn_memcpy paged_attn_memcpy
     NAMESPACE ov::Extensions::Cpu::XARCH
 )
 cross_compiled_file(${TARGET_NAME}
```
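For context: `cross_compiled_file()` builds `attn_memcpy.cpp` once per entry in `ARCH` and generates a runtime dispatch stub for each symbol listed under `NAME`, so the new `paged_attn_memcpy` has to be registered here as well as defined in the source. A simplified, self-contained illustration of that dispatch pattern (the stub body and feature-query helpers below are assumptions for the sketch, not the actual generated code):

```cpp
#include <cstdio>

// Stand-ins for the per-ISA builds of attn_memcpy.cpp; in the real plugin
// these are separate object files compiled with different ISA flags.
namespace AVX512F { void paged_attn_memcpy() { std::puts("AVX-512 kernel"); } }
namespace AVX2    { void paged_attn_memcpy() { std::puts("AVX2 kernel"); } }
namespace ANY     { void paged_attn_memcpy() { std::puts("portable kernel"); } }

// Hypothetical feature queries; a real stub would ask the runtime
// (e.g. something like ov::with_cpu_x86_avx512f()) or CPUID.
static bool cpu_has_avx512f() { return false; }
static bool cpu_has_avx2()    { return true; }

// Shape of the generated stub: one public symbol, resolved at runtime
// to the most capable variant the CPU supports.
void paged_attn_memcpy() {
    if (cpu_has_avx512f()) return AVX512F::paged_attn_memcpy();
    if (cpu_has_avx2())    return AVX2::paged_attn_memcpy();
    return ANY::paged_attn_memcpy();
}

int main() { paged_attn_memcpy(); }  // prints "AVX2 kernel" with these stubs
```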

src/plugins/intel_cpu/src/cpu_types.cpp

+1 -0

```diff
@@ -217,6 +217,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
         {"Ngram", Type::Ngram},
         {"ScaledDotProductAttention", Type::ScaledDotProductAttention},
         {"ScaledDotProductAttentionWithKVCache", Type::ScaledDotProductAttention},
+        {"PagedAttentionExtension", Type::ScaledDotProductAttention},
         {"RoPE", Type::RoPE},
     };
     return type_to_name_tbl;
```
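Mapping the name here means the plugin instantiates its existing ScaledDotProductAttention node for PagedAttentionExtension ops rather than introducing a new node type. Assuming the plugin's usual `TypeFromName()` lookup over this table (a hypothetical check, for illustration):

```cpp
// Both the stock SDPA variants and the new extension resolve to one CPU node
// type, so they share detection logic elsewhere (e.g. in graph.cpp below).
assert(TypeFromName("ScaledDotProductAttentionWithKVCache") == Type::ScaledDotProductAttention);
assert(TypeFromName("PagedAttentionExtension") == Type::ScaledDotProductAttention);
```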

src/plugins/intel_cpu/src/graph.cpp

+4 -0

```diff
@@ -1680,6 +1680,10 @@ void Graph::EnforceInferencePrecision() {
         if (node->getOriginalInputPrecisionAtPort(inPort) != ov::element::f32)
             return true;
 
+        // kvcache of PagedAttention should be written directly
+        if (node->getType() == Type::ScaledDotProductAttention && node->getOriginalInputsNumber() == 13 &&
+            (inPort == 3 || inPort == 4))
+            return true;
         const auto &parent = node->getParentEdgeAt(inPort)->getParent();
         /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
          * Element type conversion to bf16 is done automatically, if convolution follows up after Constant Inputs
```
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp

+54 -0

```diff
@@ -76,6 +76,43 @@ static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
     });
 }
 
+template <typename T, typename T2>
+static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
+                                     const ov::intel_cpu::PlainTensor& v_input,
+                                     const ov::intel_cpu::PlainTensor& past_k_output,
+                                     const ov::intel_cpu::PlainTensor& past_v_output,
+                                     const ov::intel_cpu::PlainTensor& slot_mapping) {
+    size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
+    parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+        auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
+        if (block_idx < 0) return;
+        attn_copy(past_k_output.ptr<T2>(block_idx, h, 0),
+                  k_input.ptr<T>(b, h, m, 0),
+                  S);
+        attn_copy(past_v_output.ptr<T2>(block_idx, h, 0),
+                  v_input.ptr<T>(b, h, m, 0),
+                  S);
+    });
+}
+
+static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
+                                     const ov::intel_cpu::PlainTensor& v_input,
+                                     const ov::intel_cpu::PlainTensor& past_k_output,
+                                     const ov::intel_cpu::PlainTensor& past_v_output,
+                                     const ov::intel_cpu::PlainTensor& slot_mapping) {
+    size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3];
+    parallel_for3d(B, H, L1, [&](size_t b, size_t h, size_t m) {
+        auto block_idx = slot_mapping.ptr<int32_t>(b)[m];
+        if (block_idx < 0) return;
+        std::memcpy(past_k_output.ptr_v(block_idx, h, 0),
+                    k_input.ptr_v(b, h, m, 0),
+                    S * k_input.m_element_size);
+        std::memcpy(past_v_output.ptr_v(block_idx, h, 0),
+                    v_input.ptr_v(b, h, m, 0),
+                    S * v_input.m_element_size);
+    });
+}
+
 void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                  const ov::intel_cpu::PlainTensor& v_input,
                  const ov::intel_cpu::PlainTensor& past_k_output,
@@ -90,6 +127,23 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
         OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in attn_memcpy");
     }
 }
+
+void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
+                       const ov::intel_cpu::PlainTensor& v_input,
+                       const ov::intel_cpu::PlainTensor& past_k_output,
+                       const ov::intel_cpu::PlainTensor& past_v_output,
+                       const ov::intel_cpu::PlainTensor& slot_mapping) {
+    if (past_k_output.get_precision() == k_input.get_precision()) {
+        paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping);
+    } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::f16) {
+        paged_attn_memcpy_kernel<float, ov::float16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
+    } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) {
+        paged_attn_memcpy_kernel<float, ov::bfloat16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
+    } else {
+        OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in paged_attn_memcpy");
+    }
+}
+
 } // namespace XARCH
 } // namespace Cpu
 } // namespace Extensions
```
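What's new relative to `attn_memcpy` is the addressing: instead of appending K/V contiguously along the sequence axis, each token's row is scattered to the cache slot named by `slot_mapping`, with negative slots marking padded tokens to skip; the templated overload down-converts f32 inputs to an f16/bf16 cache via `attn_copy`, while the untyped overload is a raw same-precision memcpy. A minimal standalone model of the scatter semantics (plain C++, simplified to keys only and f32 everywhere; dimension names follow the kernel):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Standalone model of paged_attn_memcpy's scatter semantics (no OpenVINO
// types; the real kernel also parallelizes over (b, h, m) and can convert
// f32 -> f16/bf16 while copying).
// k_input:      [B, H, L1, S]      new key states for this step
// past_k_cache: [num_slots, H, S]  paged cache, one row of S values per (slot, head)
// slot_mapping: [B, L1]            destination slot per token; < 0 means "padding, skip"
void paged_scatter_copy(const std::vector<float>& k_input,
                        std::vector<float>& past_k_cache,
                        const std::vector<int32_t>& slot_mapping,
                        size_t B, size_t H, size_t L1, size_t S) {
    for (size_t b = 0; b < B; ++b)
        for (size_t h = 0; h < H; ++h)
            for (size_t m = 0; m < L1; ++m) {
                int32_t slot = slot_mapping[b * L1 + m];
                if (slot < 0)
                    continue;  // padded token: nothing to write
                const float* src = &k_input[((b * H + h) * L1 + m) * S];
                float* dst = &past_k_cache[(size_t(slot) * H + h) * S];
                std::memcpy(dst, src, S * sizeof(float));
            }
}

int main() {
    const size_t B = 1, H = 2, L1 = 3, S = 4;
    std::vector<float> k(B * H * L1 * S, 1.0f);    // all-ones key states
    std::vector<float> cache(8 * H * S, 0.0f);     // 8 empty cache slots
    std::vector<int32_t> slots = {5, 2, -1};       // token 2 is padding
    paged_scatter_copy(k, cache, slots, B, H, L1, S);
    assert(cache[(5 * H + 0) * S] == 1.0f);        // token 0 landed in slot 5
    assert(cache[(7 * H + 0) * S] == 0.0f);        // slot 7 stayed untouched
    return 0;
}
```

This per-token indirection is what lets a vLLM-style paged cache place consecutive tokens of one sequence in non-contiguous physical blocks.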

src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp

+6 -0

```diff
@@ -20,6 +20,12 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
                  const ov::intel_cpu::PlainTensor& past_k_output,
                  const ov::intel_cpu::PlainTensor& past_v_output);
 
+void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
+                       const ov::intel_cpu::PlainTensor& v_input,
+                       const ov::intel_cpu::PlainTensor& past_k_output,
+                       const ov::intel_cpu::PlainTensor& past_v_output,
+                       const ov::intel_cpu::PlainTensor& slot_mapping);
+
 } // namespace XARCH
 } // namespace Cpu
 } // namespace Extensions
```
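Taken together with the .cpp above, the caller-facing contract is: `k_input`/`v_input` are [B, H, L1, S], `past_k_output`/`past_v_output` are paged caches indexed as [slot, H, S], and `slot_mapping` is a [B, L1] int32 tensor of destination slots (negative entries are skipped). Matching input/cache precisions take the raw-memcpy path; f32 inputs with an f16 or bf16 cache are converted on the fly; any other pairing throws.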
