Skip to content

Commit 489c66f

Browse files
chunyuan-wpytorchmergebot
authored andcommitted
[AOTI] fix pointer_to_list (pytorch#138806)
Fixes the `pointer_to_list` function to take `*(ptr + i)` instead of `*ptr`. This fixes the runtime error when running INT8 yolo-v7. Pull Request resolved: pytorch#138806 Approved by: https://github.com/jgong5, https://github.com/desertfire ghstack dependencies: pytorch#138691
1 parent 9af1816 commit 489c66f

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

test/inductor/test_cpu_cpp_wrapper.py

+6
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,12 @@ class BaseTest(NamedTuple):
303303
for func in dir(test_mkldnn_pattern_matcher.TestPatternMatcher())
304304
if func.startswith("test_qlinear")
305305
],
306+
BaseTest(
307+
"test_qconv2d_with_concat",
308+
"cpu",
309+
test_mkldnn_pattern_matcher.TestPatternMatcher(),
310+
condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS,
311+
),
306312
BaseTest(
307313
"test_dynamic_qlinear",
308314
"cpu",

test/inductor/test_mkldnn_pattern_matcher.py

+53
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,59 @@ def matcher_check_fn():
10891089
matcher_check_fn=matcher_check_fn,
10901090
)
10911091

1092+
@skipIfNoDynamoSupport
1093+
@skipIfNoONEDNN
1094+
def test_qconv2d_with_concat_cpu(self):
1095+
channel_1 = 32
1096+
channel_2 = 16
1097+
channel_3 = 8
1098+
channel_4 = int(channel_2 * 2 + channel_3)
1099+
1100+
class Model(torch.nn.Module):
1101+
def __init__(
1102+
self,
1103+
):
1104+
super().__init__()
1105+
self.conv1 = torch.nn.Conv2d(
1106+
channel_1, channel_2, 1, stride=1, dilation=1, padding=0
1107+
)
1108+
self.conv2 = torch.nn.Conv2d(
1109+
channel_1, channel_2, 1, stride=1, dilation=1, padding=0
1110+
)
1111+
self.conv3 = torch.nn.Conv2d(
1112+
channel_2, channel_3, 3, stride=1, dilation=1, padding=1
1113+
)
1114+
1115+
self.conv = torch.nn.Conv2d(
1116+
channel_4, channel_2, 1, stride=1, dilation=1, padding=0
1117+
)
1118+
1119+
def forward(self, x: torch.Tensor):
1120+
x1 = self.conv1(x)
1121+
x2 = self.conv2(x)
1122+
x3 = self.conv3(x2)
1123+
res = torch.cat([x1, x2, x3], dim=1)
1124+
res = self.conv(res)
1125+
return res
1126+
1127+
mod = Model().eval()
1128+
v = torch.randn(
1129+
(8, channel_1, 40, 40), dtype=torch.float32, requires_grad=False
1130+
)
1131+
1132+
def matcher_check_fn():
1133+
self.assertEqual(
1134+
counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4
1135+
)
1136+
self.assertEqual(counters["inductor"]["qconv2d_unary_matcher_count"], 3)
1137+
1138+
self._test_common(
1139+
mod,
1140+
(v,),
1141+
check_quantization=True,
1142+
matcher_check_fn=matcher_check_fn,
1143+
)
1144+
10921145
@skipIfNoDynamoSupport
10931146
@skipIfNoONEDNN
10941147
def test_qconv2d_add_2(self):

torch/csrc/inductor/aoti_torch/utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ inline std::vector<at::Tensor> pointer_to_list(
148148
std::vector<at::Tensor> result;
149149
result.reserve(len);
150150
for (int64_t i = 0; i < len; i++) {
151-
result.emplace_back(*tensor_handle_to_tensor_pointer(*ptr));
151+
result.emplace_back(*tensor_handle_to_tensor_pointer(ptr[i]));
152152
}
153153
return result;
154154
}

0 commit comments

Comments
 (0)