forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path backend_detail.h
39 lines (26 loc) · 1.05 KB
/
backend_detail.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#pragma once
#include <torch/csrc/jit/api/module.h>

#include <ATen/core/jit_type.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
namespace torch::jit {

/// Integer type used for debug handles.
using DebugHandleType = int64_t;

/// Maps a graph node to the debug handle assigned to it.
/// Keys are raw, non-owning `Node*` pointers.
using NodeToDebugHandle = std::unordered_map<Node*, DebugHandleType>;

/// Callable that takes a graph and returns a `Node*` -> debug-handle map
/// (`NodeToDebugHandle`) for it. Stored as a type-erased `std::function`.
using BackendDebugHandleGenerator =
    std::function<NodeToDebugHandle(const std::shared_ptr<Graph>&)>;

namespace detail {

/// Signature of a backend preprocessing hook. It receives:
///   - the `Module` to process,
///   - an `IValue -> IValue` dictionary (presumably the per-method compile
///     spec, as in `codegen_backend_module` below -- confirm with callers),
///   - a `BackendDebugHandleGenerator` for assigning node debug handles.
/// It returns its result as a `c10::IValue`.
using BackendPreprocessFunction = std::function<c10::IValue(
    const Module&,
    const c10::Dict<IValue, IValue>&,
    const BackendDebugHandleGenerator& generate_debug_handles)>;

/// Registers `preprocess` as the preprocessing hook for the backend named
/// `name`. Exported from the library (TORCH_API).
TORCH_API void registerBackendPreprocessFunction(
    const std::string& name,
    const BackendPreprocessFunction& preprocess);

/// Queries whether a preprocessing hook is available for backend `name`
/// (presumably one added via registerBackendPreprocessFunction -- see the
/// implementation).
/// NOTE(review): not marked TORCH_API, unlike the surrounding declarations;
/// confirm it is only meant to be called from inside the defining library.
bool hasBackendPreprocessFunction(const std::string& name);

/// Returns the preprocessing hook for backend `name`.
/// NOTE(review): behavior when no hook is registered for `name` is not
/// visible in this header -- check the implementation. Also not TORCH_API.
BackendPreprocessFunction getBackendPreprocessFunction(const std::string& name);

/// Produces a new `Module` for the backend named `backend_name` from
/// `orig_module`, using `method_compile_spec`.
/// `any_dict_ty` is a `c10::DictTypePtr` (presumably the type of
/// `method_compile_spec` -- confirm with callers).
/// Exported from the library (TORCH_API).
TORCH_API Module codegen_backend_module(
    const std::string& backend_name,
    const Module& orig_module,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const c10::DictTypePtr& any_dict_ty);

} // namespace detail
} // namespace torch::jit