
Commit 90dabff

kurtamohler authored and pytorchmergebot committed
Avoid COW materialize in various operations (pytorch#119506)
Operations affected include dot, cross, scatter/gather, shape, sort, triangular, unary, scalar, pad, complex, to_list, fft

Pull Request resolved: pytorch#119506
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#119501, pytorch#119502, pytorch#119503, pytorch#119504
1 parent 8a09f13 commit 90dabff

19 files changed (+76, -76 lines)
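The changes below all follow one pattern: tensors that are only read are accessed through const-qualified entry points (const_data_ptr(), add_const_input()) instead of the mutable ones (data_ptr(), add_input()), so a copy-on-write (COW) tensor is not forced to materialize a private copy of its storage just to be read. A minimal sketch of the read-only access pattern, assuming a recent ATen build; the helper name sum1d is hypothetical and not part of this commit:

#include <ATen/ATen.h>

// Sums a 1-D double tensor using only read access. const_data_ptr<T>() hands
// back a const T*, so no mutable access is requested and a lazily-copied
// (COW) tensor keeps sharing its storage.
double sum1d(const at::Tensor& t) {
  TORCH_CHECK(t.dim() == 1 && t.scalar_type() == at::kDouble, "expected a 1-D double tensor");
  const double* data = t.const_data_ptr<double>();
  double acc = 0.0;
  for (int64_t i = 0; i < t.numel(); ++i) {
    acc += data[i * t.stride(0)];
  }
  return acc;
}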

aten/src/ATen/native/Blas.cpp

+2 -2

@@ -185,7 +185,7 @@ Tensor dot(const Tensor &self, const Tensor &other){

   return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::BFloat16, at::ScalarType::Half, self.scalar_type(), "dot", [&] {
     Tensor result = at::empty({}, self.options());
-    result.fill_(dot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));
+    result.fill_(dot_impl<scalar_t>(self.numel(), const_cast<scalar_t*>(self.const_data_ptr<scalar_t>()), self.stride(0), const_cast<scalar_t*>(other.const_data_ptr<scalar_t>()), other.stride(0)));
     return result;
   });
 }
@@ -216,7 +216,7 @@ Tensor vdot(const Tensor &self, const Tensor &other){

   return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
     Tensor result = at::empty({}, self.options());
-    result.fill_(vdot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0)));
+    result.fill_(vdot_impl<scalar_t>(self.numel(), const_cast<scalar_t*>(self.const_data_ptr<scalar_t>()), self.stride(0), const_cast<scalar_t *>(other.const_data_ptr<scalar_t>()), other.stride(0)));
     return result;
   });
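The dot/vdot hunks above still need a const_cast because dot_impl/vdot_impl take non-const pointers even though the call only reads through them. A self-contained sketch of that call boundary, with hypothetical names (legacy_dot, dot_readonly) standing in for the real helpers:

#include <cstdint>

// Stand-in for a legacy routine that only reads through its pointer arguments
// but declares them non-const, as many BLAS-style interfaces do.
double legacy_dot(int64_t n, double* x, int64_t incx, double* y, int64_t incy) {
  double acc = 0.0;
  for (int64_t i = 0; i < n; ++i) {
    acc += x[i * incx] * y[i * incy];
  }
  return acc;
}

// The caller keeps its data const and casts away constness only at the call
// boundary; this is safe only because legacy_dot never writes through x or y.
double dot_readonly(const double* x, const double* y, int64_t n) {
  return legacy_dot(n, const_cast<double*>(x), 1, const_cast<double*>(y), 1);
}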

aten/src/ATen/native/Scalar.cpp

+1 -1

@@ -40,7 +40,7 @@ Scalar _local_scalar_dense_cpu(const Tensor& self) {
       self.scalar_type(),
       "_local_scalar_dense_cpu",
       AT_WRAP([&] {
-        scalar_t value = *self.data_ptr<scalar_t>();
+        scalar_t value = *self.const_data_ptr<scalar_t>();
         r = Scalar(value);
       }),
       AT_EXPAND(AT_SD_TYPES)

aten/src/ATen/native/Sorting.cpp

+1 -1

@@ -546,7 +546,7 @@ std::tuple<Tensor&, Tensor&> median_with_indices_impl(
     .declare_static_shape(sizes, /*squash_dims=*/dim)
     .add_output(vals)
     .add_output(inds)
-    .add_input(in)
+    .add_const_input(in)
     .build();

   AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, in.scalar_type(), "median_out", [&] {

aten/src/ATen/native/TensorFactories.cpp

+4 -4

@@ -214,8 +214,8 @@ Tensor& complex_out(const Tensor& real, const Tensor& imag, Tensor& result) {
   complex_check_dtype(result, real, imag);
   auto iter = TensorIteratorConfig()
       .add_output(result)
-      .add_input(real)
-      .add_input(imag)
+      .add_const_input(real)
+      .add_const_input(imag)
       .check_all_same_dtype(false)
       .build();
   complex_stub(iter.device_type(), iter);
@@ -234,8 +234,8 @@ Tensor& polar_out(const Tensor& abs, const Tensor& angle, Tensor& result) {
   complex_check_dtype(result, abs, angle);
   auto iter = TensorIteratorConfig()
       .add_output(result)
-      .add_input(abs)
-      .add_input(angle)
+      .add_const_input(abs)
+      .add_const_input(angle)
       .check_all_same_dtype(false)
       .build();
   polar_stub(iter.device_type(), iter);

aten/src/ATen/native/TensorShape.cpp

+11 -11

@@ -554,7 +554,7 @@ static void fastCatOutDim0(const Tensor& out, const MaterializedITensorListRef&
   for (const Tensor& input : inputs) {
     TORCH_CHECK(outBytes >= totalBytes);
     if (input.nbytes() > 0) {
-      std::memcpy(dataPtr + totalBytes, input.data_ptr(), input.nbytes());
+      std::memcpy(dataPtr + totalBytes, input.const_data_ptr(), input.nbytes());
     }
     totalBytes += input.nbytes();
   }
@@ -609,18 +609,18 @@ TORCH_IMPL_FUNC(cat_out_cpu)
       .set_check_mem_overlap(false)
       .resize_outputs(false)
       .add_output(result_slice)
-      .add_input(source_slice)
+      .add_const_input(source_slice)
       .enforce_safe_casting_to_output(true)
       .build();

     for (const Tensor& tensor : materialized) {
       if (cat_should_skip_tensor(tensor)) {
         continue;
       }
-      auto source_data = static_cast<char*>(tensor.data_ptr());
+      auto source_data = static_cast<const char*>(tensor.const_data_ptr());
       auto result_data = static_cast<char*>(result_slice_data) + offset * result_stride_bytes;
       iter.unsafe_replace_operand(0, result_data);
-      iter.unsafe_replace_operand(1, source_data);
+      iter.unsafe_replace_operand(1, const_cast<char*>(source_data));
       copy_stub(iter.device_type(), iter, false);
       offset += slice_dim_size;
     }
@@ -636,7 +636,7 @@ TORCH_IMPL_FUNC(cat_out_cpu)
       .set_check_mem_overlap(false) // Already checked above
       .resize_outputs(false)
       .add_output(result_slice)
-      .add_input(tensor)
+      .add_const_input(tensor)
       .promote_inputs_to_common_dtype(true)
       .cast_common_dtype_to_outputs(true)
       .enforce_safe_casting_to_output(true)
@@ -1004,7 +1004,7 @@ std::vector<Tensor> tensor_split(const Tensor& self, const Tensor& tensor_indice
     int64_t sections = tensor_indices_or_sections.item<int64_t>();
     return self.tensor_split(sections, dim);
   } else {
-    auto indices_data = tensor_indices_or_sections.data_ptr<int64_t>();
+    auto indices_data = tensor_indices_or_sections.const_data_ptr<int64_t>();
     auto stride = tensor_indices_or_sections.stride(0);
     auto numel = tensor_indices_or_sections.numel();
     std::vector<int64_t> indices(numel);
@@ -1344,22 +1344,22 @@ Tensor& narrow_copy_dense_cpu_out(
     return output;
   }

-  char* src_bytes = static_cast<char*>(self_contig->data_ptr());
+  const char* src_bytes = static_cast<const char*>(self_contig->const_data_ptr());
   char* dst_bytes = static_cast<char*>(output.data_ptr());

   size_t src_block_size_bytes = itemsize * src_block_size;
   size_t dst_block_size_bytes = itemsize * dst_block_size;
   size_t src_offset = unit * start;

-  char* src_offset_bytes = src_bytes + itemsize * src_offset;
+  const char* src_offset_bytes = src_bytes + itemsize * src_offset;
   char* dst_offset_bytes = dst_bytes;

   for (const auto i : c10::irange(num_blocks)) {
-    char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
+    const char* local_src_offset_bytes = src_offset_bytes + i * src_block_size_bytes;
     char* local_dst_offset_bytes = dst_offset_bytes + i * dst_block_size_bytes;
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes) <=
-        static_cast<void*>(src_bytes + src_nbytes));
+        static_cast<const void*>(local_src_offset_bytes + dst_block_size_bytes) <=
+        static_cast<const void*>(src_bytes + src_nbytes));
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
         static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes) <=
         static_cast<void*>(dst_bytes + dst_nbytes));

aten/src/ATen/native/TensorTransformations.cpp

+2 -2

@@ -63,8 +63,8 @@ Tensor flip(const Tensor& self, IntArrayRef dims) {
     .check_all_same_dtype(false)
     .declare_static_dtype_and_device(self.scalar_type(), self.device())
     .add_output(out_tensor)
-    .add_input(self)
-    .add_input(restrided_self)
+    .add_const_input(self)
+    .add_const_input(restrided_self)
     .build();

   auto* data = reinterpret_cast<char*>(iter.data_ptr(0));

aten/src/ATen/native/TriangularOps.cpp

+3 -3

@@ -41,7 +41,7 @@ namespace {
 template <typename scalar_t>
 void apply_triu_tril_single(
     scalar_t* result,
-    scalar_t* self,
+    const scalar_t* self,
     bool inplace,
     int64_t k,
     int64_t n,
@@ -86,7 +86,7 @@ template <typename scalar_t>
 void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
   auto n = self.size(-2);
   auto m = self.size(-1);
-  auto self_data = self.data_ptr<scalar_t>();
+  auto self_data = self.const_data_ptr<scalar_t>();
   auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
   auto batchsize = batchCountTrilTriu(result);
   auto self_row_stride = self.stride(-2);
@@ -107,7 +107,7 @@ void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int

   parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
     for (const auto b : c10::irange(start, end)) {
-      scalar_t* self_batch = &self_data[b * self_stride];
+      const scalar_t* self_batch = &self_data[b * self_stride];
       scalar_t* result_batch = &result_data[b * result_stride];
       apply_triu_tril_single<scalar_t>(
           result_batch,

aten/src/ATen/native/UnaryOps.cpp

+2 -2

@@ -868,7 +868,7 @@ Tensor& logical_not_out(const Tensor& self, Tensor& result) {
   TensorIterator iter = TensorIteratorConfig()
     .check_all_same_dtype(false)
     .add_output(result)
-    .add_input(self)
+    .add_const_input(self)
     .build();
   logical_not_stub(iter.device_type(), iter);
   return result;
@@ -964,7 +964,7 @@ std::tuple<Tensor&, Tensor&> frexp_out(const Tensor& self,
   auto iter = TensorIteratorConfig()
     .add_output(mantissa)
     .add_output(exponent)
-    .add_input(self)
+    .add_const_input(self)
     .check_all_same_dtype(false)
     .set_check_mem_overlap(true)
     .build();
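Many of the hunks in this commit make the same one-line change inside a TensorIteratorConfig chain: inputs that are only read are registered with add_const_input() instead of add_input(), so building the iterator does not request mutable access to them. A minimal sketch of that configuration, assuming it is compiled against ATen; the helper name make_readonly_unary_iter is hypothetical:

#include <ATen/ATen.h>
#include <ATen/TensorIterator.h>

// Builds an iterator over one writable output and one read-only input. The
// built iterator is then handed to a kernel or stub exactly as in the hunks
// above, e.g. logical_not_stub(iter.device_type(), iter).
at::TensorIterator make_readonly_unary_iter(const at::Tensor& in, at::Tensor& out) {
  return at::TensorIteratorConfig()
      .check_all_same_dtype(false)
      .add_output(out)        // output may be written; keep the mutable overload
      .add_const_input(in)    // input is read-only; no COW materialization
      .build();
}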

aten/src/ATen/native/cpu/CatKernel.cpp

+3 -3

@@ -12,11 +12,11 @@ namespace at::native {
 namespace {

 struct InputMeta {
-  void* data_ptr;
+  const void* data_ptr;
   int64_t inner_size;

   InputMeta(const Tensor& t, int64_t dim, int64_t inner)
-    : data_ptr(t.data_ptr())
+    : data_ptr(t.const_data_ptr())
     , inner_size(t.sizes()[dim] * inner) {}
 };

@@ -38,7 +38,7 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR
   for (const auto i : c10::irange(outer)) {
     for (const auto j : c10::irange(ninputs)) {
       int64_t local_inner = inputs[j].inner_size;
-      scalar_t* input_ptr = (scalar_t*)(inputs[j].data_ptr) + i * local_inner;
+      const scalar_t* input_ptr = (const scalar_t*)(inputs[j].data_ptr) + i * local_inner;
       int64_t d = 0;
       for (; d < local_inner - (local_inner % Vec::size()); d += Vec::size()) {
         Vec in_vec = Vec::loadu(input_ptr + d);

aten/src/ATen/native/cpu/CrossKernel.cpp

+2 -2

@@ -21,8 +21,8 @@ static void apply_cross(const Tensor& result, const Tensor& a, const Tensor& b,
   int64_t b_stride = b.stride(dim);
   int64_t r_stride = result.stride(dim);

-  scalar_t *a_ptr = a.data_ptr<scalar_t>();
-  scalar_t *b_ptr = b.data_ptr<scalar_t>();
+  const scalar_t *a_ptr = a.const_data_ptr<scalar_t>();
+  const scalar_t *b_ptr = b.const_data_ptr<scalar_t>();
   scalar_t *r_ptr = result.data_ptr<scalar_t>();

   parallel_for(0, total, internal::GRAIN_SIZE, [&](int64_t s, int64_t e) {

aten/src/ATen/native/cpu/PaddingKernel.cpp

+7 -7

@@ -136,7 +136,7 @@ void cpu_padding(
   auto input = input_.contiguous();
   auto output = output_.contiguous();

-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   auto output_data = output.data_ptr<scalar_t>();

   // fold nbatch and channels into single dimension for channels first.
@@ -158,7 +158,7 @@ void cpu_padding(

   // do vectorized copy whe output is overlapped with input on W,
   // only applies to positive padding
-  auto loop = [=](scalar_t* out, scalar_t* in, bool positive_padding) {
+  auto loop = [=](scalar_t* out, const scalar_t* in, bool positive_padding) {
     if (positive_padding) {
       for (const auto ow : c10::irange(pad_w)) {
         int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);
@@ -198,7 +198,7 @@ void cpu_padding(
     for (const auto i : c10::irange(begin, end)) {
       int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
       scalar_t* output_ptr = output_data + i * output_width;
-      scalar_t* input_ptr = input_data + c * input_height * input_width + ih * input_width;
+      const scalar_t* input_ptr = input_data + c * input_height * input_width + ih * input_width;

       loop(output_ptr, input_ptr, p.is_padding_positive_width);
       data_index_step(c, channels, oh, output_height);
@@ -214,7 +214,7 @@ void cpu_padding(
       int64_t id = PaddingType::index(od, input_depth, pad_d, offset_d);
       int64_t ih = PaddingType::index(oh, input_height, pad_h, offset_h);
       scalar_t* output_ptr = output_data + i * output_width;
-      scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width +
+      const scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width +
           id * input_height * input_width + ih * input_width;

       loop(output_ptr, input_ptr, p.is_padding_positive_width);
@@ -243,7 +243,7 @@ void cpu_padding_channels_last(
   auto input = input_.contiguous(memory_format);
   auto output = output_.contiguous(memory_format);

-  auto input_data = input.data_ptr<scalar_t>();
+  auto input_data = input.const_data_ptr<scalar_t>();
   auto output_data = output.data_ptr<scalar_t>();

   int64_t nbatch = p.nbatch;
@@ -274,7 +274,7 @@ void cpu_padding_channels_last(
       int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);

       scalar_t* output_ptr = output_data + i * channels;
-      scalar_t* input_ptr = input_data + (n * input_height * input_width + ih * input_width + iw) * channels;
+      const scalar_t* input_ptr = input_data + (n * input_height * input_width + ih * input_width + iw) * channels;
       copy_stub(output_ptr, input_ptr, channels);

       data_index_step(n, nbatch, oh, output_height, ow, output_width);
@@ -292,7 +292,7 @@ void cpu_padding_channels_last(
       int64_t iw = PaddingType::index(ow, input_width, pad_w, offset_w);

       scalar_t* output_ptr = output_data + i * channels;
-      scalar_t* input_ptr = input_data + (n * input_depth * input_height * input_width +
+      const scalar_t* input_ptr = input_data + (n * input_depth * input_height * input_width +
           id * input_height * input_width + ih * input_width + iw) * channels;
       copy_stub(output_ptr, input_ptr, channels);

aten/src/ATen/native/cpu/ScatterGatherKernel.cpp

+14 -14

@@ -186,7 +186,7 @@ struct cpu_scatter_gather_base_kernel {
       // NOLINTNEXTLINE(bugprone-argument-comment)
       .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
       .add_output(buffer)
-      .add_input(index)
+      .add_const_input(index)
       .build();

     auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
@@ -273,8 +273,8 @@ struct cpu_scatter_gather_base_kernel {
       // NOLINTNEXTLINE(bugprone-argument-comment)
       .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
       .add_output(buffer)
-      .add_input(src)
-      .add_input(index)
+      .add_const_input(src)
+      .add_const_input(index)
       .build();

     auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
@@ -369,8 +369,8 @@ struct cpu_scatter_gather_base_kernel {
       // NOLINTNEXTLINE(bugprone-argument-comment)
       .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
       .add_output(buffer)
-      .add_input(src)
-      .add_input(index)
+      .add_const_input(src)
+      .add_const_input(index)
       .build();

     auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
@@ -464,8 +464,8 @@ struct cpu_scatter_gather_base_kernel {
       // NOLINTNEXTLINE(bugprone-argument-comment)
       .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
       .add_output(buffer)
-      .add_input(src)
-      .add_input(index)
+      .add_const_input(src)
+      .add_const_input(index)
       .build();

     auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
@@ -560,8 +560,8 @@ struct cpu_scatter_gather_base_kernel {
       // NOLINTNEXTLINE(bugprone-argument-comment)
       .declare_static_shape(index.sizes(), /*squash_dim=*/dim)
       .add_output(buffer)
-      .add_input(src)
-      .add_input(index)
+      .add_const_input(src)
+      .add_const_input(index)
       .build();

     auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
@@ -687,9 +687,9 @@ std::pair<K*, V*> radix_sort_parallel(

 template <typename scalar_t, ReductionType reduce>
 void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
-  int64_t* index_data = index.data_ptr<int64_t>();
+  const int64_t* index_data = index.const_data_ptr<int64_t>();
   scalar_t* self_data = self.data_ptr<scalar_t>();
-  scalar_t* src_data = src.data_ptr<scalar_t>();
+  const scalar_t* src_data = src.const_data_ptr<scalar_t>();

   const int64_t M = ensure_nonempty_size(self, 0);
   const int64_t nnz = ensure_nonempty_size(index, 0);
@@ -812,9 +812,9 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,

 template <typename scalar_t>
 void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
-  int64_t* index_data = index.data_ptr<int64_t>();
+  const int64_t* index_data = index.const_data_ptr<int64_t>();
   scalar_t* result_data = result.data_ptr<scalar_t>();
-  scalar_t* self_data = self.data_ptr<scalar_t>();
+  const scalar_t* self_data = self.const_data_ptr<scalar_t>();

   const int64_t M = ensure_nonempty_size(result, 0);
   const int64_t N = ensure_nonempty_size(self, 0);
@@ -832,7 +832,7 @@ void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index,
           "index ", index,
           " is out of bounds for dimension ", 0,
           " with size ", index_upper_bound);
-      scalar_t* self_ptr = self_data + index * K;
+      const scalar_t* self_ptr = self_data + index * K;
       int64_t d = 0;
       for (; d < K - (K % Vec::size()); d += Vec::size()) {
         Vec out_vec = Vec::loadu(self_ptr + d);

aten/src/ATen/native/cpu/SortingKernel.cpp

+1 -1

@@ -216,7 +216,7 @@ static void topk_kernel(
     .declare_static_shape(sizes, /*squash_dims=*/dim)
     .add_output(values)
     .add_output(indices)
-    .add_input(self)
+    .add_const_input(self)
     .build();

   auto mode_values_stride = values.strides()[dim];

aten/src/ATen/native/cuda/CrossKernel.cu

+2 -2

@@ -68,8 +68,8 @@ void cross_impl(const Tensor& result, const Tensor& x1, const Tensor& x2, int64_

   auto iter = TensorIteratorConfig()
       .add_output(result)
-      .add_input(x1)
-      .add_input(x2)
+      .add_const_input(x1)
+      .add_const_input(x2)
       .resize_outputs(false)
       .declare_static_shape(result.sizes(), /*squash_dims=*/dim)
       .build();
