Skip to content

Commit 8564ed2

Browse files
sywangyipytorchmergebot
authored andcommitted
do not need to check if element in dict input is Tensor. (pytorch#97866)
sometimes it's a tuple with tensor element such as past value key in text generation case Fixes pytorch#97229 Pull Request resolved: pytorch#97866 Approved by: https://github.com/jgong5, https://github.com/davidberard98
1 parent 794f6e5 commit 8564ed2

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

test/jit/test_tracer.py

+16
Original file line numberDiff line numberDiff line change
@@ -1973,6 +1973,22 @@ def forward(self, x, y, **deprecated_arguments):
19731973
m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
19741974
m3 = torch.jit.trace(model, example_kwarg_inputs={'x': torch.ones(1), "y": torch.ones(1)}, strict=False)
19751975

1976+
def test_trace_with_tuple_tensor(self):
1977+
class MyClass(torch.nn.Module):
1978+
def __init__(self):
1979+
super(MyClass, self).__init__()
1980+
1981+
def forward(self, x, y):
1982+
return x + y[0] + y[1]
1983+
1984+
model = MyClass()
1985+
traced_model = torch.jit.trace(model, (torch.ones(1), (torch.ones(1), torch.ones(1))))
1986+
input_dict = {"x": torch.tensor([2, 3]), "y": (torch.tensor([5, 6]), torch.tensor([7, 8]))}
1987+
self.assertEqual(model(**input_dict), traced_model(**input_dict))
1988+
traced_model = torch.jit.trace(model, example_kwarg_inputs={
1989+
'x': torch.ones(1), "y": (torch.ones(1), torch.ones(1))})
1990+
self.assertEqual(model(**input_dict), traced_model(**input_dict))
1991+
19761992

19771993
class TestMixTracingScripting(JitTestCase):
19781994
def test_trace_script(self):

torch/csrc/jit/python/python_tracer.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,8 @@ std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
102102
for (const auto& compact_argument_name : compact_argument_names) {
103103
for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
104104
if (py::cast<std::string>(it->first) == compact_argument_name) {
105-
if (THPVariable_Check(it->second.ptr())) {
106-
compact_trace_inputs.push_back(
107-
toIValue(it->second, tryToInferType(it->second).type()));
108-
}
105+
compact_trace_inputs.push_back(
106+
toIValue(it->second, tryToInferType(it->second).type()));
109107
}
110108
}
111109
}

0 commit comments

Comments
 (0)