Skip to content

Commit f7378bb

Browse files
committed
[GPU] Fix concat_input_order pass
1 parent e4b82a5 commit f7378bb

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/concat_input_order.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ bool can_shuffle_features(program_node& node, program_node& concat_node, stream&
4848
if (pass_through) {
4949
// Primitives that are feature order invariant, pass-through shuffled features to users
5050
for (auto& user : node.get_users()) {
51-
if (!can_shuffle_features(*user, concat_node, stream))
51+
if (!can_shuffle_features(*user, node, stream))
5252
return false;
5353
}
5454
return true;

src/plugins/intel_gpu/tests/unit/test_cases/concatenation_gpu_test.cpp

+4-10
Original file line numberDiff line numberDiff line change
@@ -723,28 +723,24 @@ TEST(concat_gpu, no_exception_in_input_order_opt_b_fs_yx_fsv16_with_conv_port2)
723723
auto concat_input1 = engine.allocate_memory({ data_types::f32, format::b_fs_yx_fsv16, { 1, 48, 6, 6 }});
724724
auto concat_input2 = engine.allocate_memory({ data_types::f32, format::b_fs_yx_fsv16, { 1, 96, 6, 6 }});
725725
auto concat_input3 = engine.allocate_memory({ data_types::f32, format::b_fs_yx_fsv16, { 1, 128, 6, 6 }});
726-
auto conv_input = engine.allocate_memory({ data_types::f32, format::b_fs_yx_fsv16, { 1, 192, 6, 6 } });
727-
auto weights0 = engine.allocate_memory({ data_types::f32, format::bfyx, { 296, 192, 1, 1 } });
726+
auto weights0 = engine.allocate_memory({ data_types::f32, format::bfyx, { 296, 296, 1, 1 } });
728727

729728
std::vector<float> concat_input0_data(concat_input0->get_layout().count());
730729
std::vector<float> concat_input1_data(concat_input1->get_layout().count());
731730
std::vector<float> concat_input2_data(concat_input2->get_layout().count());
732731
std::vector<float> concat_input3_data(concat_input3->get_layout().count());
733-
std::vector<float> conv_input_data(conv_input->get_layout().count());
734732
std::vector<float> weights0_data(weights0->get_layout().count());
735733

736734
std::iota(concat_input0_data.begin(), concat_input0_data.end(), 0.f);
737735
std::iota(concat_input1_data.begin(), concat_input1_data.end(), 0.f);
738736
std::iota(concat_input2_data.begin(), concat_input2_data.end(), 0.f);
739737
std::iota(concat_input3_data.begin(), concat_input3_data.end(), 0.f);
740-
std::iota(conv_input_data.begin(), conv_input_data.end(), 0.f);
741738
std::iota(weights0_data.begin(), weights0_data.end(), 0.f);
742739

743740
set_values(concat_input0, concat_input0_data);
744741
set_values(concat_input1, concat_input1_data);
745742
set_values(concat_input2, concat_input2_data);
746743
set_values(concat_input3, concat_input3_data);
747-
set_values(conv_input, conv_input_data);
748744
set_values(weights0, weights0_data);
749745

750746
layout reorder_layout(data_types::f32, format::b_fs_yx_fsv16, {1, 296, 6, 6});
@@ -753,16 +749,15 @@ TEST(concat_gpu, no_exception_in_input_order_opt_b_fs_yx_fsv16_with_conv_port2)
753749
input_layout("concat_input1", concat_input1->get_layout()),
754750
input_layout("concat_input2", concat_input2->get_layout()),
755751
input_layout("concat_input3", concat_input3->get_layout()),
756-
input_layout("conv_input", conv_input->get_layout()),
757752
concatenation("concat",
758753
{ input_info("concat_input0"), input_info("concat_input1"), input_info("concat_input2"), input_info("concat_input3") },
759754
1,
760755
data_types::f32,
761756
padding{{0, 0, 0, 0}, 0}),
757+
pooling("pooling", input_info("concat"), pooling_mode::max, { 2, 2 }, { 1, 1 }),
762758
data("weights0", weights0),
763-
convolution("conv0", input_info("conv_input"), "weights0", "", 1, { 1, 1 }, {1, 1}, {0, 0}, {0, 0}, false),
764-
eltwise("eltwise", input_info("conv0"), input_info("concat"), eltwise_mode::sum),
765-
permute("permute", input_info("eltwise"), {0, 1, 2, 3}));
759+
convolution("conv0", input_info("pooling"), "weights0", "", 1, { 1, 1 }, {1, 1}, {0, 0}, {0, 0}, false),
760+
permute("permute", input_info("conv0"), {0, 1, 2, 3}));
766761

767762
ov::intel_gpu::ExecutionConfig config = get_test_default_config(engine);
768763
config.set_property(ov::intel_gpu::optimize_data(true));
@@ -772,7 +767,6 @@ TEST(concat_gpu, no_exception_in_input_order_opt_b_fs_yx_fsv16_with_conv_port2)
772767
network.set_input_data("concat_input1", concat_input1);
773768
network.set_input_data("concat_input2", concat_input2);
774769
network.set_input_data("concat_input3", concat_input3);
775-
network.set_input_data("conv_input", conv_input);
776770

777771
ASSERT_NO_FATAL_FAILURE(network.execute());
778772
}

0 commit comments

Comments
 (0)