@@ -1513,24 +1513,38 @@ void convert_indices_from_csr_to_coo_cpu(
     const Tensor& crow_indices,
     const Tensor& col_indices,
     const bool transpose = false) {
-  int64_t nrows = crow_indices.numel() - 1;
-  if (nrows == 0) {
-    indices.zero_();
+  int64_t nrows = crow_indices.size(-1) - 1;
+  int64_t nnz = col_indices.size(-1);
+  if (nrows == 0 || nnz == 0) {
+    indices.zero_(); // is this needed as indices has a zero-valued
+                     // dimension when nrows or nnz is 0?
     return;
   }
   auto crow_indices_ = crow_indices.expect_contiguous();
+  int64_t total_nnz = col_indices.numel();
+  int64_t batch_ndim = crow_indices.dim() - 1;
+  if (batch_ndim > 0) {
+    auto batch_indices = indices.narrow(0, 0, batch_ndim);
+    batch_indices.copy_(batch_indices.new_ones(crow_indices.sizes().slice(0, batch_ndim))
+                            .nonzero()
+                            .transpose(0, 1)
+                            .repeat_interleave(nnz, 1));
+  }
   const input_t* crow_indices_data_in = crow_indices_->data_ptr<input_t>();
   TORCH_INTERNAL_ASSERT(indices.is_contiguous());
-  auto row0 = indices.select(0, transpose ? 1 : 0);
-  auto row1 = indices.select(0, transpose ? 0 : 1);
+  auto row0 = indices.select(0, transpose ? batch_ndim + 1 : batch_ndim + 0);
+  auto row1 = indices.select(0, transpose ? batch_ndim + 0 : batch_ndim + 1);
   output_t* data_out = row0.data_ptr<output_t>();
-  row1.copy_(*col_indices.expect_contiguous());
+  auto col_indices_ = col_indices.expect_contiguous();
+  row1.copy_(col_indices_->view({-1}));
   at::parallel_for(
-      0, nrows, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
-        for (const auto i : c10::irange(start, end)) {
+      0, nrows * total_nnz / nnz, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) {
+        for (const auto i_ : c10::irange(start, end)) {
+          auto b = i_ / nrows;
+          auto i = i_ % nrows;
           std::fill(
-              &data_out[crow_indices_data_in[i]],
-              &data_out[crow_indices_data_in[i + 1]],
+              &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i]],
+              &data_out[b * nnz + crow_indices_data_in[b * (nrows + 1) + i + 1]],
               static_cast<output_t>(i));
         }
       });
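
Aside (not part of the diff): the key change above is the batched indexing inside the parallel loop, where each flat iteration index i_ is split into a batch index b and a row index i, and both the output and the compressed row pointers are offset per batch. Below is a minimal standalone C++ sketch of the same arithmetic on plain std::vector, assuming, as the kernel does, that every batch member stores the same number of specified elements nnz; the helper name csr_rows_to_coo and its parameters are made up for illustration.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the batched CSR->COO row expansion above.
// crow is laid out as nbatch runs of (nrows + 1) compressed row pointers.
std::vector<int64_t> csr_rows_to_coo(
    const std::vector<int64_t>& crow, int64_t nbatch, int64_t nrows, int64_t nnz) {
  std::vector<int64_t> row(nbatch * nnz);
  for (int64_t b = 0; b < nbatch; ++b) {
    for (int64_t i = 0; i < nrows; ++i) {
      // Same offsets as the kernel: b * nnz into the output,
      // b * (nrows + 1) into the batched crow_indices.
      int64_t begin = b * nnz + crow[b * (nrows + 1) + i];
      int64_t end = b * nnz + crow[b * (nrows + 1) + i + 1];
      std::fill(row.begin() + begin, row.begin() + end, i);
    }
  }
  return row;
}

int main() {
  // Two 2x3 CSR matrices with nnz == 3 each: crow = {0, 1, 3} and {0, 2, 3}.
  std::vector<int64_t> crow = {0, 1, 3, 0, 2, 3};
  for (int64_t r : csr_rows_to_coo(crow, /*nbatch=*/2, /*nrows=*/2, /*nnz=*/3)) {
    std::cout << r << ' '; // prints: 0 1 1 0 0 1
  }
  std::cout << '\n';
}
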
@@ -1829,27 +1843,30 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, const int64_t sparse_dim)
   Tensor values;
   Tensor indices = at::_convert_indices_from_csr_to_coo(compressed_indices, plain_indices,
                                                         false, (layout == kSparseCsc || layout == kSparseBsc));
+  const auto batch_ndim = compressed_indices.dim() - 1;
   // Only CSR is trivially coalesced
   bool coalesced = layout == kSparseCsr || self.numel() == 0 || self._nnz() == 1;
   AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(layout, "sparse_compressed_to_sparse",
-      [&] { values = self.values(); },
+      [&] { values = self.values().flatten(0, batch_ndim); },
       [&] {
-        auto size = DimVector(self.sizes().slice(0, 2));
-        auto blocksize = DimVector(self.values().sizes().slice(1, 2));
-
+        auto blocksize = DimVector(self.values().sizes().slice(batch_ndim + 1, 2));
+        DimVector batch_blocksize;
+        batch_blocksize.append(batch_ndim, 1);
+        batch_blocksize.append(blocksize);
         const auto max_blocksize = std::max(blocksize[0], blocksize[1]);
         const auto max_blocksize_arange = at::arange(max_blocksize, indices.options());
         const auto blocksize_arange_0 = max_blocksize_arange.narrow(-1, 0, blocksize[0]);
         const auto blocksize_arange_1 = max_blocksize_arange.narrow(-1, 0, blocksize[1]);
-        const auto block_coo_indices = at::stack({
+        const auto block_coo_indices_ = at::stack({
             blocksize_arange_0.unsqueeze(-1).expand({-1, blocksize[1]}),
             blocksize_arange_1.unsqueeze(0).expand({blocksize[0], -1})
-        }).flatten(-2, -1);
-
+        }).flatten(-2, -1); // equivalent to torch.ones(blocksize).nonzero().T
+        const auto block_coo_indices = at::zeros({batch_ndim + 2, blocksize[0] * blocksize[1]}, indices.options());
+        block_coo_indices.narrow(0, batch_ndim, 2).copy_(block_coo_indices_);
         indices = indices
             // Scale indices that identify blocks to element-wise coordinates that correspond
             // to the top-left corner of each block.
-            .mul(at::tensor(blocksize, indices.options()).unsqueeze_(-1))
+            .mul(at::tensor(batch_blocksize, indices.options()).unsqueeze_(1))
             // Now that we know top-left block coordinates, we offset them with element-wise
             // coordinates in the block to get the result.
             // NOTE: indices is mapped from (dim, nnz) to (dim, nnz, 1),
@@ -1861,10 +1878,10 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, const int64_t sparse_dim)
             // to produce valid nnz dimension of a COO tensor.
             .flatten(-2, -1);

-        values = self.values().flatten(0, 2);
+        values = self.values().flatten(0, batch_ndim + 2);

         // BSRs not spanning across several rows produces coalesced results.
-        coalesced |= (layout == kSparseBsr && blocksize[0] == 1);
+        coalesced |= (layout == kSparseBsr && blocksize[0] == 1 && batch_ndim == 0);
       });
   return at::native::_sparse_coo_tensor_unsafe(indices, values, self.sizes())._coalesced_(coalesced);
 }
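
Aside (not part of the diff): for the block layouts (BSR/BSC) the code above turns each block index into blocksize[0] * blocksize[1] element indices by scaling the block coordinates up to the block's top-left corner and adding the in-block offsets from block_coo_indices, while the leading batch coordinates are multiplied by 1 (batch_blocksize) and offset by 0 so they pass through unchanged. A minimal standalone sketch of that expansion for a single unbatched block matrix follows; the helper name block_to_element_coo and its arguments are made up for illustration.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Standalone sketch of the block -> element expansion: each (block_row, block_col)
// pair is scaled by the blocksize to the block's top-left corner and then offset
// by every in-block coordinate, so one block index fans out into bh * bw COO indices.
std::vector<std::pair<int64_t, int64_t>> block_to_element_coo(
    const std::vector<std::pair<int64_t, int64_t>>& block_indices,
    int64_t bh, int64_t bw) {
  std::vector<std::pair<int64_t, int64_t>> out;
  out.reserve(block_indices.size() * bh * bw);
  for (const auto& [brow, bcol] : block_indices) {
    for (int64_t r = 0; r < bh; ++r) {   // in-block row offset
      for (int64_t c = 0; c < bw; ++c) { // in-block col offset
        out.emplace_back(brow * bh + r, bcol * bw + c);
      }
    }
  }
  return out;
}

int main() {
  // One 2x2 block stored at block position (1, 0): elements land at rows 2-3, cols 0-1.
  for (const auto& [r, c] : block_to_element_coo({{1, 0}}, 2, 2)) {
    std::cout << '(' << r << ", " << c << ") "; // (2, 0) (2, 1) (3, 0) (3, 1)
  }
  std::cout << '\n';
}
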