Skip to content

Commit 03f1951

Browse files
committed
Add support for parsing AUTO, HETERO and MULTI from json config
1 parent bd32f51 commit 03f1951

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

+15
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,15 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
224224
}
225225
}
226226
}
227+
auto find_device_type_mode = [&](const std::string& device_type) -> std::string {
228+
std::string device_mode="";
229+
auto delimiter_pos = device_type.find(':');
230+
if (delimiter_pos != std::string::npos) {
231+
std::stringstream str_stream(device_type.substr(0, delimiter_pos));
232+
std::getline(str_stream, device_mode, ',');
233+
}
234+
return device_mode;
235+
};
227236

228237
// Parse device types like "AUTO:CPU,GPU" and extract individual devices
229238
auto parse_individual_devices = [&](const std::string& device_type) -> std::vector<std::string> {
@@ -272,8 +281,12 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
272281
if (session_context_.device_type.find("AUTO") == 0 ||
273282
session_context_.device_type.find("HETERO") == 0 ||
274283
session_context_.device_type.find("MULTI") == 0) {
284+
//// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO")
285+
auto device_mode = find_device_type_mode(session_context_.device_type);
275286
// Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"])
276287
auto individual_devices = parse_individual_devices(session_context_.device_type);
288+
if(!device_mode.empty()) individual_devices.emplace_back(device_mode);
289+
277290
// Set properties only for individual devices (e.g., "CPU", "GPU")
278291
for (const std::string& device : individual_devices) {
279292
if (target_config.count(device)) {
@@ -284,6 +297,8 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
284297
}
285298
}
286299
} else {
300+
std::unordered_set<std::string> valid_ov_devices = {"CPU", "GPU", "NPU", "AUTO", "HETERO", "MULTI"};
301+
287302
if (target_config.count(session_context_.device_type)) {
288303
auto supported_properties = OVCore::Get()->core.get_property(session_context_.device_type,
289304
ov::supported_properties);

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ std::string ParseDeviceType(std::shared_ptr<OVCore> ov_core, const ProviderOptio
100100
default_device = DEVICE_NAME;
101101

102102
// Validate that devices passed are valid
103-
int delimit = device_type.find(":");
104-
const auto& devices = device_type.substr(delimit + 1);
103+
int delimit = default_device.find(":");
104+
const auto& devices = default_device.substr(delimit + 1);
105105
auto device_list = split(devices, ',');
106106
for (const auto& device : devices) {
107107
if (!ov_supported_device_types.contains(device)) {
@@ -256,9 +256,9 @@ struct OpenVINO_Provider : Provider {
256256

257257
for (auto& [key, value] : json_config.items()) {
258258
ov::AnyMap inner_map;
259-
259+
std::unordered_set<std::string> valid_ov_devices = {"CPU", "GPU", "NPU", "AUTO", "HETERO", "MULTI"};
260260
// Ensure the key is one of "CPU", "GPU", or "NPU"
261-
if (key != "CPU" && key != "GPU" && key != "NPU") {
261+
if (valid_ov_devices.find(key) == valid_ov_devices.end()) {
262262
LOGS_DEFAULT(WARNING) << "Unsupported device key: " << key << ". Skipping entry.\n";
263263
continue;
264264
}

0 commit comments

Comments
 (0)