Skip to content

Commit 7b4b3b7

Browse files
authored
Update randperm.cpp
now it can handle value of n dynamically
1 parent 7c9e745 commit 7b4b3b7

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

src/frontends/pytorch/src/op/randperm.cpp

+15-15
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
15
#include "openvino/frontend/pytorch/node_context.hpp"
26
#include "openvino/op/constant.hpp"
37
#include "openvino/op/random_uniform.hpp"
4-
#include "openvino/op/shape_of.hpp"
58
#include "openvino/op/topk.hpp"
9+
#include "openvino/op/unsqueeze.hpp"
610
#include "utils.hpp"
711

812
namespace ov {
@@ -14,38 +18,34 @@ using namespace ov::op;
1418

1519
OutputVector translate_randperm(const NodeContext& context) {
1620
auto num_inputs = context.get_input_size();
17-
int64_t n = context.const_input<int64_t>(0);
21+
auto n_node = context.get_input(0);
1822
int dtype_value = 4;
1923
if (num_inputs == 1) {
2024
} else if (num_inputs == 2) {
2125
if (!context.input_is_none(1)) {
2226
dtype_value = context.const_input<int>(1);
23-
OPENVINO_ASSERT(dtype_value == 4,
24-
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
25-
dtype_value);
27+
PYTORCH_OP_CONVERSION_CHECK(dtype_value == 4,
28+
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
29+
dtype_value);
2630
}
2731
} else if (num_inputs == 5) {
2832
if (!context.input_is_none(1)) {
2933
dtype_value = context.const_input<int>(1);
30-
OPENVINO_ASSERT(dtype_value == 4,
31-
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
32-
dtype_value);
34+
PYTORCH_OP_CONVERSION_CHECK(dtype_value == 4,
35+
"Only dtype value 4 (int64) is supported for aten::randperm, got: ",
36+
dtype_value);
3337
}
3438
} else {
3539
PYTORCH_OP_CONVERSION_CHECK(false, "Unexpected number of inputs for aten::randperm: ", num_inputs);
3640
}
37-
if (n == 0) {
38-
auto const_empty = std::make_shared<v0::Constant>(element::i64, Shape{0}, std::vector<int64_t>{});
39-
return {context.mark_node(const_empty)};
40-
}
41-
auto shape = v0::Constant::create(element::i64, Shape{1}, {n});
41+
auto axis_zero = v0::Constant::create(element::i64, Shape{1}, {0});
42+
auto shape = context.mark_node(std::make_shared<v0::Unsqueeze>(n_node, axis_zero));
4243
auto min_val = v0::Constant::create(element::f32, Shape{}, {0.0f});
4344
auto max_val = v0::Constant::create(element::f32, Shape{}, {1.0f});
4445
auto random_tensor = context.mark_node(std::make_shared<v8::RandomUniform>(shape, min_val, max_val, element::f32));
4546
const int64_t axis = 0;
46-
auto k = v0::Constant::create(element::i64, Shape{}, {n});
4747
auto topk = context.mark_node(std::make_shared<v11::TopK>(random_tensor,
48-
k,
48+
n_node,
4949
axis,
5050
ov::op::TopKMode::MIN,
5151
ov::op::TopKSortType::SORT_VALUES,

0 commit comments

Comments
 (0)