-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathsplit.cpp
127 lines (102 loc) · 4.23 KB
/
split.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/split.hpp"
#include <numeric>
#include "bound_evaluate.hpp"
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/reference/split.hpp"
#include "split_shape_inference.hpp"
namespace ov {
namespace op {
namespace v1 {
namespace validate {
namespace {
bool axis_type(const element::Type& et) {
return et.is_integral_number();
}
} // namespace
} // namespace validate
/// @brief Builds a Split node from a data input and an axis input.
/// @param data       Tensor to split; becomes input 0.
/// @param axis       Split axis; becomes input 1. Its element type must be integral
///                   (checked in validate_and_infer_types()).
/// @param num_splits Number of outputs to produce; must be greater than zero
///                   (checked in validate_and_infer_types()).
Split::Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits)
    : Op(OutputVector{data, axis}),
      m_num_splits(num_splits) {
    // Runs validate_and_infer_types() so the outputs are typed right after construction.
    constructor_validate_and_infer_types();
}
/// @brief Exposes this op's attributes to @p visitor.
///
/// The only attribute is "num_splits", which is bound to m_num_splits.
/// @param visitor Visitor that reads or writes the attribute.
/// @return Always true.
bool Split::visit_attributes(AttributeVisitor& visitor) {
    OV_OP_SCOPE(v1_Split_visit_attributes);

    visitor.on_attribute("num_splits", m_num_splits);

    return true;
}
void Split::validate_and_infer_types() {
OV_OP_SCOPE(v1_Split_validate_and_infer_types);
const auto& axis_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
validate::axis_type(axis_et),
"Element type of 'axis' input must be integer. Got: ",
axis_et);
NODE_VALIDATION_CHECK(this,
m_num_splits > 0,
"Attribute 'num_splits' must be greater than zero. Got: ",
m_num_splits);
const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
const auto output_shapes = shape_infer(this, input_shapes);
for (size_t i = 0; i < m_num_splits; ++i) {
set_output_type(i, get_input_element_type(0), output_shapes[i]);
}
set_input_is_relevant_to_shape(0);
}
/// @brief Creates a new Split wired to @p new_args that keeps this node's num_splits.
/// @param new_args Replacement inputs: [0] data, [1] axis. Their count is checked by
///                 check_new_args_count() before they are accessed.
/// @return The newly created Split node.
std::shared_ptr<Node> Split::clone_with_new_inputs(const OutputVector& new_args) const {
    OV_OP_SCOPE(v1_Split_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    const auto& new_data = new_args.at(0);
    const auto& new_axis = new_args.at(1);
    return std::make_shared<Split>(new_data, new_axis, m_num_splits);
}
/// @brief Evaluates Split on concrete tensors.
///
/// @param outputs Must contain exactly m_num_splits tensors (asserted). Each one is resized
///                to its inferred shape and its buffer is passed to ov::reference::split.
/// @param inputs  [0] data tensor, [1] axis tensor. The axis value is normalized against
///                the data rank via ov::util::normalize (presumably so negative axes work;
///                verify against that helper).
/// @return false when the axis tensor's element type is not integral (outputs are left
///         untouched); true after the data has been split.
bool Split::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
    OV_OP_SCOPE(v1_Split_evaluate);
    OPENVINO_ASSERT(outputs.size() == m_num_splits);

    const auto& axis_tensor = inputs[1];
    // Unsupported axis element type: report "not evaluated" up front instead of running
    // shape inference whose result would only be discarded.
    if (!validate::axis_type(axis_tensor.get_element_type())) {
        return false;
    }

    const auto output_shapes =
        shape_infer(this, ov::util::get_tensors_partial_shapes(inputs), make_tensor_accessor(inputs));

    // Resize each output to its inferred static shape and collect the raw destination
    // pointers in output order for the reference kernel.
    auto outputs_data = std::vector<char*>(m_num_splits);
    for (size_t i = 0; i < output_shapes.size(); ++i) {
        auto& output = outputs[i];
        output.set_shape(output_shapes[i].get_shape());
        outputs_data[i] = static_cast<char*>(output.data());
    }

    const auto& data_tensor = inputs[0];
    const auto& data_shape = data_tensor.get_shape();

    auto axis = get_tensor_data_as<int64_t>(axis_tensor).front();
    axis = ov::util::normalize(axis, data_shape.size());

    ov::reference::split(static_cast<const char*>(data_tensor.data()),
                         data_shape,
                         data_tensor.get_element_type().size(),
                         axis,
                         m_num_splits,
                         outputs_data.data());
    return true;
}
bool Split::has_evaluate() const {
OV_OP_SCOPE(v1_Split_has_evaluate);
return validate::axis_type(get_input_element_type(1));
}
/// @brief Computes lower bounds of the outputs.
///
/// This is attempted only when the 'axis' input (1) reports has_and_set_bound().
/// Otherwise false is returned and default_lower_bound_evaluator is not called.
/// @param output_values Destination tensors for the lower bounds.
/// @return Result of default_lower_bound_evaluator, or false if the axis has no bound.
bool Split::evaluate_lower(ov::TensorVector& output_values) const {
    OV_OP_SCOPE(v1_Split_evaluate_lower);
    if (!get_input_tensor(1).has_and_set_bound()) {
        return false;
    }
    return default_lower_bound_evaluator(this, output_values);
}
/// @brief Computes upper bounds of the outputs.
///
/// This is attempted only when the 'axis' input (1) reports has_and_set_bound().
/// Otherwise false is returned and default_upper_bound_evaluator is not called.
/// @param output_values Destination tensors for the upper bounds.
/// @return Result of default_upper_bound_evaluator, or false if the axis has no bound.
bool Split::evaluate_upper(ov::TensorVector& output_values) const {
    OV_OP_SCOPE(v1_Split_evaluate_upper);
    if (!get_input_tensor(1).has_and_set_bound()) {
        return false;
    }
    return default_upper_bound_evaluator(this, output_values);
}
/// @brief Propagates symbols to the outputs.
///
/// @p output_symbols must hold exactly get_num_splits() entries (asserted). Symbols are
/// evaluated only when the 'axis' input (1) reports has_and_set_bound().
/// @param output_symbols Destination symbol vectors, one per output.
/// @return Result of ov::util::default_symbol_evaluator, or false if the axis has no bound.
bool Split::evaluate_symbol(TensorSymbolVector& output_symbols) const {
    OPENVINO_ASSERT(output_symbols.size() == get_num_splits());
    if (!get_input_tensor(1).has_and_set_bound()) {
        return false;
    }
    return ov::util::default_symbol_evaluator(this, output_symbols);
}
} // namespace v1
} // namespace op
} // namespace ov