#pragma once

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/scope.h>

#include <atomic>
#include <unordered_map>

namespace torch::jit {

/*
 * BackendDebugHandleManager is responsible for issuing debug handles to
 * backends. Debug handles are associated with nodes of a graph.
 * BackendDebugHandleManager also maintains a map
 * [debug-handle, DebugInfoTuple = {source range, inlined callstack ptr}]
 * that helps generate a callstack for exceptions raised using debug handles.
 * Effectively, a debug handle is something that is given to the backend, and
 * later, when an exception occurs in the backend, the backend can report,
 * using the debug handle, where the exception occurred. The runtime can then
 * generate the callstack corresponding to the exception.
 *
 * There are two parts to BackendDebugHandleManager:
 * 1. A static std::atomic debug_handle.
 * 2. A map of [debug-handle, DebugInfoTuple].
 *
 * About 1:
 * Why do debug handles have to be unique? By ensuring uniqueness of debug
 * handles, we remove the burden of another layer of mapping where we would
 * need to record which set of debug handles was generated for which lowered
 * module or bytecode function. This simplifies the API for serialization,
 * since a debug handle uniquely identifies a DebugInfoTuple. It also
 * simplifies the runtime API for throwing exceptions: exception throwing
 * only needs to know the debug_handle, not which module or method threw it.
 * There are two issues to keep in mind, though, for the static std::atomic
 * debug_handle:
 * A. Performance implications of using an atomic variable. However, it is
 *    only used during compilation, so we assume we can absorb some of that
 *    penalty; if there is no contention, there is even less to worry about.
 * B. If repeated compilation is part of a long-running process, we may
 *    overflow int64_t. We could detect this and fail; for now this is not
 *    done.
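 *
 * As an illustrative sketch (not the actual implementation; the real logic
 * lives in getNextDebugHandle), issuing a handle from the static atomic can
 * be as simple as:
 *   DebugHandleType handle = unique_debug_handle_.fetch_add(1);
 *   // a hypothetical overflow guard, were one ever added, might look like:
 *   // TORCH_CHECK(handle >= 0, "debug handle counter overflowed");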
*
 * Now about 2:
 * There are two use cases for [debug-handle, DebugInfoTuple]:
 * A. During bytecode generation, the DebugInfoTuples corresponding to the
 *    nodes of the inlined graph being serialized are stored in this object,
 *    and a unique debug handle is returned. This unique debug handle is
 *    stored in the mobile debug info for PyTorch Lite models. It will be
 *    used for raising exceptions as well as for profiling.
 * B. During backend lowering, each backend's preprocess/compile method can
 *    compile a method's graph and serialize those methods. Once the method
 *    is lowered to the backend, the graph is essentially lost. Without
 *    access to the graph it is hard to generate model-level debug info, so
 *    the debug handles provide a way to map nodes of the graph to the
 *    model-level debug info.
*
 * During bytecode model serialization, [debug-handle, DebugInfoTuple] is
 * serialized. We then know (a) the debug handles and (b) how to map debug
 * handles back to model source code. Thus we can either do eager
 * symbolication, by converting debug handles to the corresponding source
 * code at runtime, or do lazy symbolication offline.
 *
 * Note that it is not necessary to serialize [debug-handle, DebugInfoTuple]
 * for the lowered backend if the lowering process (i.e. preprocess/compile)
 * and execution happen in the same session; in that case eager symbolication
 * can be employed.
*
 * Now, how does BackendDebugHandleManager capture all of the above?
 * By providing two APIs:
 * 1. getNextDebugHandle, which, given a Node*, returns a unique debug handle
 *    that uniquely identifies its DebugInfoTuple.
 * 2. getCallStackPtrMap (exposed below as stopRecording), which returns the
 *    map [debug-handle, DebugInfoTuple].
 *
 * 1 provides debug handles to backends, and 2 provides the runtime a way to
 * map debug handles to source-level debug info.
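 *
 * As a rough sketch (names like `recorder` and `debug_info_map` are
 * illustrative assumptions, not part of this header), the two directions
 * look like:
 *   // lowering side: one unique handle per node
 *   int64_t handle = recorder.getNextDebugHandle(node);
 *   // runtime side: debug handle -> {source range, inlined callstack ptr}
 *   auto it = debug_info_map.find(handle);
 *   // if found, it->second is the DebugInfoTuple used for symbolication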
*
 * So why does a debug handle map to DebugInfoTuple = {source range, inlined
 * callstack}, serialized as {debug_handle, source_range_tag,
 * serialized_callstack}? Take this example:
 * class L(nn.Module):
 *   def __init__(self) -> None:
 *     ...
 *   def forward(self, x):
 *     return x * 5
 * class M(nn.Module):
 *   def __init__(self) -> None:
 *     self.l = L()
 *   def forward(self, x):
 *     return self.l(x) - 2
 * class N(nn.Module):
 *   def __init__(self) -> None:
 *     self.m = M()
 *   def forward(self, x):
 *     return self.m(x) + 3
 * m = torch.jit.script(N())
 * Once you inline m's forward method, m.forward.graph will look something
 * like this:
 * graph(%self...):
 *   %x1 = aten::mul(..)
 *   %x2 = aten::sub(%x1, ..)
 *   %y = aten::add(%x2, ..)
 *   ..
 * The inlined callstack pointers for these three nodes will look like:
 * aten::mul's inlined CS (callstack): [N.forward, source range] ->
 *                                     [M.forward, source range]
 * aten::sub's inlined CS (callstack): [N.forward, source range]
 * aten::add's inlined CS: null
 * The mul node's inlined CS contains only information about the callsites'
 * source ranges; the mul node's own source range ('return x * 5') is not
 * available in its inlined CS, but is part of the node's source range
 * instead. Thus, to get the full stack
 * [N.forward, source range] -> [M.forward, source range] ->
 * [aten::mul's source range], we need to track both mul's source range and
 * its inlined CS.
 */

using BackendDebugInfoMapType =
    std::unordered_map<torch::jit::DebugHandleType, DebugInfoTuple>;

/*
 * This class is used to generate the debug info map.
 * A backend's preprocess will call generate_debug_handles (see
 * backend_detail.cpp), which uses the debug_handle_manager to generate
 * debug handles. When the lowering process finishes, calling stopRecording
 * returns the debug info map from the debug_handle_manager.
 */
class TORCH_API BackendDebugInfoRecorder {
 public:
  BackendDebugInfoRecorder() = default;

  int64_t getNextDebugHandle(const Node* node);
  // The reason this is not done via RAII is that the work done in
  // stopRecording can throw, and throwing from a dtor will call
  // std::terminate, voiding any exception catching at a higher level.
  BackendDebugInfoMapType stopRecording();
  NodeToDebugHandle generate_debug_handles(const std::shared_ptr<Graph>& graph);

 private:
  static std::atomic<DebugHandleType> unique_debug_handle_;
  BackendDebugInfoMapType handles_to_inlined_callstack_ptrs_;
};
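
/*
 * Example usage: a minimal sketch of a backend lowering flow. `graph` and
 * the lowering steps are assumptions standing in for a backend's preprocess,
 * which is actually driven from backend_detail.cpp.
 *
 *   BackendDebugInfoRecorder recorder;
 *   // hand out one unique handle per node of the inlined graph
 *   NodeToDebugHandle node_to_handle = recorder.generate_debug_handles(graph);
 *   // ... lower the graph, keeping each node's handle with the compiled blob ...
 *   // once lowering finishes, collect the map to serialize with the model
 *   BackendDebugInfoMapType debug_info_map = recorder.stopRecording();
 */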
} // namespace torch::jit