|
31 | 31 | #include "ngraph_bridge/ngraph_mark_for_clustering.h"
|
32 | 32 | #include "ngraph_bridge/ngraph_utils.h"
|
33 | 33 |
|
| 34 | +#include "ocm/include/ocm_nodes_checker.h" |
| 35 | + |
34 | 36 | using namespace std;
|
35 | 37 |
|
36 | 38 | namespace tensorflow {
|
@@ -104,7 +106,30 @@ class NGraphEncapsulationPass : public NGraphRewritePass {
|
104 | 106 |
|
105 | 107 | // 1. Mark for clustering then, if requested, dump the graphs.
|
106 | 108 | std::set<string> skip_these_nodes = {};
|
107 |
| - TF_RETURN_IF_ERROR(MarkForClustering(graph, skip_these_nodes)); |
| 109 | + // TF_RETURN_IF_ERROR(MarkForClustering(graph, skip_these_nodes)); |
| 110 | + |
| 111 | + // OCM bypassing the MarkForClustering function call |
| 112 | + const char* device_id = std::getenv("NGRAPH_TF_BACKEND"); |
| 113 | + if (device_id==nullptr){ |
| 114 | + device_id = "CPU"; |
| 115 | + } |
| 116 | + std::string ov_version = "2021_1"; |
| 117 | + ocm::Framework_Names fName = ocm::Framework_Names::TF; |
| 118 | + ocm::FrameworkNodesChecker FC(fName, device_id, ov_version, options.graph->get()); |
| 119 | + std::vector<void *> nodes_list = FC.MarkSupportedNodes(); |
| 120 | + |
| 121 | + // cast back the nodes in the TF format and mark the nodes for clustering (moved out from MarkForClustering function) |
| 122 | + const std::map<std::string, SetAttributesFunction>& set_attributes_map = GetAttributeSetters(); |
| 123 | + for (auto void_node : nodes_list) { |
| 124 | + // TODO(amprocte): move attr name to a constant |
| 125 | + tensorflow::Node* node = (tensorflow::Node *)void_node; |
| 126 | + node->AddAttr("_ngraph_marked_for_clustering", true); |
| 127 | + auto it = set_attributes_map.find(node->type_string()); |
| 128 | + if (it != set_attributes_map.end()) { |
| 129 | + it->second(node); |
| 130 | + } |
| 131 | + } |
| 132 | + |
108 | 133 | util::DumpTFGraph(graph, idx, "marked");
|
109 | 134 |
|
110 | 135 | // 2. Assign clusters then, if requested, dump the graphs.
|
|
0 commit comments