Skip to content

Commit a635246

Browse files
Thiago Crepaldipytorchmergebot
Thiago Crepaldi
authored andcommitted
[ONNX] Add col2im for opset 18 (pytorch#84594)
Opset 18 will be used to introduce suport for ONNX's Col2Im-18 and resolve pytorch#84408 Depends: pytorch#83201 (CI will fail until ONNX submodule is updated) as per Faith recommendation, this PR should be merged post ORT 1.13 only Pull Request resolved: pytorch#84594 Approved by: https://github.com/justinchuby, https://github.com/titaiwangms, https://github.com/abock, https://github.com/BowenBao
1 parent ea98ba0 commit a635246

File tree

6 files changed

+111
-3
lines changed

6 files changed

+111
-3
lines changed

test/onnx/test_pytorch_onnx_no_runtime.py

+33
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,39 @@ def forward(self, x):
11561156
dim,
11571157
)
11581158

1159+
def test_col2im(self):
1160+
# This test can be moved to test/onnx/test_pytorch_onnx_onnxruntime.py when ORT implement ::Col2Im
1161+
1162+
# Random batched RGB 32x32 image-shaped input tensor of batch size 64
1163+
original_image_inputs = torch.randn((64, 3, 32, 32))
1164+
output_size = tuple(original_image_inputs.shape[2:])
1165+
kernel_size = (1, 2)
1166+
dilation = 3
1167+
padding = 2
1168+
stride = 1
1169+
model_im2col = torch.nn.Unfold(
1170+
kernel_size, dilation=dilation, padding=padding, stride=stride
1171+
)
1172+
blocks = model_im2col(original_image_inputs)
1173+
1174+
model = torch.nn.Fold(
1175+
output_size=output_size,
1176+
kernel_size=kernel_size,
1177+
dilation=dilation,
1178+
padding=padding,
1179+
stride=stride,
1180+
)
1181+
f = io.BytesIO()
1182+
torch.onnx.export(model, (blocks,), f, opset_version=18)
1183+
1184+
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
1185+
self.assertEqual(onnx_model.graph.node[-1].op_type, "Col2Im")
1186+
self.assertEqual(onnx_model.graph.node[-1].domain, "")
1187+
self.assertEqual(len(onnx_model.graph.node[-1].input), 3)
1188+
self.assertEqual(onnx_model.graph.node[-1].attribute[0].name, "dilations")
1189+
self.assertEqual(onnx_model.graph.node[-1].attribute[1].name, "pads")
1190+
self.assertEqual(onnx_model.graph.node[-1].attribute[2].name, "strides")
1191+
11591192

11601193
class TestQuantizeEagerONNXExport(common_utils.TestCase):
11611194
def _test_lower_graph_impl(self, model, data):

test/onnx/test_pytorch_onnx_onnxruntime.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@
4444
# The min onnx opset version to test for
4545
MIN_ONNX_OPSET_VERSION = 9
4646
# The max onnx opset version to test for
47-
MAX_ONNX_OPSET_VERSION = _constants.ONNX_MAX_OPSET
47+
MAX_ONNX_OPSET_VERSION = (
48+
_constants.ONNX_MAX_OPSET - 1
49+
) # TODO: ORT does not support opset 18 yet
4850

4951

5052
def _init_test_generalized_rcnn_transform():

torch/csrc/jit/serialization/export.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ namespace onnx_torch = ::torch::onnx;
5959
namespace onnx = ::ONNX_NAMESPACE;
6060

6161
const static int kInvalidOpsetVersion = -1;
62-
const static int kMainOpsetVersion = 17;
62+
const static int kMainOpsetVersion = 18;
6363
// Based on OP_SET_ID_VERSION_MAP in
6464
// https://github.com/onnx/onnx/blob/master/onnx/helper.py.
6565
constexpr static std::array<int64_t, kMainOpsetVersion + 1>
@@ -82,6 +82,7 @@ constexpr static std::array<int64_t, kMainOpsetVersion + 1>
8282
8, // opset 15
8383
8, // opset 16
8484
8, // opset 17
85+
8, // opset 18
8586
};
8687

8788
std::string getNodeStackTraceString(const Node* n) {

torch/onnx/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
symbolic_opset15,
2626
symbolic_opset16,
2727
symbolic_opset17,
28+
symbolic_opset18,
2829
utils,
2930
)
3031

@@ -62,6 +63,7 @@
6263
"symbolic_opset15",
6364
"symbolic_opset16",
6465
"symbolic_opset17",
66+
"symbolic_opset18",
6567
# Enums
6668
"ExportTypes",
6769
"OperatorExportTypes",

torch/onnx/_constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
ONNX_BASE_OPSET = 9
66
ONNX_MIN_OPSET = 7
7-
ONNX_MAX_OPSET = 17
7+
ONNX_MAX_OPSET = 18
88
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
99
ONNX_DEFAULT_OPSET = 14
1010
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9

torch/onnx/symbolic_opset18.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""This file exports ONNX ops for opset 18.
2+
3+
Note [ONNX Operators that are added/updated in opset 18]
4+
5+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
6+
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
7+
New operators:
8+
CenterCropPad
9+
Col2Im
10+
Mish
11+
OptionalGetElement
12+
OptionalHasElement
13+
Pad
14+
Resize
15+
ScatterElements
16+
ScatterND
17+
"""
18+
19+
import functools
20+
from typing import Sequence
21+
22+
from torch import _C
23+
from torch.onnx import symbolic_helper
24+
from torch.onnx._internal import _beartype, registration
25+
26+
# EDITING THIS FILE? READ THIS FIRST!
27+
# see Note [Edit Symbolic Files] in symbolic_helper.py
28+
29+
__all__ = ["col2im"]
30+
31+
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
32+
33+
34+
@_onnx_symbolic("aten::col2im")
35+
@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
36+
@_beartype.beartype
37+
def col2im(
38+
g,
39+
input: _C.Value,
40+
output_size: _C.Value,
41+
kernel_size: _C.Value,
42+
dilation: Sequence[int],
43+
padding: Sequence[int],
44+
stride: Sequence[int],
45+
):
46+
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
47+
adjusted_padding = []
48+
for pad in padding:
49+
for _ in range(2):
50+
adjusted_padding.append(pad)
51+
52+
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
53+
if not adjusted_padding:
54+
adjusted_padding = [0, 0] * num_dimensional_axis
55+
56+
if not dilation:
57+
dilation = [1] * num_dimensional_axis
58+
59+
if not stride:
60+
stride = [1] * num_dimensional_axis
61+
62+
return g.op(
63+
"Col2Im",
64+
input,
65+
output_size,
66+
kernel_size,
67+
dilations_i=dilation,
68+
pads_i=adjusted_padding,
69+
strides_i=stride,
70+
)

0 commit comments

Comments
 (0)