
Commit 8a09f13

kurtamohler authored and pytorchmergebot committed
Avoid COW materialize in index, reduce, compare, unique, and copy ops (pytorch#119504)
Pull Request resolved: pytorch#119504
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#119501, pytorch#119502, pytorch#119503
1 parent 0e6b314 commit 8a09f13

14 files changed: +94 −89 lines
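
The change is mechanical but repeated across many kernels: read-only accesses are switched from the mutable accessors (data_ptr<T>(), TensorIteratorConfig::add_input()) to their const counterparts (const_data_ptr<T>(), add_const_input()), so that a copy-on-write (COW) tensor is not forced to materialize a private copy of its storage just to be read. A minimal sketch of the accessor-level pattern, not taken from the commit (count_zeros is a hypothetical helper, assuming a contiguous float tensor):

#include <ATen/ATen.h>

// Counts zeros without ever requesting a mutable pointer. data_ptr<float>()
// must assume the caller may write through it, which forces a lazily-copied
// (COW) tensor to materialize; const_data_ptr<float>() leaves it untouched.
int64_t count_zeros(const at::Tensor& t) {
  TORCH_CHECK(t.scalar_type() == at::kFloat && t.is_contiguous());
  const float* data = t.const_data_ptr<float>();
  int64_t zeros = 0;
  for (int64_t i = 0; i < t.numel(); ++i) {
    if (data[i] == 0.f) {
      ++zeros;
    }
  }
  return zeros;
}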

aten/src/ATen/native/Copy.cpp

+8 −8
@@ -81,15 +81,15 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));
 
   _AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
-    scalar_t* sp = src.data_ptr<scalar_t>();
+    const scalar_t* sp = src.const_data_ptr<scalar_t>();
     scalar_t* rp = self.data_ptr<scalar_t>();
     scalar_t* bp = buf.data_ptr<scalar_t>();
 
     int64_t NR = src.size(0);
     int64_t NC = src.size(1);
     for (int64_t R = 0; R < NR; R += BLOCK_SZ) {
       for (int64_t C = 0; C < NC; C += BLOCK_SZ) {
-        scalar_t* spo = sp + R + C * NR;
+        const scalar_t* spo = sp + R + C * NR;
         scalar_t* rpo = rp + C + R * NC;
 
         int nr = std::min(NR - R, BLOCK_SZ);
@@ -156,22 +156,22 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
       auto* output_ptr =
           reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
       if (self.numel() < at::internal::GRAIN_SIZE) {
-        fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
+        fbgemm::FloatToFloat16_simd(src.const_data_ptr<float>(), output_ptr, self.numel());
       } else {
         at::parallel_for(
             0,
             self.numel(),
             at::internal::GRAIN_SIZE,
             [&](int64_t begin, int64_t end) {
               fbgemm::FloatToFloat16_simd(
-                  src.data_ptr<float>() + begin,
+                  src.const_data_ptr<float>() + begin,
                   output_ptr + begin,
                   end - begin);
             });
       }
     } else {
-      auto in_data = reinterpret_cast<fbgemm::float16*>(
-          src.data_ptr<at::Half>());
+      auto in_data = reinterpret_cast<const fbgemm::float16*>(
+          src.const_data_ptr<at::Half>());
       auto* output_ptr = self.data_ptr<float>();
       if (self.numel() < at::internal::GRAIN_SIZE) {
         fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
@@ -265,7 +265,7 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
 
   auto iter = TensorIteratorConfig()
     .add_output(self)
-    .add_input(src)
+    .add_const_input(src)
     .resize_outputs(false)
     .check_all_same_dtype(false)
     .check_all_same_device(false)
@@ -335,7 +335,7 @@ void copy_ignoring_overlaps(const TensorBase &dst, const TensorBase &src) {
   // FIXME: really, overlapping writes should be illegal/an error in Torch
   auto iter = TensorIteratorConfig()
       .add_output(dst)
-      .add_input(src)
+      .add_const_input(src)
       .resize_outputs(false)
       .set_check_mem_overlap(false)
       .check_all_same_dtype(true)

aten/src/ATen/native/ReduceOps.cpp

+5 −5
@@ -1250,7 +1250,7 @@ Tensor trace_cpu(const Tensor& self) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] {
     using accscalar_t = at::acc_type<scalar_t, false>;
     accscalar_t sum = 0;
-    const auto* t_data = self.data_ptr<scalar_t>();
+    const auto* t_data = self.const_data_ptr<scalar_t>();
 
     int64_t t_stride_0, t_stride_1, t_diag_size;
 
@@ -1726,7 +1726,7 @@ static double std_var_all_cpu(const Tensor& self, double correction, bool take_s
 
   auto mean = self.mean().item<double>();
   auto iter = TensorIteratorConfig()
-      .add_input(self)
+      .add_const_input(self)
       .build();
 
   auto reduction = [&](int64_t begin, int64_t end, double thread_sum) {
@@ -2197,7 +2197,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
       return true;
     }
     std::atomic<bool> result{true};
-    auto iter = TensorIteratorConfig().add_input(self).build();
+    auto iter = TensorIteratorConfig().add_const_input(self).build();
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(kHalf, kBFloat16, iter.input_dtype(), "equal_notnan_cpu", [&] {
       iter.for_each([&](char** data, const int64_t *strides, int64_t dim_size) {
         if (!result) {
@@ -2218,8 +2218,8 @@ bool cpu_equal(const Tensor& self, const Tensor& other) {
 
   std::atomic<bool> result{true};
   auto iter = TensorIteratorConfig()
-      .add_input(self)
-      .add_input(other)
+      .add_const_input(self)
+      .add_const_input(other)
       .allow_cpu_scalars(true)
       .promote_inputs_to_common_dtype(true)
      .build();
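
For kernels built on TensorIterator, as in cpu_equal above, the same effect comes from declaring read-only operands with add_const_input() instead of add_input(). A hedged sketch of that pattern (all_equal is a hypothetical helper, assuming two same-shape float tensors; a real kernel would also dispatch over dtypes):

#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>
#include <atomic>

// Element-wise comparison of two same-shape float tensors. Both operands are
// declared const, so neither is materialized if it happens to be a COW tensor.
bool all_equal(const at::Tensor& a, const at::Tensor& b) {
  TORCH_CHECK(a.scalar_type() == at::kFloat && b.scalar_type() == at::kFloat);
  TORCH_CHECK(a.sizes().equals(b.sizes()));
  std::atomic<bool> result{true};
  auto iter = at::TensorIteratorConfig()
      .add_const_input(a)
      .add_const_input(b)
      .build();
  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
      const float lhs = *reinterpret_cast<const float*>(data[0] + i * strides[0]);
      const float rhs = *reinterpret_cast<const float*>(data[1] + i * strides[1]);
      if (lhs != rhs) {
        result = false;
        return;
      }
    }
  });
  return result;
}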

aten/src/ATen/native/TensorAdvancedIndexing.cpp

+41 −41
@@ -408,9 +408,9 @@ static void build_index_op(
   config.set_check_mem_overlap(false)
       .check_all_same_dtype(false)
       .add_output(result)
-      .add_owned_input(info.src);
+      .add_owned_const_input(info.src);
   for (auto& index : info.indices) {
-    config.add_owned_input(index);
+    config.add_owned_const_input(index);
   }
   if (!result.defined()) {
     config.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device());
@@ -614,9 +614,9 @@ static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const T
   config.resize_outputs(false);
   config.check_all_same_dtype(false);
   config.add_output(info.src);
-  config.add_input(value);
+  config.add_const_input(value);
   for (auto& index : info.indices) {
-    config.add_input(index);
+    config.add_const_input(index);
   }
   return config.build();
 }
@@ -689,8 +689,8 @@ Tensor & put_(Tensor & self, const Tensor& index, const Tensor & source, const b
   auto iter = TensorIteratorConfig()
     .set_check_mem_overlap(false)
     .check_all_same_dtype(false)
-    .add_input(source)
-    .add_input(index_reshaped)
+    .add_const_input(source)
+    .add_const_input(index_reshaped)
     .build();
 
   put_stub(iter.device_type(), iter, self, accumulate);
@@ -769,7 +769,7 @@ Tensor& take_out(const Tensor& self, const Tensor& index, Tensor& out) {
     .set_check_mem_overlap(false)
     .check_all_same_dtype(false)
    .add_output(out)
-    .add_input(index)
+    .add_const_input(index)
     .build();
 
   // Early return after out has been resized
@@ -848,8 +848,8 @@ TORCH_IMPL_FUNC(index_copy_out)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(result_restrided)
-      .add_input(index_restrided)
-      .add_input(source_nonzero)
+      .add_const_input(index_restrided)
+      .add_const_input(source_nonzero)
      .build();
 
    auto result_dim_size = result_nonzero.size(dim);
@@ -943,15 +943,15 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
 
     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cpu_", [&] () {
-      auto index_data = index_contig.data_ptr<index_t>();
+      auto index_data = index_contig.const_data_ptr<index_t>();
       for (const auto i : c10::irange(numel)) {
         auto self_i = index_data[i];
         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
         auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
-        auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
+        auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
         iter.unsafe_replace_operand(0, self_data);
         iter.unsafe_replace_operand(1, self_data);
-        iter.unsafe_replace_operand(2, source_data);
+        iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
         add_stub(iter.device_type(), iter, alpha);
       }
     });
@@ -967,10 +967,10 @@ TORCH_IMPL_FUNC(index_add_cpu_out)
       auto source_stride = source.dim() == 0 ? 1 : source.stride(dim);
       // TODO: Maybe TensorAccessor can be used here?
       auto* result_ptr = result.data_ptr<scalar_t>();
-      auto* source_ptr = source.data_ptr<scalar_t>();
+      auto* source_ptr = source.const_data_ptr<scalar_t>();
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_add_cpu_",
         [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &alpha_value] {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(numel)) {
           auto self_i = index_data[i];
           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
@@ -1040,15 +1040,15 @@ static void index_reduce_func_impl(
     auto iter = TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
 
     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_func_cpu_", [&] () {
-      auto index_data = index_contig.data_ptr<index_t>();
+      auto index_data = index_contig.const_data_ptr<index_t>();
       for (const auto i : c10::irange(numel)) {
         auto self_i = index_data[i];
         TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
         auto self_data = static_cast<char*>(selfSlice.data_ptr()) + self_i * self_stride_bytes;
-        auto source_data = static_cast<char*>(sourceSlice.data_ptr()) + i * source_stride_bytes;
+        auto source_data = static_cast<const char*>(sourceSlice.const_data_ptr()) + i * source_stride_bytes;
         iter.unsafe_replace_operand(0, self_data);
         iter.unsafe_replace_operand(1, self_data);
-        iter.unsafe_replace_operand(2, source_data);
+        iter.unsafe_replace_operand(2, const_cast<char*>(source_data));
 
         switch (op) {
           case ReductionType::PROD :
@@ -1090,11 +1090,11 @@ static void index_reduce_func_impl(
       auto counts_stride = counts.dim() == 0 ? 1 : counts.stride(dim);
       // TODO: Maybe TensorAccessor can be used here?
       auto* result_ptr = result.data_ptr<scalar_t>();
-      auto* source_ptr = source.data_ptr<scalar_t>();
+      auto* source_ptr = source.const_data_ptr<scalar_t>();
       auto counts_ptr = counts.data_ptr<scalar_t>();
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_func_cpu_",
         [&index_contig, &numel, &result, &result_ptr, &result_stride, &source_ptr, &source_stride, &op, &counts_ptr, &counts_stride] {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(numel)) {
           auto self_i = index_data[i];
           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < result.numel()), "index out of range in self");
@@ -1175,7 +1175,7 @@ static Tensor & index_select_out_cpu_dim1_(
 
   auto out = static_cast<char*>(result_contig.data_ptr());
 
-  auto src_base = static_cast<const char*>(self_contig.data_ptr());
+  auto src_base = static_cast<const char*>(self_contig.const_data_ptr());
 
   auto self_sizes = self_contig.sizes();
   auto outer_dims_product = c10::size_to_dim_(1, self_sizes);
@@ -1191,7 +1191,7 @@ static Tensor & index_select_out_cpu_dim1_(
   AT_DISPATCH_INDEX_TYPES(
     index_contig.scalar_type(), "batch_index_select_compute", [&]() {
 
-      const auto* idxs = index_contig.data_ptr<index_t>();
+      const auto* idxs = index_contig.const_data_ptr<index_t>();
       check_indexarray_range<index_t>(idxs, N, src_indexing_axis_dim);
 
       // Special-case single-float copy for efficiency
@@ -1256,7 +1256,7 @@ Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor &
         "index_select(): self indexing axis dim should be positive");
     AT_DISPATCH_INDEX_TYPES(
       index_contig.scalar_type(), "index_select_empty_self_bound_check", [&]() {
-        const auto* idxs = index_contig.data_ptr<index_t>();
+        const auto* idxs = index_contig.const_data_ptr<index_t>();
         check_indexarray_range<index_t>(idxs, numel, src_indexing_axis_dim);
       });
     return result;
@@ -1280,7 +1280,7 @@ Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor &
       .check_all_same_dtype(false)
       .resize_outputs(false)
       .add_output(resultSlice)
-      .add_input(selfSlice)
+      .add_const_input(selfSlice)
       .build();
 
     auto grain_size = at::internal::GRAIN_SIZE;
@@ -1293,7 +1293,7 @@ Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor &
         AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
           [&index_contig, &start, &end, &sub_iter, &self_dim_size, &selfSlice_data, &self_stride_bytes,
            &resultSlice_data, &result_stride_bytes] () {
-          auto index_data = index_contig.data_ptr<index_t>();
+          auto index_data = index_contig.const_data_ptr<index_t>();
           for (const auto i : c10::irange(start, end)) {
             auto self_i = index_data[i];
             TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -1322,7 +1322,7 @@ Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor &
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
         [&index_contig, &slice_size_bytes, &self_dim_size, &selfSlice_data,
          &self_stride_bytes, &resultSlice_data, &result_stride_bytes, &start, &end] () {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(start, end)) {
           auto self_i = index_data[i];
           TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_dim_size), "index out of range in self");
@@ -1344,16 +1344,16 @@ Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor &
     AT_DISPATCH_QINT_TYPES(self.scalar_type(), "index_select_quant", [&index_contig, &self, &result, &dim, &numel] {
       auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
       auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
-      auto self_data_ptr = self.data_ptr<scalar_t>();
+      auto self_data_ptr = self.const_data_ptr<scalar_t>();
       auto result_data_ptr = result.data_ptr<scalar_t>();
       auto self_numel = self.numel();
       AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_quant_",
         [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
         for (const auto i : c10::irange(numel)) {
          auto self_i = index_data[i];
          TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
-          scalar_t *self_ip = self_data_ptr + self_i * self_stride;
+          const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
          *(result_data_ptr + i * result_stride) = *self_ip;
        }
      });
@@ -1364,16 +1364,16 @@ Tensor & index_select_out_cpu_(const Tensor & self, int64_t dim, const Tensor &
      auto self_stride = self.dim() == 0 ? 1 : self.stride(dim);
      auto result_stride = result.dim() == 0 ? 1 : result.stride(dim);
 
-      auto self_data_ptr = self.data_ptr<scalar_t>();
+      auto self_data_ptr = self.const_data_ptr<scalar_t>();
      auto result_data_ptr = result.data_ptr<scalar_t>();
      auto self_numel = self.numel();
      AT_DISPATCH_INDEX_TYPES(index_contig.scalar_type(), "index_select_out_cpu_",
        [&index_contig, &numel, &self_numel, &self_data_ptr, &self_stride, &result_data_ptr, &result_stride] {
-        auto index_data = index_contig.data_ptr<index_t>();
+        auto index_data = index_contig.const_data_ptr<index_t>();
        for (const auto i : c10::irange(numel)) {
          auto self_i = index_data[i];
          TORCH_CHECK_INDEX((self_i >= 0) && (self_i < self_numel), "index out of range in self");
-          scalar_t *self_ip = self_data_ptr + self_i * self_stride;
+          const scalar_t *self_ip = self_data_ptr + self_i * self_stride;
          *(result_data_ptr + i * result_stride) = *self_ip;
        }
      });
@@ -1462,7 +1462,7 @@ Tensor & index_fill_(Tensor & self, int64_t dim, const Tensor & index, const Sca
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(self_restrided)
-    .add_input(index_restrided)
+    .add_const_input(index_restrided)
     .build();
 
   auto self_dim_size = (self_nonzero_dim.sizes())[dim];
@@ -1924,7 +1924,7 @@ static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, const S
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(self)
-    .add_input(mask)
+    .add_const_input(mask)
     .build();
 
   masked_fill_stub(iter.device_type(), iter, value);
@@ -2017,8 +2017,8 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
       .check_all_same_dtype(false)
       .resize_outputs(false)
       .add_output(result_strided)
-      .add_input(*_self)
-      .add_input(*_mask)
+      .add_const_input(*_self)
+      .add_const_input(*_mask)
       .build();
 
     masked_select_serial_stub(iter.device_type(), iter, orig_stride);
@@ -2041,9 +2041,9 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
     .check_all_same_dtype(false)
     .resize_outputs(false)
     .add_output(result_strided)
-    .add_input(*_self)
-    .add_input(*_mask)
-    .add_input(mask_prefix_sum)
+    .add_const_input(*_self)
+    .add_const_input(*_mask)
+    .add_const_input(mask_prefix_sum)
     .build();
 
   masked_select_stub(iter.device_type(), iter, orig_stride);
@@ -2228,7 +2228,7 @@ Tensor count_nonzero_cpu(const Tensor& self, IntArrayRef dims){
 
   // Optimized all-reduce
   auto iter = TensorIteratorConfig()
-      .add_input(self)
+      .add_const_input(self)
       .build();
 
   const auto num_threads = at::get_num_threads();
@@ -2267,7 +2267,7 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
   at::assert_no_overlap(result, self);
 
   auto iter = TensorIteratorConfig()
-    .add_input(self)
+    .add_const_input(self)
     .enforce_linear_iteration()
     .build();
 
@@ -2495,7 +2495,7 @@ Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & s
     // order of indexing matters
     .enforce_linear_iteration()
     .add_output(self)
-    .add_input(*b_mask)
+    .add_const_input(*b_mask)
     .build();
 
   masked_scatter_stub(iter.device_type(), iter, src_cont);
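
One wrinkle, visible in the index_add and index_reduce hunks above: unsafe_replace_operand() takes a plain char*, while the source pointer is now obtained through const_data_ptr() and is therefore const char*. The const_cast only restores the pointer type expected by the iterator plumbing; the operand was registered as a read-only input and is never written through. A small illustrative sketch, not from the commit (takes_mutable_ptr is a hypothetical stand-in for that plumbing):

#include <ATen/ATen.h>
#include <functional>

// Read-only bytes obtained via const_data_ptr(), so no COW materialization.
// The const_cast exists only to satisfy an API that traffics in char*.
void feed_read_only_operand(const at::Tensor& src,
                            const std::function<void(char*)>& takes_mutable_ptr) {
  const char* bytes = static_cast<const char*>(src.const_data_ptr());
  takes_mutable_ptr(const_cast<char*>(bytes));  // read-only use by contract
}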
