1
+ // Copyright (C) 2018-2025 Intel Corporation
2
+ // SPDX-License-Identifier: Apache-2.0
3
+ //
4
+
1
5
#include " openvino/frontend/pytorch/node_context.hpp"
2
6
#include " openvino/op/constant.hpp"
3
7
#include " openvino/op/random_uniform.hpp"
4
- #include " openvino/op/shape_of.hpp"
5
8
#include " openvino/op/topk.hpp"
9
+ #include " openvino/op/unsqueeze.hpp"
6
10
#include " utils.hpp"
7
11
8
12
namespace ov {
@@ -14,38 +18,34 @@ using namespace ov::op;
14
18
15
19
OutputVector translate_randperm (const NodeContext& context) {
16
20
auto num_inputs = context.get_input_size ();
17
- int64_t n = context.const_input < int64_t > (0 );
21
+ auto n_node = context.get_input (0 );
18
22
int dtype_value = 4 ;
19
23
if (num_inputs == 1 ) {
20
24
} else if (num_inputs == 2 ) {
21
25
if (!context.input_is_none (1 )) {
22
26
dtype_value = context.const_input <int >(1 );
23
- OPENVINO_ASSERT (dtype_value == 4 ,
24
- " Only dtype value 4 (int64) is supported for aten::randperm, got: " ,
25
- dtype_value);
27
+ PYTORCH_OP_CONVERSION_CHECK (dtype_value == 4 ,
28
+ " Only dtype value 4 (int64) is supported for aten::randperm, got: " ,
29
+ dtype_value);
26
30
}
27
31
} else if (num_inputs == 5 ) {
28
32
if (!context.input_is_none (1 )) {
29
33
dtype_value = context.const_input <int >(1 );
30
- OPENVINO_ASSERT (dtype_value == 4 ,
31
- " Only dtype value 4 (int64) is supported for aten::randperm, got: " ,
32
- dtype_value);
34
+ PYTORCH_OP_CONVERSION_CHECK (dtype_value == 4 ,
35
+ " Only dtype value 4 (int64) is supported for aten::randperm, got: " ,
36
+ dtype_value);
33
37
}
34
38
} else {
35
39
PYTORCH_OP_CONVERSION_CHECK (false , " Unexpected number of inputs for aten::randperm: " , num_inputs);
36
40
}
37
- if (n == 0 ) {
38
- auto const_empty = std::make_shared<v0::Constant>(element::i64, Shape{0 }, std::vector<int64_t >{});
39
- return {context.mark_node (const_empty)};
40
- }
41
- auto shape = v0::Constant::create (element::i64, Shape{1 }, {n});
41
+ auto axis_zero = v0::Constant::create (element::i64, Shape{1 }, {0 });
42
+ auto shape = context.mark_node (std::make_shared<v0::Unsqueeze>(n_node, axis_zero));
42
43
auto min_val = v0::Constant::create (element::f32, Shape{}, {0 .0f });
43
44
auto max_val = v0::Constant::create (element::f32, Shape{}, {1 .0f });
44
45
auto random_tensor = context.mark_node (std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32));
45
46
const int64_t axis = 0 ;
46
- auto k = v0::Constant::create (element::i64, Shape{}, {n});
47
47
auto topk = context.mark_node (std::make_shared<v11::TopK>(random_tensor,
48
- k ,
48
+ n_node ,
49
49
axis,
50
50
ov::op::TopKMode::MIN,
51
51
ov::op::TopKSortType::SORT_VALUES,
0 commit comments