1
+ // Copyright (C) 2022-2023 Intel Corporation
2
+ // SPDX-License-Identifier: Apache-2.0
3
+ //
4
+
5
+ #include " ov_finite_comparer.hpp"
6
+ #include " ov_models/utils/ov_helpers.hpp"
7
+
8
+ using namespace ov ::test;
9
+
10
+ void ov::test::FiniteLayerComparer::compare (const std::vector<ov::Tensor>& expected_outputs,
11
+ const std::vector<ov::Tensor>& actual_outputs,
12
+ float threshold,
13
+ bool to_check_nans,
14
+ std::optional<double > infinity_value) {
15
+ for (std::size_t output_iIndex = 0 ; output_iIndex < expected_outputs.size (); ++output_iIndex) {
16
+ const auto & expected = expected_outputs[output_iIndex];
17
+ const auto & actual = actual_outputs[output_iIndex];
18
+ FiniteLayerComparer::compare (expected, actual, threshold, to_check_nans, infinity_value);
19
+ }
20
+ }
21
+
22
+ template <typename T_IE>
23
+ inline void call_compare (const ov::Tensor& expected,
24
+ const T_IE* actual_buffer,
25
+ size_t size,
26
+ float threshold,
27
+ bool to_check_nans,
28
+ std::optional<double > infinity_value) {
29
+ const auto & precision = expected.get_element_type ();
30
+ switch (precision) {
31
+ case ov::element::Type_t::i64:
32
+ FiniteLayerComparer::compare<T_IE>(
33
+ expected.data <int64_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
34
+ break ;
35
+ case ov::element::Type_t::i32:
36
+ FiniteLayerComparer::compare<T_IE>(
37
+ expected.data <int32_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
38
+ break ;
39
+ case ov::element::Type_t::i16:
40
+ FiniteLayerComparer::compare<T_IE>(
41
+ expected.data <int16_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
42
+ break ;
43
+ case ov::element::Type_t::i8:
44
+ FiniteLayerComparer::compare<T_IE>(
45
+ expected.data <int8_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
46
+ break ;
47
+ case ov::element::Type_t::u64:
48
+ FiniteLayerComparer::compare<T_IE>(
49
+ expected.data <uint64_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
50
+ break ;
51
+ case ov::element::Type_t::u32:
52
+ FiniteLayerComparer::compare<T_IE>(
53
+ expected.data <uint32_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
54
+ break ;
55
+ case ov::element::Type_t::u16:
56
+ FiniteLayerComparer::compare<T_IE>(
57
+ expected.data <uint16_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
58
+ break ;
59
+ case ov::element::Type_t::boolean:
60
+ case ov::element::Type_t::u8:
61
+ FiniteLayerComparer::compare<T_IE>(
62
+ expected.data <uint8_t >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
63
+ break ;
64
+ case ov::element::Type_t::f64:
65
+ FiniteLayerComparer::compare<T_IE>(
66
+ expected.data <double >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
67
+ break ;
68
+ case ov::element::Type_t::f32:
69
+ FiniteLayerComparer::compare<T_IE>(
70
+ expected.data <float >(), actual_buffer, size, threshold, to_check_nans, infinity_value);
71
+ break ;
72
+ case ov::element::Type_t::f16:
73
+ FiniteLayerComparer::compare<T_IE>(
74
+ expected.data <ov::float16>(), actual_buffer, size, threshold, to_check_nans, infinity_value);
75
+ break ;
76
+ case ov::element::Type_t::bf16:
77
+ FiniteLayerComparer::compare<T_IE>(
78
+ expected.data <ov::bfloat16>(), actual_buffer, size, threshold, to_check_nans, infinity_value);
79
+ break ;
80
+ case ov::element::Type_t::dynamic:
81
+ case ov::element::Type_t::undefined:
82
+ FiniteLayerComparer::compare<T_IE, T_IE>(
83
+ expected.data <T_IE>(), actual_buffer, size, threshold, to_check_nans, infinity_value);
84
+ break ;
85
+ default :
86
+ FAIL () << " Comparator for " << precision << " precision isn't supported" ;
87
+ }
88
+ return ;
89
+ }
90
+
91
+ void FiniteLayerComparer::compare (const ov::Tensor& expected,
92
+ const ov::Tensor& actual,
93
+ float threshold,
94
+ bool to_check_nans,
95
+ std::optional<double > infinity_value) {
96
+ const auto & precision = actual.get_element_type ();
97
+ auto k = static_cast <float >(expected.get_element_type ().size ()) / precision.size ();
98
+ // W/A for int4, uint4
99
+ if (expected.get_element_type () == ov::element::Type_t::u4 ||
100
+ expected.get_element_type () == ov::element::Type_t::i4) {
101
+ k /= 2 ;
102
+ } else if (expected.get_element_type () == ov::element::Type_t::undefined ||
103
+ expected.get_element_type () == ov::element::Type_t::dynamic) {
104
+ k = 1 ;
105
+ }
106
+ ASSERT_EQ (expected.get_byte_size (), actual.get_byte_size () * k);
107
+
108
+ const auto & size = actual.get_size ();
109
+ switch (precision) {
110
+ case ov::element::f32:
111
+ call_compare (expected, actual.data <float >(), size, threshold, to_check_nans, infinity_value);
112
+ break ;
113
+ case ov::element::i32:
114
+ call_compare (expected, actual.data <int32_t >(), size, threshold, to_check_nans, infinity_value);
115
+ break ;
116
+ case ov::element::u32:
117
+ call_compare (expected, actual.data <uint32_t >(), size, threshold, to_check_nans, infinity_value);
118
+ break ;
119
+ case ov::element::i64:
120
+ call_compare (expected, actual.data <int64_t >(), size, threshold, to_check_nans, infinity_value);
121
+ break ;
122
+ case ov::element::i8:
123
+ call_compare (expected, actual.data <int8_t >(), size, threshold, to_check_nans, infinity_value);
124
+ break ;
125
+ case ov::element::u16:
126
+ call_compare (expected, actual.data <uint16_t >(), size, threshold, to_check_nans, infinity_value);
127
+ break ;
128
+ case ov::element::i16:
129
+ call_compare (expected, actual.data <int16_t >(), size, threshold, to_check_nans, infinity_value);
130
+ break ;
131
+ case ov::element::boolean:
132
+ case ov::element::u8:
133
+ call_compare (expected, actual.data <uint8_t >(), size, threshold, to_check_nans, infinity_value);
134
+ break ;
135
+ case ov::element::u64:
136
+ call_compare (expected, actual.data <uint64_t >(), size, threshold, to_check_nans, infinity_value);
137
+ break ;
138
+ case ov::element::bf16:
139
+ call_compare (expected, actual.data <ov::bfloat16>(), size, threshold, to_check_nans, infinity_value);
140
+ break ;
141
+ case ov::element::f16:
142
+ call_compare (expected, actual.data <ov::float16>(), size, threshold, to_check_nans, infinity_value);
143
+ break ;
144
+ default :
145
+ FAIL () << " Comparator for " << precision << " precision isn't supported" ;
146
+ }
147
+ }
148
+
149
+ void ov::test::FiniteLayerComparer::compare (const std::vector<ov::Tensor>& expected_outputs,
150
+ const std::vector<ov::Tensor>& actual_outputs) {
151
+ FiniteLayerComparer::compare (
152
+ expected_outputs, actual_outputs, abs_threshold, this ->to_check_nans , this ->infinity_value );
153
+ }
0 commit comments