|
3 | 3 | //
|
4 | 4 |
|
5 | 5 | #include "test_utils.h"
|
| 6 | +#include "random_generator.hpp" |
6 | 7 |
|
7 | 8 | #include <intel_gpu/primitives/input_layout.hpp>
|
8 | 9 | #include <intel_gpu/primitives/data.hpp>
|
9 | 10 | #include <intel_gpu/primitives/shape_of.hpp>
|
10 | 11 | #include <intel_gpu/primitives/broadcast.hpp>
|
11 | 12 | #include <intel_gpu/primitives/gather.hpp>
|
12 | 13 | #include <intel_gpu/primitives/non_zero.hpp>
|
| 14 | +#include <intel_gpu/primitives/paged_attention.hpp> |
| 15 | +#include <intel_gpu/primitives/gather.hpp> |
13 | 16 |
|
14 | 17 | #include "program_wrapper.h"
|
15 | 18 |
|
@@ -65,4 +68,181 @@ TEST(update_shape_test, ocl_impl_in_shapeof_subgraph) {
|
65 | 68 | std::map<primitive_id, network_output> outputs;
|
66 | 69 | OV_ASSERT_NO_THROW(outputs = network.execute());
|
67 | 70 | }
|
| 71 | + |
| 72 | +TEST(update_shape_test, max_context_len_shapeof_subgraph) { |
| 73 | + tests::random_generator rg(GET_SUITE_NAME); |
| 74 | + auto& engine = get_test_engine(); |
| 75 | + |
| 76 | + auto input_data_layout = layout{ov::PartialShape{1, -1}, data_types::f16, format::bfyx}; |
| 77 | + |
| 78 | + auto qkv_mem_layout = layout{ov::PartialShape{1, 128}, data_types::f16, format::bfyx}; |
| 79 | + auto qkv_mem = engine.allocate_memory(qkv_mem_layout); |
| 80 | + auto qkv_rnd = rg.generate_random_1d<ov::float16>(qkv_mem_layout.count(), 0, 10); |
| 81 | + set_values(qkv_mem, qkv_rnd); |
| 82 | + |
| 83 | + auto key_cache_mem_layout = layout{ov::PartialShape{1, 2, 64, 16}, data_types::f16, format::bfyx}; |
| 84 | + auto value_cache_mem_layout = layout{ov::PartialShape{1, 2, 16, 64}, data_types::f16, format::bfyx}; |
| 85 | + auto key_cache_mem = engine.allocate_memory(key_cache_mem_layout); |
| 86 | + auto value_cache_mem = engine.allocate_memory(value_cache_mem_layout); |
| 87 | + auto cache_rnd = rg.generate_random_1d<ov::float16>(key_cache_mem_layout.count(), 0, 10); |
| 88 | + set_values(key_cache_mem, cache_rnd); |
| 89 | + set_values(value_cache_mem, cache_rnd); |
| 90 | + |
| 91 | + auto past_lens_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; |
| 92 | + auto past_lens_mem = engine.allocate_memory(past_lens_mem_layout); |
| 93 | + set_values(value_cache_mem, {8}); |
| 94 | + |
| 95 | + auto subsequence_begins_mem_layout = layout{ov::PartialShape{2}, data_types::i32, format::bfyx}; |
| 96 | + auto subsequence_begins_mem = engine.allocate_memory(subsequence_begins_mem_layout); |
| 97 | + set_values(subsequence_begins_mem, {0, 1}); |
| 98 | + |
| 99 | + auto block_indices_mem_layout = layout{ov::PartialShape{2}, data_types::i32, format::bfyx}; |
| 100 | + auto block_indices_mem = engine.allocate_memory(block_indices_mem_layout); |
| 101 | + set_values(block_indices_mem, {0}); |
| 102 | + |
| 103 | + auto block_indices_begins_mem_layout = layout{ov::PartialShape{2}, data_types::i32, format::bfyx}; |
| 104 | + auto block_indices_begins_mem = engine.allocate_memory(block_indices_begins_mem_layout); |
| 105 | + set_values(block_indices_begins_mem, {0, 1}); |
| 106 | + |
| 107 | + auto scale_mem_layout = layout{ov::PartialShape{1}, data_types::f16, format::bfyx}; |
| 108 | + auto scale_mem = engine.allocate_memory(scale_mem_layout); |
| 109 | + set_values<ov::float16>(scale_mem, {1}); |
| 110 | + |
| 111 | + auto sliding_window_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; |
| 112 | + auto sliding_window_mem = engine.allocate_memory(sliding_window_mem_layout); |
| 113 | + set_values(sliding_window_mem, {0}); |
| 114 | + |
| 115 | + auto alibi_mem_layout = layout{ov::PartialShape{0}, data_types::f16, format::bfyx}; |
| 116 | + auto alibi_mem = engine.allocate_memory(alibi_mem_layout); |
| 117 | + |
| 118 | + auto const_one_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; |
| 119 | + auto const_one_mem = engine.allocate_memory(const_one_layout); |
| 120 | + set_values(const_one_mem, {1}); |
| 121 | + |
| 122 | + auto input_data_mem_layout = layout{ov::PartialShape{1, 9}, data_types::f16, format::bfyx}; |
| 123 | + auto input_data_mem = engine.allocate_memory(input_data_mem_layout); |
| 124 | + auto input_data_rnd = rg.generate_random_1d<ov::float16>(input_data_mem_layout.count(), 0, 10); |
| 125 | + set_values(input_data_mem, input_data_rnd); |
| 126 | + |
| 127 | + auto query_layout = layout{ov::PartialShape{-1, 128}, data_types::f16, format::bfyx}; |
| 128 | + auto key_layout = query_layout; |
| 129 | + auto value_layout = query_layout; |
| 130 | + auto key_cache_layout = layout{ov::PartialShape{-1, 2, 64, 16}, data_types::f16, format::bfyx}; |
| 131 | + auto dynamic_i32_layout = layout{ov::PartialShape::dynamic(1), data_types::i32, format::bfyx}; |
| 132 | + auto value_cache_layout = key_cache_layout; |
| 133 | + auto past_lens_layout = dynamic_i32_layout; |
| 134 | + auto subsequence_begins_layout = dynamic_i32_layout; |
| 135 | + auto block_indices_layout = dynamic_i32_layout; |
| 136 | + auto block_indices_begins_layout = dynamic_i32_layout; |
| 137 | + auto scale_layout = layout{ov::PartialShape{1}, data_types::f16, format::bfyx}; |
| 138 | + auto sliding_window_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; |
| 139 | + auto alibi_layout = layout{ov::PartialShape{0}, data_types::f16, format::bfyx}; |
| 140 | + auto max_context_len_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; |
| 141 | + |
| 142 | + std::vector<input_info> pa_inputs = {input_info("query"), |
| 143 | + input_info("key"), |
| 144 | + input_info("value"), |
| 145 | + input_info("key_cache"), |
| 146 | + input_info("value_cache"), |
| 147 | + input_info("past_lens"), |
| 148 | + input_info("subsequence_begins"), |
| 149 | + input_info("block_indices"), |
| 150 | + input_info("block_indices_begins"), |
| 151 | + input_info("scale"), |
| 152 | + input_info("sliding_window"), |
| 153 | + input_info("alibi"), |
| 154 | + input_info("max_context_len")}; |
| 155 | + |
| 156 | + auto pa_prim = paged_attention("paged_attention", pa_inputs); |
| 157 | + pa_prim.head_size = 64; |
| 158 | + pa_prim.kv_heads_num = 2; |
| 159 | + pa_prim.heads_num = 2; |
| 160 | + pa_prim.scale_val = 1.f; |
| 161 | + pa_prim.has_alibi = false; |
| 162 | + pa_prim.num_outputs = 1; |
| 163 | + pa_prim.has_rotated_blocks = false; |
| 164 | + |
| 165 | + topology topology; |
| 166 | + topology.add(input_layout("input_data", input_data_layout)); |
| 167 | + topology.add(input_layout("query", query_layout)); |
| 168 | + topology.add(input_layout("key", key_layout)); |
| 169 | + topology.add(input_layout("value", value_layout)); |
| 170 | + topology.add(input_layout("key_cache", key_cache_layout)); |
| 171 | + topology.add(input_layout("value_cache", value_cache_layout)); |
| 172 | + topology.add(input_layout("past_lens", past_lens_layout)); |
| 173 | + topology.add(input_layout("subsequence_begins", subsequence_begins_layout)); |
| 174 | + topology.add(input_layout("block_indices", block_indices_layout)); |
| 175 | + topology.add(input_layout("block_indices_begins", block_indices_begins_layout)); |
| 176 | + topology.add(input_layout("scale", scale_layout)); |
| 177 | + topology.add(input_layout("sliding_window", sliding_window_layout)); |
| 178 | + topology.add(input_layout("alibi", alibi_layout)); |
| 179 | + topology.add(input_layout("max_context_len", max_context_len_layout)); |
| 180 | + topology.add(data("const_one", const_one_mem)); |
| 181 | + topology.add(shape_of("shape_of", input_info("input_data"), data_types::i32)); |
| 182 | + topology.add(gather("gather", input_info("shape_of"), input_info("const_one"), 0, 1, ov::Shape{})); |
| 183 | + topology.add(broadcast("broadcast", input_info("gather"), input_info("max_context_len"), {}, ov::op::BroadcastType::BIDIRECTIONAL)); |
| 184 | + topology.add(pa_prim); |
| 185 | + |
| 186 | + ExecutionConfig config = get_test_default_config(engine); |
| 187 | + config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); |
| 188 | + |
| 189 | + network network(engine, topology, config); |
| 190 | + |
| 191 | + network.set_input_data("input_data", input_data_mem); |
| 192 | + network.set_input_data("query", qkv_mem); |
| 193 | + network.set_input_data("key", qkv_mem); |
| 194 | + network.set_input_data("value", qkv_mem); |
| 195 | + network.set_input_data("key_cache", key_cache_mem); |
| 196 | + network.set_input_data("value_cache", value_cache_mem); |
| 197 | + network.set_input_data("past_lens", past_lens_mem); |
| 198 | + network.set_input_data("subsequence_begins", subsequence_begins_mem); |
| 199 | + network.set_input_data("block_indices", block_indices_mem); |
| 200 | + network.set_input_data("block_indices_begins", block_indices_begins_mem); |
| 201 | + network.set_input_data("scale", scale_mem); |
| 202 | + network.set_input_data("sliding_window", sliding_window_mem); |
| 203 | + network.set_input_data("alibi", alibi_mem); |
| 204 | + |
| 205 | + // Set original max_context_len value |
| 206 | + auto max_context_len_mem_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx}; |
| 207 | + auto max_context_len_mem = engine.allocate_memory(max_context_len_mem_layout); |
| 208 | + set_values(max_context_len_mem, {9}); |
| 209 | + |
| 210 | + network.set_input_data("max_context_len", max_context_len_mem); |
| 211 | + |
| 212 | + // 1st network execution |
| 213 | + network.execute(); |
| 214 | + |
| 215 | + auto broadcast_inst = network.get_primitive("broadcast"); |
| 216 | + ASSERT_EQ(broadcast_inst->get_node().get_dependant_shape_of_nodes().size(), 2); |
| 217 | + |
| 218 | + // Verify broadcast shape after first execution |
| 219 | + auto broadcast_shape = broadcast_inst->get_impl_params()->get_output_layout().get_shape(); |
| 220 | + ASSERT_EQ(broadcast_shape, ov::Shape{9}); |
| 221 | + |
| 222 | + network.set_input_data("input_data", input_data_mem); |
| 223 | + network.set_input_data("query", qkv_mem); |
| 224 | + network.set_input_data("key", qkv_mem); |
| 225 | + network.set_input_data("value", qkv_mem); |
| 226 | + network.set_input_data("key_cache", key_cache_mem); |
| 227 | + network.set_input_data("value_cache", value_cache_mem); |
| 228 | + network.set_input_data("past_lens", past_lens_mem); |
| 229 | + network.set_input_data("subsequence_begins", subsequence_begins_mem); |
| 230 | + network.set_input_data("block_indices", block_indices_mem); |
| 231 | + network.set_input_data("block_indices_begins", block_indices_begins_mem); |
| 232 | + network.set_input_data("scale", scale_mem); |
| 233 | + network.set_input_data("sliding_window", sliding_window_mem); |
| 234 | + network.set_input_data("alibi", alibi_mem); |
| 235 | + |
| 236 | + // Update max_context_len value, which should be taken into account in shape recalculation for broadcast |
| 237 | + set_values(max_context_len_mem, {8}); |
| 238 | + |
| 239 | + network.set_input_data("max_context_len", max_context_len_mem); |
| 240 | + |
| 241 | + // 2nd network execution with updated max_context_len |
| 242 | + network.execute(); |
| 243 | + |
| 244 | + // Check if broadcast shape was recalculated |
| 245 | + broadcast_shape = broadcast_inst->get_impl_params()->get_output_layout().get_shape(); |
| 246 | + ASSERT_EQ(broadcast_shape, ov::Shape{8}); |
| 247 | +} |
68 | 248 | } // update_shape_test
|
0 commit comments