Skip to content

Commit 576383c

Browse files
lancertspytorchmergebot
authored andcommitted
Add torch check for dtype within bilinear (pytorch#118900)
Fixes pytorch#117237 Short-term fix, when dtype does not match, it will be reflected in the torch check. @ezyang a cpp test case is added Pull Request resolved: pytorch#118900 Approved by: https://github.com/ezyang, https://github.com/malfet
1 parent a4355d6 commit 576383c

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

aten/src/ATen/native/Linear.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,28 @@ Tensor bilinear(const Tensor& input1, const Tensor& input2, const Tensor& weight
703703
// See [Note: hacky wrapper removal for optional tensor]
704704
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
705705
const Tensor& bias = *bias_maybe_owned;
706+
if (bias.defined()) {
707+
TORCH_CHECK(
708+
input1.dtype() == input2.dtype() && input1.dtype() == weight.dtype() &&
709+
input1.dtype() == bias.dtype(),
710+
"All tensors must have the same dtype, got input1: ",
711+
input1.dtype(),
712+
", input2: ",
713+
input2.dtype(),
714+
", weight: ",
715+
weight.dtype(),
716+
", bias: ",
717+
bias.dtype());
718+
} else {
719+
TORCH_CHECK(
720+
input1.dtype() == input2.dtype() && input1.dtype() == weight.dtype(),
721+
"All tensors must have the same dtype, got input1: ",
722+
input1.dtype(),
723+
", input2: ",
724+
input2.dtype(),
725+
", weight: ",
726+
weight.dtype());
727+
}
706728

707729
TORCH_CHECK(input1.dim() == input2.dim(), "bilinear(): input dimensions do not match: got ", input1.dim(), " and ", input2.dim());
708730
for (const auto i : c10::irange(input1.dim() - 1)) {

test/cpp/api/functional.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,13 @@ TEST_F(FunctionalTest, Bilinear) {
15241524
ASSERT_EQ(y_no_bias.sizes(), torch::IntArrayRef({2, 1}));
15251525
auto y_no_bias_exp = torch::tensor({{448, 1701}}).reshape({2, 1});
15261526
ASSERT_TRUE(torch::allclose(y_no_bias, y_no_bias_exp, 1e-4, 1e-7));
1527+
1528+
input1 = input1.to(torch::kFloat64);
1529+
input2 = input2.to(torch::kInt32);
1530+
weight = weight.to(torch::kInt32);
1531+
ASSERT_THROWS_WITH(
1532+
F::bilinear(input1, input2, weight),
1533+
"All tensors must have the same dtype, got input1: double, input2: int, weight: int");
15271534
}
15281535

15291536
TEST_F(FunctionalTest, Normalize) {

0 commit comments

Comments
 (0)