Commit 1199dc0

Enabling L2+ Optimizations for EPs (microsoft#23517)
There are some requirements to modify the graph which are specific to the EP/hardware. ORT has a hardcoded EP list for optimizations, but that doesn't scale and is hard to extend to enable EP custom optimizations.

Here is the prototype to enable L2+ optimizations for EPs (the original overview is provided by @skottmckay), as well as the TRT EP implementation of the ConstantFoldingDQ optimization.

Signatures for selection and optimization functions:

````
- Selection: std::function<std::vector<std::unique_ptr<ComputeCapability>>(const GraphViewer&, const KeyValueConfig&)>
- Optimization: std::function<Status(const Graph&, const ComputeCapability& this_optimization, ComputeCapability& cc_to_update)>
````

GetCapability
- Call the (new) provider bridge API to look up a pre-defined optimizer by name and get its selection function.
- ComputeCapability.optimize_func, i.e. the optimization function, would be set by the optimizer to the function that does the optimization.
- The EP has to update the returned ComputeCapability to include the optimization ComputeCapability in nodes_to_optimize, so that ORT can later perform the optimization/transformation accordingly.

GraphPartitioner
- After assigning the ComputeCapability to the EP and prior to Compile, if the ComputeCapability has nodes_to_optimize, iterate that list.
- The optimization function needs to be called with:
  - a mutable Graph instance
  - the ComputeCapability for the individual optimization
  - the overall ComputeCapability, so it can be updated
1 parent c28bf78 commit 1199dc0

67 files changed

+874
-107
lines changed

cmake/onnxruntime_optimizer.cmake

+1
@@ -9,6 +9,7 @@ if (onnxruntime_MINIMAL_BUILD)
   list(APPEND onnxruntime_optimizer_src_patterns
     "${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer.h"
     "${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer.cc"
+    "${ONNXRUNTIME_ROOT}/core/optimizer/graph_optimizer_registry.cc"
   )
 
   if (onnxruntime_EXTENDED_MINIMAL_BUILD)

include/onnxruntime/core/framework/execution_provider.h

+16
@@ -20,6 +20,7 @@ struct ComputeCapability;
 class KernelRegistry;
 struct KernelCreateInfo;
 class Node;
+class GraphOptimizerRegistry;
 } // namespace onnxruntime
 #else
 #include <memory>
@@ -129,10 +130,25 @@ class IExecutionProvider {
      and decide whether a node will be assigned to <*this> execution provider.
      For kernels registered in a kernel registry, `kernel_lookup` must be used
      to find a matching kernel for this EP.
+
+     The graph_optimizer_registry is designed for enabling L2+ graph optimizations tailored for EPs.
+     These optimizations are applied after the graph partitioner assigns ComputeCapability to the EP
+     and before EP's "Compile" or fusion.
+
+     Steps to use graph_optimizer_registry and create the optimization ComputeCapability:
+     1. Lookup Optimizer: The EP calls provider bridge API to lookup pre-defined optimizer by name and get selection function.
+        - Example: g_host->GetOptimizerByName(optimizer_name, graph_optimizer_registry, selection_func)
+     2. Run Selection Function: The EP executes the selection function to obtain the selection ComputeCapability.
+        - ComputeCapability.optimize_func would be set by the optimizer to the function that does the optimization.
+     3. Create Optimization ComputeCapability: The EP uses the selection ComputeCapability to create the optimization ComputeCapability.
+     4. Return ComputeCapability: The EP returns the final ComputeCapability, with nodes_to_optimize set to the optimization ComputeCapability.
+
+     Note: For more detailed implementations of using graph_optimizer_registry, please refer to TensorRT EP.
   */
   virtual std::vector<std::unique_ptr<ComputeCapability>>
   GetCapability(const onnxruntime::GraphViewer& graph_viewer,
                 const IKernelLookup& kernel_lookup,
+                const GraphOptimizerRegistry& graph_optimizer_registry,
                 IResourceAccountant* resource_accountant = nullptr) const;
 
   /**
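
The four steps in the doc comment above can be sketched from the EP's side as follows. This is a minimal illustration, not ORT code: `GraphViewer`, `KeyValueConfig`, `ComputeCapability`, and `GraphOptimizerRegistry` are simplified stand-ins, the `Register`/`Find` methods and the optimizer name `"ToyOptimizer"` are hypothetical, and the optimization callback's signature is reduced to a single argument. In ORT the lookup goes through the provider bridge (`g_host->GetOptimizerByName(...)`), which is not modeled here.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified, hypothetical stand-ins for the ORT types named in the diff.
struct GraphViewer { std::vector<int> node_ids; };
struct KeyValueConfig {};

struct ComputeCapability {
  std::vector<int> nodes;  // stands in for IndexedSubGraph::nodes
  std::function<bool(ComputeCapability& cc_to_update)> optimization_func;  // reduced signature
  std::vector<std::unique_ptr<ComputeCapability>> nodes_to_optimize;
};

using SelectionFunc = std::function<std::vector<std::unique_ptr<ComputeCapability>>(
    const GraphViewer&, const KeyValueConfig&)>;

// Hypothetical registry: maps an optimizer name to its selection function.
class GraphOptimizerRegistry {
 public:
  void Register(const std::string& name, SelectionFunc f) { funcs_[name] = std::move(f); }
  const SelectionFunc* Find(const std::string& name) const {
    auto it = funcs_.find(name);
    return it == funcs_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, SelectionFunc> funcs_;
};

// Toy selection function: selects nodes with even ids and sets the optimization callback.
std::vector<std::unique_ptr<ComputeCapability>> SelectEvenNodes(const GraphViewer& graph,
                                                                const KeyValueConfig&) {
  auto cc = std::make_unique<ComputeCapability>();
  for (int id : graph.node_ids) {
    if (id % 2 == 0) cc->nodes.push_back(id);
  }
  cc->optimization_func = [](ComputeCapability&) { return true; };
  std::vector<std::unique_ptr<ComputeCapability>> out;
  if (!cc->nodes.empty()) out.push_back(std::move(cc));
  return out;
}

// EP-side GetCapability, reduced to the registry-related steps.
std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
    const GraphViewer& graph, const GraphOptimizerRegistry& registry) {
  auto ep_cc = std::make_unique<ComputeCapability>();
  ep_cc->nodes = graph.node_ids;  // this toy EP claims every node

  // Step 1: look up the optimizer by name to get its selection function.
  if (const SelectionFunc* select = registry.Find("ToyOptimizer")) {
    // Step 2: run the selection function.
    auto selected = (*select)(graph, KeyValueConfig{});
    // Step 3: build an optimization ComputeCapability from each selection result.
    for (auto& sel : selected) {
      auto opt_cc = std::make_unique<ComputeCapability>();
      opt_cc->nodes = sel->nodes;
      opt_cc->optimization_func = sel->optimization_func;
      // Step 4: attach it to the capability returned to ORT.
      ep_cc->nodes_to_optimize.push_back(std::move(opt_cc));
    }
  }

  std::vector<std::unique_ptr<ComputeCapability>> result;
  result.push_back(std::move(ep_cc));
  return result;
}
```

Registering `SelectEvenNodes` under `"ToyOptimizer"` and calling `GetCapability` on a graph with node ids 1 to 4 yields one capability whose `nodes_to_optimize` holds a single entry covering nodes 2 and 4. If the name is not registered, the EP still returns its capability, just without any optimization attached.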

include/onnxruntime/core/graph/indexed_sub_graph.h

+6
@@ -72,6 +72,12 @@ struct IndexedSubGraph {
     return meta_def_.get();
   }
 
+  /** Gets the mutable meta definition needed to represent this subgraph as a FunctionProto.
+  @returns MetaDef instance if it has been set. nullptr if not. */
+  MetaDef* GetMutableMetaDef() {
+    return meta_def_.get();
+  }
+
   // Check if the accounting is enabled for the current EP
   bool IsAccountingEnabled() const {
     return resource_accountant != nullptr &&
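
As a rough illustration of why a mutable accessor is useful here, the sketch below shows an optimization step recording a newly folded initializer in the MetaDef. `IndexedSubGraph` is a cut-down stand-in with only the members this example needs, and `RecordFoldedInitializer` is a hypothetical helper, not something this commit adds.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Cut-down stand-in for IndexedSubGraph; only what this sketch needs.
struct IndexedSubGraph {
  struct MetaDef {
    std::vector<std::string> inputs;
    std::vector<std::string> outputs;
    std::vector<std::string> constant_initializers;
  };

  void SetMetaDef(std::unique_ptr<MetaDef> def) { meta_def_ = std::move(def); }

  // Read-only view, as before the change.
  const MetaDef* GetMetaDef() const { return meta_def_.get(); }

  // Mutable view, mirroring the accessor added in the diff.
  MetaDef* GetMutableMetaDef() { return meta_def_.get(); }

 private:
  std::unique_ptr<MetaDef> meta_def_;
};

// Hypothetical helper: after constant folding, record the new initializer name.
// Inputs and outputs are left untouched. Returns false if no MetaDef has been set.
bool RecordFoldedInitializer(IndexedSubGraph& sub_graph, const std::string& name) {
  IndexedSubGraph::MetaDef* def = sub_graph.GetMutableMetaDef();
  if (def == nullptr) return false;
  def->constant_initializers.push_back(name);
  return true;
}
```

The null check matters because the accessor returns `nullptr` when no MetaDef has been set, so callers cannot assume a MetaDef exists.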

onnxruntime/core/framework/compute_capability.h

+20
@@ -2,8 +2,11 @@
 // Licensed under the MIT License.
 
 #pragma once
+#include <functional>
 #include "core/common/common.h"
 #include "core/graph/indexed_sub_graph.h"
+#include "core/graph/graph.h"
+#include "core/optimizer/graph_optimizer_registry.h"
 
 namespace onnxruntime {
 // A structure encodes a subgraph and the method to run it.
@@ -21,5 +24,22 @@ struct ComputeCapability {
 
   ComputeCapability(std::unique_ptr<IndexedSubGraph> t_sub_graph)
       : sub_graph(std::move(t_sub_graph)) {}
+
+  // Optional function to optimize this ComputeCapability.
+  // This will be called by ORT once the ComputeCapability is assigned to the EP.
+  std::function<Status(Graph&,
+                       const ComputeCapability& /* this_optimization*/,
+                       ComputeCapability& /* cc_to_update */,
+                       const GraphOptimizerRegistry&)>
+      optimization_func;
+
+  // Optional ComputeCapability instances for sets of nodes within this ComputeCapability that should be optimized.
+  // when an optimization is applied, ORT will update this ComputeCapability to reflect the changes made.
+  // IndexedSubGraph.nodes:
+  //  - update based on RemovedNode/AddNode calls
+  // IndexedSubGraph.MetaDef (if present):
+  //  - inputs and outputs will be unchanged
+  //  - constant_initializers MAY change if we constant fold an initializer during optimization
+  std::vector<std::unique_ptr<ComputeCapability>> nodes_to_optimize;
 };
 } // namespace onnxruntime
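
The way a partitioner might consume `nodes_to_optimize` can be sketched as below. This is an assumption-laden illustration rather than the actual GraphPartitioner code: `Graph`, `Status`, and `ComputeCapability` are simplified stand-ins, the `GraphOptimizerRegistry` argument of the real callback is omitted, and `RemoveSelectedNodes` and `ApplyOptimizations` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

// Simplified, hypothetical stand-ins for the ORT types.
struct Graph { std::vector<int> node_ids; };
struct Status { bool ok = true; };

struct ComputeCapability;
// Mirrors the shape of optimization_func in the diff, minus the registry argument.
using OptimizationFunc = std::function<Status(Graph&,
                                              const ComputeCapability& this_optimization,
                                              ComputeCapability& cc_to_update)>;

struct ComputeCapability {
  std::vector<int> nodes;  // stands in for IndexedSubGraph::nodes
  OptimizationFunc optimization_func;
  std::vector<std::unique_ptr<ComputeCapability>> nodes_to_optimize;
};

// Toy optimization: removes the selected nodes from the graph and keeps the
// overall capability's node list in sync with that removal.
Status RemoveSelectedNodes(Graph& graph, const ComputeCapability& this_optimization,
                           ComputeCapability& cc_to_update) {
  auto is_selected = [&](int id) {
    const auto& sel = this_optimization.nodes;
    return std::find(sel.begin(), sel.end(), id) != sel.end();
  };
  auto erase_selected = [&](std::vector<int>& ids) {
    ids.erase(std::remove_if(ids.begin(), ids.end(), is_selected), ids.end());
  };
  erase_selected(graph.node_ids);
  erase_selected(cc_to_update.nodes);
  return Status{};
}

// Partitioner-side loop: after the capability is assigned to the EP and before
// Compile, run each attached optimization against the mutable graph.
Status ApplyOptimizations(Graph& graph, ComputeCapability& cc) {
  for (const auto& opt : cc.nodes_to_optimize) {
    if (!opt || !opt->optimization_func) continue;  // nothing to run for this entry
    Status status = opt->optimization_func(graph, *opt, cc);
    if (!status.ok) return status;  // stop at the first failure
  }
  return Status{};
}
```

Each callback receives both its own ComputeCapability and the overall one, which is what lets the toy optimization above shrink `cc.nodes` as it removes nodes from the graph.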

onnxruntime/core/framework/execution_provider.cc

+1
@@ -14,6 +14,7 @@ namespace onnxruntime {
 std::vector<std::unique_ptr<ComputeCapability>>
 IExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
                                   const IKernelLookup& kernel_lookup,
+                                  const GraphOptimizerRegistry&,
                                   IResourceAccountant*) const {
   std::vector<std::unique_ptr<ComputeCapability>> result;
   for (const auto& node : graph.Nodes()) {
