Skip to content

Commit b02b0f2

Browse files
authored
Supprot complex tensors for equal (openvinotoolkit#439)
1 parent e3e71cd commit b02b0f2

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/tensorflow_translators.cpp

+14-1
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,20 @@ ov::OutputVector translate_equal(const ov::frontend::NodeContext& node) {
503503
result = std::make_shared<Convert>(equal_str, element::boolean);
504504
}
505505
else {
506-
result = std::make_shared<Equal>(input1, input2)->output(0);
506+
auto lhs_complex = ov::as_type_ptr<ComplexTypeMark>(input1.get_node_shared_ptr());
507+
auto rhs_complex = ov::as_type_ptr<ComplexTypeMark>(input2.get_node_shared_ptr());
508+
if (lhs_complex && rhs_complex) {
509+
auto lhs_data = lhs_complex->get_data();
510+
auto rhs_data = rhs_complex->get_data();
511+
auto equal = std::make_shared<Equal>(lhs_data, rhs_data);
512+
513+
// reduce along the last dimension using ReduceAnd
514+
auto reduce_axes = std::make_shared<Constant>(element::i32, Shape{1}, std::vector<int32_t>{-1});
515+
result = std::make_shared<ReduceLogicalAnd>(equal, reduce_axes, false)->output(0);
516+
}
517+
else{
518+
result = std::make_shared<Equal>(input1, input2)->output(0);
519+
}
507520
}
508521

509522
result.get_node_shared_ptr()->set_friendly_name(node_name);

0 commit comments

Comments
 (0)