|
8 | 8 | #include "identity.hpp"
|
9 | 9 | #include "openvino/frontend/exception.hpp"
|
10 | 10 | #include "openvino/op/constant.hpp"
|
| 11 | +#include "openvino/op/convert.hpp" |
11 | 12 | #include "openvino/op/exp.hpp"
|
12 | 13 | #include "openvino/op/log.hpp"
|
13 | 14 | #include "openvino/op/multiply.hpp"
|
@@ -94,6 +95,27 @@ const std::set<element::Type> supported_types_v1 =
|
94 | 95 | {element::u32, element::u64, element::i32, element::i64, element::f16, element::f32, element::f64};
|
95 | 96 | const std::set<element::Type> supported_types_v2 =
|
96 | 97 | {element::u32, element::u64, element::i32, element::i64, element::f16, element::f32, element::f64, element::bf16};
|
| 98 | +const std::set<element::Type> supported_types_v3 = {element::u32, |
| 99 | + element::u64, |
| 100 | + element::i32, |
| 101 | + element::i64, |
| 102 | + element::f16, |
| 103 | + element::f32, |
| 104 | + element::f64, |
| 105 | + element::bf16, |
| 106 | + element::i8, |
| 107 | + element::u8}; |
| 108 | +const std::set<element::Type> supported_types_v4 = {element::u32, |
| 109 | + element::u64, |
| 110 | + element::i32, |
| 111 | + element::i64, |
| 112 | + element::f16, |
| 113 | + element::f32, |
| 114 | + element::f64, |
| 115 | + element::bf16, |
| 116 | + element::i8, |
| 117 | + element::u8, |
| 118 | + element::boolean}; |
97 | 119 |
|
98 | 120 | template <typename OpType>
|
99 | 121 | std::shared_ptr<ov::Node> make_ov_reduction_op(const Node& node,
|
@@ -177,11 +199,33 @@ namespace set_13 {
|
177 | 199 | ov::OutputVector reduce_sum(const ov::frontend::onnx::Node& node) {
|
178 | 200 | return {make_ov_reduction_op<v1::ReduceSum>(node, node.get_ov_inputs().at(0), supported_types_v2, false)};
|
179 | 201 | }
|
| 202 | +ov::OutputVector reduce_max(const ov::frontend::onnx::Node& node) { |
| 203 | + return {make_ov_reduction_op<v1::ReduceMax>(node, node.get_ov_inputs().at(0), supported_types_v3)}; |
| 204 | +} |
180 | 205 | } // namespace set_13
|
181 | 206 |
|
182 | 207 | namespace set_18 {
|
183 |
| -// Placeholder |
| 208 | +ov::OutputVector reduce_max(const ov::frontend::onnx::Node& node) { |
| 209 | + return {make_ov_reduction_op<v1::ReduceMax>(node, node.get_ov_inputs().at(0), supported_types_v3, false)}; |
| 210 | +} |
184 | 211 | } // namespace set_18
|
| 212 | + |
| 213 | +namespace set_20 { |
| 214 | +ov::OutputVector reduce_max(const ov::frontend::onnx::Node& node) { |
| 215 | + auto data = node.get_ov_inputs().at(0); |
| 216 | + if (data.get_element_type() != element::boolean) { |
| 217 | + return {make_ov_reduction_op<v1::ReduceMax>(node, data, supported_types_v3, false)}; |
| 218 | + } else { |
| 219 | + // Handling boolean as a uint8 |
| 220 | + return {std::make_shared<v0::Convert>( |
| 221 | + make_ov_reduction_op<v1::ReduceMax>(node, |
| 222 | + std::make_shared<ov::op::v0::Convert>(data, element::u8), |
| 223 | + supported_types_v4, |
| 224 | + false), |
| 225 | + element::boolean)}; |
| 226 | + } |
| 227 | +} |
| 228 | +} // namespace set_20 |
185 | 229 | } // namespace op
|
186 | 230 | } // namespace onnx
|
187 | 231 | } // namespace frontend
|
|
0 commit comments