2
2
#include < comm/SYCLContext.h>
3
3
#include < comm/xpu_aten.h>
4
4
5
+ #include < ATen/ceil_div.h>
5
6
#include < ATen/native/xpu/sycl/NMSKernel.h>
6
7
7
8
namespace at {
@@ -97,9 +98,68 @@ struct NMSKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
97
98
sycl_local_acc_t <acc_t > slm_;
98
99
};
99
100
101
// Second phase of NMS: a single work-group serially scans the boxes
// (assumed pre-sorted by score) and computes the final keep flags from
// the pairwise suppression mask produced by the first kernel.
//
// Mask layout: dev_mask_ is n_boxes_ rows of col_blocks_ 64-bit words.
// Bit `inblock` of word (i, nblock) being set means box i suppresses box
// (nblock * nms_items_per_group + inblock).
// NOTE(review): inferred from how the mask is consumed below — confirm
// against the mask-producing NMSKernelFunctor above.
//
// `removed_` (work-group local memory) is the running bitset of boxes
// suppressed so far; `keep_` entries are only ever set to true, so the
// caller must zero-initialize the output.
struct GatherKeepFromMask : public __SYCL_KER_CONFIG_CONVENTION__ {
  void operator()(sycl::nd_item<1> item) const {
    const int thread_id = item.get_local_id(0);

    // Initialize removed: cooperatively zero one 64-bit word per block
    // of nms_items_per_group boxes.
    for (int i = thread_id; i < col_blocks_; i += nms_items_per_group) {
      removed_[i] = 0;
    }
    item.barrier(sycl_local_fence);

    // Walk the boxes in score order, one 64-box block at a time.  The
    // scan is inherently sequential: whether box i is kept depends on
    // every higher-scored box before it.
    for (int nblock = 0; nblock < col_blocks_; nblock++) {
      auto removed_val = removed_[nblock];
      item.barrier(sycl_local_fence);
      const int i_offset = nblock * nms_items_per_group;

      for (int inblock = 0; inblock < nms_items_per_group; inblock++) {
        const int i = i_offset + inblock;
        if (i >= n_boxes_)
          break;

        // Select a candidate, check if it should be kept: box i
        // survives iff no earlier kept box has suppressed it.
        // removed_val is identical across the work-group, so this
        // branch is group-uniform and the barrier inside it is safe.
        if (!(removed_val & (1ULL << inblock))) {
          if (thread_id == 0) {
            keep_[i] = true;
          }
          // p = suppression row of box i.
          auto p = dev_mask_ + i * col_blocks_;

          // Remove all bboxes which overlap the candidate.  Words
          // before the current block are skipped: those boxes have
          // already been fully decided.
          for (int j = thread_id; j < col_blocks_; j += nms_items_per_group) {
            if (j >= nblock)
              removed_[j] |= p[j];
          }
          item.barrier(sycl_local_fence);
          // Re-read the current word: box i may have just suppressed
          // later boxes inside this very block.
          removed_val = removed_[nblock];
        }
      }
    }
  }

  // keep:     [n_boxes] output flags; only set to true here, so the
  //           caller zero-initializes.
  // dev_mask: [n_boxes x col_blocks] pairwise suppression bitmask.
  // n_boxes:  number of candidate boxes.
  GatherKeepFromMask(
      bool* keep,
      const unsigned long long* dev_mask,
      const int n_boxes)
      : keep_(keep),
        dev_mask_(dev_mask),
        n_boxes_(n_boxes),
        col_blocks_(ceil_div(n_boxes, nms_items_per_group)) {}

  // Reserve one local-memory word per 64-box block for `removed_`.
  void sycl_ker_config_convention(sycl::handler& cgh) {
    removed_ = sycl_local_acc_t<unsigned long long>(col_blocks_, cgh);
  }

 private:
  bool* keep_;
  const unsigned long long* dev_mask_;
  const int n_boxes_;
  const int col_blocks_;
  sycl_local_acc_t<unsigned long long> removed_;
};
159
+
100
160
Tensor nms_kernel (const Tensor& dets_sorted, float iou_threshold) {
101
161
int dets_num = dets_sorted.size (0 );
102
- int col_blocks = (dets_num + nms_items_per_group - 1 ) / nms_items_per_group ;
162
+ int col_blocks = ceil_div (dets_num, nms_items_per_group) ;
103
163
auto mask = at::empty (
104
164
{dets_num * col_blocks}, dets_sorted.options ().dtype (at::kLong ));
105
165
@@ -120,7 +180,19 @@ Tensor nms_kernel(const Tensor& dets_sorted, float iou_threshold) {
120
180
sycl_kernel_submit (
121
181
global_range, local_range, at::xpu::getCurrentSYCLQueue (), caller);
122
182
});
123
- return mask;
183
+
184
+ at::Tensor keep = at::zeros (
185
+ {dets_num}, dets_sorted.options ().dtype (at::kBool ).device (at::kXPU ));
186
+ auto caller = GatherKeepFromMask (
187
+ keep.data_ptr <bool >(),
188
+ (unsigned long long *)mask.data_ptr <int64_t >(),
189
+ dets_num);
190
+ sycl_kernel_submit (
191
+ std::min (col_blocks, nms_items_per_group),
192
+ std::min (col_blocks, nms_items_per_group),
193
+ at::xpu::getCurrentSYCLQueue (),
194
+ caller);
195
+ return keep;
124
196
}
125
197
126
198
} // namespace xpu
0 commit comments