Commit 46e504d

[GPU] Fix gemm_tiled_opt kernel accuracy for the dynamic case with TILE_N=32 and transposed output shape

1 parent d384662 commit 46e504d

4 files changed: +67 −16 lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl

+19 −8

@@ -162,10 +162,15 @@ KERNEL(gemm_tiled_opt)(
 #ifdef BIAS_TERM
     const uint batch_offset_input2 = FUNC_CALL(get_input2_batch_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z);
 #endif // BIAS_TERM
-    uint write_id = 0;
+    uint y_write_id = 0;
+    uint x_write_id = 0;
     const uint batch_offset_output = FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR TR_B, TR_F, TR_W, TR_Z, TR_Y, TR_X);
-    write_id = 1;
-    const uint batch_offset_output_diff = FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR TR_B, TR_F, TR_W, TR_Z, TR_Y, TR_X) - batch_offset_output;
+    y_write_id = 1;
+    x_write_id = 0;
+    const uint output_y_pitch = FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR TR_B, TR_F, TR_W, TR_Z, TR_Y, TR_X) - batch_offset_output;
+    y_write_id = 0;
+    x_write_id = 1;
+    const uint output_x_pitch = FUNC_CALL(get_output_index)(OPTIONAL_SHAPE_INFO_TENSOR TR_B, TR_F, TR_W, TR_Z, TR_Y, TR_X) - batch_offset_output;
 
     // Start pointers offsets
 #if TRANSPOSE_INPUT0 == TRANSPOSE_X_LAST
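The renamed variables make the kernel's pitch-probing trick explicit: the JIT substitutes TR_Y/TR_X with `(y+y_write_id)`/`(x+x_write_id)` (see gemm_kernel_base.cpp below), so evaluating `get_output_index` with one of the `*_write_id` variables set to 1 and subtracting the base offset yields the linear stride of that dimension under whatever permutation the output layout uses. A minimal standalone sketch of the same idea, assuming a plain strided 4D layout and illustrative names (`linear_index`, `order` are not from the kernel):

```cpp
#include <cstddef>
#include <cstdio>

// Linear index into a 4D tensor stored with its dimensions permuted by `order`
// (order[i] names the logical dim stored at nesting level i).
static size_t linear_index(const size_t dims[4], const int order[4], const size_t coord[4]) {
    size_t idx = 0;
    for (int i = 0; i < 4; ++i)
        idx = idx * dims[order[i]] + coord[order[i]];
    return idx;
}

int main() {
    const size_t dims[4] = {1, 1, 128, 9}; // logical b, f, y, x (M = 128, N = 9)
    const int order[4] = {0, 1, 3, 2};     // transposed output: y innermost, x one level out

    const size_t c0[4] = {0, 0, 0, 0};     // y_write_id = 0, x_write_id = 0
    const size_t cy[4] = {0, 0, 1, 0};     // y_write_id = 1
    const size_t cx[4] = {0, 0, 0, 1};     // x_write_id = 1

    const size_t base = linear_index(dims, order, c0);
    printf("output_y_pitch = %zu\n", linear_index(dims, order, cy) - base); // 1 here
    printf("output_x_pitch = %zu\n", linear_index(dims, order, cx) - base); // 128 here
    return 0;
}
```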
@@ -424,7 +429,7 @@ KERNEL(gemm_tiled_opt)(
 #endif // TILE_K > SIMD_WIDTH
             }
         }
-#if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING
+#if IS_DYNAMIC && !INDIRECT_INPUT0 && !HAS_DYNAMIC_K_PADDING
         // Read A for next dot_id
 #if TILE_K_NOT_DIVISIBLE
         a_read = (dot_id + 1 < tile_m_iterations) ? TILE_K_NOT_DIVISIBLE_CALC ? a_ptr[sglid] : BLOCK_READ_A(a_ptr, 0) : 0;
@@ -732,7 +737,13 @@ KERNEL(gemm_tiled_opt)(
 #endif // HAS_FUSED_OPS
         }
 #else
-        OUTPUT_TYPE* d_ptr_tmp = d_ptr + sglid;
+#if TRANSPOSE_OUTPUT == TRANSPOSE_X_LAST
+        const uint x_pitch = 1;
+#else
+        const uint x_pitch = output_x_pitch;
+#endif
+        OUTPUT_TYPE* d_ptr_tmp = d_ptr + sglid * x_pitch;
+
 #ifdef BIAS_TERM
         ACCUMULATOR_TYPE_VEC dequantized = (ACCUMULATOR_TYPE_VEC)(ALPHA) * c_tile[write_id] + TO_ACCUMULATOR_TYPE(BETA) * c_ptr[sglid];
 #else // BIAS_TERM
@@ -743,13 +754,13 @@ KERNEL(gemm_tiled_opt)(
         OUTPUT_TYPE_VEC result = FUSED_OPS_RESULT_VEC;
         unroll_for (uint n_elem = 0; n_elem < B_VEC_SIZE; ++n_elem) {
             if (b_raw_global_id + SIMD_WIDTH * n_elem < N) {
-                *(d_ptr_tmp + SIMD_WIDTH * n_elem) = result[n_elem];
+                *(d_ptr_tmp + SIMD_WIDTH * n_elem * x_pitch) = result[n_elem];
             }
         }
 #else
         unroll_for (uint n_elem = 0; n_elem < B_VEC_SIZE; ++n_elem) {
             if (b_raw_global_id + SIMD_WIDTH * n_elem < N) {
-                *(d_ptr_tmp + SIMD_WIDTH * n_elem) = dequantized[n_elem];
+                *(d_ptr_tmp + SIMD_WIDTH * n_elem * x_pitch) = dequantized[n_elem];
             }
         }
 #endif // HAS_FUSED_OPS
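These two stores are the heart of the accuracy fix. With TILE_N = 32 and a SIMD width of 16, B_VEC_SIZE is 2, so each subgroup lane writes two output elements that are SIMD_WIDTH columns apart along N; before this change the offsets were computed as if N (the kernel's x axis) were always the innermost output dimension. A hedged sketch of the patched addressing with illustrative values (SIMD_WIDTH = 16, x_pitch = 128 as in a y-innermost layout):

```cpp
#include <cstdio>

int main() {
    const unsigned SIMD_WIDTH = 16, B_VEC_SIZE = 2; // TILE_N = 32 / SIMD 16 case
    const unsigned x_pitch = 128;                   // output_x_pitch when y is innermost
    // Old (broken for transposed outputs): off = sglid + SIMD_WIDTH * n_elem.
    // New: every step along N is scaled by x_pitch.
    for (unsigned sglid = 0; sglid < 2; ++sglid)          // first two lanes only
        for (unsigned n_elem = 0; n_elem < B_VEC_SIZE; ++n_elem)
            printf("lane %u, n_elem %u -> d_ptr + %u\n",
                   sglid, n_elem, sglid * x_pitch + SIMD_WIDTH * n_elem * x_pitch);
    return 0;
}
```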
@@ -796,7 +807,7 @@ KERNEL(gemm_tiled_opt)(
 #endif // HAS_FUSED_OPS
 #endif // TILE_N_NOT_DIVISIBLE || B_VEC_SIZE == 1
 #endif // IS_DYNAMIC
-        d_ptr += batch_offset_output_diff;
+        d_ptr += output_y_pitch;
 #ifdef BIAS_TERM
         c_ptr += N;
 #endif // BIAS_TERM

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp

+6 −2

@@ -114,13 +114,17 @@ std::vector<std::string> GemmKernelBase::GetTransposedDims(const std::vector<int
         break;
     case 6:
         if (is_tiled_opt) {
-            dim_ids.push_back("(y+write_id)");
+            dim_ids.push_back("(y+y_write_id)");
         } else {
             dim_ids.push_back("y");
         }
         break;
     case 7:
-        dim_ids.push_back("x");
+        if (is_tiled_opt) {
+            dim_ids.push_back("(x+x_write_id)");
+        } else {
+            dim_ids.push_back("x");
+        }
        break;
    default:
        break;
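For reference, a rough standalone sketch (not the real kernel_selector plumbing) of where these strings end up: GetTransposedDims(...).at(6)/.at(7) feed the TR_Y/TR_X JIT constants, so the kernel's get_output_index call receives `(y+y_write_id)` and `(x+x_write_id)` as textual arguments, which is what lets the kernel toggle `y_write_id`/`x_write_id` to probe pitches:

```cpp
#include <iostream>
#include <string>

int main() {
    const bool is_tiled_opt = true;
    const std::string tr_y = is_tiled_opt ? "(y+y_write_id)" : "y";
    const std::string tr_x = is_tiled_opt ? "(x+x_write_id)" : "x";
    // Roughly what MakeJitConstant("TR_Y", ...) / ("TR_X", ...) paste into the kernel source:
    std::cout << "FUNC_CALL(get_output_index)(..., " << tr_y << ", " << tr_x << ")\n";
    return 0;
}
```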

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_tiled_opt.cpp

+6 −0

@@ -196,6 +196,12 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
         MakeJitConstant("TR_X", GetTransposedDims(params.output_order, true).at(7)),
     });
 
+    bool transpose_output = (params.output_order.size() > 0 && (params.output_order.back() != (static_cast<int>(params.output_order.size()) - 1)));
+    if (transpose_output)
+        jit.AddConstant(MakeJitConstant("TRANSPOSE_OUTPUT", 2 /* set as TRANSPOSE_OTHER */));
+    else
+        jit.AddConstant(MakeJitConstant("TRANSPOSE_OUTPUT", 0 /* set as TRANSPOSE_X_LAST */));
+
     bool has_dynamic_k_padding = params.transpose_input0 ? params.inputs[0].Y().pad.is_dynamic
                                                          : params.inputs[0].X().pad.is_dynamic;
     bool has_dynamic_n_padding = params.transpose_input1 ? params.inputs[1].Y().pad.is_dynamic
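The predicate reads: the output counts as transposed for this kernel whenever the last entry of output_order is not the index of the innermost dimension. A minimal standalone check against the two orders exercised by the new unit test (the 0/2 values mirror the TRANSPOSE_X_LAST/TRANSPOSE_OTHER constants used above):

```cpp
#include <cassert>
#include <vector>

static bool is_transposed_output(const std::vector<int>& order) {
    return !order.empty() && order.back() != static_cast<int>(order.size()) - 1;
}

int main() {
    assert(!is_transposed_output({0, 1, 2, 3})); // x still innermost -> TRANSPOSE_X_LAST (0)
    assert(is_transposed_output({0, 1, 3, 2}));  // y innermost      -> TRANSPOSE_OTHER (2)
    return 0;
}
```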

src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp

+36 −6

@@ -927,7 +927,7 @@ class gemm_gpu_tests: public ::testing::Test {
         ov::Shape ref_input1_broadcasted_shape;
         ov::Shape ref_input1_shape;
         ov::Shape ref_output_shape;
-
+
         ref_input0_shape = { BATCH_SIZE, 16, M_SIZE, K_SIZE };
         ref_input1_broadcasted_shape = { N_SIZE, BATCH_SIZE, 16, K_SIZE };
         ref_input1_shape = { BATCH_SIZE, 16, K_SIZE, N_SIZE };
@@ -1063,7 +1063,7 @@ class gemm_gpu_tests: public ::testing::Test {
         ov::Shape ref_input1_reshaped_shape;
         ov::Shape ref_input1_shape;
         ov::Shape ref_output_shape;
-
+
         ref_input0_shape = { BATCH_SIZE, 32, M_SIZE, K_SIZE };
         ref_input1_broadcasted_shape = { N_SIZE, BATCH_SIZE, 2, 16, K_SIZE };
         ref_input1_reshaped_shape = { N_SIZE, BATCH_SIZE, 32, K_SIZE };
@@ -1313,16 +1313,22 @@ class gemm_gpu_tests: public ::testing::Test {
             output_shape_default = { M_SIZE, N_SIZE };
         } else if (num_dims == 3) {
             input0_shape_default = { BATCH_SIZE, M_SIZE, K_SIZE };
-            input1_shape_default = { BATCH_SIZE, K_SIZE, N_SIZE };
+            input1_shape_default = { BATCH_SIZE, K_SIZE, N_SIZE };
             output_shape_default = { BATCH_SIZE, M_SIZE, N_SIZE };
         } else if (num_dims == 4) {
             input0_shape_default = { BATCH_SIZE, 1, M_SIZE, K_SIZE};
-            input1_shape_default = { BATCH_SIZE, 1, K_SIZE, N_SIZE};
+            input1_shape_default = { BATCH_SIZE, 1, K_SIZE, N_SIZE};
             output_shape_default = { BATCH_SIZE, 1, M_SIZE, N_SIZE };
         }
     }
 
-    void test_transpose_matmul_f32(size_t num_dims, bool is_input_dynamic, bool is_caching_test, std::vector<size_t> BMKN, std::vector<int64_t> input0_order, std::vector<int64_t> input1_order) {
+    void test_transpose_matmul_f32(size_t num_dims,
+                                   bool is_input_dynamic,
+                                   bool is_caching_test,
+                                   std::vector<size_t> BMKN,
+                                   std::vector<int64_t> input0_order,
+                                   std::vector<int64_t> input1_order,
+                                   std::vector<int64_t> output_order = {}) {
         tests::random_generator rg;
         rg.set_seed(GET_SUITE_NAME);
 
@@ -1337,6 +1343,7 @@ class gemm_gpu_tests: public ::testing::Test {
         set_default_shapes(num_dims, BMKN, input0_shape_default, input1_shape_default, output_shape_default);
         ov::Shape input0_shape(input0_shape_default.size());
         ov::Shape input1_shape(input1_shape_default.size());
+        ov::Shape output_shape(output_shape_default.size());
 
         for (size_t dim = 0; dim < input0_shape_default.size(); ++dim) {
             input0_shape[input0_order[dim]] = input0_shape_default[dim];
@@ -1346,6 +1353,12 @@ class gemm_gpu_tests: public ::testing::Test {
             input1_shape[input1_order[dim]] = input1_shape_default[dim];
         }
 
+        if (!output_order.empty()) {
+            for (size_t dim = 0; dim < output_shape_default.size(); ++dim) {
+                output_shape[output_order[dim]] = output_shape_default[dim];
+            }
+        }
+
         if (is_input_dynamic) {
             input0_layout = layout{ov::PartialShape::dynamic(input0_shape.size()), data_types::f32, format::bfyx};
             input1_layout = layout{ov::PartialShape::dynamic(input1_shape.size()), data_types::f32, format::bfyx};
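The scatter loop above builds the transposed output shape the same way the input shapes are built: dimension `dim` of the default shape lands at position output_order[dim]. A quick standalone check using the values from the new test case (BMKN = {1, 128, 1, 9}, output_order = {0, 1, 3, 2}):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Default (non-transposed) 4D output shape {B, 1, M, N} for BMKN = {1, 128, 1, 9}.
    const std::vector<size_t> output_shape_default = {1, 1, 128, 9};
    const std::vector<int64_t> output_order = {0, 1, 3, 2};
    std::vector<size_t> output_shape(output_shape_default.size());

    // Same statement as in the diff above.
    for (size_t dim = 0; dim < output_shape_default.size(); ++dim)
        output_shape[output_order[dim]] = output_shape_default[dim];

    for (size_t d : output_shape)
        printf("%zu ", d); // prints: 1 1 9 128
    printf("\n");
    return 0;
}
```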
@@ -1366,7 +1379,7 @@ class gemm_gpu_tests: public ::testing::Test {
         topology topology;
         topology.add(input_layout("input0", input0_layout),
                      input_layout("input1", input1_layout),
-                     gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f32, {}, {}, {}, {}, input0_order, input1_order)
+                     gemm("gemm", { input_info("input0"), input_info("input1") }, data_types::f32, {}, {}, {}, {}, input0_order, input1_order, output_order)
         );
 
         ExecutionConfig config = get_test_default_config(engine);
@@ -1415,6 +1428,19 @@ class gemm_gpu_tests: public ::testing::Test {
                                          false,
                                          false);
 
+        if (!output_order.empty()) {
+            std::vector<float> out_data_transposed(ov::shape_size(output_shape_default));
+
+            ov::reference::transpose((const char *)(ref_out_data.data()),
+                                     (char *)(out_data_transposed.data()),
+                                     output_shape_default,
+                                     sizeof(float),
+                                     output_order,
+                                     output_shape);
+
+            ref_out_data = out_data_transposed;
+        }
+
         ASSERT_EQ(output_ptr.size(), ref_out_data.size());
 
         const auto abs_error = 0.0001;
@@ -1614,6 +1640,10 @@ TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_4d_f32) {
     this->test_transpose_matmul_f32(4, true, false, /*BMKN*/{19, 37, 23, 29}, /*input0_order*/{0, 2, 3, 1}, /*input1_order*/{1, 2, 3, 0});
 }
 
+TEST_F(gemm_gpu_tests, transpose_matmul_dynamic_4d_f32_n_tile_32_output_ylast) {
+    this->test_transpose_matmul_f32(4, true, false, /*BMKN*/{1, 128, 1, 9}, /*input0_order*/{0, 1, 2, 3}, /*input1_order*/{0, 1, 2, 3}, /*output_order*/{0, 1, 3, 2});
+}
+
 TEST_F(gemm_gpu_tests, transpose_matmul_static_4d_f16) {
     this->test_transpose_matmul_f16(4, false, false, /*BMKN*/{19, 37, 23, 29}, /*input0_order*/{0, 2, 3, 1}, /*input1_order*/{1, 2, 3, 0});
 }
