@@ -113,6 +113,53 @@ class PatchingSpec:
113
113
op_wrapper : Optional [Callable ] = None
114
114
115
115
116
+ # An ONNX-export-compatible version of `tensor.unfold`. Without this, we get:
117
+ # torch.onnx.errors.SymbolicValueError: Unsupported: ONNX export of operator Unfold, input size not accessible.
118
+ # See https://github.com/pytorch/pytorch/issues/81871 for more information
119
+ def onnx_compatible_unfold (input_tensor , dimension , size , step ):
120
+ """
121
+ Custom implementation of torch.unfold without using torch.unfold.
122
+
123
+ Args:
124
+ input_tensor (torch.Tensor): The input tensor.
125
+ dimension (int): The dimension to unfold.
126
+ size (int): The size of each slice.
127
+ step (int): The step size between slices.
128
+
129
+ Returns:
130
+ torch.Tensor: The unfolded tensor.
131
+ """
132
+ # Check if dimension is within the valid range
133
+ if not (- input_tensor .dim () <= dimension < input_tensor .dim ()):
134
+ raise ValueError (
135
+ f"Dimension out of range (expected to be in range of [{ - input_tensor .dim ()} , { input_tensor .dim () - 1 } ], but got { dimension } )"
136
+ )
137
+
138
+ # Normalize negative dimension
139
+ dimension = dimension % input_tensor .dim ()
140
+
141
+ # Compute the shape of the unfolded output
142
+ input_size = input_tensor .size (dimension )
143
+ num_slices = (input_size - size ) // step + 1
144
+
145
+ # Permute dimension to the end for easier indexing
146
+ input_tensor = input_tensor .transpose (dimension , - 1 )
147
+
148
+ # Extract slices
149
+ slices = []
150
+ for i in range (num_slices ):
151
+ start = i * step
152
+ end = start + size
153
+ slices .append (input_tensor [..., start :end ])
154
+
155
+ # Stack slices and permute dimensions back
156
+ result = torch .stack (slices , dim = - 2 ).transpose (dimension , - 2 )
157
+ return result
158
+
159
+
160
+ UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec (torch .Tensor , "unfold" , onnx_compatible_unfold , torch .Tensor .unfold )]
161
+
162
+
116
163
class ModelPatcher :
117
164
def __init__ (
118
165
self ,
@@ -122,9 +169,11 @@ def __init__(
122
169
):
123
170
self ._model = model
124
171
125
- patching_specs = config .PATCHING_SPECS
172
+ patching_specs = config .PATCHING_SPECS or []
173
+ patching_specs .extend (UNSUPPORTED_OPS_PATCHING_SPEC )
174
+
126
175
self ._patching_specs = []
127
- for spec in patching_specs if patching_specs is not None else [] :
176
+ for spec in patching_specs :
128
177
final_spec = spec
129
178
if spec .orig_op is None :
130
179
final_spec = dataclasses .replace (spec , orig_op = getattr (spec .o , spec .name ))
0 commit comments