@@ -79,64 +79,6 @@ Tensor bmm_nested(const Tensor& self, const Tensor& mat2) {
79
79
return output;
80
80
}
81
81
82
- // utilities support `matmul_nested`
83
- namespace {
84
- // Args:
85
- // self_sizes: the sizes of `self` in `matmul_nested`
86
- // mat2_sizes: the sizes of `mat2` in `matmul_nested`
87
- // buffer_op: the options for new buffer
88
- // sizemat_op: the options for new size matrix
89
- // Returns:
90
- // the batch size of each input underlying tensor, i.e. the product of batch-dimension sizes
91
- // the empty output nested tensor
92
- inline std::tuple<std::vector<int64_t>, Tensor>
93
- matmul_nested_helper(
94
- const std::vector<IntArrayRef>& self_sizes,
95
- const std::vector<IntArrayRef>& mat2_sizes,
96
- const c10::TensorOptions& buffer_op,
97
- const c10::TensorOptions& sizemat_op) {
98
- int64_t ntensors = self_sizes.size(),
99
- ndims = self_sizes[0].size();
100
- std::vector<int64_t> batch_sizes(ntensors, 1);
101
- Tensor sizemat = at::empty({ntensors, ndims}, sizemat_op);
102
- int64_t* sizemat_ptr = sizemat.mutable_data_ptr<int64_t>();
103
- int64_t numel = 0;
104
- for (int64_t i = 0; i < ntensors; i++) {
105
- const IntArrayRef& self_size = self_sizes[i],
106
- & mat2_size = mat2_sizes[i];
107
- int64_t& batch_size = batch_sizes[i];
108
- // batch dimensions
109
- for (int64_t j = 0; j < ndims - 2; j++) {
110
- const int64_t& self_sizej = self_size[j],
111
- & mat2_sizej = mat2_size[j];
112
- TORCH_CHECK(
113
- self_sizej == mat2_sizej,
114
- "matmul: For nested tensors, no broadcasting is currently performed: ",
115
- i, "-th nested matrices in batch at dimension ", j + 1,
116
- " have mismatching sizes ", self_sizej, " and ", mat2_sizej);
117
- sizemat_ptr[j] = self_sizej;
118
- batch_size *= sizemat_ptr[j];
119
- }
120
- // matrix multiplication dimensions
121
- const int64_t& self_size0 = self_size[ndims - 2], & self_size1 = self_size[ndims - 1],
122
- & mat2_size0 = mat2_size[ndims - 2], & mat2_size1 = mat2_size[ndims - 1];
123
- TORCH_CHECK(
124
- self_size1 == mat2_size0,
125
- "matmul: ",
126
- i, "-th nested matrices in batch cannot be multiplied (",
127
- self_size0, "x", self_size1, " and ",
128
- mat2_size0, "x", mat2_size1, ")");
129
- sizemat_ptr[ndims - 2] = self_size0;
130
- sizemat_ptr[ndims - 1] = mat2_size1;
131
- sizemat_ptr += ndims;
132
- numel += batch_size * self_size0 * mat2_size1;
133
- }
134
- Tensor buffer = at::empty(numel, buffer_op);
135
- Tensor output = wrap_buffer(buffer, sizemat);
136
- return std::make_tuple(batch_sizes, output);
137
- }
138
- }
139
-
140
82
Tensor matmul_with_bmm_nested(const Tensor& self, const Tensor& mat2) {
141
83
// Tensor self = self_.contiguous();
142
84
// Tensor mat2 = mat2_.contiguous();
0 commit comments