Skip to content

Commit 80c7af0

Browse files
committed
Move input_layout specific check to input_layout_inst
1 parent 8991494 commit 80c7af0

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/plugins/intel_gpu/src/graph/input_layout.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ event::ptr input_layout_inst::set_data(memory::ptr mem, bool need_to_check_memor
7070
void input_layout_inst::update_shape() {
7171
OPENVINO_ASSERT(!_outputs.empty() && _outputs[0] != nullptr, "[GPU] input memory is not set");
7272
auto mem_layout = _outputs[0]->get_layout();
73-
if (_impl_params->get_output_layout() != mem_layout) {
73+
// Set SHAPE_CHANGED flag if the actual data layout has changed, or if the node is included
74+
// into shape_of subgraph to trigger proper shape_of subgraph shape recalculation
75+
if (_impl_params->get_output_layout() != mem_layout || _node->is_in_shape_of_subgraph()) {
7476
set_flag(ExecutionFlags::SHAPE_CHANGED);
7577
}
7678
_impl_params->output_layouts[0] = mem_layout;

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ void primitive_inst::update_shape() {
368368
if (_node->is_in_shape_of_subgraph()) {
369369
bool subgraph_input_changed = false;
370370
for (size_t i = 0; i < dependant_shape_of_insts.size(); i++) {
371-
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED) || dependant_shape_of_insts[i]->_node->is_type<input_layout>()) {
371+
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED)) {
372372
subgraph_input_changed = true;
373373
break;
374374
}
@@ -396,7 +396,6 @@ void primitive_inst::update_shape() {
396396
const auto& insts = _deps[i].first->dependant_shape_of_insts;
397397
for (auto& inst : insts) {
398398
can_skip &= !inst->get_flag(ExecutionFlags::SHAPE_CHANGED);
399-
can_skip &= !inst->_node->is_type<input_layout>();
400399
}
401400
if (can_skip)
402401
continue;
@@ -1850,7 +1849,7 @@ void primitive_inst::prepare_primitive() {
18501849
if (_node->is_in_shape_of_subgraph() && dependant_shape_of_insts.front()->is_dynamic()) {
18511850
bool subgraph_input_changed = false;
18521851
for (size_t i = 0; i < dependant_shape_of_insts.size(); i++) {
1853-
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED) || dependant_shape_of_insts[i]->_node->is_type<input_layout>()) {
1852+
if (dependant_shape_of_insts[i]->get_flag(ExecutionFlags::SHAPE_CHANGED)) {
18541853
subgraph_input_changed = true;
18551854
break;
18561855
}

0 commit comments

Comments
 (0)