Support batch and classes for NonMaxSuppression #3999
base: main
Conversation
The following tests are passing: https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops/onnx/node/generated/test_nonmaxsuppression_two_batches. The following test fails in IREE for CPU due to the large tensor size used for stack allocation by the sort operation. Tested with smaller tensor sizes, and it works as expected with correct results.
Left a few comments, but I'm not quite clear about the nmsLoop part. @zjgarvey It would be nice to have your review too.
auto finalResIdx = batchLoopBody->getArgument(2);
auto numResultValues = batchLoopBody->getArgument(3);

auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
I would use AtenSliceTensorOp here, and also for the slice-tensor cases in the rest of the changes.
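For illustration, a minimal sketch of what the slice-based variant could look like; the names boxes, dimValue, idx, and sliceTy are hypothetical stand-ins, not from this patch:

// Hypothetical sketch: take element `idx` along `dimValue` as the size-1
// slice [idx, idx + 1) instead of a select. Unlike AtenSelectIntOp, the
// sliced dim is kept (with size 1) in the result type.
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
    loc, rewriter.getI64IntegerAttr(1));
Value endIdx = rewriter.create<Torch::AtenAddIntOp>(loc, idx, cstOne);
Value boxSlice = rewriter.create<Torch::AtenSliceTensorOp>(
    loc, sliceTy, boxes, dimValue, /*start=*/idx, /*end=*/endIdx,
    /*step=*/cstOne);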
@jinchen62 I had used AtenSelectIntOp since the selected dim needs to be removed in all the usages, and it gets decomposed in the subsequent DecomposeComplexOps pass. Please let me know if it would be better to use add + slice_tensor + squeeze in this change itself.
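To make the alternative concrete, a hedged sketch of the add + slice_tensor + squeeze variant being discussed; input, dim, idx, and the intermediate types are hypothetical:

// Hypothetical sketch of select expressed as slice + squeeze: the slice
// keeps the dim with size 1, and the squeeze then drops it so the result
// shape matches what AtenSelectIntOp would have produced directly.
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
    loc, rewriter.getI64IntegerAttr(1));
Value endIdx = rewriter.create<Torch::AtenAddIntOp>(loc, idx, cstOne);
Value sliced = rewriter.create<Torch::AtenSliceTensorOp>(
    loc, slicedTy, input, dim, /*start=*/idx, /*end=*/endIdx, /*step=*/cstOne);
Value selected =
    rewriter.create<Torch::AtenSqueezeDimOp>(loc, selectedTy, sliced, dim);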
        loc, emptyTensorTy, numOutputBoxes);
Value maxBoxesPerClass =
    rewriter.create<Torch::PrimNumToTensorScalarOp>(
        loc, emptyTensorTy, maxOutputBoxesPerClass);
I would use a tensor type with shape [1] instead of [], since those few arguments come in with shape [1] and you apply the Minimum op to them.
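A sketch of how a [1]-shaped value could be built, assuming an unsqueeze after prim.NumToTensor; the si64 dtype and the variable names are illustrative:

// Hypothetical sketch: materialize the scalar as a rank-0 tensor, then
// unsqueeze dim 0 so the shape becomes [1] and lines up with the
// [1]-shaped operands of the Minimum op without relying on broadcasting.
auto si64Ty = rewriter.getIntegerType(64, /*isSigned=*/true);
auto scalarTy = Torch::ValueTensorType::get(
    rewriter.getContext(), ArrayRef<int64_t>{}, si64Ty);
auto rank1Ty = Torch::ValueTensorType::get(
    rewriter.getContext(), ArrayRef<int64_t>{1}, si64Ty);
Value scalarT = rewriter.create<Torch::PrimNumToTensorScalarOp>(
    loc, scalarTy, maxOutputBoxesPerClass);
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
    loc, rewriter.getI64IntegerAttr(0));
Value rank1T =
    rewriter.create<Torch::AtenUnsqueezeOp>(loc, rank1Ty, scalarT, cstZero);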
Both values passed to the Minimum op are scalars coming from aten.size.int ops, so I had used shape [] for the minimum op.
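For context, a sketch of the shape-[] flow being described; boxes and cstZero are assumed to be in scope, and emptyTensorTy is the rank-0 type from the snippet above:

// Hypothetical sketch: both Minimum operands originate from aten.size.int,
// i.e. plain !torch.int scalars, hence the rank-0 ([]) tensor type.
Value numBoxes =
    rewriter.create<Torch::AtenSizeIntOp>(loc, boxes, /*dim=*/cstZero);
Value numBoxesT = rewriter.create<Torch::PrimNumToTensorScalarOp>(
    loc, emptyTensorTy, numBoxes);
Value maxBoxesT = rewriter.create<Torch::PrimNumToTensorScalarOp>(
    loc, emptyTensorTy, maxOutputBoxesPerClass);
Value minT = rewriter.create<Torch::AtenMinimumOp>(
    loc, emptyTensorTy, numBoxesT, maxBoxesT);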
I'm mostly having difficulty parsing the nmsLoop. Could you give some details as to what the implementation there is trying to do?
rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);

// Loop through the nms result
auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
I'm finding it difficult to parse this loop. The result of the (per-batch per-channel) torchvision nms op has shape <num_selected>, and we need it to be <num_selected x 3>, where each triple is like [batch_index, class_index, selected_box_index]. Is the purpose of this loop to insert these elements into the final result? Is it possible to avoid using a loop for this and instead concatenate the nms result with some splat tensors, then insert that into the final result by keeping track of what the cumulative num_selected is?
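Roughly, that suggestion could look like the following sketch; nmsResult, batchIdx, classIdx, dtypeInt, ctx, the cstZero/cstOne constants, and the result types are hypothetical and untested:

// Hypothetical sketch of splat + concat: build [num_selected, 1] columns
// holding the batch and class indices, widen the nms result to
// [num_selected, 1], and cat along dim 1 to get [num_selected, 3] triples.
Value numSel =
    rewriter.create<Torch::AtenSizeIntOp>(loc, nmsResult, /*dim=*/cstZero);
Value colSizes = rewriter.create<Torch::PrimListConstructOp>(
    loc, Torch::ListType::get(Torch::IntType::get(ctx)),
    ValueRange{numSel, cstOne});
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(loc);
Value batchCol = rewriter.create<Torch::AtenFullOp>(
    loc, colTy, colSizes, batchIdx, /*dtype=*/dtypeInt, /*layout=*/noneVal,
    /*device=*/noneVal, /*pin_memory=*/noneVal);
Value classCol = rewriter.create<Torch::AtenFullOp>(
    loc, colTy, colSizes, classIdx, dtypeInt, noneVal, noneVal, noneVal);
Value nmsCol =
    rewriter.create<Torch::AtenUnsqueezeOp>(loc, colTy, nmsResult, cstOne);
Value catList = rewriter.create<Torch::PrimListConstructOp>(
    loc, Torch::ListType::get(colTy), ValueRange{batchCol, classCol, nmsCol});
Value triples = rewriter.create<Torch::AtenCatOp>(
    loc, triplesTy, catList, /*dim=*/cstOne);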
@zjgarvey @jinchen62 Updated the comments for the nmsLoop part. This loop is used to insert the triple [batch_index, class_index, selected_box_index] at the required indices, element by element.

"Is the purpose of this loop to insert these elements into the final result? Is it possible to avoid using a loop for this and instead concatenate the nms result with some splat tensors, then insert that into the final result by keeping track of what the cumulative num_selected is?"

-> Yes, I had already tried the splat + concat approach as part of #3981, but I was running into runtime issues like segfaults / invalid memory accesses because IREE does not handle the dynamic dims. The IR using the concat + splat method is here. I made use of loops so that we can have a working solution initially and then update the logic once the issues in IREE are fixed. Please let me know your thoughts on this!
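For reference, the overall shape of the interim loop-based approach is roughly the skeleton below; it is simplified and hypothetical (the real code carries more iter args), with cstTrue, cstOne, numSelected, finalRes, resIdx, and the types assumed to be in scope:

// Hypothetical skeleton: iterate num_selected times; each iteration is
// meant to write one [batch_index, class_index, selected_box_index]
// triple into the preallocated result and advance the running write index.
auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
    loc, TypeRange({finalResTy, intTy}), numSelected, cstTrue,
    ValueRange({finalRes, resIdx}));
{
  PatternRewriter::InsertionGuard guard(rewriter);
  Block *body = rewriter.createBlock(
      &nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
      TypeRange({intTy, finalResTy, intTy}), {loc, loc, loc});
  Value iv = body->getArgument(0);       // position in the nms result
  Value result = body->getArgument(1);   // partially filled final result
  Value writeIdx = body->getArgument(2); // next free row in the result
  // ... the real body inserts the triple for `iv` into `result` here ...
  Value nextWriteIdx =
      rewriter.create<Torch::AtenAddIntOp>(loc, writeIdx, cstOne);
  rewriter.create<Torch::PrimLoopConditionOp>(
      loc, cstTrue, ValueRange({result, nextWriteIdx}));
}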