@@ -76,6 +76,43 @@ static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input,
76
76
});
77
77
}
78
78
79
+ template <typename T, typename T2>
80
+ static void paged_attn_memcpy_kernel (const ov::intel_cpu::PlainTensor& k_input,
81
+ const ov::intel_cpu::PlainTensor& v_input,
82
+ const ov::intel_cpu::PlainTensor& past_k_output,
83
+ const ov::intel_cpu::PlainTensor& past_v_output,
84
+ const ov::intel_cpu::PlainTensor& slot_mapping) {
85
+ size_t B = k_input.m_dims [0 ], H = k_input.m_dims [1 ], L1 = k_input.m_dims [2 ], S = k_input.m_dims [3 ];
86
+ parallel_for3d (B, H, L1, [&](size_t b, size_t h, size_t m) {
87
+ auto block_idx = slot_mapping.ptr <int32_t >(b)[m];
88
+ if (block_idx < 0 ) return ;
89
+ attn_copy (past_k_output.ptr <T2>(block_idx, h, 0 ),
90
+ k_input.ptr <T>(b, h, m, 0 ),
91
+ S);
92
+ attn_copy (past_v_output.ptr <T2>(block_idx, h, 0 ),
93
+ v_input.ptr <T>(b, h, m, 0 ),
94
+ S);
95
+ });
96
+ }
97
+
98
+ static void paged_attn_memcpy_kernel (const ov::intel_cpu::PlainTensor& k_input,
99
+ const ov::intel_cpu::PlainTensor& v_input,
100
+ const ov::intel_cpu::PlainTensor& past_k_output,
101
+ const ov::intel_cpu::PlainTensor& past_v_output,
102
+ const ov::intel_cpu::PlainTensor& slot_mapping) {
103
+ size_t B = k_input.m_dims [0 ], H = k_input.m_dims [1 ], L1 = k_input.m_dims [2 ], S = k_input.m_dims [3 ];
104
+ parallel_for3d (B, H, L1, [&](size_t b, size_t h, size_t m) {
105
+ auto block_idx = slot_mapping.ptr <int32_t >(b)[m];
106
+ if (block_idx < 0 ) return ;
107
+ std::memcpy (past_k_output.ptr_v (block_idx, h, 0 ),
108
+ k_input.ptr_v (b, h, m, 0 ),
109
+ S * k_input.m_element_size );
110
+ std::memcpy (past_v_output.ptr_v (block_idx, h, 0 ),
111
+ v_input.ptr_v (b, h, m, 0 ),
112
+ S * v_input.m_element_size );
113
+ });
114
+ }
115
+
79
116
void attn_memcpy (const ov::intel_cpu::PlainTensor& k_input,
80
117
const ov::intel_cpu::PlainTensor& v_input,
81
118
const ov::intel_cpu::PlainTensor& past_k_output,
@@ -90,6 +127,23 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input,
90
127
OPENVINO_THROW (" unsupport src type: " , k_input.get_precision (), " , dst type: " , past_k_output.get_precision (), " in attn_memcpy" );
91
128
}
92
129
}
130
+
131
+ void paged_attn_memcpy (const ov::intel_cpu::PlainTensor& k_input,
132
+ const ov::intel_cpu::PlainTensor& v_input,
133
+ const ov::intel_cpu::PlainTensor& past_k_output,
134
+ const ov::intel_cpu::PlainTensor& past_v_output,
135
+ const ov::intel_cpu::PlainTensor& slot_mapping) {
136
+ if (past_k_output.get_precision () == k_input.get_precision ()) {
137
+ paged_attn_memcpy_kernel (k_input, v_input, past_k_output, past_v_output, slot_mapping);
138
+ } else if (k_input.get_precision () == ov::element::f32 && past_k_output.get_precision () == ov::element::f16) {
139
+ paged_attn_memcpy_kernel<float , ov::float16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
140
+ } else if (k_input.get_precision () == ov::element::f32 && past_k_output.get_precision () == ov::element::bf16) {
141
+ paged_attn_memcpy_kernel<float , ov::bfloat16>(k_input, v_input, past_k_output, past_v_output, slot_mapping);
142
+ } else {
143
+ OPENVINO_THROW (" unsupport src type: " , k_input.get_precision (), " , dst type: " , past_k_output.get_precision (), " in paged_attn_memcpy" );
144
+ }
145
+ }
146
+
93
147
} // namespace XARCH
94
148
} // namespace Cpu
95
149
} // namespace Extensions
0 commit comments