Commit 3c4beb1

[PT FE]: support aten::outer op (openvinotoolkit#18903)

1 parent da36633 commit 3c4beb1

3 files changed: +99 −0

New file (+41 lines):
@@ -0,0 +1,41 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_outer(const NodeContext& context) {
    // aten::outer(Tensor self, Tensor vec2) -> Tensor
    // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 2, 3);
    auto vec1 = context.get_input(0);
    auto vec2 = context.get_input(1);
    align_eltwise_input_types(context, vec1, vec2, true);
    auto const_zero = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    auto const_minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
    // Reshape vec1 to a column (n, 1) and vec2 to a row (1, m); their MatMul
    // is the (n, m) outer product.
    vec1 = context.mark_node(std::make_shared<v0::Unsqueeze>(vec1, const_minus_one));
    vec2 = context.mark_node(std::make_shared<v0::Unsqueeze>(vec2, const_zero));
    auto out = context.mark_node(std::make_shared<v0::MatMul>(vec1, vec2));
    if (!context.input_is_none(2)) {
        // out= overload: cast the result to the dtype of the provided buffer
        // and write it back through the out input.
        out = context.mark_node(std::make_shared<v1::ConvertLike>(out, context.get_input(2)));
        context.mutate_input(2, out);
    }
    return {out};
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
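The converter lowers aten::outer to Unsqueeze + MatMul rather than a dedicated kernel: a 1-D vec1 becomes an (n, 1) column, vec2 becomes a (1, m) row, and their matrix product is the (n, m) outer product. A minimal PyTorch sketch (not part of the commit) checking that equivalence:

import torch

vec1 = torch.tensor([1.0, 2.0, 3.0])
vec2 = torch.tensor([4.0, 5.0])

# Mirror the converter's graph: Unsqueeze(vec1, -1) @ Unsqueeze(vec2, 0).
decomposed = vec1.unsqueeze(-1) @ vec2.unsqueeze(0)  # shape (3, 2)

assert torch.equal(decomposed, torch.outer(vec1, vec2))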

src/frontends/pytorch/src/op_table.cpp (+2 lines):

@@ -109,6 +109,7 @@ OP_CONVERTER(translate_norm);
 OP_CONVERTER(translate_numel);
 OP_CONVERTER(translate_ones);
 OP_CONVERTER(translate_ones_like);
+OP_CONVERTER(translate_outer);
 OP_CONVERTER(translate_pad);
 OP_CONVERTER(translate_pairwise_distance);
 OP_CONVERTER(translate_pow);
@@ -361,6 +362,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
     {"aten::numel", op::translate_numel},
     {"aten::ones", op::translate_ones},
     {"aten::ones_like", op::translate_ones_like},
+    {"aten::outer", op::translate_outer},
     {"aten::pad", op::translate_pad},
     {"aten::pairwise_distance", op::translate_pairwise_distance},
     {"aten::permute", op::translate_1to1_match_2_inputs<opset10::Transpose>},
New file (+56 lines):

@@ -0,0 +1,56 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestOuter(PytorchLayerTest):
    def _prepare_input(self, x_shape, y_shape, x_dtype, y_dtype, out=False):
        import numpy as np
        x = np.random.randn(*x_shape).astype(x_dtype)
        y = np.random.randn(*y_shape).astype(y_dtype)
        if not out:
            return (x, y)
        out = np.zeros((x_shape[0], y_shape[0]))
        return (x, y, out)

    def create_model(self, out=False, x_dtype="float32", y_dtype="float32"):
        import torch

        dtypes = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int32": torch.int32
        }
        x_dtype = dtypes[x_dtype]
        y_dtype = dtypes[y_dtype]

        class aten_outer(torch.nn.Module):
            def __init__(self, out, x_dtype, y_dtype) -> None:
                super().__init__()
                self.x_dtype = x_dtype
                self.y_dtype = y_dtype
                if out:
                    self.forward = self.forward_out

            def forward(self, x, y):
                return torch.outer(x.to(self.x_dtype), y.to(self.y_dtype))

            def forward_out(self, x, y, out):
                return torch.outer(x.to(self.x_dtype), y.to(self.y_dtype), out=out), out

        ref_net = None

        return aten_outer(out, x_dtype, y_dtype), ref_net, 'aten::outer'

    @pytest.mark.parametrize("x_shape", ([1], [2], [3]))
    @pytest.mark.parametrize("y_shape", ([1], [7], [5]))
    @pytest.mark.parametrize("x_dtype", ("float32", "float64", "int32"))
    @pytest.mark.parametrize("y_dtype", ("float32", "float64", "int32"))
    @pytest.mark.parametrize("out", [True, False])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_outer(self, x_shape, y_shape, x_dtype, y_dtype, out, ie_device, precision, ir_version):
        self._test(*self.create_model(out, x_dtype, y_dtype), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"out": out, "x_shape": x_shape, "y_shape": y_shape,
                                            "x_dtype": x_dtype, "y_dtype": y_dtype})
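The out=True cases exercise the ConvertLike branch: _prepare_input hands the test a float64 zeros buffer, so the converted graph must cast the MatMul result to the buffer's dtype before writing it back. The same semantics in plain PyTorch (illustration only, not from the commit):

import torch

x = torch.arange(3, dtype=torch.int32)
y = torch.arange(4, dtype=torch.int32)

# Without out=, the result keeps the promoted input dtype (int32 here).
print(torch.outer(x, y).dtype)   # torch.int32

# With an out= buffer, the result is written in the buffer's dtype;
# the ConvertLike node reproduces this cast in the OpenVINO graph.
out = torch.zeros(3, 4, dtype=torch.float64)
torch.outer(x, y, out=out)
print(out.dtype)                 # torch.float64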
