15
15
namespace onnxruntime {
16
16
namespace openvino_ep {
17
17
void ParseConfigOptions (ProviderInfo& pi ) {
18
- if (pi .config_options == NULL )
18
+ if (pi .config_options == NULL )
19
19
return ;
20
20
21
21
pi .so_disable_cpu_ep_fallback = pi .config_options ->GetConfigOrDefault (kOrtSessionOptionsDisableCPUEPFallback , " 0" ) == " 1" ;
@@ -29,7 +29,6 @@ void ParseConfigOptions(ProviderInfo& pi) {
29
29
map[" NPU_COMPILATION_MODE_PARAMS" ] = " enable-wd-blockarg-input=true compute-layers-with-higher-precision=Sqrt,Power,ReduceSum" ;
30
30
pi .load_config [" NPU" ] = std::move (map);
31
31
}
32
-
33
32
}
34
33
35
34
void * ParseUint64 (const ProviderOptions& provider_options, std::string option_name) {
@@ -123,48 +122,68 @@ std::string ParsePrecision(const ProviderOptions& provider_options, std::string&
123
122
using foo = std::pair<DefaultValue, ValidValues>;
124
123
using ParserHelper = std::map<DeviceName, foo>;
125
124
ParserHelper helper = {
126
- {" GPU" , {" FP16" , {" FP16" , " FP32" }}},
127
- {" NPU" , {" FP16" , {" FP16" }}},
128
- {" CPU" , {" FP32" , {" FP32" }}},
125
+ {" GPU" , {" FP16" , {" FP16" , " FP32" , " ACCURACY " }}},
126
+ {" NPU" , {" FP16" , {" FP16" , " ACCURACY " }}},
127
+ {" CPU" , {" FP32" , {" FP32" , " ACCURACY " }}},
129
128
};
130
129
131
130
std::set<std::string> deprecated_device_types = {" CPU_FP32" , " GPU_FP32" ,
132
131
" GPU.0_FP32" , " GPU.1_FP32" , " GPU_FP16" ,
133
132
" GPU.0_FP16" , " GPU.1_FP16" };
134
133
134
+ bool is_composite = device_type.find (' :' ) != std::string::npos; // FOR devices AUTO:,HETRO:,MULTI:
135
+
135
136
if (provider_options.contains (option_name)) {
136
- // Start by checking if the device_type is a normal valid one
137
- if (helper. contains (device_type)) {
138
- auto const & valid_values = helper[device_type]. second ;
139
- const auto & precision = provider_options. at (option_name) ;
140
- if (precision == " ACCURACY " ) {
141
- return valid_values. back (); // Return highest supported precision
137
+ const auto & precision = provider_options. at (option_name);
138
+
139
+ if (is_composite) {
140
+ std::set<std::string> allowed_precisions = { " FP16 " , " FP32 " , " ACCURACY " } ;
141
+ if (allowed_precisions. contains ( precision) ) {
142
+ return precision;
142
143
} else {
143
- if (std::find (valid_values.begin (), valid_values.end (), precision) != valid_values.end ()) {
144
- return precision; // Return precision selected if valid
144
+ ORT_THROW (" [ERROR] [OpenVINO] Unsupported inference precision is selected. " , precision, " .\n " );
145
+ }
146
+ } else {
147
+ if (helper.contains (device_type)) {
148
+ auto const & valid_values = helper[device_type].second ;
149
+
150
+ if (precision == " ACCURACY" ) {
151
+ return valid_values.back (); // Return highest supported precision
145
152
} else {
146
- auto value_iter = valid_values.begin ();
147
- std::string valid_values_joined = *value_iter;
148
- // Append 2nd and up, if only one then ++value_iter is same as end()
149
- for (++value_iter; value_iter != valid_values.end (); ++value_iter) {
150
- valid_values_joined += " , " + *value_iter;
151
- }
153
+ if (std::find (valid_values.begin (), valid_values.end (), precision) != valid_values.end ()) {
154
+ return precision; // Return precision selected if valid
155
+ } else {
156
+ auto value_iter = valid_values.begin ();
157
+ std::string valid_values_joined = *value_iter;
158
+ // Append 2nd and up, if only one then ++value_iter is same as end()
159
+ for (++value_iter; value_iter != valid_values.end (); ++value_iter) {
160
+ valid_values_joined += " , " + *value_iter;
161
+ }
152
162
153
- ORT_THROW (" [ERROR] [OpenVINO] Unsupported inference precision is selected. " , device_type, " only supports" , valid_values_joined, " .\n " );
163
+ ORT_THROW (" [ERROR] [OpenVINO] Unsupported inference precision is selected. " , device_type, " only supports" , valid_values_joined, " .\n " );
164
+ }
154
165
}
166
+ } else if (deprecated_device_types.contains (device_type)) {
167
+ LOGS_DEFAULT (WARNING) << " [OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n "
168
+ << " Update the 'device_type' to specified types 'CPU', 'GPU', 'GPU.0', "
169
+ << " 'GPU.1', 'NPU' or from"
170
+ << " HETERO/MULTI/AUTO options and set 'precision' separately. \n " ;
171
+ auto delimit = device_type.find (" _" );
172
+ device_type = device_type.substr (0 , delimit);
173
+ return device_type.substr (delimit + 1 );
174
+ } else {
175
+ ORT_THROW (" [ERROR] [OpenVINO] Unsupported device type provided: " , device_type, " \n " );
155
176
}
156
- } else if (deprecated_device_types. contains (device_type)) {
157
- LOGS_DEFAULT (WARNING) << " [OpenVINO] Selected 'device_type' " + device_type + " is deprecated. \n "
158
- << " Update the 'device_type' to specified types 'CPU', ' GPU', 'GPU.0', "
159
- << " 'GPU.1', 'NPU' or from "
160
- << " HETERO/MULTI/AUTO options and set 'precision' separately. \n " ;
161
- auto delimit = device_type. find ( " _ " ) ;
162
- device_type = device_type. substr ( 0 , delimit);
163
- return device_type. substr (delimit + 1 );
177
+ }
178
+ } else {
179
+ if (device_type. find ( " NPU " ) != std::string::npos || device_type. find ( " GPU" ) != std::string::npos) {
180
+ return " FP16 " ;
181
+ } else if (device_type. find ( " CPU " ) != std::string::npos) {
182
+ return " FP32 " ;
183
+ } else {
184
+ ORT_THROW ( " [ERROR] [OpenVINO] Unsupported device is selected " , device_type, " \n " );
164
185
}
165
186
}
166
- // Return default
167
- return helper[device_type].first ;
168
187
}
169
188
170
189
void ParseProviderOptions ([[maybe_unused]] ProviderInfo& result, [[maybe_unused]] const ProviderOptions& config_options) {}
@@ -204,7 +223,7 @@ struct OpenVINO_Provider : Provider {
204
223
const ProviderOptions* provider_options_ptr = reinterpret_cast <ProviderOptions*>(pointers_array[0 ]);
205
224
const ConfigOptions* config_options = reinterpret_cast <ConfigOptions*>(pointers_array[1 ]);
206
225
207
- if (provider_options_ptr == NULL ) {
226
+ if (provider_options_ptr == NULL ) {
208
227
LOGS_DEFAULT (ERROR) << " [OpenVINO EP] Passed NULL ProviderOptions to CreateExecutionProviderFactory()" ;
209
228
return nullptr ;
210
229
}
0 commit comments