1
+ # Copyright (C) 2018-2023 Intel Corporation
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+
6
+ from pytorch_layer_test_class import PytorchLayerTest
7
+
8
+
9
+ class TestOuter (PytorchLayerTest ):
10
+ def _prepare_input (self , x_shape , y_shape , x_dtype , y_dtype , out = False ):
11
+ import numpy as np
12
+ x = np .random .randn (* x_shape ).astype (x_dtype )
13
+ y = np .random .randn (* y_shape ).astype (y_dtype )
14
+ if not out :
15
+ return (x , y )
16
+ out = np .zeros ((x_shape [0 ], y_shape [0 ]))
17
+ return (x , y , out )
18
+
19
+ def create_model (self , out = False , x_dtype = "float32" , y_dtype = "float32" ):
20
+ import torch
21
+
22
+ dtypes = {
23
+ "float32" : torch .float32 ,
24
+ "float64" : torch .float64 ,
25
+ "int32" : torch .int32
26
+ }
27
+ x_dtype = dtypes [x_dtype ]
28
+ y_dtype = dtypes [y_dtype ]
29
+ class aten_outer (torch .nn .Module ):
30
+ def __init__ (self , out , x_dtype , y_dtype ) -> None :
31
+ super ().__init__ ()
32
+ self .x_dtype = x_dtype
33
+ self .y_dtype = y_dtype
34
+ if out :
35
+ self .forward = self .forward_out
36
+
37
+ def forward (self , x , y ):
38
+ return torch .outer (x .to (self .x_dtype ), y .to (self .y_dtype ))
39
+
40
+ def forward_out (self , x , y , out ):
41
+ return torch .outer (x .to (self .x_dtype ), y .to (self .y_dtype ), out = out ), out
42
+
43
+ ref_net = None
44
+
45
+ return aten_outer (out , x_dtype , y_dtype ), ref_net , 'aten::outer'
46
+
47
+ @pytest .mark .parametrize ("x_shape" , ([1 ], [2 ], [3 ]))
48
+ @pytest .mark .parametrize ("y_shape" , ([1 ], [7 ], [5 ]))
49
+ @pytest .mark .parametrize ("x_dtype" , ("float32" , "float64" , "int32" ))
50
+ @pytest .mark .parametrize ("y_dtype" , ("float32" , "float64" , "int32" ))
51
+ @pytest .mark .parametrize ("out" , [True , False ])
52
+ @pytest .mark .nightly
53
+ @pytest .mark .precommit
54
+ def test_numel (self , x_shape , y_shape , x_dtype , y_dtype , out , ie_device , precision , ir_version ):
55
+ self ._test (* self .create_model (out , x_dtype , y_dtype ), ie_device , precision , ir_version ,
56
+ kwargs_to_prepare_input = {"out" : out , "x_shape" : x_shape , "y_shape" : y_shape , "x_dtype" : x_dtype , "y_dtype" : y_dtype })
0 commit comments