
Commit 53278fe

Reshape feature implementation

jatinwadhwa921 committed Feb 17, 2025
1 parent ade3b59
Showing 10 changed files with 208 additions and 8 deletions.
47 changes: 44 additions & 3 deletions onnxruntime/core/providers/openvino/backend_manager.cc
@@ -5,6 +5,7 @@
#include <algorithm>
#include <cassert>
#include <fstream>
#include <map>
#include <regex>
#include <sstream>
#include <unordered_map>
@@ -61,12 +62,19 @@ BackendManager::BackendManager(SessionContext& session_context,
return "";
}(subgraph);

// Save the indexes of graph inputs among fused_node's inputDefs
// (which also contains initializers).
if (!session_context_.shape.empty()) {
ValidateInputShapes(session_context_.shape, subgraph.GetInputs());
}

for (uint32_t index = 0; const auto& node : subgraph.GetInputs()) {
if (subgraph.GetGraph().GetConsumerNodes(node->Name()).empty()) {
  continue;
}
subgraph_context_.input_names.insert({node->Name(), index++});
}


for (uint32_t index = 0; const auto& node : subgraph.GetOutputs()) {
subgraph_context_.output_names.insert({node->Name(), index++});
}
@@ -100,7 +108,7 @@
}
}

if (ModelHasSymbolicInputDims(subgraph)) {
if (ModelHasSymbolicInputDims(subgraph) && session_context_.shape.empty()) {
subgraph_context_.has_dynamic_input_shape = true;
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims";
if ((session_context_.device_type.find("CPU") != std::string::npos ||
@@ -308,6 +316,39 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& s
return has_sym_dims;
}

void BackendManager::ValidateInputShapes(const shape_t& shape,
const std::vector<const NodeArg*>& graph_inputs) const {
for (const auto& [tensor_name, requested_shape] : shape) {
// Find matching input in graph
const NodeArg* graph_input = nullptr;
for (const auto* input : graph_inputs) {
if (input->Name() == tensor_name) {
graph_input = input;
break;
}
}

if (!graph_input) {
ORT_THROW("Input " + tensor_name + "specified in reshape_input does not exist");
}

const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape();
if (!graph_shape) {
ORT_THROW("Graph input" + tensor_name + "has no shape information");
}

// Check dimensions count matches
size_t graph_dim_count = graph_shape->dim_size();
size_t requested_dim_count = requested_shape.get_max_shape().size();
if (graph_dim_count != requested_dim_count) {
ORT_THROW("Dimensions mismatched for input" + tensor_name +
": graph expects " + std::to_string(graph_dim_count) +
" dimensions but reshape_input specifies " +
std::to_string(requested_dim_count) + " dimensions");
}
}
}

// Check to see if the graph is QDQ
static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
std::unordered_set<std::string> qdq_ops = {"QuantizeLinear", "DequantizeLinear"};
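For illustration, a minimal sketch (not part of this commit) of the rank comparison ValidateInputShapes performs. The input name "data" and its 4-D shape are assumptions:

#include <openvino/core/partial_shape.hpp>

void rank_check_sketch() {
  // Assume reshape_input requested "data[1,3,224..448,224..448]".
  ov::PartialShape requested{1, 3, ov::Dimension(224, 448), ov::Dimension(224, 448)};
  size_t requested_dim_count = requested.get_max_shape().size();  // 4
  // A graph input whose TensorShapeProto reports dim_size() == 4 passes;
  // any other rank triggers the ORT_THROW above, with both counts in the message.
}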
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/openvino/backend_manager.h
@@ -39,6 +39,8 @@ class BackendManager {

bool ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const;
bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const;
void ValidateInputShapes(const shape_t& shape,
const std::vector<const NodeArg*>& graph_inputs) const;

std::shared_ptr<ONNX_NAMESPACE::ModelProto>
ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto);
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/openvino/backend_utils.cc
@@ -146,6 +146,11 @@ CreateOVModel(const std::string model,
try {
auto ov_model = OVCore::ReadModel(model, session_context.onnx_model_path_name.string());

if (!session_context.shape.empty()) {
LOGS_DEFAULT(INFO) << log_tag << "Reshaping the OV model to the user-specified shape";
ov_model->reshape(session_context.shape);
}

// Check for Constant Folding
if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) {
ov::pass::ConstantFolding pass_const_obj;
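For context, ov::Model::reshape accepts exactly the map held in session_context.shape (tensor name to ov::PartialShape). A minimal sketch of the call above, assuming ov_model is the std::shared_ptr<ov::Model> read earlier and an input named "data" (illustrative):

// Illustrative only: fix batch/channels, bound height/width to [224, 448].
std::map<std::string, ov::PartialShape> shape;
shape["data"] = ov::PartialShape{1, 3, ov::Dimension(224, 448), ov::Dimension(224, 448)};
ov_model->reshape(shape);  // bounds become part of the model before compile_model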
54 changes: 49 additions & 5 deletions onnxruntime/core/providers/openvino/backends/basic_backend.cc
@@ -7,6 +7,7 @@
#include <sstream>
#include <fstream>
#include <utility>
#include <vector>

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/openvino/backend_utils.h"
@@ -96,6 +97,7 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
} else if (!session_context_.has_external_weights &&
!subgraph_context_.has_dynamic_input_shape &&
!session_context_.so_context_enable &&
session_context_.shape.empty() &&
auto_unified_compile) {
// Unified OV compile_model is efficient when ov model caching is enabled
// Unified OV compile_model API is supported with AUTO from version 2024.3 and above
@@ -418,9 +420,20 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
(it != ort_ov_tensor_map.end() && (it->second.ort_ptr != tensor.GetTensorRawData()))) {
ov_tensor_data_t ov_tensor_data;
const auto& input = ov_input_info.at(input_idx);
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
const_cast<void*>(tensor.GetTensorRawData()));

if (!session_context_.shape.empty()) {
ov::PartialShape partial_shape = input.get_partial_shape();
const auto& ort_dims = tensor.GetTensorTypeAndShapeInfo().GetShape();
ValidateOrtDimsAgainstPartialShape(ort_dims, partial_shape);
ov::Shape concrete_shape(ort_dims.begin(), ort_dims.end());
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), concrete_shape,
const_cast<void*>(tensor.GetTensorRawData()));
} else {
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(input.get_element_type(), input.get_shape(),
const_cast<void*>(tensor.GetTensorRawData()));
}
ov_tensor_data.ort_ptr = tensor.GetTensorRawData();
ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data;

@@ -434,6 +447,10 @@
}
} // Loop subgraph original input names

if (!session_context_.shape.empty()) {
infer_request->Infer();
}

if (session_context_.device_type.find("NPU") != std::string::npos) {
// Set the output blob as remote blob
auto graph_output_info = exe_network_.Get().outputs();
@@ -465,8 +482,15 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
ov_tensor_data_t ov_tensor_data;
const auto& output = graph_output_info.at(output_idx);
ov_tensor_data.ort_ptr = tensor.GetTensorRawData();
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
const_cast<void*>(tensor.GetTensorRawData()));

if (!session_context_.shape.empty()) {
ov::Tensor output_tensor = infer_request->GetOutputTensor(output_idx);
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output_tensor.get_shape(),
const_cast<void*>(tensor.GetTensorRawData()));
} else {
ov_tensor_data.tensor_ptr = std::make_shared<ov::Tensor>(output.get_element_type(), output.get_shape(),
const_cast<void*>(tensor.GetTensorRawData()));
}
ort_ov_tensor_map[ort_tensor_key] = ov_tensor_data;

try {
@@ -669,6 +693,26 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe
}
}

void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
const ov::PartialShape& partial_shape) const {
// Check if the number of dimensions matches
if (static_cast<int64_t>(ort_dims.size()) != partial_shape.rank().get_length()) {
ORT_THROW("Mismatch in number of dimensions between ORT tensor and OpenVINO PartialShape.");
}
// Validate each dimension
for (size_t i = 0; i < ort_dims.size(); ++i) {
const auto& ov_dim = partial_shape[i]; // OpenVINO dimension at index i
int64_t ort_dim = ort_dims[i]; // ORT dimension at index i

// Check if the ORT dimension is within the specified range
int64_t min_dim = ov_dim.get_min_length();
int64_t max_dim = ov_dim.get_max_length();
if (ort_dim < min_dim || ort_dim > max_dim) {
ORT_THROW(" ORT Dimension is out of range");
}
}
}

void BasicBackend::Infer(OrtKernelContext* ctx) {
// Preliminary Thread safety mechanism
// currently allows a maximum of 8 Infer request's to parallel execute at the same time
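A self-contained sketch of the per-dimension bound check ValidateOrtDimsAgainstPartialShape performs; the bounds and the concrete ORT shape below are assumptions for illustration:

#include <cassert>
#include <cstdint>
#include <vector>
#include <openvino/core/partial_shape.hpp>

int main() {
  // Assumed reshape bounds: static batch/channels, spatial dims in [224, 448].
  ov::PartialShape bounds{1, 3, ov::Dimension(224, 448), ov::Dimension(224, 448)};
  std::vector<int64_t> ort_dims{1, 3, 320, 320};  // concrete shape handed in by ORT
  for (size_t i = 0; i < ort_dims.size(); ++i) {
    // A static dimension has min == max, so one test covers both cases.
    assert(ort_dims[i] >= bounds[i].get_min_length() &&
           ort_dims[i] <= bounds[i].get_max_length());
  }
}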
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/openvino/backends/basic_backend.h
@@ -51,6 +51,8 @@ class BasicBackend : public IBackend {
void EnableStreams();
void SetNumThreads(ov::AnyMap& device_config);
void StartAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
void ValidateOrtDimsAgainstPartialShape(const std::vector<int64_t>& ort_dims,
const ov::PartialShape& partial_shape) const;

#ifdef IO_BUFFER_ENABLED
void StartRemoteAsyncInference(Ort::KernelContext& context, std::shared_ptr<OVInferRequest> infer_request);
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/openvino/contexts.h
@@ -61,6 +61,7 @@ struct SharedContext {
};

using config_t = std::map<std::string, ov::AnyMap>;
using shape_t = std::map<std::string, ov::PartialShape>;

struct ProviderInfo {
std::string device_type{""}; // [device_type]: Overrides the accelerator hardware type and
@@ -74,6 +75,7 @@ struct ProviderInfo {
uint32_t num_of_threads{0}; // [num_of_threads]: Overrides the accelerator default value of
// number of threads with this value at runtime.
config_t load_config{}; // JSON config map to load custom OV parameters.
shape_t shape{}; // [reshape_input]: Reshapes model inputs to a static shape or a bounded (min..max) dynamic shape
fs::path cache_dir{""}; // [cache_dir]: specify the path to
// dump and load the blobs for the model caching/kernel caching
// (GPU) feature. If blob files are already present,
91 changes: 91 additions & 0 deletions onnxruntime/core/providers/openvino/openvino_provider_factory.cc
@@ -212,6 +212,97 @@ struct OpenVINO_Provider : Provider {

pi.precision = ParsePrecision(provider_options, pi.device_type, "precision");

if (provider_options.contains("reshape_input") && pi.device_type == "NPU") {
auto parse_input_shapes = [&](const std::string& reshape_input_definition) {
std::map<std::string, ov::PartialShape> parsed_shape_map;
std::string unparsed_definition = reshape_input_definition;

while (!unparsed_definition.empty()) {
// Find the next shape definition bracket
auto shape_start_bracket = unparsed_definition.find_first_of('[');
if (shape_start_bracket == std::string::npos) {
ORT_THROW("Malformed input: missing opening bracket '[' in: " + unparsed_definition);
}
// Extract the tensor name
std::string tensor_name = unparsed_definition.substr(0, shape_start_bracket);
// Remove the leading/trailing whitespaces
tensor_name.erase(0, tensor_name.find_first_not_of(" \t"));
tensor_name.erase(tensor_name.find_last_not_of(" \t") + 1);

if (tensor_name.empty()) {
ORT_THROW("Empty tensor name provided in rehsape_input parameter");
}

// Closing bracket for current shape definition
auto shape_end_bracket = unparsed_definition.find_first_of(']', shape_start_bracket);

if (shape_end_bracket == std::string::npos || shape_end_bracket < shape_start_bracket) {
ORT_THROW("Missing closing bracket ']' for tensor: " + tensor_name);
}

// Extract shape dimensions string
std::string shape_dimension_str = unparsed_definition.substr(shape_start_bracket + 1,
shape_end_bracket - shape_start_bracket - 1);
std::vector<ov::Dimension> dimension_values;
std::stringstream dimension_stream(shape_dimension_str);
std::string dimension_token;

while (std::getline(dimension_stream, dimension_token, ',')) {
// Remove leading/trailing whitespaces
dimension_token.erase(0, dimension_token.find_first_not_of(" \t"));
dimension_token.erase(dimension_token.find_last_not_of(" \t") + 1);

// Check if dimension is a range
size_t range_separator_pos = dimension_token.find("..");
if (range_separator_pos != std::string::npos) {
std::string range_start_str = dimension_token.substr(0, range_separator_pos);
std::string range_end_str = dimension_token.substr(range_separator_pos + 2);

// Remove leading/trailing spaces
range_start_str.erase(0, range_start_str.find_first_not_of(" \t"));
range_start_str.erase(range_start_str.find_last_not_of(" \t") + 1);
range_end_str.erase(0, range_end_str.find_first_not_of(" \t"));
range_end_str.erase(range_end_str.find_last_not_of(" \t") + 1);

if (range_start_str.empty() || range_end_str.empty() ||
!std::all_of(range_start_str.begin(), range_start_str.end(), ::isdigit) ||
!std::all_of(range_end_str.begin(), range_end_str.end(), ::isdigit)) {
ORT_THROW("Invalid dimension range format: " + dimension_token + " for tensor: " + tensor_name);
}

int range_start = std::stoi(range_start_str);
int range_end = std::stoi(range_end_str);

if (range_start > range_end) {
ORT_THROW("Invalid dimension range (start > end) for tensor: " + tensor_name);
}

dimension_values.emplace_back(ov::Dimension(range_start, range_end));
} else {
// Handle single dimension value
if (dimension_token.empty() ||
!std::all_of(dimension_token.begin(), dimension_token.end(), ::isdigit)) {
ORT_THROW("Invalid dimension value: " + dimension_token + " for tensor: " + tensor_name);
}
dimension_values.emplace_back(std::stoi(dimension_token));
}
}

// Store parsed shape in result map
parsed_shape_map[tensor_name] = ov::PartialShape(dimension_values);
// Update remaining unparsed string
unparsed_definition = unparsed_definition.substr(shape_end_bracket + 1);
if (!unparsed_definition.empty() && unparsed_definition.front() == ',') {
unparsed_definition = unparsed_definition.substr(1);
}
// Remove leading whitespaces
unparsed_definition.erase(0, unparsed_definition.find_first_not_of(" \t"));
}
return parsed_shape_map;
};
pi.shape = parse_input_shapes(provider_options.at("reshape_input"));
}

if (provider_options.contains("load_config")) {
auto parse_config = [&](const std::string& config_str) -> std::map<std::string, ov::AnyMap> {
// If the config string is empty, return an empty map and skip processing
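Taken together, the grammar the parser accepts is name[dims], where dims is a comma-separated list of integers or min..max ranges, and multiple entries are separated by commas. A hedged usage sketch; the tensor names are illustrative, and AppendExecutionProvider_OpenVINO_V2 is the ORT C++ helper for string-keyed provider options:

#include <string>
#include <unordered_map>
#include <onnxruntime_cxx_api.h>

// Accepted forms (derived from the parsing code above):
//   "data[1,3,224,224]"                      -> static shape
//   "data[1,3,224..448,224..448]"            -> ranges become ov::Dimension(min, max)
//   "input_ids[1,128],attention_mask[1,128]" -> several inputs, comma separated
void configure_session_sketch() {
  std::unordered_map<std::string, std::string> ov_options;
  ov_options["device_type"] = "NPU";  // reshape_input is only parsed for NPU here
  ov_options["reshape_input"] = "data[1,3,224..448,224..448]";
  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_OpenVINO_V2(ov_options);
}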
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/openvino/ov_interface.cc
@@ -211,6 +211,16 @@ OVTensorPtr OVInferRequest::GetTensor(const std::string& input_name) {
}
}

OVTensor OVInferRequest::GetOutputTensor(size_t output_idx) {
try {
return ovInfReq.get_output_tensor(output_idx);
} catch (const Exception& e) {
ORT_THROW(log_tag + " Cannot access output tensor: " + e.what());
} catch (...) {
ORT_THROW(log_tag + " Cannot access output tensor");
}
}

std::string OVInferRequest::GetInputTensorName(uint32_t index) {
try {
const auto& model = ovInfReq.get_compiled_model();
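Why the new accessor exists (sketch): once inputs are bounded, the compiled model's outputs can stay partially dynamic, so StartAsyncInference reads the concrete shape back from the request after Infer() rather than from the compiled model. The wrapper forwards to the underlying OpenVINO call:

ov::Tensor out = ovInfReq.get_output_tensor(0);  // first output of this request
ov::Shape concrete_shape = out.get_shape();      // actual shape for this run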
1 change: 1 addition & 0 deletions onnxruntime/core/providers/openvino/ov_interface.h
@@ -91,6 +91,7 @@ class OVInferRequest {
OVTensorPtr GetTensor(const std::string& name);
std::string GetInputTensorName(uint32_t index);
void SetTensor(const std::string& name, OVTensorPtr& blob);
OVTensor GetOutputTensor(size_t output_idx);
void StartAsync();
void Infer();
void WaitRequest();
2 changes: 2 additions & 0 deletions onnxruntime/test/perftest/ort_test_session.cc
@@ -787,6 +787,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
}
} else if (key == "device_memory_name") {
device_memory_name_ = std::move(value);
} else if (key == "reshape_input") {
ov_options[key] = value;
} else {
ORT_THROW(
"[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO."
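With this pass-through in place, the feature can be exercised from onnxruntime_perf_test via -i, which forwards key|value provider-option pairs (model path and tensor name below are placeholders):

onnxruntime_perf_test -e openvino -i "device_type|NPU reshape_input|data[1,3,224..448,224..448]" model.onnx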
