Skip to content

Commit 9346da6

Browse files
authored
Merge branch 'main' into xccl/fix_group
2 parents 6c6775b + b8c05de commit 9346da6

File tree

2 files changed

+76
-35
lines changed

2 files changed

+76
-35
lines changed

src/ATen/native/xpu/NMS.cpp

+2-33
Original file line numberDiff line numberDiff line change
@@ -42,39 +42,8 @@ Tensor nms(const Tensor& dets, const Tensor& scores, double iou_threshold_) {
4242
scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true));
4343
auto dets_sorted = dets.index_select(0, order_t).contiguous();
4444

45-
int dets_num = dets.size(0);
46-
int col_blocks = (dets_num + nms_items_per_group - 1) / nms_items_per_group;
47-
48-
auto mask = nms_kernel(dets_sorted, iou_threshold);
49-
50-
at::Tensor mask_cpu = mask.to(at::kCPU);
51-
unsigned long long* mask_host =
52-
(unsigned long long*)mask_cpu.mutable_data_ptr();
53-
54-
std::vector<unsigned long long> remv(col_blocks);
55-
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
56-
57-
at::Tensor keep =
58-
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
59-
int64_t* keep_out = keep.mutable_data_ptr<int64_t>();
60-
61-
int num_to_keep = 0;
62-
for (int i = 0; i < dets_num; i++) {
63-
int nblock = i / nms_items_per_group;
64-
int inblock = i % nms_items_per_group;
65-
66-
if (!(remv[nblock] & (1ULL << inblock))) {
67-
keep_out[num_to_keep++] = i;
68-
unsigned long long* p = mask_host + i * col_blocks;
69-
for (int j = nblock; j < col_blocks; j++) {
70-
remv[j] |= p[j];
71-
}
72-
}
73-
}
74-
75-
return order_t.index(
76-
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
77-
.to(order_t.device(), keep.scalar_type())});
45+
auto keep = nms_kernel(dets_sorted, iou_threshold);
46+
return order_t.masked_select(keep);
7847
}
7948

8049
} // namespace at::native::xpu

src/ATen/native/xpu/sycl/NMSKernel.cpp

+74-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <comm/SYCLContext.h>
33
#include <comm/xpu_aten.h>
44

5+
#include <ATen/ceil_div.h>
56
#include <ATen/native/xpu/sycl/NMSKernel.h>
67

78
namespace at {
@@ -97,9 +98,68 @@ struct NMSKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
9798
sycl_local_acc_t<acc_t> slm_;
9899
};
99100

101+
struct GatherKeepFromMask : public __SYCL_KER_CONFIG_CONVENTION__ {
102+
void operator()(sycl::nd_item<1> item) const {
103+
const int thread_id = item.get_local_id(0);
104+
105+
// Initialize removed
106+
for (int i = thread_id; i < col_blocks_; i += nms_items_per_group) {
107+
removed_[i] = 0;
108+
}
109+
item.barrier(sycl_local_fence);
110+
111+
for (int nblock = 0; nblock < col_blocks_; nblock++) {
112+
auto removed_val = removed_[nblock];
113+
item.barrier(sycl_local_fence);
114+
const int i_offset = nblock * nms_items_per_group;
115+
116+
for (int inblock = 0; inblock < nms_items_per_group; inblock++) {
117+
const int i = i_offset + inblock;
118+
if (i >= n_boxes_)
119+
break;
120+
121+
// Select a candidate, check if it should be kept
122+
if (!(removed_val & (1ULL << inblock))) {
123+
if (thread_id == 0) {
124+
keep_[i] = true;
125+
}
126+
auto p = dev_mask_ + i * col_blocks_;
127+
128+
// Remove all bboxes which overlap the candidate
129+
for (int j = thread_id; j < col_blocks_; j += nms_items_per_group) {
130+
if (j >= nblock)
131+
removed_[j] |= p[j];
132+
}
133+
item.barrier(sycl_local_fence);
134+
removed_val = removed_[nblock];
135+
}
136+
}
137+
}
138+
}
139+
GatherKeepFromMask(
140+
bool* keep,
141+
const unsigned long long* dev_mask,
142+
const int n_boxes)
143+
: keep_(keep),
144+
dev_mask_(dev_mask),
145+
n_boxes_(n_boxes),
146+
col_blocks_(ceil_div(n_boxes, nms_items_per_group)) {}
147+
148+
void sycl_ker_config_convention(sycl::handler& cgh) {
149+
removed_ = sycl_local_acc_t<unsigned long long>(col_blocks_, cgh);
150+
}
151+
152+
private:
153+
bool* keep_;
154+
const unsigned long long* dev_mask_;
155+
const int n_boxes_;
156+
const int col_blocks_;
157+
sycl_local_acc_t<unsigned long long> removed_;
158+
};
159+
100160
Tensor nms_kernel(const Tensor& dets_sorted, float iou_threshold) {
101161
int dets_num = dets_sorted.size(0);
102-
int col_blocks = (dets_num + nms_items_per_group - 1) / nms_items_per_group;
162+
int col_blocks = ceil_div(dets_num, nms_items_per_group);
103163
auto mask = at::empty(
104164
{dets_num * col_blocks}, dets_sorted.options().dtype(at::kLong));
105165

@@ -120,7 +180,19 @@ Tensor nms_kernel(const Tensor& dets_sorted, float iou_threshold) {
120180
sycl_kernel_submit(
121181
global_range, local_range, at::xpu::getCurrentSYCLQueue(), caller);
122182
});
123-
return mask;
183+
184+
at::Tensor keep = at::zeros(
185+
{dets_num}, dets_sorted.options().dtype(at::kBool).device(at::kXPU));
186+
auto caller = GatherKeepFromMask(
187+
keep.data_ptr<bool>(),
188+
(unsigned long long*)mask.data_ptr<int64_t>(),
189+
dets_num);
190+
sycl_kernel_submit(
191+
std::min(col_blocks, nms_items_per_group),
192+
std::min(col_blocks, nms_items_per_group),
193+
at::xpu::getCurrentSYCLQueue(),
194+
caller);
195+
return keep;
124196
}
125197

126198
} // namespace xpu

0 commit comments

Comments
 (0)