@@ -183,28 +183,30 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
183
183
MakeJitConstant (" BEAM_TABLE_TERM" , params.indirect_input0 || params.indirect_input1 ),
184
184
});
185
185
186
- auto get_output_size = [this ](const std::vector<int64_t >& output_order_idx , const int target_idx) {
187
- auto output_dims_order = GetDimsOrder (output_order_idx );
186
+ auto get_tensor_size = [this ](const std::vector<int64_t >& order_idx , const int target_idx, std::string name ) {
187
+ auto output_dims_order = GetDimsOrder (order_idx );
188
188
189
189
switch (output_dims_order.at (target_idx)) {
190
190
case ' b' :
191
- return " OUTPUT_BATCH_NUM " ;
191
+ return name + " _BATCH_NUM " ;
192
192
case ' f' :
193
- return " OUTPUT_FEATURE_NUM " ;
193
+ return name + " _FEATURE_NUM " ;
194
194
case ' w' :
195
- return " OUTPUT_SIZE_W " ;
195
+ return name + " _SIZE_W " ;
196
196
case ' z' :
197
- return " OUTPUT_SIZE_Z " ;
197
+ return name + " _SIZE_Z " ;
198
198
case ' y' :
199
- return " OUTPUT_SIZE_Y " ;
199
+ return name + " _SIZE_Y " ;
200
200
case ' x' :
201
- return " OUTPUT_SIZE_X " ;
201
+ return name + " _SIZE_X " ;
202
202
default :
203
- return " " ;
203
+ return std::string ( " " ) ;
204
204
}
205
205
};
206
206
if (params.indirect_input0 || params.indirect_input1 ) {
207
+ auto beam_table_batch = get_tensor_size (params.indirect_input0 ? params.input0_order : params.input1_order , 0 , " BEAM_TABLE" );
207
208
jit.AddConstant (MakeJitConstant (" BEAM_TABLE" , params.inputs [params.inputs .size () - 1 ]));
209
+ jit.AddConstant (MakeJitConstant (" TR_BEAM_TABLE_BATCH_NUM" , beam_table_batch));
208
210
}
209
211
210
212
if (params.inputs .size () == 4 || (!params.indirect_input0 && !params.indirect_input1 && params.inputs .size () == 3 )) {
@@ -217,10 +219,10 @@ JitConstants GemmKernelBase::GetJitConstants(const gemm_params& params) const {
217
219
MakeJitConstant (" TRANSPOSE_OTHER" , 2 ),
218
220
MakeJitConstant (" INPUT0_DIMS_ORDER" , GetDimsOrder (params.input0_order )),
219
221
MakeJitConstant (" INPUT1_DIMS_ORDER" , GetDimsOrder (params.input1_order )),
220
- MakeJitConstant (" TR_OUTPUT_SIZE_Z" , get_output_size (params.output_order , 6 )),
221
- MakeJitConstant (" TR_OUTPUT_SIZE_W" , get_output_size (params.output_order , 4 )),
222
- MakeJitConstant (" TR_OUTPUT_FEATURE_NUM" , get_output_size (params.output_order , 2 )),
223
- MakeJitConstant (" TR_OUTPUT_BATCH_NUM" , get_output_size (params.output_order , 0 )),
222
+ MakeJitConstant (" TR_OUTPUT_SIZE_Z" , get_tensor_size (params.output_order , 6 , " OUTPUT " )),
223
+ MakeJitConstant (" TR_OUTPUT_SIZE_W" , get_tensor_size (params.output_order , 4 , " OUTPUT " )),
224
+ MakeJitConstant (" TR_OUTPUT_FEATURE_NUM" , get_tensor_size (params.output_order , 2 , " OUTPUT " )),
225
+ MakeJitConstant (" TR_OUTPUT_BATCH_NUM" , get_tensor_size (params.output_order , 0 , " OUTPUT " )),
224
226
});
225
227
226
228
return jit;
0 commit comments