@@ -4,10 +4,15 @@

#include "include/batch_headers/common.cl"

+#if IS_KV_COMPRESSED
+#define SUBGROUPS_PER_WG 1
+#else
#define SUBGROUPS_PER_WG KV_HEADS_NUM
+#endif
+#define ACCUMULATOR_TYPE float

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
-__attribute__((reqd_work_group_size(SUBGROUP_SIZE, KV_HEADS_NUM, 1)))
+__attribute__((reqd_work_group_size(SUBGROUP_SIZE, SUBGROUPS_PER_WG, 1)))
KERNEL(pa_kv_cache_rotate)(
    OPTIONAL_SHAPE_INFO_ARG
    __global const INPUT0_TYPE* rotated_block_indices,
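With KV cache compression enabled, each work-group now runs a single subgroup instead of one subgroup per KV head, presumably so the shared rotation_coefficients scratch buffer can later be reused as per-head SLM storage without the heads racing on it. A minimal host-side sketch of the resulting work-group shape, using illustrative values for SUBGROUP_SIZE and KV_HEADS_NUM (both are build-time defines in the real kernel, not these numbers):

/* Sketch of the work-group shape implied by reqd_work_group_size above.
 * SUBGROUP_SIZE and KV_HEADS_NUM are illustrative assumptions. */
#include <stdio.h>

#define SUBGROUP_SIZE 16
#define KV_HEADS_NUM  8

int main(void) {
    for (int is_kv_compressed = 0; is_kv_compressed <= 1; is_kv_compressed++) {
        /* Mirrors the #if IS_KV_COMPRESSED block in the kernel. */
        int subgroups_per_wg = is_kv_compressed ? 1 : KV_HEADS_NUM;
        printf("IS_KV_COMPRESSED=%d -> local size (%d, %d, 1)\n",
               is_kv_compressed, SUBGROUP_SIZE, subgroups_per_wg);
    }
    return 0;
}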
@@ -62,22 +67,76 @@ KERNEL(pa_kv_cache_rotate)(
    barrier(CLK_LOCAL_MEM_FENCE);

    const uint token_coefficient_idx = per_token_rotation ? sglid : 0;
-    const uint block_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
-                              head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + sglid;
+    const uint block_base_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
+                                   head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
+    const uint token_offset = block_base_offset + sglid;
+
+#if IS_KV_COMPRESSED
+    const uint comp_offset = block_base_offset + HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
+    __global UNCOMPRESSED_TYPE* comp_ptr = (__global UNCOMPRESSED_TYPE*)(key_cache + comp_offset);
+    UNCOMPRESSED_TYPE comp_scale = comp_ptr[0 + sglid];
+    UNCOMPRESSED_TYPE comp_zp = comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid];
+
+    UNCOMPRESSED_TYPE max_value = UNCOMPRESSED_VAL_MIN;
+    UNCOMPRESSED_TYPE min_value = UNCOMPRESSED_VAL_MAX;
+
+    // Reuse SLM to store dequantized rotated values
+    __local UNCOMPRESSED_TYPE* rotated_data = (__local UNCOMPRESSED_TYPE*)(&rotation_coefficients[0][0]);
+#endif
+
+    // Apply cache rotation
    for (uint i = 0; i < HEAD_SIZE / 2; i++) {
-        const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
-        OUTPUT_TYPE cache_value_first = key_cache[cache_offset];
-        OUTPUT_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
+        const uint cache_offset = token_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
+
+#if IS_KV_COMPRESSED
+        UNCOMPRESSED_TYPE cache_value_first = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset] - comp_zp) * comp_scale;
+        UNCOMPRESSED_TYPE cache_value_second = TO_UNCOMPRESSED_TYPE(key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] - comp_zp) * comp_scale;
+#else
+        UNCOMPRESSED_TYPE cache_value_first = key_cache[cache_offset];
+        UNCOMPRESSED_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];
+#endif

        INPUT2_TYPE rotation_value_cos = rotation_coefficients[i][token_coefficient_idx];
        INPUT2_TYPE rotation_value_sin = rotation_coefficients[i + (HEAD_SIZE / 2)][token_coefficient_idx];

-        OUTPUT_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
-        OUTPUT_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;
+        UNCOMPRESSED_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
+        UNCOMPRESSED_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;

+#if IS_KV_COMPRESSED
+        max_value = fmax(fmax(max_value, new_cache_value_first), new_cache_value_second);
+        min_value = fmin(fmin(min_value, new_cache_value_first), new_cache_value_second);
+
+        rotated_data[(i + 0) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_first;
+        rotated_data[(i + (HEAD_SIZE / 2)) * PAGED_ATTENTION_BLOCK_SIZE + sglid] = new_cache_value_second;
+#else
        key_cache[cache_offset] = new_cache_value_first;
        key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] = new_cache_value_second;
+#endif
+    }
+
+#if IS_KV_COMPRESSED
+    // Re-quantize the rotated cache data
+    ACCUMULATOR_TYPE grp_max = 0.001f;
+    ACCUMULATOR_TYPE diff_value = max_value == min_value ? (grp_max) : (max_value - min_value);
+    ACCUMULATOR_TYPE scale_tmp = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / diff_value);
+    ACCUMULATOR_TYPE zp_tmp = (ACCUMULATOR_TYPE)(-min_value * scale_tmp) + CHAR_MIN;
+    UNCOMPRESSED_TYPE scale = (UNCOMPRESSED_TYPE)(scale_tmp);
+    UNCOMPRESSED_TYPE zp = (UNCOMPRESSED_TYPE)(zp_tmp);
+
+    // Note: without this explicit unroll hint the compiler unrolls the loop
+    // automatically, causing register spills; set the factor manually instead
+    __attribute__((opencl_unroll_hint(8)))
+    for (uint i = 0; i < HEAD_SIZE; i++) {
+        OUTPUT_TYPE quantized_res = convert_char_rte(rotated_data[i * PAGED_ATTENTION_BLOCK_SIZE + sglid] * scale + zp);
+
+        const uint cache_offset = token_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
+        key_cache[cache_offset] = quantized_res;
    }
+
+    comp_ptr[0 + sglid] = 1.0f / scale;
+    comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid] = zp;
+#endif
}

+#undef ACCUMULATOR_TYPE
#undef SUBGROUPS_PER_WG
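The offset rework in the second hunk splits the old block_offset into a per-head block_base_offset plus a per-token lane offset (sglid), and switches the stride to ADJUSTED_HEAD_SIZE so that each head's slice leaves room for the compression parameters after its HEAD_SIZE data rows. A standalone C sketch of that indexing, with all sizes assumed for illustration (in particular, ADJUSTED_HEAD_SIZE = HEAD_SIZE + 4 only holds for an int8 cache with fp16 scale/zero-point rows):

#include <stddef.h>
#include <stdio.h>

/* Illustrative indexing for one compressed KV cache block. The constants
 * are assumptions for the sketch: an int8 cache with fp16 scale/zp rows
 * appended per head, i.e. 2 extra int8 rows for the scales plus 2 for the
 * zero-points. */
#define HEAD_SIZE 128
#define PAGED_ATTENTION_BLOCK_SIZE 16
#define KV_HEADS_NUM 8
#define ADJUSTED_HEAD_SIZE (HEAD_SIZE + 4)

int main(void) {
    unsigned physical_block = 42, head_idx = 2, sglid = 5; /* example values */

    /* Per-block, per-head base, as in block_base_offset above. */
    size_t block_base_offset =
        (size_t)physical_block * KV_HEADS_NUM * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
        (size_t)head_idx * ADJUSTED_HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;

    /* Token sglid, dimension i, lives at token_offset + i * block size. */
    size_t token_offset = block_base_offset + sglid;

    /* Compression parameters start right after the HEAD_SIZE data rows. */
    size_t comp_offset = block_base_offset + (size_t)HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;

    printf("base=%zu token=%zu comp=%zu\n", block_base_offset, token_offset, comp_offset);
    return 0;
}

With this layout, comp_ptr[sglid] addresses token sglid's scale and comp_ptr[PAGED_ATTENTION_BLOCK_SIZE + sglid] its zero-point, which matches the comp_ptr reads and writes in the hunk.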
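Putting the compressed path together: each lane dequantizes its token's key values with the stored inverse scale and zero-point, applies the pairwise rotation, tracks the min/max of the rotated values, and re-quantizes over the full int8 range, storing 1/scale so that later decompression stays a single subtract-and-multiply. A self-contained C model of one token's round trip (function and parameter names are illustrative, and int8/float stand in for OUTPUT_TYPE/UNCOMPRESSED_TYPE):

#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>

/* Models one token lane of the compressed path: dequantize with the stored
 * inverse scale/zero-point, rotate pairs (i, i + head_size/2), track the
 * min/max of the rotated values, re-quantize, and store the new inverse
 * scale. All names are illustrative. */
void rotate_compressed_token(signed char *qdata, float *inv_scale, float *zp,
                             const float *cos_coef, const float *sin_coef,
                             int head_size) {
    float buf[256]; /* assumes head_size <= 256; plays the role of the SLM scratch */
    float max_v = -FLT_MAX, min_v = FLT_MAX;

    for (int i = 0; i < head_size / 2; i++) {
        float a = ((float)qdata[i] - *zp) * *inv_scale;
        float b = ((float)qdata[i + head_size / 2] - *zp) * *inv_scale;
        float na = a * cos_coef[i] - b * sin_coef[i]; /* 2D rotation */
        float nb = a * sin_coef[i] + b * cos_coef[i];
        buf[i] = na;
        buf[i + head_size / 2] = nb;
        max_v = fmaxf(fmaxf(max_v, na), nb);
        min_v = fminf(fminf(min_v, na), nb);
    }

    /* Re-quantize over the full int8 range, as in the kernel. */
    float diff = (max_v == min_v) ? 0.001f : (max_v - min_v);
    float scale = (float)(CHAR_MAX - CHAR_MIN) / diff;
    float new_zp = -min_v * scale + (float)CHAR_MIN;
    for (int i = 0; i < head_size; i++)
        qdata[i] = (signed char)lrintf(buf[i] * scale + new_zp);

    *inv_scale = 1.0f / scale; /* stored inverted so decompression is a multiply */
    *zp = new_zp;
}

int main(void) {
    /* Toy usage: head_size = 4, arbitrary coefficients and quantized data. */
    signed char q[4] = {-128, 0, 64, 127};
    float inv_scale = 0.01f, zp = 0.0f;
    const float cosc[2] = {0.8f, 1.0f};
    const float sinc[2] = {0.6f, 0.0f};
    rotate_compressed_token(q, &inv_scale, &zp, cosc, sinc, 4);
    printf("q = {%d, %d, %d, %d}, inv_scale = %f, zp = %f\n",
           q[0], q[1], q[2], q[3], inv_scale, zp);
    return 0;
}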