Skip to content

Commit 1dbc8ad

Browse files
kurtamohlerpytorchmergebot
authored andcommitted
Add Warning class and refactor C++ warnings to use it (pytorch#84101)
Also adds `TORCH_WARN_WITH` and `TORCH_WARN_DEPRECATION` macros Part of pytorch#72948 Pull Request resolved: pytorch#84101 Approved by: https://github.com/albanD
1 parent db65909 commit 1dbc8ad

13 files changed

+222
-149
lines changed

c10/util/Exception.cpp

+52-27
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void torchInternalAssertFail(
116116

117117
} // namespace detail
118118

119-
namespace Warning {
119+
namespace WarningUtils {
120120

121121
namespace {
122122
WarningHandler* getBaseHandler() {
@@ -147,27 +147,6 @@ thread_local WarningHandler* ThreadWarningHandler::warning_handler_ = nullptr;
147147

148148
} // namespace
149149

150-
void warn(
151-
const SourceLocation& source_location,
152-
const std::string& msg,
153-
const bool verbatim) {
154-
ThreadWarningHandler::get_handler()->process(source_location, msg, verbatim);
155-
}
156-
157-
void warn(
158-
SourceLocation source_location,
159-
detail::CompileTimeEmptyString msg,
160-
const bool verbatim) {
161-
warn(source_location, "", verbatim);
162-
}
163-
164-
void warn(
165-
SourceLocation source_location,
166-
const char* msg,
167-
const bool verbatim) {
168-
ThreadWarningHandler::get_handler()->process(source_location, msg, verbatim);
169-
}
170-
171150
void set_warning_handler(WarningHandler* handler) noexcept(true) {
172151
ThreadWarningHandler::set_handler(handler);
173152
}
@@ -195,14 +174,60 @@ WarnAlways::~WarnAlways() {
195174
set_warnAlways(prev_setting);
196175
}
197176

198-
} // namespace Warning
177+
} // namespace WarningUtils
178+
179+
void warn(const Warning& warning) {
180+
WarningUtils::ThreadWarningHandler::get_handler()->process(warning);
181+
}
199182

200-
void WarningHandler::process(
183+
Warning::Warning(
184+
warning_variant_t type,
201185
const SourceLocation& source_location,
202186
const std::string& msg,
203-
const bool /*verbatim*/) {
204-
LOG_AT_FILE_LINE(WARNING, source_location.file, source_location.line)
205-
<< "Warning: " << msg << " (function " << source_location.function << ")";
187+
const bool verbatim)
188+
: type_(type),
189+
source_location_(source_location),
190+
msg_(msg),
191+
verbatim_(verbatim) {}
192+
193+
Warning::Warning(
194+
warning_variant_t type,
195+
SourceLocation source_location,
196+
detail::CompileTimeEmptyString msg,
197+
const bool verbatim)
198+
: Warning(type, std::move(source_location), "", verbatim) {}
199+
200+
Warning::Warning(
201+
warning_variant_t type,
202+
SourceLocation source_location,
203+
const char* msg,
204+
const bool verbatim)
205+
: type_(type),
206+
source_location_(std::move(source_location)),
207+
msg_(std::string(msg)),
208+
verbatim_(verbatim) {}
209+
210+
Warning::warning_variant_t Warning::type() const {
211+
return type_;
212+
}
213+
214+
const SourceLocation& Warning::source_location() const {
215+
return source_location_;
216+
}
217+
218+
const std::string& Warning::msg() const {
219+
return msg_;
220+
}
221+
222+
bool Warning::verbatim() const {
223+
return verbatim_;
224+
}
225+
226+
void WarningHandler::process(const Warning& warning) {
227+
LOG_AT_FILE_LINE(
228+
WARNING, warning.source_location().file, warning.source_location().line)
229+
<< "Warning: " << warning.msg() << " (function "
230+
<< warning.source_location().function << ")";
206231
}
207232

208233
std::string GetExceptionString(const std::exception& e) {

c10/util/Exception.h

+82-56
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <c10/macros/Macros.h>
55
#include <c10/util/Deprecated.h>
66
#include <c10/util/StringUtil.h>
7+
#include <c10/util/variant.h>
78

89
#include <cstddef>
910
#include <exception>
@@ -112,17 +113,66 @@ class C10_API Error : public std::exception {
112113
std::string compute_what(bool include_backtrace) const;
113114
};
114115

116+
class C10_API Warning {
117+
public:
118+
class C10_API UserWarning {};
119+
class C10_API DeprecationWarning {};
120+
121+
using warning_variant_t = c10::variant<UserWarning, DeprecationWarning>;
122+
123+
Warning(
124+
warning_variant_t type,
125+
const SourceLocation& source_location,
126+
const std::string& msg,
127+
bool verbatim);
128+
129+
Warning(
130+
warning_variant_t type,
131+
SourceLocation source_location,
132+
const char* msg,
133+
bool verbatim);
134+
135+
Warning(
136+
warning_variant_t type,
137+
SourceLocation source_location,
138+
::c10::detail::CompileTimeEmptyString msg,
139+
bool verbatim);
140+
141+
// Getters for members
142+
warning_variant_t type() const;
143+
const SourceLocation& source_location() const;
144+
const std::string& msg() const;
145+
bool verbatim() const;
146+
147+
private:
148+
// The type of warning
149+
warning_variant_t type_;
150+
151+
// Where the warning happened.
152+
SourceLocation source_location_;
153+
154+
// The actual warning message.
155+
std::string msg_;
156+
157+
// See note: [Verbatim Warnings]
158+
bool verbatim_;
159+
};
160+
161+
using UserWarning = Warning::UserWarning;
162+
using DeprecationWarning = Warning::DeprecationWarning;
163+
164+
// Issue a warning with a given message. Dispatched to the current
165+
// warning handler.
166+
void C10_API warn(const Warning& warning);
167+
115168
class C10_API WarningHandler {
116169
public:
117170
virtual ~WarningHandler() = default;
118171
/// The default warning handler. Prints the message to stderr.
119-
virtual void process(
120-
const SourceLocation& source_location,
121-
const std::string& msg,
122-
const bool verbatim);
172+
virtual void process(const Warning& warning);
123173
};
124174

125-
namespace Warning {
175+
namespace WarningUtils {
126176

127177
// Note: [Verbatim Warnings]
128178
// Warnings originating in C++ code can appear out-of-place to Python users:
@@ -137,20 +187,6 @@ namespace Warning {
137187
// context in their warnings should set verbatim to true so their warnings
138188
// appear without modification.
139189

140-
/// Issue a warning with a given message. Dispatched to the current
141-
/// warning handler.
142-
C10_API void warn(
143-
const SourceLocation& source_location,
144-
const std::string& msg,
145-
bool verbatim);
146-
C10_API void warn(
147-
SourceLocation source_location,
148-
const char* msg,
149-
bool verbatim);
150-
C10_API void warn(
151-
SourceLocation source_location,
152-
::c10::detail::CompileTimeEmptyString msg,
153-
bool verbatim);
154190
/// Sets the global warning handler. This is not thread-safe, so it should
155191
/// generally be called once during initialization or while holding the GIL
156192
/// for programs that use python.
@@ -165,11 +201,11 @@ class C10_API WarningHandlerGuard {
165201

166202
public:
167203
WarningHandlerGuard(WarningHandler* new_handler)
168-
: prev_handler_(c10::Warning::get_warning_handler()) {
169-
c10::Warning::set_warning_handler(new_handler);
204+
: prev_handler_(c10::WarningUtils::get_warning_handler()) {
205+
c10::WarningUtils::set_warning_handler(new_handler);
170206
}
171207
~WarningHandlerGuard() {
172-
c10::Warning::set_warning_handler(prev_handler_);
208+
c10::WarningUtils::set_warning_handler(prev_handler_);
173209
}
174210
};
175211

@@ -190,7 +226,7 @@ struct C10_API WarnAlways {
190226
bool prev_setting;
191227
};
192228

193-
} // namespace Warning
229+
} // namespace WarningUtils
194230

195231
// Used in ATen for out-of-bound indices that can reasonably only be detected
196232
// lazily inside a kernel (See: advanced indexing). These turn into
@@ -516,53 +552,43 @@ namespace detail {
516552
#define TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) \
517553
TORCH_CHECK_WITH_MSG(NotImplementedError, cond, "TYPE", __VA_ARGS__)
518554

555+
#ifdef STRIP_ERROR_MESSAGES
556+
#define WARNING_MESSAGE_STRING(...) \
557+
::c10::detail::CompileTimeEmptyString {}
558+
#else
559+
#define WARNING_MESSAGE_STRING(...) ::c10::str(__VA_ARGS__)
560+
#endif
561+
519562
// Report a warning to the user. Accepts an arbitrary number of extra
520563
// arguments which are concatenated into the warning message using operator<<
521564
//
522-
#ifdef STRIP_ERROR_MESSAGES
523-
#define TORCH_WARN(...) \
524-
::c10::Warning::warn( \
565+
#define TORCH_WARN_WITH(warning_t, ...) \
566+
::c10::warn(::c10::Warning( \
567+
warning_t(), \
525568
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
526-
::c10::detail::CompileTimeEmptyString{}, \
527-
false)
528-
#else
529-
#define TORCH_WARN(...) \
530-
::c10::Warning::warn( \
531-
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
532-
::c10::str(__VA_ARGS__), \
533-
false)
534-
#endif
569+
WARNING_MESSAGE_STRING(__VA_ARGS__), \
570+
false));
571+
572+
#define TORCH_WARN(...) TORCH_WARN_WITH(::c10::UserWarning, __VA_ARGS__);
573+
574+
#define TORCH_WARN_DEPRECATION(...) \
575+
TORCH_WARN_WITH(::c10::DeprecationWarning, __VA_ARGS__);
535576

536577
// Report a warning to the user only once. Accepts an arbitrary number of extra
537578
// arguments which are concatenated into the warning message using operator<<
538579
//
539-
#ifdef STRIP_ERROR_MESSAGES
540-
#define _TORCH_WARN_ONCE(...) \
541-
C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \
542-
[&] { \
543-
::c10::Warning::warn( \
544-
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
545-
::c10::detail::CompileTimeEmptyString{}, \
546-
false); \
547-
return true; \
548-
}()
549-
#else
550580
#define _TORCH_WARN_ONCE(...) \
551581
C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \
552582
[&] { \
553-
::c10::Warning::warn( \
554-
{__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, \
555-
::c10::str(__VA_ARGS__), \
556-
false); \
583+
TORCH_WARN(__VA_ARGS__); \
557584
return true; \
558585
}()
559-
#endif
560586

561-
#define TORCH_WARN_ONCE(...) \
562-
if (::c10::Warning::get_warnAlways()) { \
563-
TORCH_WARN(__VA_ARGS__); \
564-
} else { \
565-
_TORCH_WARN_ONCE(__VA_ARGS__); \
587+
#define TORCH_WARN_ONCE(...) \
588+
if (::c10::WarningUtils::get_warnAlways()) { \
589+
TORCH_WARN(__VA_ARGS__); \
590+
} else { \
591+
_TORCH_WARN_ONCE(__VA_ARGS__); \
566592
}
567593

568594
// Report an error with a specific argument

test/cpp/api/autograd.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,7 @@ TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
11541154
}
11551155

11561156
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
1157-
c10::Warning::WarnAlways guard(true);
1157+
c10::WarningUtils::WarnAlways guard(true);
11581158

11591159
torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
11601160
auto z = x * x;

test/cpp/api/inference_mode.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ TEST(InferenceModeTest, TestCustomFunction) {
648648
}
649649

650650
TEST(InferenceModeTest, TestLegacyAutoNonVariableTypeModeWarning) {
651-
c10::Warning::WarnAlways warn_always(true);
651+
c10::WarningUtils::WarnAlways warn_always(true);
652652
WarningCapture warnings;
653653
at::AutoNonVariableTypeMode guard;
654654
ASSERT_TRUE(

test/cpp/api/support.h

+5-8
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ struct SeedingFixture : public ::testing::Test {
3737
};
3838

3939
struct WarningCapture : public WarningHandler {
40-
WarningCapture() : prev_(Warning::get_warning_handler()) {
41-
Warning::set_warning_handler(this);
40+
WarningCapture() : prev_(WarningUtils::get_warning_handler()) {
41+
WarningUtils::set_warning_handler(this);
4242
}
4343

4444
~WarningCapture() {
45-
Warning::set_warning_handler(prev_);
45+
WarningUtils::set_warning_handler(prev_);
4646
}
4747

4848
const std::vector<std::string>& messages() {
@@ -53,11 +53,8 @@ struct WarningCapture : public WarningHandler {
5353
return c10::Join("\n", messages_);
5454
}
5555

56-
void process(
57-
const SourceLocation& source_location,
58-
const std::string& msg,
59-
const bool /*verbatim*/) override {
60-
messages_.push_back(msg);
56+
void process(const c10::Warning& warning) override {
57+
messages_.push_back(warning.msg());
6158
}
6259

6360
private:

test/test_torch.py

+21
Original file line numberDiff line numberDiff line change
@@ -5683,6 +5683,27 @@ def test_pytorch_library_disabled_env(self):
56835683
except subprocess.CalledProcessError as e:
56845684
raise RuntimeError("Could not 'import torch' with PYTORCH_DISABLE_LIBRARY=0") from e
56855685

5686+
# Test that warnings generated from C++ are translated to the correct type
5687+
def test_warn_types(self):
5688+
test_cases = [
5689+
# function, warning type, message
5690+
(torch._C._warn, UserWarning, r"Test message for TORCH_WARN"),
5691+
(torch._C._warn_deprecation, DeprecationWarning, r"Test message for TORCH_WARN_DEPRECATION"),
5692+
]
5693+
5694+
for fn, warning_type, message in test_cases:
5695+
with warnings.catch_warnings(record=True) as w:
5696+
warnings.resetwarnings()
5697+
warnings.filterwarnings('always', category=warning_type)
5698+
fn()
5699+
5700+
self.assertEqual(len(w), 1, msg=f'{warning_type} not raised')
5701+
warning = w[0].message
5702+
self.assertTrue(isinstance(warning, warning_type), msg=f'{warning_type} not raised')
5703+
self.assertTrue(re.search(
5704+
message,
5705+
str(warning)))
5706+
56865707
def test_structseq_repr(self):
56875708
a = torch.arange(250).reshape(5, 5, 10)
56885709
expected = """

0 commit comments

Comments
 (0)