Skip to content

Commit 68a9853

Browse files
kshitij12345pytorchmergebot
authored andcommitted
[fix] nn c++ : segfault in modulelist and moduledict (pytorch#93074)
Fixes pytorch#73565 Pull Request resolved: pytorch#93074 Approved by: https://github.com/albanD
1 parent 219e953 commit 68a9853

File tree

4 files changed

+46
-4
lines changed

4 files changed

+46
-4
lines changed

test/cpp/api/moduledict.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,11 @@ TEST_F(ModuleDictTest, PrettyPrintModuleDict) {
299299
" (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
300300
")");
301301
}
302+
303+
TEST_F(ModuleDictTest, InvalidAt) {
304+
torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
305+
{"linear", Linear(10, 3).ptr()}};
306+
ModuleDict dict(ordereddict);
307+
ASSERT_THROWS_WITH(
308+
dict->at<torch::nn::Dropout2dImpl>("linear"), "Unable to cast module");
309+
}

test/cpp/api/modulelist.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,9 @@ TEST_F(ModuleListTest, RangeBasedForLoop) {
300300
module->pretty_print(buffer);
301301
}
302302
}
303+
304+
TEST_F(ModuleListTest, InvalidAt) {
305+
torch::nn::ModuleList m(torch::nn::Linear(1, 2));
306+
ASSERT_THROWS_WITH(
307+
m->at<torch::nn::Dropout2dImpl>(0), "Unable to cast module");
308+
}

torch/csrc/api/include/torch/nn/modules/container/moduledict.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,14 @@ class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
178178
static_assert(
179179
torch::detail::is_module<T>::value,
180180
"Can only call ModuleList::at with an nn::Module type");
181-
return *modules_[key]->as<T>();
181+
auto module = modules_[key]->as<T>();
182+
TORCH_CHECK(
183+
module,
184+
"Unable to cast module[",
185+
key,
186+
"] to ",
187+
c10::demangle(typeid(T).name()));
188+
return *module;
182189
}
183190

184191
/// Attempts to return the module at the given key as the requested type.
@@ -189,7 +196,14 @@ class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
189196
static_assert(
190197
torch::detail::is_module<T>::value,
191198
"Can only call ModuleList::at with an nn::Module type");
192-
return *modules_[key]->as<T>();
199+
const auto module = modules_[key]->as<T>();
200+
TORCH_CHECK(
201+
module,
202+
"Unable to cast module[",
203+
key,
204+
"] to ",
205+
c10::demangle(typeid(T).name()));
206+
return *module;
193207
}
194208

195209
/// Removes and returns the `Module` associated with the given `key`.

torch/csrc/api/include/torch/nn/modules/container/modulelist.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,14 @@ class ModuleListImpl : public Cloneable<ModuleListImpl> {
147147
torch::detail::is_module<T>::value,
148148
"Can only call ModuleList::at with an nn::Module type");
149149
TORCH_CHECK(index < size(), "Index out of range");
150-
return *modules_[index]->as<T>();
150+
auto module = modules_[index]->as<T>();
151+
TORCH_CHECK(
152+
module,
153+
"Unable to cast module[",
154+
index,
155+
"] to ",
156+
c10::demangle(typeid(T).name()));
157+
return *module;
151158
}
152159

153160
/// Attempts to return the module at the given index as the requested type.
@@ -159,7 +166,14 @@ class ModuleListImpl : public Cloneable<ModuleListImpl> {
159166
torch::detail::is_module<T>::value,
160167
"Can only call ModuleList::at with an nn::Module type");
161168
TORCH_CHECK(index < size(), "Index out of range");
162-
return *modules_[index]->as<T>();
169+
const auto module = modules_[index]->as<T>();
170+
TORCH_CHECK(
171+
module,
172+
"Unable to cast module[",
173+
index,
174+
"] to ",
175+
c10::demangle(typeid(T).name()));
176+
return *module;
163177
}
164178

165179
/// Attempts to return a `std::shared_ptr` whose dynamic type is that of the

0 commit comments

Comments
 (0)