Commit 4a37f57

pearu authored and pytorchmergebot committed
Add batched sparse CSR/CSC/BSR/BSC to sparse COO conversion support (pytorch#116206)
As in the title.

Fixes pytorch#104868

Pull Request resolved: pytorch#116206
Approved by: https://github.com/amjames, https://github.com/lezcano, https://github.com/cpuhrsch
1 parent 4b74bb6 commit 4a37f57
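For context, a minimal usage sketch of the conversion this commit enables. The shapes and values below are illustrative only and are not taken from the PR; `to_sparse_csr`/`to_sparse` are the existing public APIs exercised by the changed code paths.

import torch

# Two batches of 3x4 matrices with the same number of specified elements per
# batch, as required for a batched CSR tensor.
dense = torch.zeros(2, 3, 4)
dense[0, 0, 1] = 1.0
dense[0, 2, 3] = 2.0
dense[1, 0, 1] = 3.0
dense[1, 2, 3] = 4.0
batched_csr = dense.to_sparse_csr()

# Before this commit, converting a batched compressed tensor to COO raised
# "crow_indices is supposed to be a vector ..." (see the removed test below);
# now it returns a COO tensor whose indices carry the batch coordinates as
# leading sparse dimensions.
coo = batched_csr.to_sparse()
print(coo.indices())  # shape (1 batch dim + 2, total nnz): batch, row, col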

File tree

5 files changed: +72 -77 lines changed


aten/src/ATen/native/TensorConversions.cpp

+37-20
@@ -1513,24 +1513,38 @@ void convert_indices_from_csr_to_coo_cpu(
     const Tensor& crow_indices,
     const Tensor& col_indices,
     const bool transpose = false) {
-  int64_t nrows = crow_indices.numel() - 1;
-  if (nrows == 0) {
-    indices.zero_();
+  int64_t nrows = crow_indices.size(-1) - 1;
+  int64_t nnz = col_indices.size(-1);
+  if (nrows == 0 || nnz == 0) {
+    indices.zero_(); // is this needed as indices has a zero-valued
+                     // dimension when nrows or nnz is 0?
     return;
   }
   auto crow_indices_ = crow_indices.expect_contiguous();
+  int64_t total_nnz = col_indices.numel();
+  int64_t batch_ndim = crow_indices.dim() - 1;
+  if (batch_ndim > 0) {
+    auto batch_indices = indices.narrow(0, 0, batch_ndim);
+    batch_indices.copy_(batch_indices.new_ones(crow_indices.sizes().slice(0, batch_ndim))
+                            .nonzero()
+                            .transpose(0, 1)
+                            .repeat_interleave(nnz, 1));
+  }
   const input_t* crow_indices_data_in = crow_indices_->data_ptr<input_t>();
   TORCH_INTERNAL_ASSERT(indices.is_contiguous());
-  auto row0 = indices.select(0, transpose ? 1 : 0);
-  auto row1 = indices.select(0, transpose ? 0 : 1);
+  auto row0 = indices.select(0, transpose ? batch_ndim + 1 : batch_ndim + 0);
+  auto row1 = indices.select(0, transpose ? batch_ndim + 0 : batch_ndim + 1);
   output_t* data_out = row0.data_ptr<output_t>();
-  row1.copy_(*col_indices.expect_contiguous());
+  auto col_indices_ = col_indices.expect_contiguous();
+  row1.copy_(col_indices_->view({-1}));
   at::parallel_for(
-      0, nrows, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
-        for (const auto i : c10::irange(start, end)) {
+      0, nrows * total_nnz / nnz, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
+        for (const auto i_ : c10::irange(start, end)) {
+          auto b = i_ / nrows;
+          auto i = i_ % nrows;
           std::fill(
-              &data_out[crow_indices_data_in[i]],
-              &data_out[crow_indices_data_in[i + 1]],
+              &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i]],
+              &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i + 1]],
               static_cast<output_t>(i));
         }
       });
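The new batch handling in the hunk above fills the leading rows of `indices` with every batch coordinate, repeated once per specified element of its batch. A rough Python equivalent of that tensor trick, not part of the commit and with a hypothetical helper name, is:

import torch

def batch_coo_indices(batch_shape, nnz):
    # Enumerate all batch coordinates as a (batch_ndim, num_batches) tensor via
    # ones(...).nonzero().T, then repeat each coordinate column nnz times so it
    # lines up with the per-batch nnz layout of the flattened col_indices.
    coords = torch.ones(batch_shape).nonzero().transpose(0, 1)
    return coords.repeat_interleave(nnz, dim=1)

print(batch_coo_indices((2, 3), 2))  # shape (2, 12): 6 batch coords, each repeated twice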
@@ -1829,27 +1843,30 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, const int64_t sparse_dim)
   Tensor values;
   Tensor indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices,
                                                         false, (layout == kSparseCsc || layout == kSparseBsc));
+  const auto batch_ndim = compressed_indices.dim() - 1;
   // Only CSR is trivially coalesced
   bool coalesced = layout == kSparseCsr || self.numel() == 0 || self._nnz() == 1;
   AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "sparse_compressed_to_sparse",
-      [&] { values = self.values(); },
+      [&] { values = self.values().flatten(0, batch_ndim); },
       [&] {
-        auto size = DimVector(self.sizes().slice(0, 2));
-        auto blocksize = DimVector(self.values().sizes().slice(1, 2));
-
+        auto blocksize = DimVector(self.values().sizes().slice(batch_ndim + 1, 2));
+        DimVector batch_blocksize;
+        batch_blocksize.append(batch_ndim, 1);
+        batch_blocksize.append(blocksize);
         const auto max_blocksize = std::max(blocksize[0], blocksize[1]);
         const auto max_blocksize_arange = at::arange(max_blocksize, indices.options());
         const auto blocksize_arange_0 = max_blocksize_arange.narrow(-1, 0, blocksize[0]);
         const auto blocksize_arange_1 = max_blocksize_arange.narrow(-1, 0, blocksize[1]);
-        const auto block_coo_indices = at::stack({
+        const auto block_coo_indices_ = at::stack({
             blocksize_arange_0.unsqueeze(-1).expand({-1, blocksize[1]}),
             blocksize_arange_1.unsqueeze(0).expand({blocksize[0], -1})
-        }).flatten(-2, -1);
-
+        }).flatten(-2, -1); // equivalent to torch.ones(blocksize).nonzero().T
+        const auto block_coo_indices = at::zeros({batch_ndim + 2, blocksize[0] * blocksize[1]}, indices.options());
+        block_coo_indices.narrow(0, batch_ndim, 2).copy_(block_coo_indices_);
         indices = indices
                       // Scale indices that identify blocks to element-wise coordinates that correspond
                       // to the top-left corner of each block.
-                      .mul(at::tensor(blocksize, indices.options()).unsqueeze_(-1))
+                      .mul(at::tensor(batch_blocksize, indices.options()).unsqueeze_(1))
                       // Now that we know top-left block coordinates, we offset them with element-wise
                       // coordinates in the block to get the result.
                       // NOTE: indices is mapped from (dim, nnz) to (dim, nnz, 1),
@@ -1861,10 +1878,10 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, const int64_t sparse_dim)
                       // to produce valid nnz dimension of a COO tensor.
                       .flatten(-2, -1);

-        values = self.values().flatten(0, 2);
+        values = self.values().flatten(0, batch_ndim + 2);

         // BSRs not spanning across several rows produces coalesced results.
-        coalesced |= (layout == kSparseBsr && blocksize[0] == 1);
+        coalesced |= (layout == kSparseBsr && blocksize[0] == 1 && batch_ndim == 0);
       });
   return at::native::_sparse_coo_tensor_unsafe(indices, values, self.sizes())._coalesced_(coalesced);
 }
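In the blocked branch above, the within-block offsets built with at::stack are, as the new inline comment notes, just an enumeration of every element coordinate inside one block. A quick illustrative check in Python (the blocksize is chosen arbitrarily; this is not part of the commit):

import torch

blocksize = (2, 3)  # illustrative block shape
a0 = torch.arange(blocksize[0])
a1 = torch.arange(blocksize[1])
block_coo = torch.stack([
    a0.unsqueeze(-1).expand(-1, blocksize[1]),
    a1.unsqueeze(0).expand(blocksize[0], -1),
]).flatten(-2, -1)

# Same coordinates as enumerating the nonzeros of an all-ones block.
assert torch.equal(block_coo, torch.ones(blocksize).nonzero().T)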

aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp

+4-35
@@ -147,21 +147,18 @@ TORCH_META_FUNC(_convert_indices_from_csr_to_coo)
     const bool out_int32,
     const bool transpose) {
   TORCH_CHECK(
-      crow_indices.dim() == 1, "crow_indices is supposed to be a vector, but got ",
-      crow_indices.dim(), " dimensional tensor.");
-  TORCH_CHECK(col_indices.dim() == 1, "col_indices is supposed to be a vector, but got ",
-      col_indices.dim(), " dimensional tensor.");
+      crow_indices.dim() == col_indices.dim(), "crow_indices and col_indices are supposed to have"
+      " the same dimensionality, but got ", crow_indices.dim(), " and ",
+      crow_indices.dim(), " dimensional tensors, respectively.");
   ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
   c10::TensorOptions options = crow_indices.options().dtype(scalar_type);
-  set_output_raw_strided(0, {2, col_indices.numel()}, {}, options, {});
+  set_output_raw_strided(0, {col_indices.dim() + 1, col_indices.numel()}, {}, options, {});
 }

 } // namespace meta

 namespace {

-constexpr int64_t GRAIN_SIZE = at::internal::GRAIN_SIZE;
-
 template <typename F>
 Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result) {
   TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
@@ -194,34 +191,6 @@ Tensor& unary_op_inplace(Tensor& self, const F& op_inplace, Args&&... args) {
   return self;
 }

-template <typename input_t, typename output_t>
-void convert_indices_from_csr_to_coo_cpu(
-    const Tensor& indices,
-    const Tensor& crow_indices,
-    const Tensor& col_indices,
-    const bool transpose = false) {
-  int64_t nrows = crow_indices.numel() - 1;
-  if (nrows == 0) {
-    indices.zero_();
-    return;
-  }
-  auto crow_indices_ = crow_indices.expect_contiguous();
-  const input_t* crow_indices_data_in = crow_indices_->data_ptr<input_t>();
-  TORCH_INTERNAL_ASSERT(indices.is_contiguous());
-  auto row0 = indices.select(0, transpose ? 1 : 0);
-  auto row1 = indices.select(0, transpose ? 0 : 1);
-  output_t* data_out = row0.data_ptr<output_t>();
-  row1.copy_(*col_indices.expect_contiguous());
-  at::parallel_for(0, nrows, GRAIN_SIZE, [&](int64_t start, int64_t end) {
-    for (const auto i : c10::irange(start, end)) {
-      std::fill(
-          &data_out[crow_indices_data_in[i]],
-          &data_out[crow_indices_data_in[i + 1]],
-          static_cast<output_t>(i));
-    }
-  });
-}
-
 } // end anonymous namespace

 namespace native {

aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu

+27-12
@@ -85,36 +85,51 @@ void convert_indices_from_coo_to_csr_cuda(const Tensor& result, const Tensor& in
 }

 template <typename input_t, typename output_t>
-__global__ void convert_indices_from_csr_to_coo_cuda_kernel(output_t* data_out, const input_t* data_in, const int64_t nrows) {
+__global__ void convert_indices_from_csr_to_coo_cuda_kernel(output_t* data_out, const input_t* data_in, const int64_t nrows, const int64_t nnz, const int64_t nbatches) {
   int64_t tid = blockDim.x * blockIdx.x + threadIdx.x;

-  if (tid < nrows) {
-    for (int64_t i = data_in[tid]; i < data_in[tid + 1]; i++)
-      data_out[i] = static_cast<output_t>(tid);
+  if (tid < nrows * nbatches) {
+    int64_t b = tid / nrows;
+    int64_t i_ = b * (nrows + 1) + tid % nrows;
+    for (int64_t i = data_in[i_]; i < data_in[i_ + 1]; i++) {
+      data_out[b * nnz + i] = static_cast<output_t>(tid % nrows);
+    }
   }
 }

 template <typename input_t, typename output_t>
 void convert_indices_from_csr_to_coo_cuda(const Tensor& indices, const Tensor& crow_indices, const Tensor& col_indices, const bool transpose=false) {
-  int64_t nrows = crow_indices.numel() - 1;
-  if (nrows == 0) {
+  int64_t nrows = crow_indices.size(-1) - 1;
+  int64_t nnz = col_indices.size(-1);
+  if (nrows == 0 || nnz == 0) {
     indices.zero_();
     return;
   }
+  int64_t total_nnz = col_indices.numel();
+  int64_t batch_ndim = crow_indices.dim() - 1;
+  if (batch_ndim > 0) {
+    auto batch_indices = indices.narrow(0, 0, batch_ndim);
+    batch_indices.copy_(batch_indices.new_ones(crow_indices.sizes().slice(0, batch_ndim))
+                            .nonzero()
+                            .transpose(0, 1)
+                            .repeat_interleave(nnz, 1));
+  }

   auto crow_indices_ = crow_indices.expect_contiguous();
   const input_t* crow_indices_data_in = crow_indices_->data_ptr<input_t>();
   TORCH_INTERNAL_ASSERT(indices.is_contiguous());
-  auto row0 = indices.select(0, transpose?1:0);
-  auto row1 = indices.select(0, transpose?0:1);
+  auto row0 = indices.select(0, transpose?batch_ndim + 1:batch_ndim + 0);
+  auto row1 = indices.select(0, transpose?batch_ndim + 0:batch_ndim + 1);
+  auto col_indices_ = col_indices.expect_contiguous();
+  row1.copy_(col_indices_->view({-1}));
   output_t* data_out = row0.data_ptr<output_t>();

-  // Run nrows threads...
+  // Run nrows * nbatches threads...
+  int64_t nbatches = total_nnz / nnz;
   int64_t THREADS = at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock;
-  int64_t BLOCKS = (nrows + THREADS) / THREADS;
+  int64_t BLOCKS = (nrows * nbatches + THREADS) / THREADS;
   at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
-  row1.copy_(*col_indices.expect_contiguous());
-  convert_indices_from_csr_to_coo_cuda_kernel<<<BLOCKS, THREADS, 0, stream>>>(data_out, crow_indices_data_in, nrows);
+  convert_indices_from_csr_to_coo_cuda_kernel<<<BLOCKS, THREADS, 0, stream>>>(data_out, crow_indices_data_in, nrows, nnz, nbatches);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
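The rewritten kernel launches one thread per (batch, row) pair instead of one per row: thread tid decodes b = tid / nrows, reads batch b's row pointers, and writes row indices into that batch's nnz-sized slice of the output. A rough CPU emulation of that index arithmetic in Python (csr_rows_to_coo_rows is a hypothetical helper, not part of the commit):

import torch

def csr_rows_to_coo_rows(crow_indices, nnz):
    # crow_indices: (nbatches, nrows + 1) batched CSR row pointers; nnz is the
    # per-batch number of specified elements. Returns the flattened COO row
    # indices of length nbatches * nnz, handling one (batch, row) pair per "thread".
    nbatches, nrows = crow_indices.shape[0], crow_indices.shape[1] - 1
    out = torch.zeros(nbatches * nnz, dtype=torch.int64)
    for tid in range(nbatches * nrows):
        b, i = divmod(tid, nrows)
        start, end = int(crow_indices[b, i]), int(crow_indices[b, i + 1])
        out[b * nnz + start : b * nnz + end] = i
    return out

crow = torch.tensor([[0, 1, 1, 2], [0, 0, 2, 2]])  # two batches, 3 rows, nnz=2 each
print(csr_rows_to_coo_rows(crow, 2))  # tensor([0, 2, 1, 1])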

test/test_sparse.py

-9
@@ -4660,15 +4660,6 @@ def explicit_to_sparse(x):
                             r"conversion from Sparse to .* for input tensors with sparse_dim\(\)!=2 is not supported"):
                         explicit_to_sparse(t)
                     continue
-                elif from_layout in {torch.sparse_csr, torch.sparse_csc,
-                                     torch.sparse_bsr, torch.sparse_bsc} and to_layout is torch.sparse_coo and is_batch:
-                    with self.assertRaisesRegex(RuntimeError,
-                                                "crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"):
-                        t.to_sparse(layout=to_layout, blocksize=blocksize)
-                    with self.assertRaisesRegex(RuntimeError,
-                                                "crow_indices is supposed to be a vector, but got \\d+ dimensional tensor"):
-                        explicit_to_sparse(t)
-                    continue
                 elif (from_layout, to_layout) in {(torch.sparse_bsc, torch.sparse_csr), (torch.sparse_bsc, torch.sparse_csc),
                                                   (torch.sparse_bsr, torch.sparse_csr), (torch.sparse_bsr, torch.sparse_csc)}:
                     with self.assertRaisesRegex(

torch/testing/_internal/opinfo/definitions/sparse.py

+4-1
@@ -574,7 +574,10 @@ def _validate_sample_input_elementwise_binary_sparse_mul(sample):
     if layout is torch.sparse_csr and batch_dim > 0 and t_args[0].ndim > 0:
         return ErrorInput(
             sample,
-            error_regex="crow_indices is supposed to be a vector, but got 2 dimensional tensor",
+            error_regex=(
+                "coo_to_sparse_csr: conversion from Sparse to SparseCsr for input"
+                " tensors with sparse_dim[(][)]!=2 is not supported"
+            ),
         )
     elif layout is torch.sparse_csc and t_args[0].ndim > 0:
         return ErrorInput(
