@@ -746,7 +746,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
746
746
const float alpha,
747
747
const float beta,
748
748
const float gamma,
749
- const std::vector<ov::element::Type>& input_precisions = {}) {
749
+ const std::vector<ov::element::Type>& input_precisions = {},
750
+ const std::vector<ov::element::Type>& output_precisions = {}) {
750
751
#if defined(OPENVINO_ARCH_X86_64)
751
752
const auto isISASupportedByJIT = mayiuse (dnnl::impl::cpu::x64::sse41);
752
753
#elif defined(OPENVINO_ARCH_ARM64)
@@ -788,15 +789,27 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
788
789
return false ;
789
790
}
790
791
791
- const std::set<ov::element::Type> supported_precisions =
792
+ std::set<ov::element::Type> supported_input_precisions = std::set<ov::element::Type>{ov::element::f16,
793
+ ov::element::f32,
794
+ ov::element::i32,
795
+ ov::element::i8,
796
+ ov::element::u8};
797
+
798
+ std::set<ov::element::Type> supported_output_precisions = supported_input_precisions;
799
+ if (one_of (algorithm, Algorithm::EltwiseDivide, Algorithm::EltwiseFloor)) {
800
+ supported_input_precisions = std::set<ov::element::Type>{ov::element::f16, ov::element::f32};
801
+ }
802
+
803
+ auto fusedOps = node->getFusedWith ();
804
+ if (!fusedOps.empty ()) {
792
805
// Divide and Floor (issue #138629) operations are supported for fp32 and fp16 only.
793
- ((algorithm == Algorithm::EltwiseDivide) || (algorithm == Algorithm::EltwiseFloor))
794
- ? std::set<ov::element::Type>{ov::element::f16, ov::element::f32}
795
- : std::set<ov::element::Type>{ov::element::f16,
796
- ov::element::f32,
797
- ov::element::i32,
798
- ov::element::i8,
799
- ov::element::u8};
806
+ if ( one_of (fusedOps. back ()-> getAlgorithm (), Algorithm::EltwiseDivide, Algorithm::EltwiseFloor)) {
807
+ supported_output_precisions = std::set<ov::element::Type>{ov::element::f16, ov::element::f32};
808
+ }
809
+ } else {
810
+ supported_output_precisions = supported_input_precisions;
811
+ }
812
+
800
813
#elif defined(OPENVINO_ARCH_RISCV64)
801
814
if (!one_of (algorithm,
802
815
Algorithm::EltwiseAdd,
@@ -813,36 +826,37 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
813
826
return false ;
814
827
}
815
828
816
- const std::set<ov::element::Type> supported_precisions = {ov::element::f32,
817
- ov::element::i32,
818
- ov::element::i8,
819
- ov::element::u8};
829
+ const std::set<ov::element::Type> supported_input_precisions = {ov::element::f32,
830
+ ov::element::i32,
831
+ ov::element::i8,
832
+ ov::element::u8};
833
+ auto supported_input_precisions = supported_output_precisions;
820
834
#endif
821
835
822
836
#if defined(OPENVINO_ARCH_ARM64) || defined(OPENVINO_ARCH_RISCV64)
823
- const auto check_precisions = [](const std::vector<ov::element::Type>& input_precisions,
824
- const std::vector<ov::element::Type>& output_precisions,
825
- const std::set<ov::element::Type>& supported_precisions) {
837
+ const auto check_precisions = [&](const std::vector<ov::element::Type>& input_precisions,
838
+ const std::vector<ov::element::Type>& output_precisions) {
826
839
if (std::any_of (input_precisions.begin (),
827
840
input_precisions.end (),
828
- [&supported_precisions ](const ov::element::Type& precision) {
829
- return supported_precisions .find (precision) == supported_precisions .end ();
841
+ [&supported_input_precisions ](const ov::element::Type& precision) {
842
+ return supported_input_precisions .find (precision) == supported_input_precisions .end ();
830
843
})) {
831
844
return false ;
832
845
}
833
846
834
847
if (std::any_of (output_precisions.begin (),
835
848
output_precisions.end (),
836
- [&supported_precisions ](const ov::element::Type& precision) {
837
- return supported_precisions .find (precision) == supported_precisions .end ();
849
+ [&supported_output_precisions ](const ov::element::Type& precision) {
850
+ return supported_output_precisions .find (precision) == supported_output_precisions .end ();
838
851
})) {
839
852
return false ;
840
853
}
841
854
842
855
return true ;
843
856
};
844
857
845
- return check_precisions (input_precisions, node->getOriginalOutputPrecisions (), supported_precisions);
858
+ auto out_precisions = output_precisions.empty () ? node->getOriginalOutputPrecisions () : output_precisions;
859
+ return check_precisions (input_precisions, out_precisions);
846
860
#endif
847
861
848
862
// Unsupported architectures should return false:
0 commit comments