Skip to content

Commit ef38e2a

Browse files
authored
Make vitdet jit trace compliant (#30065)
* remove controlflows * style * rename patch_ to padded_ following review comment * style
1 parent a71def0 commit ef38e2a

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/transformers/models/vitdet/modeling_vitdet.py

+20-15
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, widt
9494
if has_cls_token:
9595
abs_pos_embeddings = abs_pos_embeddings[:, 1:]
9696
num_position = abs_pos_embeddings.shape[1]
97-
size = int(math.sqrt(num_position))
97+
size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
9898
if size * size != num_position:
9999
raise ValueError("Absolute position embeddings must be a square number.")
100100

101-
if size != height or size != width:
101+
if torch.jit.is_tracing() or (size != height or size != width):
102+
# nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
102103
new_abs_pos_embeddings = nn.functional.interpolate(
103104
abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
104105
size=(height, width),
@@ -132,6 +133,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
132133
return embeddings
133134

134135

136+
@torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic.
135137
def get_rel_pos(q_size, k_size, rel_pos):
136138
"""
137139
Get relative positional embeddings according to the relative positions of query and key sizes.
@@ -399,21 +401,23 @@ def window_partition(hidden_state, window_size):
399401
Returns:
400402
`tuple(torch.FloatTensor)` comprising various elements:
401403
- windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels].
402-
- (patch_height, patch_width): padded height and width before partition
404+
- (padded_height, padded_width): padded height and width before partition
403405
"""
404406
batch_size, height, width, num_channels = hidden_state.shape
405407

406408
pad_height = (window_size - height % window_size) % window_size
407409
pad_width = (window_size - width % window_size) % window_size
408-
if pad_height > 0 or pad_width > 0:
409-
hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
410-
patch_height, patch_width = height + pad_height, width + pad_width
410+
411+
# Noop in case pad_width == 0 and pad_height == 0.
412+
hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
413+
414+
padded_height, padded_width = height + pad_height, width + pad_width
411415

412416
hidden_state = hidden_state.view(
413-
batch_size, patch_height // window_size, window_size, patch_width // window_size, window_size, num_channels
417+
batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels
414418
)
415419
windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
416-
return windows, (patch_height, patch_width)
420+
return windows, (padded_height, padded_width)
417421

418422

419423
def window_unpartition(windows, window_size, pad_height_width, height_width):
@@ -426,23 +430,24 @@ def window_unpartition(windows, window_size, pad_height_width, height_width):
426430
window_size (`int`):
427431
Window size.
428432
pad_height_width (`Tuple[int]`):
429-
Padded height and width (patch_height, patch_width).
433+
Padded height and width (padded_height, padded_width).
430434
height_width (`Tuple[int]`):
431435
Original height and width before padding.
432436
433437
Returns:
434438
hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels].
435439
"""
436-
patch_height, patch_width = pad_height_width
440+
padded_height, padded_width = pad_height_width
437441
height, width = height_width
438-
batch_size = windows.shape[0] // (patch_height * patch_width // window_size // window_size)
442+
batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size)
439443
hidden_state = windows.view(
440-
batch_size, patch_height // window_size, patch_width // window_size, window_size, window_size, -1
444+
batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1
441445
)
442-
hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, patch_height, patch_width, -1)
446+
hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous()
447+
hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1)
443448

444-
if patch_height > height or patch_width > width:
445-
hidden_state = hidden_state[:, :height, :width, :].contiguous()
449+
# We always have height <= padded_height and width <= padded_width
450+
hidden_state = hidden_state[:, :height, :width, :].contiguous()
446451
return hidden_state
447452

448453

0 commit comments

Comments
 (0)