Skip to content

Commit cfaecfe

Browse files
committed
[GPU] Apply transpose to beam table batch dim
1 parent 13391b0 commit cfaecfe

File tree

3 files changed: +20 -18 lines changed

3 files changed: +20 -18 lines changed

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_ref.cl

+2-2
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ KERNEL(gemm_ref)(
125125
uint b0 = b;
126126
uint b1 = b;
127127
#if INDIRECT_INPUT0
128-
b0 = BEAM_TABLE_BATCH_NUM > 1 : beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b;
128+
b0 = TR_BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, ki)] : b;
129129
#endif
130130
#if INDIRECT_INPUT1
131-
b1 = BEAM_TABLE_BATCH_NUM > 1 : beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b;
131+
b1 = TR_BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, ki, x)] : b;
132132
#endif
133133

134134
uint in0_idx = FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b0, f, w, z, y, ki);

src/plugins/intel_gpu/src/kernel_selector/cl_kernels/gemm_tiled_opt.cl

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ inline uint FUNC(get_bt_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, u
7373

7474
#if INDIRECT_INPUT0
7575
inline uint FUNC(get_input0_indirect_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x, __global BEAM_TABLE_TYPE* beam_table) {
76-
int b_index = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
76+
int b_index = TR_BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
7777
return FUNC_CALL(get_input0_index)(OPTIONAL_SHAPE_INFO_TENSOR b_index, f, w, z, y, x);
7878
}
7979
#endif
8080

8181
#if INDIRECT_INPUT1
8282
inline uint FUNC(get_input1_indirect_index)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint w, uint z, uint y, uint x, __global BEAM_TABLE_TYPE* beam_table) {
83-
int b_index = BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
83+
int b_index = TR_BEAM_TABLE_BATCH_NUM > 1 ? beam_table[FUNC_CALL(get_bt_index)(OPTIONAL_SHAPE_INFO_TENSOR b, f, w, z, y, x)] : b;
8484
return FUNC_CALL(get_input1_index)(OPTIONAL_SHAPE_INFO_TENSOR b_index, f, w, z, y, x);
8585
}
8686
#endif
@@ -203,7 +203,7 @@ KERNEL(gemm_tiled_opt)(
203203
const uint b_raw_global_id = tile_n_offset + sglid;
204204

205205
#if INDIRECT_INPUT0 || INDIRECT_INPUT1
206-
const char do_indirect_load = BEAM_TABLE_BATCH_NUM > 1;
206+
const char do_indirect_load = TR_BEAM_TABLE_BATCH_NUM > 1;
207207
#endif
208208

209209
#if TRANSPOSE_INPUT0 != TRANSPOSE_X_LAST

src/plugins/intel_gpu/src/kernel_selector/kernels/gemm/gemm_kernel_base.cpp

+15-13
Original file line numberDiff line numberDiff line change
@@ -183,28 +183,30 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
183183
MakeJitConstant("BEAM_TABLE_TERM", params.indirect_input0 || params.indirect_input1),
184184
});
185185

186-
auto get_output_size = [this](const std::vector<int64_t>& output_order_idx, const int target_idx) {
187-
auto output_dims_order = GetDimsOrder(output_order_idx);
186+
auto get_tensor_size = [this](const std::vector<int64_t>& order_idx, const int target_idx, std::string name) {
187+
auto output_dims_order = GetDimsOrder(order_idx);
188188

189189
switch (output_dims_order.at(target_idx)) {
190190
case 'b':
191-
return "OUTPUT_BATCH_NUM";
191+
return name + "_BATCH_NUM";
192192
case 'f':
193-
return "OUTPUT_FEATURE_NUM";
193+
return name + "_FEATURE_NUM";
194194
case 'w':
195-
return "OUTPUT_SIZE_W";
195+
return name + "_SIZE_W";
196196
case 'z':
197-
return "OUTPUT_SIZE_Z";
197+
return name + "_SIZE_Z";
198198
case 'y':
199-
return "OUTPUT_SIZE_Y";
199+
return name + "_SIZE_Y";
200200
case 'x':
201-
return "OUTPUT_SIZE_X";
201+
return name + "_SIZE_X";
202202
default:
203-
return "";
203+
return std::string("");
204204
}
205205
};
206206
if (params.indirect_input0 || params.indirect_input1) {
207+
auto beam_table_batch = get_tensor_size(params.indirect_input0 ? params.input0_order : params.input1_order, 0, "BEAM_TABLE");
207208
jit.AddConstant(MakeJitConstant("BEAM_TABLE", params.inputs[params.inputs.size() - 1]));
209+
jit.AddConstant(MakeJitConstant("TR_BEAM_TABLE_BATCH_NUM", beam_table_batch));
208210
}
209211

210212
if (params.inputs.size() == 4 || (!params.indirect_input0 && !params.indirect_input1 && params.inputs.size() == 3)) {
@@ -217,10 +219,10 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
217219
MakeJitConstant("TRANSPOSE_OTHER", 2),
218220
MakeJitConstant("INPUT0_DIMS_ORDER", GetDimsOrder(params.input0_order)),
219221
MakeJitConstant("INPUT1_DIMS_ORDER", GetDimsOrder(params.input1_order)),
220-
MakeJitConstant("TR_OUTPUT_SIZE_Z", get_output_size(params.output_order, 6)),
221-
MakeJitConstant("TR_OUTPUT_SIZE_W", get_output_size(params.output_order, 4)),
222-
MakeJitConstant("TR_OUTPUT_FEATURE_NUM", get_output_size(params.output_order, 2)),
223-
MakeJitConstant("TR_OUTPUT_BATCH_NUM", get_output_size(params.output_order, 0)),
222+
MakeJitConstant("TR_OUTPUT_SIZE_Z", get_tensor_size(params.output_order, 6, "OUTPUT")),
223+
MakeJitConstant("TR_OUTPUT_SIZE_W", get_tensor_size(params.output_order, 4, "OUTPUT")),
224+
MakeJitConstant("TR_OUTPUT_FEATURE_NUM", get_tensor_size(params.output_order, 2, "OUTPUT")),
225+
MakeJitConstant("TR_OUTPUT_BATCH_NUM", get_tensor_size(params.output_order, 0, "OUTPUT")),
224226
});
225227

226228
return jit;

0 commit comments

Comments
 (0)