@@ -94,11 +94,12 @@ def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width
         if has_cls_token:
             abs_pos_embeddings = abs_pos_embeddings[:, 1:]
         num_position = abs_pos_embeddings.shape[1]
-        size = int(math.sqrt(num_position))
+        size = int(math.sqrt(num_position))  # This is a constant and can be recorded as such in the ONNX export.
         if size * size != num_position:
             raise ValueError("Absolute position embeddings must be a square number.")

-        if size != height or size != width:
+        if torch.jit.is_tracing() or (size != height or size != width):
+            # nn.functional.interpolate is a no-op when size == height and size == width, so we always capture this path with jit.trace.
             new_abs_pos_embeddings = nn.functional.interpolate(
                 abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
                 size=(height, width),
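Context for the `torch.jit.is_tracing()` guard above: `torch.jit.trace` records only the code path taken by the example inputs, so a Python branch that happens to be false at trace time vanishes from the traced graph entirely. A minimal toy sketch of that pitfall (not from this PR):

```python
import torch

def f(x):
    if x.shape[0] != 2:  # plain Python branch: evaluated once, at trace time
        x = x * 2
    return x

# The example input does not take the branch, so the multiply is absent from the trace.
traced = torch.jit.trace(f, torch.ones(2))
print(traced(torch.ones(3)))  # tensor([1., 1., 1.]) -- still not doubled
```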
@@ -132,6 +133,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return embeddings


+@torch.jit.script_if_tracing  # nn.functional.interpolate's `size` needs to be dynamic.
 def get_rel_pos(q_size, k_size, rel_pos):
     """
     Get relative positional embeddings according to the relative positions of query and key sizes.
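On the `script_if_tracing` decorator: when a traced model calls a function decorated this way, the function is compiled with TorchScript instead of being inlined into the trace, so values like `interpolate`'s `size` stay dynamic rather than being frozen to the example input's value. A rough toy sketch under that assumption (hypothetical `stretch` function; exact traced-int behavior can vary across PyTorch versions):

```python
import torch
from torch import Tensor

@torch.jit.script_if_tracing  # compiled with TorchScript only when reached inside torch.jit.trace
def stretch(x: Tensor, out_len: int) -> Tensor:
    # In a plain trace, `out_len` would be recorded as a constant; scripting
    # keeps it a real integer input to interpolate's `size`.
    return torch.nn.functional.interpolate(x, size=[out_len], mode="linear", align_corners=False)

def model(x: Tensor) -> Tensor:
    return stretch(x, 2 * x.shape[-1])

traced = torch.jit.trace(model, torch.randn(1, 1, 4))
print(traced(torch.randn(1, 1, 8)).shape)  # expected torch.Size([1, 1, 16]) if `size` stays dynamic
```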
@@ -399,21 +401,23 @@ def window_partition(hidden_state, window_size):
     Returns:
         `tuple(torch.FloatTensor)` comprising various elements:
             - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
-            - (patch_height, patch_width): padded height and width before partition
+            - (padded_height, padded_width): padded height and width before partition
     """
     batch_size, height, width, num_channels = hidden_state.shape

     pad_height = (window_size - height % window_size) % window_size
     pad_width = (window_size - width % window_size) % window_size
-    if pad_height > 0 or pad_width > 0:
-        hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
-    patch_height, patch_width = height + pad_height, width + pad_width
+
+    # No-op when pad_width == 0 and pad_height == 0.
+    hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
+
+    padded_height, padded_width = height + pad_height, width + pad_width

     hidden_state = hidden_state.view(
-        batch_size, patch_height // window_size, window_size, patch_width // window_size, window_size, num_channels
+        batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
     )
     windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
-    return windows, (patch_height, patch_width)
+    return windows, (padded_height, padded_width)


 def window_unpartition(windows, window_size, pad_height_width, height_width):
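The dropped `if pad_height > 0 or pad_width > 0` guard was only a shortcut: an all-zero `nn.functional.pad` leaves the values untouched, so calling it unconditionally keeps the traced graph shape-agnostic at no cost. A quick check with hypothetical shapes:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 14, 14, 8)                      # 14 is already a multiple of window_size=7
padded = nn.functional.pad(x, (0, 0, 0, 0, 0, 0))  # every pad amount is zero
assert torch.equal(padded, x)                      # identical values: the old guard changed nothing
```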
@@ -426,23 +430,24 @@ def window_unpartition(windows, window_size, pad_height_width, height_width):
         window_size (`int`):
             Window size.
         pad_height_width (`Tuple[int]`):
-            Padded height and width (patch_height, patch_width).
+            Padded height and width (padded_height, padded_width).
         height_width (`Tuple[int]`):
             Original height and width before padding.

     Returns:
         hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
     """
-    patch_height, patch_width = pad_height_width
+    padded_height, padded_width = pad_height_width
     height, width = height_width
-    batch_size = windows.shape[0] // (patch_height * patch_width // window_size // window_size)
+    batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
     hidden_state = windows.view(
-        batch_size, patch_height // window_size, patch_width // window_size, window_size, window_size, -1
+        batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
     )
-    hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, patch_height, patch_width, -1)
+    hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
+    hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)

-    if patch_height > height or patch_width > width:
-        hidden_state = hidden_state[:, :height, :width, :].contiguous()
+    # We always have height <= padded_height and width <= padded_width.
+    hidden_state = hidden_state[:, :height, :width, :].contiguous()
     return hidden_state
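A round-trip sanity check for the pair as patched (hypothetical sizes, assuming `window_partition` and `window_unpartition` from this file are importable): partitioning pads 10 up to 12, and unpartitioning slices the padding back off, so the input is recovered exactly.

```python
import torch

x = torch.randn(2, 10, 10, 4)
windows, pad_hw = window_partition(x, window_size=4)         # pads 10 -> 12, i.e. 3x3 windows per image
restored = window_unpartition(windows, 4, pad_hw, (10, 10))  # the unconditional slice trims the padding
assert torch.equal(restored, x)
```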