Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make vitdet jit trace complient #30065

Merged
merged 4 commits into from
Apr 8, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/transformers/models/vitdet/modeling_vitdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, widt
if has_cls_token:
abs_pos_embeddings = abs_pos_embeddings[:, 1:]
num_position = abs_pos_embeddings.shape[1]
size = int(math.sqrt(num_position))
size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
if size * size != num_position:
raise ValueError("Absolute position embeddings must be a square number.")

if size != height or size != width:
if torch.jit.is_tracing() or (size != height or size != width):
# nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
new_abs_pos_embeddings = nn.functional.interpolate(
abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
size=(height, width),
Expand Down Expand Up @@ -131,7 +132,7 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:

return embeddings


@torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic.
def get_rel_pos(q_size, k_size, rel_pos):
"""
Get relative positional embeddings according to the relative positions of query and key sizes.
Expand Down Expand Up @@ -405,8 +406,7 @@ def window_partition(hidden_state, window_size):

pad_height = (window_size - height % window_size) % window_size
pad_width = (window_size - width % window_size) % window_size
if pad_height > 0 or pad_width > 0:
hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height))
hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height)) # Noop in case pad_width == 0 and pad_height == 0.
patch_height, patch_width = height + pad_height, width + pad_width

hidden_state = hidden_state.view(
Expand Down Expand Up @@ -441,8 +441,8 @@ def window_unpartition(windows, window_size, pad_height_width, height_width):
)
hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, patch_height, patch_width, -1)

if patch_height > height or patch_width > width:
hidden_state = hidden_state[:, :height, :width, :].contiguous()
# We always have height <= patch_height and width <= patch_width
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a victim of terrible variable naming - it should really be padded_height and padded_width. Patch height should always be less that or equal to height. I know it's not related to tracing, but could you update these for the sake of future readers?

hidden_state = hidden_state[:, :height, :width, :].contiguous()
return hidden_state


Expand Down
Loading