diff --git a/src/core/dev_api/openvino/runtime/itensor.hpp b/src/core/dev_api/openvino/runtime/itensor.hpp
index 72b5d592df3f00..49c0b37d2b6b8e 100644
--- a/src/core/dev_api/openvino/runtime/itensor.hpp
+++ b/src/core/dev_api/openvino/runtime/itensor.hpp
@@ -51,24 +51,32 @@ class OPENVINO_API ITensor : public std::enable_shared_from_this<ITensor> {
     virtual const ov::Strides& get_strides() const = 0;
 
     /**
-     * @brief Provides an access to the underlaying host memory
+     * @brief Provides access to the underlying host memory
      * @param type Optional type parameter.
     * @note If type parameter is specified, the method throws an exception
      * if specified type's fundamental type does not match with tensor element type's fundamental type
      * @return A host pointer to tensor memory
+     * @{
      */
-    virtual void* data(const element::Type& type = {}) const = 0;
+    virtual void* data(const element::Type& type = {});
+    virtual const void* data(const element::Type& type = {}) const = 0;
+    /// @}
 
     /**
-     * @brief Provides an access to the underlaying host memory casted to type `T`
+     * @brief Provides access to the underlying host memory cast to type `T`
      * @return A host pointer to tensor memory casted to specified type `T`.
      * @note Throws exception if specified type does not match with tensor element type
      */
     template <typename T, typename datatype = typename std::decay<T>::type>
-    T* data() const {
+    T* data() {
         return static_cast<T*>(data(element::from<datatype>()));
     }
 
+    template <typename T, typename datatype = typename std::decay<T>::type>
+    const T* data() const {
+        return static_cast<const T*>(data(element::from<datatype>()));
+    }
+
     /**
      * @brief Reports whether the tensor is continuous or not
      *
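The split of the single const-qualified `data()` into a const/non-const overload pair is what lets read-only tensors reject mutation at the API boundary: the mutable accessor is now a separate virtual that a read-only implementation can override to throw. A minimal caller-side sketch, assuming the public `ov::Tensor` API from this patch (variable names are illustrative):

```cpp
#include <openvino/runtime/tensor.hpp>

void example(ov::Tensor& t, const ov::Tensor& ct) {
    void* rw = t.data();         // non-const overload: writable pointer
    const void* ro = ct.data();  // const overload: read-only pointer
    // Only the non-const call is blocked for a read-only view (it throws);
    // const access keeps working.
    (void)rw;
    (void)ro;
}
```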
diff --git a/src/core/include/openvino/runtime/tensor.hpp b/src/core/include/openvino/runtime/tensor.hpp
index 830b1db694dd1d..c29c470839bc7f 100644
--- a/src/core/include/openvino/runtime/tensor.hpp
+++ b/src/core/include/openvino/runtime/tensor.hpp
@@ -121,6 +121,17 @@ class OPENVINO_API Tensor {
      */
     Tensor(const element::Type& type, const Shape& shape, void* host_ptr, const Strides& strides = {});
 
+    /**
+     * @brief Constructs Tensor using element type and shape. Wraps allocated host memory as read-only.
+     * @note Does not perform memory allocation internally
+     * @param type Tensor element type
+     * @param shape Tensor shape
+     * @param host_ptr Pointer to pre-allocated host memory with initialized objects
+     * @param strides Optional strides parameters in bytes. If omitted, strides are computed automatically based
+     * on shape and element size
+     */
+    Tensor(const element::Type& type, const Shape& shape, const void* host_ptr, const Strides& strides = {});
+
     /**
      * @brief Constructs Tensor using port from node. Allocate internal host storage using default allocator
      * @param port port from node
@@ -138,6 +149,16 @@ class OPENVINO_API Tensor {
      */
     Tensor(const ov::Output<const ov::Node>& port, void* host_ptr, const Strides& strides = {});
 
+    /**
+     * @brief Constructs Tensor using port from node. Wraps allocated host memory as read-only.
+     * @note Does not perform memory allocation internally
+     * @param port port from node
+     * @param host_ptr Pointer to pre-allocated host memory with initialized objects
+     * @param strides Optional strides parameters in bytes. If omitted, strides are computed automatically based
+     * on shape and element size
+     */
+    Tensor(const ov::Output<const ov::Node>& port, const void* host_ptr, const Strides& strides = {});
+
     /**
      * @brief Constructs region of interest (ROI) tensor from another tensor.
      * @note Does not perform memory allocation internally
@@ -197,23 +218,37 @@ class OPENVINO_API Tensor {
     Strides get_strides() const;
 
     /**
-     * @brief Provides an access to the underlaying host memory
+     * @brief Provides access to the underlying host memory
     * @param type Optional type parameter.
      * @note If type parameter is specified, the method throws an exception
      * if specified type's fundamental type does not match with tensor element type's fundamental type
      * @return A host pointer to tensor memory
+     * @{
      */
-    void* data(const element::Type& type = {}) const;
+    const void* data(const element::Type& type = {}) const;
+    void* data(const element::Type& type = {});
+    /// @}
 
     /**
-     * @brief Provides an access to the underlaying host memory casted to type `T`
+     * @brief Provides access to the underlying host memory cast to type `T`
      * @return A host pointer to tensor memory casted to specified type `T`.
      * @note Throws exception if specified type does not match with tensor element type
+     * @{
      */
-    template <typename T, typename datatype = typename std::decay<T>::type>
-    T* data() const {
-        return static_cast<T*>(data(element::from<datatype>()));
+    template <typename T, typename datatype = std::decay_t<T>>
+    const T* data() const {
+        return static_cast<const T*>(data(element::from<datatype>()));
+    }
+
+    template <typename T, typename datatype = std::decay_t<T>>
+    T* data() {
+        if constexpr (std::is_const_v<T>) {
+            return static_cast<const Tensor*>(this)->data<T>();
+        } else {
+            return static_cast<T*>(data(element::from<datatype>()));
+        }
     }
+    /// @}
 
     /**
      * @brief Checks if current Tensor object is not initialized
@@ -234,7 +269,7 @@ class OPENVINO_API Tensor {
      * @return true if this object can be dynamically cast to the type const T*. Otherwise, false
      */
     template <typename T>
-    typename std::enable_if<std::is_base_of<Tensor, T>::value, bool>::type is() const noexcept {
+    std::enable_if_t<std::is_base_of_v<Tensor, T>, bool> is() const noexcept {
         try {
             T::type_check(*this);
         } catch (...) {
@@ -250,7 +285,7 @@ class OPENVINO_API Tensor {
      * @return T object
      */
     template <typename T>
-    const typename std::enable_if<std::is_base_of<Tensor, T>::value, T>::type as() const {
+    const std::enable_if_t<std::is_base_of_v<Tensor, T>, T> as() const {
         T::type_check(*this);
         return *static_cast<const T*>(this);
     }
diff --git a/src/core/reference/include/openvino/reference/interpolate.hpp b/src/core/reference/include/openvino/reference/interpolate.hpp
index 5696fc16e172ea..df47689a6b6848 100644
--- a/src/core/reference/include/openvino/reference/interpolate.hpp
+++ b/src/core/reference/include/openvino/reference/interpolate.hpp
@@ -803,7 +803,7 @@ void interpolate(const T* input_data,
 }
 
 template <typename T>
-void interpolate(T* input_data,
+void interpolate(const T* input_data,
                  const PartialShape& input_data_shape,
                  T* out,
                  const Shape& out_shape,
@@ -817,7 +817,7 @@ void interpolate(const T* input_data,
     size_t bytes_in_padded_input = shape_size(padded_input_shape) * sizeof(T);
     std::vector<uint8_t> padded_input_data(bytes_in_padded_input, 0);
     uint8_t* padded_data_ptr = padded_input_data.data();
-    pad_input_data(reinterpret_cast<uint8_t*>(input_data),
+    pad_input_data(reinterpret_cast<const uint8_t*>(input_data),
                    padded_data_ptr,
                    sizeof(T),
                    input_data_shape.to_shape(),
diff --git a/src/core/reference/include/openvino/reference/mvn.hpp b/src/core/reference/include/openvino/reference/mvn.hpp
index c99c5eb802cb19..fd539b5428a0a5 100644
--- a/src/core/reference/include/openvino/reference/mvn.hpp
+++ b/src/core/reference/include/openvino/reference/mvn.hpp
@@ -89,7 +89,7 @@ void mvn_6(const T* arg,
 
 template <typename T>
 AxisSet mvn_6_reduction_axes(const ov::Tensor& axes_input, size_t rank) {
-    T* a = axes_input.data<T>();
+    const T* a = axes_input.data<T>();
     auto v = std::vector<T>(a, a + axes_input.get_shape()[0]);
     std::vector<size_t> axes(v.size(), 0);
     for (size_t i = 0; i < v.size(); i++) {
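The `mvn.hpp` change above is the common pattern for reference code after this patch: a `const ov::Tensor&` argument now hands out `const T*` from `data<T>()`, and the local pointer type follows. Callers that explicitly want the const path can request it via the template parameter; a small sketch under the new `tensor.hpp` API (assuming an f32 tensor; names illustrative):

```cpp
#include <openvino/runtime/tensor.hpp>

float sum(ov::Tensor& t) {
    // data<const float>() dispatches to the const overload through the
    // `if constexpr (std::is_const_v<T>)` branch, so this compiles and works
    // even when `t` wraps read-only memory.
    const float* p = t.data<const float>();
    float acc = 0.f;
    for (size_t i = 0; i < t.get_size(); ++i)
        acc += p[i];
    return acc;
}
```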
diff --git a/src/core/reference/src/op/loop.cpp b/src/core/reference/src/op/loop.cpp
index 217c2e5658f8de..718e392ad9f52d 100644
--- a/src/core/reference/src/op/loop.cpp
+++ b/src/core/reference/src/op/loop.cpp
@@ -113,7 +113,7 @@ void loop(const std::shared_ptr<Model>& func,
                 pointers_to_data[slice_desc->m_stride > 0 ? j : (pointers_to_data.size() - j - 1)] =
                     static_cast<char*>(sliced_values[slice_in_idx][j].data());
             }
-            reference::split(static_cast<char*>(args[slice_desc->m_input_index].data()),
+            reference::split(static_cast<const char*>(args[slice_desc->m_input_index].data()),
                              args[slice_desc->m_input_index].get_shape(),
                              el_size,
                              slice_desc->m_axis,
@@ -211,7 +211,7 @@ void loop(const std::shared_ptr<Model>& func,
             std::vector<const char*> pointers_on_values;
             pointers_on_values.reserve(values_to_concat[i].size());
             for (const auto& vec : values_to_concat[i]) {
-                pointers_on_values.push_back(static_cast<char*>(vec.data()));
+                pointers_on_values.push_back(static_cast<const char*>(vec.data()));
             }
             reference::concat(pointers_on_values,
                               static_cast<char*>(out[concat_desc->m_output_index].data()),
diff --git a/src/core/reference/src/op/tensor_iterator.cpp b/src/core/reference/src/op/tensor_iterator.cpp
index 53640226ef0da2..b90f6ca985e67c 100644
--- a/src/core/reference/src/op/tensor_iterator.cpp
+++ b/src/core/reference/src/op/tensor_iterator.cpp
@@ -64,7 +64,7 @@ void tensor_iterator(uint64_t num_iterations,
                 pointers_to_data[slice_desc->m_stride > 0 ? j : (pointers_to_data.size() - j - 1)] =
                     static_cast<char*>(sliced_values[slice_in_idx][j].data());
             }
-            reference::split(static_cast<char*>(args[slice_desc->m_input_index].data()),
+            reference::split(static_cast<const char*>(args[slice_desc->m_input_index].data()),
                              args[slice_desc->m_input_index].get_shape(),
                              el_size,
                              slice_desc->m_axis,
diff --git a/src/core/src/bound_evaluate.cpp b/src/core/src/bound_evaluate.cpp
index af2945c9f1291c..fddc596287e4ca 100644
--- a/src/core/src/bound_evaluate.cpp
+++ b/src/core/src/bound_evaluate.cpp
@@ -453,7 +453,7 @@ bool ov::interval_bound_evaluator(const Node* node,
     node->evaluate(lower_output_values, *input_variants.begin());
 
     auto zero = op::v0::Constant::create(element::i64, {1}, {0});
-    const auto zero_t = ov::Tensor(element::i64, Shape{});
+    auto zero_t = ov::Tensor(element::i64, Shape{});
     *zero_t.data<int64_t>() = 0;
 
     std::vector<TensorVector> unsqueezed_output_variants;
@@ -529,8 +529,8 @@ bool ov::interval_bound_evaluator(const Node* node,
             fully_defined = false;
         } else {
             // Can not set to make_tensor_of_min_value(lower_output_values[i]->get_element_type()) yet
-            const auto then = Tensor{lower_out[0].get_element_type(), Shape{}};
-            const auto then_data = static_cast<char*>(then.data());
+            auto then = Tensor{lower_out[0].get_element_type(), Shape{}};
+            auto then_data = static_cast<char*>(then.data());
             std::memset(then_data, 0, then.get_byte_size());
             op::v1::Select().evaluate(lower_out, {final_input_dyn_mask, then, lower_out[0]});
             node->get_output_tensor(i).set_lower_value(lower_out[0]);
diff --git a/src/core/src/op/constant.cpp b/src/core/src/op/constant.cpp
index ec6aef50215041..5577481a8a799b 100644
--- a/src/core/src/op/constant.cpp
+++ b/src/core/src/op/constant.cpp
@@ -207,8 +207,10 @@ Constant::Constant(const Tensor& tensor)
     : m_element_type{tensor.get_element_type()},
      m_shape{tensor.get_shape()},
       m_byte_strides{m_element_type.bitwidth() >= 8 ?
tensor.get_strides() : Strides{}},
-      m_data{
-          std::make_shared<SharedBuffer<Tensor>>(static_cast<char*>(tensor.data()), tensor.get_byte_size(), tensor)} {
+      // The cast is for internal use only, to store the tensor data in the shared buffer (not for modification)
+      m_data{std::make_shared<SharedBuffer<Tensor>>(const_cast<char*>(static_cast<const char*>(tensor.data())),
+                                                    tensor.get_byte_size(),
+                                                    tensor)} {
     constructor_validate_and_infer_types();
 }
diff --git a/src/core/src/op/convert.cpp b/src/core/src/op/convert.cpp
index c3acf6caeaf2c3..c10e3438f1a278 100644
--- a/src/core/src/op/convert.cpp
+++ b/src/core/src/op/convert.cpp
@@ -57,7 +57,7 @@ struct Evaluate : public element::NoAction<bool> {
                           CONVERT_ET_LIST,
                           EvalByOutputType,
                           out.get_element_type(),
-                          iterator<ET_IN>(reinterpret_cast<const TI*>(arg.data())),
+                          iterator<ET_IN>(arg.data()),
                           out,
                           count);
     }
diff --git a/src/core/src/op/depth_to_space.cpp b/src/core/src/op/depth_to_space.cpp
index d6f0316dbd70ba..a7e4d0da2b81ff 100644
--- a/src/core/src/op/depth_to_space.cpp
+++ b/src/core/src/op/depth_to_space.cpp
@@ -51,7 +51,7 @@ bool DepthToSpace::evaluate(TensorVector& outputs, const TensorVector& inputs) c
     OPENVINO_ASSERT(outputs.size() == 1);
 
     const auto& in = inputs[0];
-    const auto& out = outputs[0];
+    auto& out = outputs[0];
     reference::depth_to_space(static_cast<const char*>(in.data()),
                               in.get_shape(),
                               static_cast<char*>(out.data()),
diff --git a/src/core/src/op/divide.cpp b/src/core/src/op/divide.cpp
index bfa35150421a62..1cc9a66c771981 100644
--- a/src/core/src/op/divide.cpp
+++ b/src/core/src/op/divide.cpp
@@ -94,7 +94,7 @@ bool evaluate_bound(const Node* node, TensorVector& output_values, bool is_upper
         return false;
 
     const auto zeros_const = Constant::create(input2.get_element_type(), {}, {0});
-    const auto zero_t = Tensor(input2.get_element_type(), Shape{});
+    auto zero_t = Tensor(input2.get_element_type(), Shape{});
     memcpy(zero_t.data(), zeros_const->get_data_ptr(), zero_t.get_byte_size());
 
     const auto max_value = ov::util::make_tensor_of_max_value(input2.get_element_type());
@@ -172,7 +172,7 @@ bool evaluate_bound(const Node* node, TensorVector& output_values, bool is_upper
 
     // replace zeros by 1 values to get result of divide for other values of arguments
     const auto ones = Constant::create(input2.get_element_type(), input2.get_shape(), {1});
-    const auto ones_t = Tensor(ones->get_element_type(), ones->get_shape());
+    auto ones_t = Tensor(ones->get_element_type(), ones->get_shape());
     memcpy(ones_t.data(), ones->get_data_ptr(), ones_t.get_byte_size());
 
     status = Select().evaluate(value2_outs, {input2_zeros_mask, ones_t, value2});
diff --git a/src/core/src/op/roi_align.cpp b/src/core/src/op/roi_align.cpp
index c62f08478c2ed7..28a0cfb3b151c6 100644
--- a/src/core/src/op/roi_align.cpp
+++ b/src/core/src/op/roi_align.cpp
@@ -162,7 +162,7 @@ template <typename T>
 bool evaluate(const Tensor& feature_maps,
               const Tensor& rois,
               const std::vector<int64_t>& batch_indices_vec_scaled_up,
-              const Tensor& out,
+              Tensor& out,
               const int pooled_height,
               const int pooled_width,
               const int sampling_ratio,
@@ -189,7 +189,7 @@ bool evaluate(const Tensor& feature_maps,
 }
 
 bool evaluate(const TensorVector& args,
-              const Tensor& out,
+              Tensor& out,
               const int pooled_height,
               const int pooled_width,
               const int sampling_ratio,
diff --git a/src/core/src/op/space_to_batch.cpp b/src/core/src/op/space_to_batch.cpp
index 17d20d38f40738..842a9a84de3221 100644
--- a/src/core/src/op/space_to_batch.cpp
+++ b/src/core/src/op/space_to_batch.cpp
@@ -74,7 +74,7 @@ namespace space_to_batch {
 namespace {
 bool evaluate(TensorVector& outputs, const TensorVector& inputs) {
     const auto& data = inputs[0];
-    const auto& out = outputs[0];
+    auto& out = outputs[0];
     const auto elem_size = data.get_element_type().size();
 
     auto data_shape = data.get_shape();
diff --git a/src/core/src/op/space_to_depth.cpp b/src/core/src/op/space_to_depth.cpp
index 33ff10639187e3..39d192eb7f9e72 100644
--- a/src/core/src/op/space_to_depth.cpp
+++ b/src/core/src/op/space_to_depth.cpp
@@ -55,7 +55,7 @@ bool SpaceToDepth::evaluate(TensorVector& outputs, const TensorVector& inputs) c
     OPENVINO_ASSERT(outputs.size() == 1);
 
     const auto& in = inputs[0];
-    const auto& out = outputs[0];
+    auto& out = outputs[0];
     reference::space_to_depth(static_cast<const char*>(in.data()),
                               in.get_shape(),
                               static_cast<char*>(out.data()),
diff --git a/src/core/src/op/split.cpp b/src/core/src/op/split.cpp
index ca54db7ee7c400..f449564675345b 100644
--- a/src/core/src/op/split.cpp
+++ b/src/core/src/op/split.cpp
@@ -91,7 +91,7 @@ bool Split::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
     auto axis = get_tensor_data_as<int64_t>(axis_tensor).front();
     axis = ov::util::normalize(axis, data_tensor.get_shape().size());
 
-    ov::reference::split(static_cast<char*>(data_tensor.data()),
+    ov::reference::split(static_cast<const char*>(data_tensor.data()),
                          data_tensor.get_shape(),
                          data_tensor.get_element_type().size(),
                          axis,
diff --git a/src/core/src/op/transpose.cpp b/src/core/src/op/transpose.cpp
index 0de48f87442d0e..78f9a10e74953e 100644
--- a/src/core/src/op/transpose.cpp
+++ b/src/core/src/op/transpose.cpp
@@ -96,7 +96,8 @@ bool Transpose::evaluate(TensorVector& outputs, const TensorVector& inputs) cons
     };
 
     auto out_ptr = int4_iterator(static_cast<uint8_t*>(out.data()));
-    auto in_ptr = int4_iterator(static_cast<uint8_t*>(arg.data()));
+    // int4_iterator does not support const pointers, but the data is not modified here
+    auto in_ptr = int4_iterator(static_cast<uint8_t*>(const_cast<void*>(arg.data())));
 
     if ((arg_type == ov::element::i4 || arg_type == ov::element::u4) && arg.get_shape().size() == 2) {
         for (size_t i = 0; i < out_shape[0]; i++) {
             size_t off = i;
diff --git a/src/core/src/op/util/pad_base.cpp b/src/core/src/op/util/pad_base.cpp
index 814385c0466a2e..f6c630dcd5ffb9 100644
--- a/src/core/src/op/util/pad_base.cpp
+++ b/src/core/src/op/util/pad_base.cpp
@@ -106,7 +106,7 @@ bool op::util::PadBase::evaluate_pad(TensorVector& outputs, const TensorVector&
     const char* pad_value = nullptr;
     const std::vector<char> pad_zero_value(elem_size, 0);
     if (get_input_size() == 4) {
-        pad_value = static_cast<char*>(inputs[3].data());
+        pad_value = static_cast<const char*>(inputs[3].data());
     } else {
         pad_value = pad_zero_value.data();
     }
@@ -127,7 +127,7 @@ bool op::util::PadBase::evaluate_pad(TensorVector& outputs, const TensorVector&
     }
     outputs[0].set_shape(padded_shape);
 
-    ov::reference::pad(static_cast<char*>(inputs[0].data()),
+    ov::reference::pad(static_cast<const char*>(inputs[0].data()),
                        pad_value,
                        static_cast<char*>(outputs[0].data()),
                        elem_size,
diff --git a/src/core/src/runtime/itensor.cpp b/src/core/src/runtime/itensor.cpp
index 67dde4e38aa463..a2049faf470653 100644
--- a/src/core/src/runtime/itensor.cpp
+++ b/src/core/src/runtime/itensor.cpp
@@ -189,4 +189,8 @@ void ITensor::copy_to(const std::shared_ptr<ov::ITensor>& dst) const {
     }
 }
 
+void* ITensor::data(const element::Type& type) {
+    return const_cast<void*>(static_cast<const ITensor*>(this)->data(type));
+}
+
 }  // namespace ov
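The base-class implementation above uses the standard const-delegation idiom: the const overload is the one "real" virtual, and the non-const overload forwards to it and strips the qualifier back off. That keeps implementers to a single override unless they want read-only behavior. A standalone sketch of the pattern (illustrative type, not OpenVINO code):

```cpp
struct Buffer {
    const void* data() const { return storage; }  // the single real implementation
    void* data() {
        // Safe: *this is known non-const here, so removing the constness
        // introduced by the cast to const Buffer* is well-defined.
        return const_cast<void*>(static_cast<const Buffer*>(this)->data());
    }
    unsigned char storage[16] = {};
};
```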
diff --git a/src/core/src/runtime/tensor.cpp b/src/core/src/runtime/tensor.cpp
index 5aad89d4ffb841..5f5241389ef25c 100644
--- a/src/core/src/runtime/tensor.cpp
+++ b/src/core/src/runtime/tensor.cpp
@@ -53,6 +53,9 @@ Tensor::Tensor(const element::Type& element_type, const Shape& shape, const Allo
 Tensor::Tensor(const element::Type& element_type, const Shape& shape, void* host_ptr, const Strides& byte_strides)
     : _impl{make_tensor(element_type, shape, host_ptr, byte_strides)} {}
 
+Tensor::Tensor(const element::Type& element_type, const Shape& shape, const void* host_ptr, const Strides& byte_strides)
+    : _impl{make_tensor(element_type, shape, host_ptr, byte_strides)} {}
+
 Tensor::Tensor(const Tensor& owner, const Coordinate& begin, const Coordinate& end)
     : _impl{make_tensor(owner._impl, begin, end)},
       _so{owner._so} {}
@@ -68,6 +71,12 @@ Tensor::Tensor(const ov::Output<const ov::Node>& port, void* host_ptr, const Str
               host_ptr,
               byte_strides) {}
 
+Tensor::Tensor(const ov::Output<const ov::Node>& port, const void* host_ptr, const Strides& byte_strides)
+    : Tensor(port.get_element_type(),
+             port.get_partial_shape().is_dynamic() ? ov::Shape{0} : port.get_shape(),
+             host_ptr,
+             byte_strides) {}
+
 const element::Type& Tensor::get_element_type() const {
     OV_TENSOR_STATEMENT(return _impl->get_element_type());
 }
@@ -96,10 +105,15 @@ size_t Tensor::get_byte_size() const {
     OV_TENSOR_STATEMENT(return _impl->get_byte_size(););
 }
 
-void* Tensor::data(const element::Type& element_type) const {
+void* Tensor::data(const element::Type& element_type) {
     OV_TENSOR_STATEMENT(return _impl->data(element_type));
 }
 
+const void* Tensor::data(const element::Type& element_type) const {
+    using const_data = const void* (ITensor::*)(const element::Type&) const;
+    OV_TENSOR_STATEMENT(return std::invoke<const_data>(&ITensor::data, _impl, element_type););
+}
+
 bool Tensor::operator!() const noexcept {
     return !_impl;
 }
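The `std::invoke<const_data>` form in the const `Tensor::data` deserves a note: `_impl` is a `shared_ptr` to a non-const `ITensor`, so a plain `_impl->data(element_type)` would resolve to the non-const overload — exactly the one that throws on read-only tensors. Naming the member-pointer type as the explicit template argument forces resolution to the const overload. A reduced illustration (types are illustrative):

```cpp
#include <functional>
#include <memory>

struct I {
    virtual void* data() { return nullptr; }              // may throw in read-only impls
    virtual const void* data() const { return nullptr; }  // always available
    virtual ~I() = default;
};

const void* const_data_of(const std::shared_ptr<I>& impl) {
    // impl->data() would pick the non-const overload; force the const one:
    using const_fn = const void* (I::*)() const;
    return std::invoke<const_fn>(&I::data, impl);
}
```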
diff --git a/src/core/tests/ov_tensor_test.cpp b/src/core/tests/ov_tensor_test.cpp
index df20e559045ee3..66aba8dfad12f0 100644
--- a/src/core/tests/ov_tensor_test.cpp
+++ b/src/core/tests/ov_tensor_test.cpp
@@ -8,22 +8,24 @@
 #include <gtest/gtest.h>
 #include <gmock/gmock.h>
 
-#include <openvino/core/shape.hpp>
-#include <openvino/core/strides.hpp>
-#include <openvino/core/type/element_type.hpp>
 #include "common_test_utils/test_assertions.hpp"
 #include "openvino/core/except.hpp"
 #include "openvino/core/partial_shape.hpp"
+#include "openvino/core/shape.hpp"
+#include "openvino/core/strides.hpp"
+#include "openvino/core/type/element_type.hpp"
 #include "openvino/core/type/element_type_traits.hpp"
+#include "openvino/op/constant.hpp"
 #include "openvino/op/parameter.hpp"
 #include "openvino/reference/utils/coordinate_transform.hpp"
 #include "openvino/runtime/allocator.hpp"
 #include "openvino/runtime/remote_tensor.hpp"
 #include "openvino/runtime/tensor.hpp"
 
+namespace ov::test {
 using OVTensorTest = ::testing::Test;
-using testing::_;
+using testing::_, testing::HasSubstr;
 
 const size_t string_size = ov::element::string.size();
@@ -84,6 +86,26 @@ TEST_F(OVTensorTest, createTensorFromPort) {
     EXPECT_EQ(t4.get_element_type(), parameter3->get_element_type());
 }
 
+TEST_F(OVTensorTest, createTensorFromConstantPort) {
+    auto constant1 = std::make_shared<ov::op::v0::Constant>(ov::element::f64, ov::Shape{1, 3, 2, 2}, 0);
+    auto constant2 = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 3}, 1.0f);
+    ov::Tensor t1{constant1->output(0)};
+    ov::Tensor t2{constant2->output(0), constant1->get_data_ptr()};
+    const auto& t2c = t2;
+    const ov::Tensor t3{constant2->output(0), constant2->get_data_ptr()};
+
+    EXPECT_EQ(t1.get_shape(), constant1->get_shape());
+    EXPECT_EQ(t1.get_element_type(), constant1->get_element_type());
+    EXPECT_EQ(t2.get_shape(), constant2->get_shape());
+    EXPECT_EQ(t2.get_element_type(), constant2->get_element_type());
+    OV_EXPECT_THROW(t2.data(), ov::Exception, _);
+    OV_ASSERT_NO_THROW(t2c.data());
+
+    EXPECT_EQ(t3.get_shape(), constant2->get_shape());
+    EXPECT_EQ(t3.get_element_type(), constant2->get_element_type());
+    OV_ASSERT_NO_THROW(t3.data());
+}
+
 TEST_F(OVTensorTest, createStringTensorFromPort) {
     auto parameter1 = std::make_shared<ov::op::v0::Parameter>(ov::element::string, ov::Shape{1, 3, 2, 2});
     auto parameter2 = std::make_shared<ov::op::v0::Parameter>(ov::element::string, ov::Shape{1, 3});
@@ -290,12 +312,12 @@ TEST_F(OVTensorTest, canAccessExternalDataWithStridesStringTensor) {
 
 TEST_F(OVTensorTest, cannotCreateTensorWithExternalNullptr) {
     ov::Shape shape = {2, 3};
-    ASSERT_THROW(ov::Tensor(ov::element::f32, shape, nullptr), ov::Exception);
+    ASSERT_THROW(ov::Tensor(ov::element::f32, shape, static_cast<void*>(nullptr)), ov::Exception);
 }
 
 TEST_F(OVTensorTest, cannotCreateStringTensorWithExternalNullptr) {
     ov::Shape shape = {2, 3};
-    ASSERT_THROW(ov::Tensor(ov::element::string, shape, nullptr), ov::Exception);
+    ASSERT_THROW(ov::Tensor(ov::element::string, shape, static_cast<void*>(nullptr)), ov::Exception);
 }
 
 TEST_F(OVTensorTest, cannotCreateTensorWithWrongStrides) {
@@ -555,6 +577,15 @@ TEST_F(OVTensorTest, canChangeShapeOnStridedTensorStringTensor) {
     OV_ASSERT_NO_THROW(t.set_shape(correct_shape));
 }
 
+TEST_F(OVTensorTest, createReadOnlyStridedView) {
+    const std::vector<std::string> data(64 * 4);
+    ov::Tensor strided_view{element::string, {4, 2, 2}, data.data(), {8 * string_size, 3 * string_size, string_size}};
+
+    OV_ASSERT_NO_THROW(static_cast<const ov::Tensor&>(strided_view).data(element::string));
+    OV_ASSERT_NO_THROW(strided_view.data<const std::string>());
+    OV_EXPECT_THROW(strided_view.data(), ov::Exception, HasSubstr("Can not access non-const pointer"));
+}
+
 TEST_F(OVTensorTest, makeRangeRoiTensor) {
     ov::Tensor t{ov::element::i32, {1, 3, 6, 5}};  // RGBp picture of size (WxH) = 5x6
     ov::Tensor roi_tensor{t, {0, 0, 1, 2}, {1, 3, 5, 4}};
@@ -831,7 +862,7 @@ TEST_F(OVTensorTest, checkIsContinuousTensor3Dimensions) {
 }
 
 TEST_F(OVTensorTest, checkIsContinuousTensor4Dimensions) {
-    ov::Tensor tensor(ov::element::f32, ov::Shape{3, 5, 32, 128});
+    const ov::Tensor tensor(ov::element::f32, ov::Shape{3, 5, 32, 128});
     auto data = tensor.data();
     auto strides = tensor.get_strides();
 
@@ -865,6 +896,19 @@ TEST_F(OVTensorTest, checkIsContinuousTensor4Dimensions) {
     EXPECT_EQ(view_tensor.is_continuous(), true);
 }
 
+TEST_F(OVTensorTest, createReadOnlyView) {
+    const std::vector<int32_t> data(10, 1);
+    ov::Tensor data_view(ov::element::i32, ov::Shape{10}, data.data());
+
+    OV_ASSERT_NO_THROW(static_cast<const ov::Tensor&>(data_view).data(element::i32));
+    OV_ASSERT_NO_THROW(data_view.data<const int32_t>());
+    OV_EXPECT_THROW(data_view.data(), ov::Exception, HasSubstr("Can not access non-const pointer"));
+}
+
+TEST_F(OVTensorTest, createReadOnlyViewFromNullptr) {
+    OV_EXPECT_THROW(Tensor(ov::element::i32, ov::Shape{10}, static_cast<const void*>(nullptr)), ov::Exception, _);
+}
+
 struct TestParams {
     ov::Shape src_shape;
     ov::Strides src_strides;
@@ -903,8 +947,8 @@ void compare_data(const ov::Tensor& src, const ov::Tensor& dst) {
 template <ov::element::Type_t ET,
           typename T = typename ov::element_type_traits<ET>::value_type,
           typename std::enable_if<ET != ov::element::Type_t::string, bool>::type = true>
-void init_tensor(const ov::Tensor& tensor, bool input) {
-    const auto origPtr = tensor.data<T>();
+void init_tensor(ov::Tensor& tensor, bool input) {
+    auto origPtr = tensor.data<T>();
     ASSERT_NE(nullptr, origPtr);
     for (size_t i = 0; i < tensor.get_size(); ++i) {
         origPtr[i] = static_cast<T>(input ? i : -1);
     }
 }
 
@@ -914,7 +958,7 @@ void init_tensor(const ov::Tensor& tensor, bool input) {
 template <ov::element::Type_t ET,
           typename T = typename ov::element_type_traits<ET>::value_type,
           typename std::enable_if<ET == ov::element::Type_t::string, bool>::type = true>
-void init_tensor(const ov::Tensor& tensor, bool input) {
+void init_tensor(ov::Tensor& tensor, bool input) {
     const auto origPtr = tensor.data<T>();
     ASSERT_NE(nullptr, origPtr);
     for (size_t i = 0; i < tensor.get_size(); ++i) {
     }
 }
 
@@ -922,7 +966,7 @@ void init_tensor(ov::Tensor& tensor, bool input) {
-void init_tensor(const ov::Tensor& tensor, bool input) {
+void init_tensor(ov::Tensor& tensor, bool input) {
     switch (tensor.get_element_type()) {
     case ov::element::bf16:
         init_tensor<ov::element::bf16>(tensor, input);
         break;
@@ -1141,3 +1185,4 @@ INSTANTIATE_TEST_SUITE_P(copy_tests_strings,
                          }
                          )));
 // clang-format on
+}  // namespace ov::test
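Taken together, the tests above pin down the contract of the new read-only views: construction from a `const void*` succeeds (except for null pointers), const access keeps working, and only the mutable accessor throws. In application code the pattern looks roughly like this (buffer contents and shape are illustrative):

```cpp
#include <openvino/runtime/tensor.hpp>
#include <vector>

int main() {
    const std::vector<float> weights(12, 0.5f);  // externally owned, immutable
    // Wrap without copying; the tensor becomes a read-only view.
    ov::Tensor view(ov::element::f32, ov::Shape{3, 4}, weights.data());

    const float* ro = view.data<const float>();  // fine
    // view.data() or view.data<float>() would throw ov::Exception here.
    return ro[0] == 0.5f ? 0 : 1;
}
```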
diff --git a/src/core/tests/utils/eval_utils.hpp b/src/core/tests/utils/eval_utils.hpp
index abeb5fc751c68b..4ceee915c0f2d8 100644
--- a/src/core/tests/utils/eval_utils.hpp
+++ b/src/core/tests/utils/eval_utils.hpp
@@ -11,7 +11,7 @@
 namespace {
 template <typename T>
-void copy_data(const ov::Tensor& tv, const std::vector<T>& data) {
+void copy_data(ov::Tensor& tv, const std::vector<T>& data) {
     size_t data_size = data.size() * sizeof(T);
     if (data_size > 0) {
         OPENVINO_ASSERT(tv.get_byte_size() >= data_size);
@@ -21,13 +21,13 @@ void copy_data(const ov::Tensor& tv, const std::vector<T>& data) {
 }
 
 template <>
-inline void copy_data<bool>(const ov::Tensor& tv, const std::vector<bool>& data) {
+inline void copy_data<bool>(ov::Tensor& tv, const std::vector<bool>& data) {
     std::vector<char> data_char(data.begin(), data.end());
     copy_data(tv, data_char);
 }
 
 template <typename T>
-void init_int_tv(const ov::Tensor& tv, std::default_random_engine& engine, T min, T max) {
+void init_int_tv(ov::Tensor& tv, std::default_random_engine& engine, T min, T max) {
     size_t size = tv.get_size();
     std::uniform_int_distribution<T> dist(min, max);
     std::vector<T> vec(size);
@@ -40,7 +40,7 @@ void init_int_tv(const ov::Tensor& tv, std::default_random_engine& engine, T min
 }
 
 template <>
-inline void init_int_tv<char>(const ov::Tensor& tv, std::default_random_engine& engine, char min, char max) {
+inline void init_int_tv<char>(ov::Tensor& tv, std::default_random_engine& engine, char min, char max) {
     size_t size = tv.get_size();
     std::uniform_int_distribution<int16_t> dist(static_cast<int16_t>(min), static_cast<int16_t>(max));
     std::vector<char> vec(size);
@@ -53,7 +53,7 @@ inline void init_int_tv<char>(ov::Tensor& tv, std::default_random_engine&
 }
 
 template <>
-inline void init_int_tv<int8_t>(const ov::Tensor& tv, std::default_random_engine& engine, int8_t min, int8_t max) {
+inline void init_int_tv<int8_t>(ov::Tensor& tv, std::default_random_engine& engine, int8_t min, int8_t max) {
     size_t size = tv.get_size();
     std::uniform_int_distribution<int16_t> dist(static_cast<int16_t>(min), static_cast<int16_t>(max));
     std::vector<int8_t> vec(size);
@@ -66,7 +66,7 @@ inline void init_int_tv<int8_t>(ov::Tensor& tv, std::default_random_engine
 }
 
 template <>
-inline void init_int_tv<uint8_t>(const ov::Tensor& tv, std::default_random_engine& engine, uint8_t min, uint8_t max) {
+inline void init_int_tv<uint8_t>(ov::Tensor& tv, std::default_random_engine& engine, uint8_t min, uint8_t max) {
     size_t size = tv.get_size();
     std::uniform_int_distribution<int16_t> dist(static_cast<int16_t>(min), static_cast<int16_t>(max));
     std::vector<uint8_t> vec(size);
@@ -79,7 +79,7 @@ inline void init_int_tv<uint8_t>(ov::Tensor& tv, std::default_random_engin
 }
 
 template <typename T>
-void init_real_tv(const ov::Tensor& tv, std::default_random_engine& engine, T min, T max) {
+void init_real_tv(ov::Tensor& tv, std::default_random_engine& engine, T min, T max) {
     size_t size = tv.get_size();
     std::uniform_real_distribution<T> dist(min, max);
     std::vector<T> vec(size);
@@ -91,7 +91,7 @@ void init_real_tv(const ov::Tensor& tv, std::default_random_engine& engine, T mi
     memcpy(tv.data(), vec.data(), data_size);
 }
 
-inline void random_init(const ov::Tensor& tv, std::default_random_engine& engine) {
+inline void random_init(ov::Tensor& tv, std::default_random_engine& engine) {
     ov::element::Type et = tv.get_element_type();
     if (et == ov::element::boolean) {
         init_int_tv<char>(tv, engine, 0, 1);
diff --git a/src/inference/dev_api/openvino/runtime/iremote_tensor.hpp b/src/inference/dev_api/openvino/runtime/iremote_tensor.hpp
index 82be5a7496d85f..f1c8112bd07fbd 100644
--- a/src/inference/dev_api/openvino/runtime/iremote_tensor.hpp
+++ b/src/inference/dev_api/openvino/runtime/iremote_tensor.hpp
@@ -17,7 +17,7 @@ namespace ov {
 
 class OPENVINO_RUNTIME_API IRemoteTensor : public ITensor {
 public:
-    void* data(const element::Type& type = {}) const override final {
+    const void* data(const element::Type& type = {}) const override final {
         OPENVINO_NOT_IMPLEMENTED;
     }
diff --git a/src/inference/dev_api/openvino/runtime/make_tensor.hpp b/src/inference/dev_api/openvino/runtime/make_tensor.hpp
index 6cfe74ebaad95a..10d5de93f43eb0 100644
--- a/src/inference/dev_api/openvino/runtime/make_tensor.hpp
+++ b/src/inference/dev_api/openvino/runtime/make_tensor.hpp
@@ -39,6 +39,20 @@ OPENVINO_RUNTIME_API std::shared_ptr<ITensor> make_tensor(const element::Type ty
                                                           void* host_ptr,
                                                           const Strides& strides = {});
 
+/**
+ * @brief Constructs Tensor using element type and shape. Wraps allocated host memory as read-only.
+ * @note Does not perform memory allocation internally
+ * @param type Tensor element type
+ * @param shape Tensor shape
+ * @param host_ptr Pointer to pre-allocated host memory
+ * @param strides Optional strides parameters in bytes. If omitted, strides are computed automatically based
+ * on shape and element size
+ */
+OPENVINO_RUNTIME_API std::shared_ptr<ITensor> make_tensor(const element::Type type,
+                                                          const Shape& shape,
+                                                          const void* host_ptr,
+                                                          const Strides& strides = {});
+
 /**
  * @brief Constructs region of interest (ROI) tensor from another tensor.
 * @note Does not perform memory allocation internally
diff --git a/src/inference/src/dev/compilation_context.cpp b/src/inference/src/dev/compilation_context.cpp
index bf1a7197826f49..49e8ae2b1eb3ac 100644
--- a/src/inference/src/dev/compilation_context.cpp
+++ b/src/inference/src/dev/compilation_context.cpp
@@ -109,7 +109,7 @@ std::string ModelCache::compute_hash(const std::string& modelStr,
             if (tensor) {
                 seed = hash_combine(seed, tensor.get_size());
 
-                auto ptr = static_cast<size_t*>(tensor.data());
+                auto ptr = static_cast<const size_t*>(tensor.data());
                 size_t size = tensor.get_size() / sizeof(size_t);
 
                 // 10MB block size in size_t
@@ -141,7 +141,7 @@ std::string ModelCache::compute_hash(const std::string& modelStr,
                 }
 
                 auto size_done = size * sizeof(size_t);
-                auto ptr_left = static_cast<uint8_t*>(tensor.data()) + size_done;
+                auto ptr_left = static_cast<const uint8_t*>(tensor.data()) + size_done;
                 size_t size_left = tensor.get_size() - size_done;
                 for (size_t i = 0; i < size_left; i++)
                     seed = hash_combine(seed, ptr_left[i]);
diff --git a/src/inference/src/dev/make_tensor.cpp b/src/inference/src/dev/make_tensor.cpp
index 88f0989007c522..35a8e85cb6878a 100644
--- a/src/inference/src/dev/make_tensor.cpp
+++ b/src/inference/src/dev/make_tensor.cpp
@@ -57,11 +57,8 @@ class ViewTensor : public ITensor {
         OPENVINO_ASSERT(m_element_type.is_static());
     }
 
-    void* data(const element::Type& element_type) const override {
-        if (element_type.is_static() && (element_type.bitwidth() != get_element_type().bitwidth() ||
-                                         element_type.is_real() != get_element_type().is_real() ||
-                                         (element_type == element::string && get_element_type() != element::string) ||
-                                         (element_type != element::string && get_element_type() == element::string))) {
+    const void* data(const element::Type& element_type) const override {
+        if (!is_pointer_representable(element_type)) {
             OPENVINO_THROW("Tensor data with element type ",
                            get_element_type(),
                            ", is not representable as pointer to ",
                           element_type);
         }
         return m_ptr;
     }
@@ -94,6 +91,12 @@ class ViewTensor : public ITensor {
     }
 
 protected:
+    bool is_pointer_representable(const element::Type& element_type) const {
+        return element_type.is_dynamic() || ((element_type.bitwidth() == get_element_type().bitwidth() &&
+                                              element_type.is_real() == get_element_type().is_real()) ||
+                                             (element_type == element::string && element_type == get_element_type()));
+    }
+
     void update_strides() const {
         if (m_element_type.bitwidth() < 8)
             return;
@@ -118,6 +121,22 @@ class ViewTensor : public ITensor {
     void* m_ptr;
 };
 
+/**
+ * @brief Read-only view tensor to external memory
+ * The tensor doesn't own the external memory
+ */
+class ReadOnlyViewTensor : public ViewTensor {
+public:
+    ReadOnlyViewTensor(const element::Type element_type, const Shape& shape, const void* ptr)
+        : ViewTensor{element_type, shape, const_cast<void*>(ptr)} {}
+
+    using ViewTensor::data;
+
+    [[noreturn]] void* data(const element::Type& element_type) override {
+        OPENVINO_THROW("Can not access non-const pointer. Use e.g. 'static_cast<const ov::Tensor&>(tensor).data()'");
+    }
+};
+
 /**
  * @brief View tensor on external memory with strides
 */
class StridedViewTensor : public ViewTensor {
@@ -173,6 +192,21 @@ class StridedViewTensor : public ViewTensor {
     }
 };
 
+class ReadOnlyStridedViewTensor : public StridedViewTensor {
+public:
+    ReadOnlyStridedViewTensor(const element::Type element_type,
+                              const Shape& shape,
+                              const void* ptr,
+                              const Strides& strides)
+        : StridedViewTensor{element_type, shape, const_cast<void*>(ptr), strides} {}
+
+    using StridedViewTensor::data;
+
+    [[noreturn]] void* data(const element::Type& element_type) override {
+        OPENVINO_THROW("Can not access non-const pointer. Use e.g. 'static_cast<const ov::Tensor&>(tensor).data()'");
+    }
+};
+
 /**
  * @brief Creates view tensor on external memory
 *
@@ -191,6 +225,27 @@ std::shared_ptr<ITensor> make_tensor(const element::Type element_type,
                                 : std::make_shared<StridedViewTensor>(element_type, shape, ptr, byte_strides);
 }
 
+/**
+ * @brief Creates read-only view tensor on external memory
+ *
+ * @param element_type Tensor element type
+ * @param shape Tensor shape
+ * @param ptr pointer to external memory
+ * @param byte_strides Tensor strides
+ *
+ * @return Shared pointer to tensor interface
+ */
+std::shared_ptr<ITensor> make_tensor(const element::Type element_type,
+                                     const Shape& shape,
+                                     const void* ptr,
+                                     const Strides& byte_strides) {
+    if (byte_strides.empty()) {
+        return std::make_shared<ReadOnlyViewTensor>(element_type, shape, ptr);
+    } else {
+        return std::make_shared<ReadOnlyStridedViewTensor>(element_type, shape, ptr, byte_strides);
+    }
+}
+
 /**
  * @brief Tensor with allocated memory
  * Tensor owns the memory
@@ -348,7 +403,7 @@ class RoiTensor : public BaseRoiTensor, public ITensor {
         BaseRoiTensor::set_shape(new_shape);
     }
 
-    void* data(const element::Type& element_type) const override {
+    const void* data(const element::Type& element_type) const override {
         auto owner_data = m_owner->data(element_type);
         return static_cast<uint8_t*>(owner_data) + m_offset;
     }
'static_cast.data()'"); + } +}; + /** * @brief Creates view tensor on external memory * @@ -191,6 +225,27 @@ std::shared_ptr make_tensor(const element::Type element_type, : std::make_shared(element_type, shape, ptr, byte_strides); } +/** + * @brief Creates read-only view tensor on external memory + * + * @param element_type Tensor element type + * @param shape Tensor shape + * @param ptr pointer to external memory + * @param byte_strides Tensor strides + * + * @return Shared pointer to tensor interface + */ +std::shared_ptr make_tensor(const element::Type element_type, + const Shape& shape, + const void* ptr, + const Strides& byte_strides) { + if (byte_strides.empty()) { + return std::make_shared(element_type, shape, ptr); + } else { + return std::make_shared(element_type, shape, ptr, byte_strides); + } +} + /** * @brief Tensor with allocated memory * Tensor owns the memory @@ -348,7 +403,7 @@ class RoiTensor : public BaseRoiTensor, public ITensor { BaseRoiTensor::set_shape(new_shape); } - void* data(const element::Type& element_type) const override { + const void* data(const element::Type& element_type) const override { auto owner_data = m_owner->data(element_type); return static_cast(owner_data) + m_offset; } diff --git a/src/inference/src/model_reader.cpp b/src/inference/src/model_reader.cpp index cc544e9e2f0dce..de8e865d23dc10 100644 --- a/src/inference/src/model_reader.cpp +++ b/src/inference/src/model_reader.cpp @@ -175,7 +175,7 @@ std::shared_ptr read_model(const std::string& model, ov::AnyVector params{&modelStream}; if (weights) { std::shared_ptr weights_buffer = - std::make_shared>(reinterpret_cast(weights.data()), + std::make_shared>(reinterpret_cast(const_cast(weights.data())), weights.get_byte_size(), weights); params.emplace_back(weights_buffer); diff --git a/src/plugins/intel_cpu/src/cpu_tensor.cpp b/src/plugins/intel_cpu/src/cpu_tensor.cpp index 055acd43f6dfb3..724fa9c7bed24c 100644 --- a/src/plugins/intel_cpu/src/cpu_tensor.cpp +++ b/src/plugins/intel_cpu/src/cpu_tensor.cpp @@ -83,7 +83,7 @@ void Tensor::update_strides() const { }); } -void* Tensor::data(const element::Type& element_type) const { +const void* Tensor::data(const element::Type& element_type) const { if (element_type.is_static()) { OPENVINO_ASSERT(element_type == get_element_type(), "Tensor data with element type ", diff --git a/src/plugins/intel_cpu/src/cpu_tensor.h b/src/plugins/intel_cpu/src/cpu_tensor.h index 77506389a1fd5f..fcb1ffdb723a46 100644 --- a/src/plugins/intel_cpu/src/cpu_tensor.h +++ b/src/plugins/intel_cpu/src/cpu_tensor.h @@ -27,7 +27,7 @@ class Tensor : public ITensor { const ov::Strides& get_strides() const override; - void* data(const element::Type& type) const override; + const void* data(const element::Type& type) const override; MemoryPtr get_memory() { return m_memptr; diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp index f17c20cd99242e..136f450d433e45 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp @@ -142,7 +142,7 @@ void convert_and_copy( cldnn::memory::ptr dst, cldnn::stream& stream, const cldnn::layout& src_layout = cldnn::layout({}, ov::element::dynamic, cldnn::format::bfyx, cldnn::padding())); -void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor const* dst, const cldnn::stream& stream); +void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor* dst, const cldnn::stream& 
diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp
index f17c20cd99242e..136f450d433e45 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp
@@ -142,7 +142,7 @@ void convert_and_copy(
     cldnn::memory::ptr dst,
     cldnn::stream& stream,
     const cldnn::layout& src_layout = cldnn::layout({}, ov::element::dynamic, cldnn::format::bfyx, cldnn::padding()));
-void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor const* dst, const cldnn::stream& stream);
+void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor* dst, const cldnn::stream& stream);
 void convert_and_copy(const ov::ITensor* src, ov::ITensor* dst, const cldnn::stream& stream);
 void convert_and_copy(const cldnn::memory::ptr src, cldnn::memory::ptr dst, cldnn::stream& stream);
diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/sync_infer_request.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/sync_infer_request.hpp
index 73a92d98e6a7e9..42168bc663cbeb 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/plugin/sync_infer_request.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/sync_infer_request.hpp
@@ -113,7 +113,7 @@ class SyncInferRequest : public ov::ISyncInferRequest {
     void allocate_states();
     void allocate_input(const ov::Output<const ov::Node>& port, size_t input_idx);
     void allocate_output(const ov::Output<const ov::Node>& port, size_t output_idx);
-    cldnn::event::ptr copy_output_data(cldnn::memory::ptr src, const ov::ITensor& dst) const;
+    cldnn::event::ptr copy_output_data(cldnn::memory::ptr src, ov::ITensor& dst) const;
 
     void init_mappings();
     bool is_batched_input(const ov::Output<const ov::Node>& port) const;
diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/usm_host_tensor.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/usm_host_tensor.hpp
index c19340af329b4b..caaa029bd154b9 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/plugin/usm_host_tensor.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/usm_host_tensor.hpp
@@ -20,7 +20,7 @@ class USMHostTensor : public ov::ITensor {
 
     ~USMHostTensor() override = default;
 
-    void* data(const element::Type& element_type) const override;
+    const void* data(const element::Type& element_type) const override;
 
     const element::Type& get_element_type() const override;
     const Shape& get_shape() const override;
diff --git a/src/plugins/intel_gpu/src/plugin/common_utils.cpp b/src/plugins/intel_gpu/src/plugin/common_utils.cpp
index 229d1b8eaa3bc0..ac1ccfd1455687 100644
--- a/src/plugins/intel_gpu/src/plugin/common_utils.cpp
+++ b/src/plugins/intel_gpu/src/plugin/common_utils.cpp
@@ -164,7 +164,7 @@ void convert_and_copy(const ov::ITensor* src, cldnn::memory::ptr dst, cldnn::str
     dst->copy_from(stream, tmp_tensor.data(), blocking);
 }
 
-void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor const* dst, const cldnn::stream& stream) {
+void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor* dst, const cldnn::stream& stream) {
     auto src_et = src->get_layout().data_type;
     auto dst_et = dst->get_element_type();
diff --git a/src/plugins/intel_gpu/src/plugin/remote_tensor.cpp b/src/plugins/intel_gpu/src/plugin/remote_tensor.cpp
index 97ca22abdc210f..8e6a7eea823863 100644
--- a/src/plugins/intel_gpu/src/plugin/remote_tensor.cpp
+++ b/src/plugins/intel_gpu/src/plugin/remote_tensor.cpp
@@ -243,7 +243,8 @@ void RemoteTensorImpl::copy_from(const std::shared_ptr<const ov::ITensor>& src,
     OPENVINO_ASSERT(!std::dynamic_pointer_cast<const ov::IRemoteTensor>(src), "[GPU] Unsupported Remote Tensor type");
 
-    auto src_mem = MemWrapper(stream, nullptr, src->data());
+    // MemWrapper uses the tensor pointer as read-only, so const_cast is safe here
+    auto src_mem = MemWrapper(stream, nullptr, const_cast<void*>(src->data()));
     auto dst_mem = MemWrapper(stream, get_memory(), nullptr);
 
     copy_roi(src_mem, dst_mem, src_offset, dst_offset, src->get_strides(), get_strides(), roi_strides,
              src->get_shape(), get_shape(), shape);
diff --git a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
index 7b8bb21da00393..c34e4ea4a74e3e 100644
--- a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
+++ b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
@@ -561,7 +561,7 @@ TensorWrapper SyncInferRequest::create_or_share_device_tensor(const TensorWrappe
     return { create_device_tensor(actual_memory_shape, element_type, need_lockable_mem), TensorOwner::PLUGIN };
 }
 
-cldnn::event::ptr SyncInferRequest::copy_output_data(cldnn::memory::ptr src, const ov::ITensor& dst) const {
+cldnn::event::ptr SyncInferRequest::copy_output_data(cldnn::memory::ptr src, ov::ITensor& dst) const {
     OV_ITT_SCOPED_TASK(itt::domains::intel_gpu_plugin, "SyncInferRequest::copy_output_data");
     OPENVINO_ASSERT(src->count() <= dst.get_size(),
                     "[GPU] Unexpected elements count of dst tensor: ",
diff --git a/src/plugins/intel_gpu/src/plugin/usm_host_tensor.cpp b/src/plugins/intel_gpu/src/plugin/usm_host_tensor.cpp
index 50580537a8d2b4..e3edfb818c1692 100644
--- a/src/plugins/intel_gpu/src/plugin/usm_host_tensor.cpp
+++ b/src/plugins/intel_gpu/src/plugin/usm_host_tensor.cpp
@@ -15,7 +15,7 @@ USMHostTensor::USMHostTensor(std::shared_ptr<RemoteContextImpl> context, const e
 USMHostTensor::USMHostTensor(std::shared_ptr<RemoteTensorImpl> tensor) : m_impl(tensor) {}
 
-void* USMHostTensor::data(const element::Type& element_type) const {
+const void* USMHostTensor::data(const element::Type& element_type) const {
     return m_impl->get_original_memory()->buffer_ptr();
 }
diff --git a/src/plugins/intel_gpu/tests/functional/remote_tensor_tests/gpu_remote_tensor_tests.cpp b/src/plugins/intel_gpu/tests/functional/remote_tensor_tests/gpu_remote_tensor_tests.cpp
index 90f2a310eac9bb..7d3036130c4cf6 100644
--- a/src/plugins/intel_gpu/tests/functional/remote_tensor_tests/gpu_remote_tensor_tests.cpp
+++ b/src/plugins/intel_gpu/tests/functional/remote_tensor_tests/gpu_remote_tensor_tests.cpp
@@ -2564,9 +2564,8 @@ void compare_data(const ov::Tensor& src, const ov::Tensor& dst) {
     }
 }
 
-template <ov::element::Type_t type,
-          typename T = typename ov::element_type_traits<type>::value_type>
-void init_tensor(const ov::Tensor& tensor) {
+template <ov::element::Type_t type, typename T = typename ov::element_type_traits<type>::value_type>
+void init_tensor(ov::Tensor& tensor) {
     const auto origPtr = tensor.data<T>();
     ASSERT_NE(nullptr, origPtr);
     for (size_t i = 0; i < tensor.get_size(); ++i) {
@@ -2574,7 +2573,7 @@ void init_tensor(ov::Tensor& tensor) {
 }
 
-void init_tensor(const ov::Tensor& tensor) {
+void init_tensor(ov::Tensor& tensor) {
     switch (tensor.get_element_type()) {
     case ov::element::f32:
         init_tensor<ov::element::f32>(tensor);
diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/matrix_nms.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/matrix_nms.cpp
index 646aa9e8d3732f..ac62b9fa8a2c1d 100644
--- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/matrix_nms.cpp
+++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/matrix_nms.cpp
@@ -117,8 +117,8 @@ void MatrixNmsLayerTestGPU::compare(const std::vector<ov::Tensor> &expectedOutpu
     for (int outputIndex = static_cast<int>(expectedOutputs.size()) - 1; outputIndex >= 0 ; outputIndex--) {
         const auto& expected = expectedOutputs[outputIndex];
         const auto& actual = actualOutputs[outputIndex];
-        const auto actualBuffer = static_cast<uint8_t*>(actual.data());
-        const auto expectedBuffer = static_cast<uint8_t*>(expected.data());
+        const auto actualBuffer = static_cast<const uint8_t*>(actual.data());
+        const auto expectedBuffer = static_cast<const uint8_t*>(expected.data());
 
         //Compare Selected Outputs & Selected Indices
         if (outputIndex != batchIndex) {
@@ -147,7 +147,7 @@ void MatrixNmsLayerTestGPU::compare(const std::vector<ov::Tensor> &expectedOutpu
                 default:
                     break;
             }
-            const auto fBuffer = static_cast<float*>(actual.data());
+            const auto fBuffer = static_cast<const float*>(actual.data());
             for (size_t tailing = validNums * 6; tailing < maxOutputBoxesPerBatch * 6; tailing++) {
                 ASSERT_TRUE(std::abs(fBuffer[(actual_offset * 6 + tailing)] + 1.f) < 1e-5)
                     << "Invalid default value: " << fBuffer[i] << " at index: " << i;
diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/multiclass_nms.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/multiclass_nms.cpp
index 79c7b60cc6f9dc..d7023580663647 100644
--- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/multiclass_nms.cpp
+++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/single_layer_tests/multiclass_nms.cpp
@@ -237,8 +237,8 @@ void MulticlassNmsLayerTestGPU::compare(const std::vector<ov::Tensor> &expectedO
     for (int outputIndex = static_cast<int>(expectedOutputs.size()) - 1; outputIndex >= 0; outputIndex--) {
         const auto& expected = expectedOutputs[outputIndex];
         const auto& actual = actualOutputs[outputIndex];
-        const auto actualBuffer = static_cast<uint8_t*>(actual.data());
-        const auto expectedBuffer = static_cast<uint8_t*>(expected.data());
+        const auto actualBuffer = static_cast<const uint8_t*>(actual.data());
+        const auto expectedBuffer = static_cast<const uint8_t*>(expected.data());
 
         const auto expected_shape = expected.get_shape();
         const auto actual_shape = actual.get_shape();
@@ -270,7 +270,7 @@ void MulticlassNmsLayerTestGPU::compare(const std::vector<ov::Tensor> &expectedO
                 break;
             }
 
-            const auto fBuffer = static_cast<float*>(actual.data());
+            const auto fBuffer = static_cast<const float*>(actual.data());
             for (size_t tailing = validNums * 6; tailing < maxOutputBoxesPerBatch * 6; tailing++) {
                 ASSERT_TRUE(std::abs(fBuffer[(actual_offset * 6 + tailing)] - -1.f) < 1e-5)
                     << "Invalid default value: " << fBuffer[i] << " at index: " << i;
diff --git a/src/plugins/intel_npu/src/backend/include/zero_host_tensor.hpp b/src/plugins/intel_npu/src/backend/include/zero_host_tensor.hpp
index 55ba4e476ec755..2e92c0ca25119b 100644
--- a/src/plugins/intel_npu/src/backend/include/zero_host_tensor.hpp
+++ b/src/plugins/intel_npu/src/backend/include/zero_host_tensor.hpp
@@ -23,7 +23,7 @@ class ZeroHostTensor : public ov::ITensor {
 
     ~ZeroHostTensor() override = default;
 
-    void* data(const ov::element::Type& element_type) const override;
+    const void* data(const ov::element::Type& element_type) const override;
 
     const ov::element::Type& get_element_type() const override;
     const ov::Shape& get_shape() const override;
diff --git a/src/plugins/intel_npu/src/backend/include/zero_tensor.hpp b/src/plugins/intel_npu/src/backend/include/zero_tensor.hpp
index 4a42f003ea4723..332aea6fdda232 100644
--- a/src/plugins/intel_npu/src/backend/include/zero_tensor.hpp
+++ b/src/plugins/intel_npu/src/backend/include/zero_tensor.hpp
@@ -29,7 +29,7 @@ class ZeroTensor final : public ov::ITensor {
                const ov::Shape& shape,
                const ov::Allocator& allocator);
 
-    void* data(const ov::element::Type& type = {}) const override;
+    const void* data(const ov::element::Type& type = {}) const override;
 
     const ov::element::Type& get_element_type() const override;
diff --git a/src/plugins/intel_npu/src/backend/src/zero_host_tensor.cpp b/src/plugins/intel_npu/src/backend/src/zero_host_tensor.cpp
index 7f55cfb3e9c976..985d8f5fd6a7f9 100644
--- a/src/plugins/intel_npu/src/backend/src/zero_host_tensor.cpp
+++ b/src/plugins/intel_npu/src/backend/src/zero_host_tensor.cpp
@@ -24,7 +24,7 @@ ZeroHostTensor::ZeroHostTensor(const std::shared_ptr<ov::IRemoteContext>& contex
                                             tensor_type,
                                             ov::intel_npu::MemType::L0_INTERNAL_BUF)) {}
 
-void* ZeroHostTensor::data(const ov::element::Type&) const {
+const void* ZeroHostTensor::data(const ov::element::Type&) const {
     return _impl->get_original_memory();
 }
diff --git a/src/plugins/intel_npu/src/backend/src/zero_tensor.cpp b/src/plugins/intel_npu/src/backend/src/zero_tensor.cpp
index ae84ce4201c5dc..9e474dae19b843 100644
--- a/src/plugins/intel_npu/src/backend/src/zero_tensor.cpp
+++ b/src/plugins/intel_npu/src/backend/src/zero_tensor.cpp
@@ -33,7 +33,7 @@ ZeroTensor::ZeroTensor(const std::shared_ptr<ZeroInitStructsHolder>& init_struct
     _ptr = data;
 }
 
-void* ZeroTensor::data(const ov::element::Type& element_type) const {
+const void* ZeroTensor::data(const ov::element::Type& element_type) const {
     if (element_type != ov::element::dynamic &&
         (element_type.bitwidth() != get_element_type().bitwidth() ||
          element_type.is_real() != get_element_type().is_real() ||
diff --git a/src/plugins/intel_npu/src/plugin/npuw/util.cpp b/src/plugins/intel_npu/src/plugin/npuw/util.cpp
index 322525421a5519..892f6b8a0ba248 100644
--- a/src/plugins/intel_npu/src/plugin/npuw/util.cpp
+++ b/src/plugins/intel_npu/src/plugin/npuw/util.cpp
@@ -460,7 +460,7 @@ ov::Tensor ov::npuw::util::to_f16(const ov::Tensor& t) {
 }
 
 inline uint8_t tread_4b(const ov::Tensor& t, std::size_t r, std::size_t c, std::size_t COLS) {
-    const uint8_t* tdata = static_cast<uint8_t*>(t.data());
+    const uint8_t* tdata = static_cast<const uint8_t*>(t.data());
     const uint8_t* trow = tdata + r * COLS / 2;
     const uint8_t* telem = trow + c / 2;
     if (c % 2 == 0) {
@@ -471,7 +471,7 @@ inline uint8_t tread_4b(const ov::Tensor& t, std::size_t r, std::size_t c, std::
 
 template <typename T>
 inline T tread(const ov::Tensor& t, std::size_t r, std::size_t c, std::size_t COLS) {
-    const T* tdata = static_cast<T*>(t.data());
+    const T* tdata = static_cast<const T*>(t.data());
     const T* trow = tdata + r * COLS;
     const T* telem = trow + c;
     return *telem;
@@ -529,7 +529,7 @@ void permute120(const ov::Tensor& src, ov::Tensor& dst) {
     const ov::Shape dst_shape = dst.get_shape();
 
     NPUW_ASSERT(src_shape.size() == 3);  // Yes, so far only transpose 3D tensors
 
-    const T* pSrc = static_cast<T*>(src.data());
+    const T* pSrc = static_cast<const T*>(src.data());
     T* pDst = static_cast<T*>(dst.data());
 
     // DSTs [b,r,c] map to SRC's [r,c,b]
@@ -659,7 +659,7 @@ ov::Tensor ov::npuw::util::concat(const std::vector<ov::Tensor>& tt, std::size_t
         const bool is_4bit = (type == ov::element::i4 || type == ov::element::u4);
 
         for (std::size_t t_idx = 0; t_idx < tt.size(); t_idx++) {
-            const uint8_t* pSrc = static_cast<uint8_t*>(tt[t_idx].data());
+            const uint8_t* pSrc = static_cast<const uint8_t*>(tt[t_idx].data());
             const auto copy_size = lens[t_idx] * shape[1] * shape[2];
             const auto copy_len = is_4bit ? copy_size / 2 : copy_size * type.size();
@@ -683,7 +683,7 @@ ov::Tensor ov::npuw::util::concat(const std::vector<ov::Tensor>& tt, std::size_t
                 uint8_t* pDstRow = pDst + r_offset + c_offset;
 
                 const auto r_offset_src = is_4bit ? lens[t_idx] * r / 2 : lens[t_idx] * r * type.size();
-                const uint8_t* pSrc = static_cast<uint8_t*>(t_src.data());
+                const uint8_t* pSrc = static_cast<const uint8_t*>(t_src.data());
                 const uint8_t* pSrcRow = pSrc + r_offset_src;
                 std::copy_n(pSrcRow, copy_len, pDstRow);
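The 4-bit helpers above read two packed values per byte, so const-correctness only has to hold on the byte pointer; the nibble arithmetic is unchanged. A standalone sketch of the access pattern — a simplified re-implementation for illustration, not the NPUW helper itself, assuming low-nibble-first packing:

```cpp
#include <cstddef>
#include <cstdint>

// Even columns live in the low nibble, odd columns in the high nibble.
inline uint8_t read_4b(const uint8_t* data, std::size_t r, std::size_t c, std::size_t cols) {
    const uint8_t byte = data[r * cols / 2 + c / 2];
    return (c % 2 == 0) ? static_cast<uint8_t>(byte & 0x0F) : static_cast<uint8_t>(byte >> 4);
}
```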
diff --git a/src/plugins/intel_npu/tools/common/include/tensor_utils.hpp b/src/plugins/intel_npu/tools/common/include/tensor_utils.hpp
index 18d431c157d2dd..e4adeeaa194a40 100644
--- a/src/plugins/intel_npu/tools/common/include/tensor_utils.hpp
+++ b/src/plugins/intel_npu/tools/common/include/tensor_utils.hpp
@@ -19,7 +19,7 @@ namespace utils {
 * @param in The source tensor
 * @param out The destination tensor
 */
-void copyTensor(const ov::Tensor& in, const ov::Tensor& out);
+void copyTensor(const ov::Tensor& in, ov::Tensor& out);
 
/**
 * @brief Copies the contents of one tensor into another one which bears the same shape. Precision conversions from
@@ -28,7 +28,7 @@ void copyTensor(const ov::Tensor& in, ov::Tensor& out);
 * @param in The source tensor
 * @param out The destination tensor
 */
-void convertTensorPrecision(const ov::Tensor& in, const ov::Tensor& out);
+void convertTensorPrecision(const ov::Tensor& in, ov::Tensor& out);
 
/**
 * @brief Constructs a tensor with the same content as the source but with the precision converted to the specified
diff --git a/src/plugins/intel_npu/tools/common/src/tensor_utils.cpp b/src/plugins/intel_npu/tools/common/src/tensor_utils.cpp
index c2b4902497777e..1532891aada2a5 100644
--- a/src/plugins/intel_npu/tools/common/src/tensor_utils.cpp
+++ b/src/plugins/intel_npu/tools/common/src/tensor_utils.cpp
@@ -17,7 +17,7 @@ namespace {
 
 template <typename InT, typename OutT>
-void convertTensorPrecisionImpl(const ov::Tensor& in, const ov::Tensor& out) {
+void convertTensorPrecisionImpl(const ov::Tensor& in, ov::Tensor& out) {
     const auto inputBuffer = in.data<InT>();
     OPENVINO_ASSERT(inputBuffer != nullptr, "Tensor was not allocated");
 
@@ -34,7 +34,7 @@ void convertTensorPrecisionImpl(const ov::Tensor& in, ov::Tensor& out) {
 namespace npu {
 namespace utils {
 
-void copyTensor(const ov::Tensor& in, const ov::Tensor& out) {
+void copyTensor(const ov::Tensor& in, ov::Tensor& out) {
     OPENVINO_ASSERT(in.get_element_type() == out.get_element_type(), "Precision mismatch");
     OPENVINO_ASSERT(in.get_shape() == out.get_shape(), "Shape mismatch");
 
@@ -47,7 +47,7 @@ void copyTensor(const ov::Tensor& in, ov::Tensor& out) {
     std::copy_n(inputBuffer, in.get_byte_size(), outputBuffer);
 }
 
-void convertTensorPrecision(const ov::Tensor& in, const ov::Tensor& out) {
+void convertTensorPrecision(const ov::Tensor& in, ov::Tensor& out) {
     OPENVINO_ASSERT(in.get_shape() == out.get_shape(), "Mismatch in Dims");
 
     const ov::element::Type& inPrecision = in.get_element_type();
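With the destination now taken by non-const reference, call sites of these tool helpers must hand in a mutable tensor. Typical usage after this change (shapes and precisions illustrative):

```cpp
#include <openvino/runtime/tensor.hpp>
// #include "tensor_utils.hpp"  // npu::utils helpers from the NPU tools tree

ov::Tensor to_f16(const ov::Tensor& src) {
    ov::Tensor dst(ov::element::f16, src.get_shape());  // must be non-const now
    npu::utils::convertTensorPrecision(src, dst);       // writes through dst.data()
    return dst;
}
```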
@@ -880,7 +889,7 @@ ov::Tensor loadInput(const ov::element::Type& modelPrecision,
 }
 
 ov::Tensor loadTensor(const ov::element::Type& precision, const ov::Shape& shape, const std::string& filePath) {
-    const ov::Tensor tensor(precision, shape);
+    ov::Tensor tensor(precision, shape);
 
     std::ifstream file(filePath, std::ios_base::in | std::ios_base::binary);
     OPENVINO_ASSERT(file.is_open(), "Can't open file ", filePath, " for read");
@@ -895,7 +904,7 @@ void dumpTensor(const ov::Tensor& tensor, const std::string& filePath) {
     std::ofstream file(filePath, std::ios_base::out | std::ios_base::binary);
     OPENVINO_ASSERT(file.is_open(), "Can't open file ", filePath, " for write");
 
-    const auto dataBuffer = reinterpret_cast<char*>(tensor.data());
+    const auto dataBuffer = reinterpret_cast<const char*>(tensor.data());
     file.write(dataBuffer, static_cast<std::streamsize>(tensor.get_byte_size()));
 }
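Conversely, `dumpTensor` keeps its `const ov::Tensor&` parameter and only the cast changes: a read-only dump now goes through `const char*`. The same pattern in isolation, as a minimal sketch:

    #include <fstream>
    #include <openvino/runtime/tensor.hpp>

    // Read-only access: data() on a const tensor returns const void*,
    // so the cast target must be const char* as well.
    void dumpRaw(const ov::Tensor& tensor, std::ofstream& file) {
        const char* bytes = reinterpret_cast<const char*>(tensor.data());
        file.write(bytes, static_cast<std::streamsize>(tensor.get_byte_size()));
    }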
@@ -1436,7 +1445,7 @@ std::vector<float> softmax(std::vector<float>& tensor) {
     return results;
 }
 
-bool testNRMSE(const TensorMap& outputs, const TensorMap& references, size_t batch_size = 1) {
+bool testNRMSE(TensorMap& outputs, const TensorMap& references, size_t batch_size = 1) {
     if (batch_size != 1) {
         throw std::runtime_error(
             "The testcase 'nrmse' doesn't support any `override_model_batch_size` values besides 1 yet");
@@ -1450,7 +1459,7 @@ bool testNRMSE(const TensorMap& outputs, const TensorMap& references, size_t bat
     std::vector<std::string> skipped_layers;
     skipped_layers = splitStringList(FLAGS_skip_output_layers, ';');
 
-    for (const auto& [tensorName, output] : outputs) {
+    for (auto& [tensorName, output] : outputs) {
         if (std::find(skipped_layers.begin(), skipped_layers.end(), tensorName) != skipped_layers.end()) {
             std::cout << "Skip NRMSE test for layers: " << tensorName << std::endl;
             continue;
@@ -1472,7 +1481,10 @@ bool testNRMSE(const TensorMap& outputs, const TensorMap& references, size_t bat
             auto refSoftMax = softmax(refOutput);
 
             std::copy_n(actSoftMax.begin(), output.get_size(), output.data<float>());
-            std::copy_n(refSoftMax.begin(), referencesIterator->second.get_size(), referencesIterator->second.data<float>());
+            // Why is the reference data not updated?
+            std::copy_n(refSoftMax.begin(),
+                        referencesIterator->second.get_size(),
+                        const_cast<float*>(referencesIterator->second.data<float>()));
         }
 
         std::cout << "Compare " << tensorName << " with reference" << std::endl;
@@ -2272,10 +2284,10 @@ static int runSingleImageTest() {
             std::cout << "Run inference on " << FLAGS_device << std::endl;
 
             const auto startTime = Time::now();
-            const auto outInference = runInfer(inferRequest, compiledModel, inTensors, dumpedInputsPaths);
+            auto outInference = runInfer(inferRequest, compiledModel, inTensors, dumpedInputsPaths);
             const auto endTime = Time::now();
 
-            const TensorMap& outputTensors = outInference.first;
+            TensorMap& outputTensors = outInference.first;
 
             printPerformanceCountsAndLatency(numberOfTestCase, outInference.second, endTime - startTime);
diff --git a/src/plugins/template/backend/evaluates_map.cpp b/src/plugins/template/backend/evaluates_map.cpp
index 26d95ccdeb360c..06d5eee729bbab 100644
--- a/src/plugins/template/backend/evaluates_map.cpp
+++ b/src/plugins/template/backend/evaluates_map.cpp
@@ -12,19 +12,19 @@ std::vector<float> get_floats(const ov::Tensor& input, const ov::Shape& shape) {
     switch (input.get_element_type()) {
     case ov::element::bf16: {
-        ov::bfloat16* p = input.data<ov::bfloat16>();
+        auto p = input.data<ov::bfloat16>();
         for (size_t i = 0; i < input_size; ++i) {
             result[i] = float(p[i]);
         }
     } break;
     case ov::element::f16: {
-        ov::float16* p = input.data<ov::float16>();
+        auto p = input.data<ov::float16>();
         for (size_t i = 0; i < input_size; ++i) {
             result[i] = float(p[i]);
         }
     } break;
     case ov::element::f32: {
-        float* p = input.data<float>();
+        auto p = input.data<float>();
         memcpy(result.data(), p, input_size * sizeof(float));
     } break;
     default:
diff --git a/src/plugins/template/backend/ops/binary_convolution.cpp b/src/plugins/template/backend/ops/binary_convolution.cpp
index ff24f022a3f318..3ac2a36044d502 100644
--- a/src/plugins/template/backend/ops/binary_convolution.cpp
+++ b/src/plugins/template/backend/ops/binary_convolution.cpp
@@ -14,8 +14,8 @@ inline void evaluate(const std::shared_ptr<ov::op::v1::BinaryConvolution>& op,
     using T_IN = typename ov::element_type_traits<IN_ET>::value_type;
     using T_F = typename ov::element_type_traits<F_ET>::value_type;
 
-    const auto in_data_ptr = static_cast<T_IN*>(inputs[0].data());
-    const auto filter_data_ptr = static_cast<T_F*>(inputs[1].data());
+    const auto in_data_ptr = static_cast<const T_IN*>(inputs[0].data());
+    const auto filter_data_ptr = static_cast<const T_F*>(inputs[1].data());
     auto out_data_ptr = static_cast<T_IN*>(outputs[0].data());
     const auto in_shape = inputs[0].get_shape();
     const auto filter_shape = inputs[1].get_shape();
diff --git a/src/plugins/template/backend/ops/ctc_loss.cpp b/src/plugins/template/backend/ops/ctc_loss.cpp
index e63c1bf5297f58..ad7253eab11a8c 100644
--- a/src/plugins/template/backend/ops/ctc_loss.cpp
+++ b/src/plugins/template/backend/ops/ctc_loss.cpp
@@ -32,12 +32,12 @@ inline void evaluate(const std::shared_ptr<ov::op::v4::CTCLoss>& op,
                      const ov::TensorVector& inputs) {
     using T1 = typename ov::element_type_traits<ET1>::value_type;
     using T2 = typename ov::element_type_traits<ET2>::value_type;
-    ov::reference::CTCLoss<T1, T2>(static_cast<T1*>(inputs[0].data()),
+    ov::reference::CTCLoss<T1, T2>(static_cast<const T1*>(inputs[0].data()),
                                    inputs[0].get_shape(),
-                                   static_cast<T2*>(inputs[1].data()),
-                                   static_cast<T2*>(inputs[2].data()),
-                                   static_cast<T2*>(inputs[3].data()),
-                                   static_cast<T2*>(inputs[4].data()),
+                                   static_cast<const T2*>(inputs[1].data()),
+                                   static_cast<const T2*>(inputs[2].data()),
+                                   static_cast<const T2*>(inputs[3].data()),
+                                   static_cast<const T2*>(inputs[4].data()),
                                    op->get_preprocess_collapse_repeated(),
                                    op->get_ctc_merge_repeated(),
                                    op->get_unique(),
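The template-backend hunks above and below all apply one mechanical pattern: `inputs` arrives as `const ov::TensorVector&`, so every input pointer is now obtained as a pointer-to-const, while output pointers stay writable. A generic sketch of the pattern; illustrative only, and `copy_op` is a hypothetical name:

    #include <openvino/runtime/tensor.hpp>

    template <typename T>
    void copy_op(ov::TensorVector& outputs, const ov::TensorVector& inputs) {
        const T* in = static_cast<const T*>(inputs[0].data());  // read-only input
        T* out = static_cast<T*>(outputs[0].data());            // writable output
        for (size_t i = 0; i < inputs[0].get_size(); ++i) {
            out[i] = in[i];  // stands in for the op's real kernel
        }
    }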
diff --git a/src/plugins/template/backend/ops/interpolate.cpp b/src/plugins/template/backend/ops/interpolate.cpp
index a6a5cd70c2bc7d..e804b35fd7e1ba 100644
--- a/src/plugins/template/backend/ops/interpolate.cpp
+++ b/src/plugins/template/backend/ops/interpolate.cpp
@@ -14,11 +14,11 @@ bool evaluate(const std::shared_ptr<ov::op::v0::Interpolate>& op,
     ov::element::Type input_et = op->get_input_element_type(0);
     switch (input_et) {
     case ov::element::f64:
-        ov::reference::interpolate<double>(inputs[0].data<double>(),
-                                           op->get_input_partial_shape(0),
-                                           outputs[0].data<double>(),
-                                           op->get_output_shape(0),
-                                           op->get_attrs());
+        ov::reference::interpolate<double>(inputs[0].data<const double>(),
+                                           op->get_input_partial_shape(0),
+                                           outputs[0].data<double>(),
+                                           op->get_output_shape(0),
+                                           op->get_attrs());
         break;
     case ov::element::f32:
         ov::reference::interpolate<float>(inputs[0].data<float>(),
@@ -157,7 +157,7 @@ std::vector<float> get_scales_vector(const ov::TensorVector& args,
     std::vector<float> scales;
     size_t num_of_axes = axes.size();
     if (attrs.shape_calculation_mode == ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) {
-        float* scales_ptr = args[scales_port].data<float>();
+        auto scales_ptr = args[scales_port].data<float>();
         scales.insert(scales.end(), scales_ptr, scales_ptr + num_of_axes);
     } else {
         auto target_shape = get_target_shape_vector(args, num_of_axes);
@@ -209,7 +209,7 @@ bool evaluate_interpolate(const std::shared_ptr<ov::op::v4::Interpolate>& op,
     size_t bytes_in_padded_input = shape_size(padded_input_shape) * type_size;
 
     std::vector<uint8_t> padded_input_data(bytes_in_padded_input, 0);
-    const uint8_t* data_ptr = static_cast<uint8_t*>(inputs[0].data());
+    auto data_ptr = static_cast<const uint8_t*>(inputs[0].data());
     uint8_t* padded_data_ptr = padded_input_data.data();
 
     reference::pad_input_data(data_ptr,
diff --git a/src/plugins/template/backend/ops/mvn.cpp b/src/plugins/template/backend/ops/mvn.cpp
index 17c913787f0814..aa566ee2b868c2 100644
--- a/src/plugins/template/backend/ops/mvn.cpp
+++ b/src/plugins/template/backend/ops/mvn.cpp
@@ -21,7 +21,7 @@ bool evaluate(const std::shared_ptr<ov::op::v0::MVN>& op, ov::TensorVector& outp
 namespace mvn_6_axes {
 template <typename T>
 ov::AxisSet mvn_6_reduction_axes(const ov::Tensor& axes_input, size_t rank) {
-    T* a = axes_input.data<T>();
+    const T* a = axes_input.data<T>();
     auto v = std::vector<T>(a, a + axes_input.get_shape()[0]);
     std::vector<size_t> axes(v.size(), 0);
     for (size_t i = 0; i < v.size(); i++) {
diff --git a/src/plugins/template/backend/ops/reorg_yolo.cpp b/src/plugins/template/backend/ops/reorg_yolo.cpp
index 8e92e76ddf7ddd..b41587cacfb249 100644
--- a/src/plugins/template/backend/ops/reorg_yolo.cpp
+++ b/src/plugins/template/backend/ops/reorg_yolo.cpp
@@ -9,7 +9,7 @@ bool evaluate(const std::shared_ptr<ov::op::v0::ReorgYolo>& op,
               ov::TensorVector& outputs,
               const ov::TensorVector& inputs) {
-    ov::reference::reorg_yolo(static_cast<char*>(inputs[0].data()),
+    ov::reference::reorg_yolo(static_cast<const char*>(inputs[0].data()),
                               static_cast<char*>(outputs[0].data()),
                               inputs[0].get_shape(),
                               op->get_strides().at(0),
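Note the `data<const double>()` spelling in the interpolate hunk: with the delegating overload introduced by this patch, requesting a const-qualified element type yields a read-only pointer even through the non-const overload, which keeps call sites uniform. A short sketch of the three spellings, assuming that overload set:

    #include <openvino/runtime/tensor.hpp>

    void spellings() {
        ov::Tensor t(ov::element::f32, ov::Shape{4});
        float* rw = t.data<float>();               // non-const overload: writable
        const float* ro1 = t.data<const float>();  // const-qualified T: read-only
        const ov::Tensor& ct = t;
        const float* ro2 = ct.data<float>();       // const tensor: const overload
        (void)rw; (void)ro1; (void)ro2;
    }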
diff --git a/src/plugins/template/backend/ops/sequences.cpp b/src/plugins/template/backend/ops/sequences.cpp
index 8704acff3109c8..44f2ddd5046950 100644
--- a/src/plugins/template/backend/ops/sequences.cpp
+++ b/src/plugins/template/backend/ops/sequences.cpp
@@ -15,17 +15,17 @@ inline void evaluate(const std::shared_ptr<ov::op::v5::RNNSequence>& op,
                      const ov::TensorVector& inputs) {
     using T1 = typename ov::element_type_traits<ET1>::value_type;
     using T2 = typename ov::element_type_traits<ET2>::value_type;
-    ov::reference::rnn_sequence<T1, T2>(static_cast<T1*>(inputs[0].data()),
+    ov::reference::rnn_sequence<T1, T2>(static_cast<const T1*>(inputs[0].data()),
                                         inputs[0].get_shape(),
-                                        static_cast<T1*>(inputs[1].data()),
+                                        static_cast<const T1*>(inputs[1].data()),
                                         inputs[1].get_shape(),
-                                        static_cast<T2*>(inputs[2].data()),
+                                        static_cast<const T2*>(inputs[2].data()),
                                         inputs[2].get_shape(),
-                                        static_cast<T1*>(inputs[3].data()),
+                                        static_cast<const T1*>(inputs[3].data()),
                                         inputs[3].get_shape(),
-                                        static_cast<T1*>(inputs[4].data()),
+                                        static_cast<const T1*>(inputs[4].data()),
                                         inputs[4].get_shape(),
-                                        static_cast<T1*>(inputs[5].data()),
+                                        static_cast<const T1*>(inputs[5].data()),
                                         inputs[5].get_shape(),
                                         static_cast<T1*>(outputs[0].data()),
                                         static_cast<T1*>(outputs[1].data()),
@@ -61,19 +61,19 @@ inline void evaluate(const std::shared_ptr<ov::op::v5::LSTMSequence>& op,
                      const ov::TensorVector& inputs) {
     using T1 = typename ov::element_type_traits<ET1>::value_type;
     using T2 = typename ov::element_type_traits<ET2>::value_type;
-    ov::reference::lstm_sequence<T1, T2>(static_cast<T1*>(inputs[0].data()),
+    ov::reference::lstm_sequence<T1, T2>(static_cast<const T1*>(inputs[0].data()),
                                          inputs[0].get_shape(),
-                                         static_cast<T1*>(inputs[1].data()),
+                                         static_cast<const T1*>(inputs[1].data()),
                                          inputs[1].get_shape(),
-                                         static_cast<T1*>(inputs[2].data()),
+                                         static_cast<const T1*>(inputs[2].data()),
                                          inputs[2].get_shape(),
-                                         static_cast<T2*>(inputs[3].data()),
+                                         static_cast<const T2*>(inputs[3].data()),
                                          inputs[3].get_shape(),
-                                         static_cast<T1*>(inputs[4].data()),
+                                         static_cast<const T1*>(inputs[4].data()),
                                          inputs[4].get_shape(),
-                                         static_cast<T1*>(inputs[5].data()),
+                                         static_cast<const T1*>(inputs[5].data()),
                                          inputs[5].get_shape(),
-                                         static_cast<T1*>(inputs[6].data()),
+                                         static_cast<const T1*>(inputs[6].data()),
                                          inputs[6].get_shape(),
                                          static_cast<T1*>(outputs[0].data()),
                                          static_cast<T1*>(outputs[1].data()),
@@ -112,17 +112,17 @@ inline void evaluate(const std::shared_ptr<ov::op::v5::GRUSequence>& op,
                      const ov::TensorVector& inputs) {
     using T1 = typename ov::element_type_traits<ET1>::value_type;
     using T2 = typename ov::element_type_traits<ET2>::value_type;
-    ov::reference::gru_sequence<T1, T2>(static_cast<T1*>(inputs[0].data()),
+    ov::reference::gru_sequence<T1, T2>(static_cast<const T1*>(inputs[0].data()),
                                         inputs[0].get_shape(),
-                                        static_cast<T1*>(inputs[1].data()),
+                                        static_cast<const T1*>(inputs[1].data()),
                                         inputs[1].get_shape(),
-                                        static_cast<T2*>(inputs[2].data()),
+                                        static_cast<const T2*>(inputs[2].data()),
                                         inputs[2].get_shape(),
-                                        static_cast<T1*>(inputs[3].data()),
+                                        static_cast<const T1*>(inputs[3].data()),
                                         inputs[3].get_shape(),
-                                        static_cast<T1*>(inputs[4].data()),
+                                        static_cast<const T1*>(inputs[4].data()),
                                         inputs[4].get_shape(),
-                                        static_cast<T1*>(inputs[5].data()),
+                                        static_cast<const T1*>(inputs[5].data()),
                                         inputs[5].get_shape(),
                                         static_cast<T1*>(outputs[0].data()),
                                         static_cast<T1*>(outputs[1].data()),
@@ -160,17 +160,17 @@ inline void evaluate(const std::shared_ptr<ov::op::internal::AUGRUSequence>& op,
                      const ov::TensorVector& inputs) {
     using T1 = typename ov::element_type_traits<ET1>::value_type;
     using T2 = typename ov::element_type_traits<ET2>::value_type;
-    ov::reference::gru_sequence<T1, T2>(static_cast<T1*>(inputs[0].data()),
+    ov::reference::gru_sequence<T1, T2>(static_cast<const T1*>(inputs[0].data()),
                                         inputs[0].get_shape(),
-                                        static_cast<T1*>(inputs[1].data()),
+                                        static_cast<const T1*>(inputs[1].data()),
                                         inputs[1].get_shape(),
-                                        static_cast<T2*>(inputs[2].data()),
+                                        static_cast<const T2*>(inputs[2].data()),
                                         inputs[2].get_shape(),
-                                        static_cast<T1*>(inputs[3].data()),
+                                        static_cast<const T1*>(inputs[3].data()),
                                         inputs[3].get_shape(),
-                                        static_cast<T1*>(inputs[4].data()),
+                                        static_cast<const T1*>(inputs[4].data()),
                                         inputs[4].get_shape(),
-                                        static_cast<T1*>(inputs[5].data()),
+                                        static_cast<const T1*>(inputs[5].data()),
                                         inputs[5].get_shape(),
                                         static_cast<T1*>(outputs[0].data()),
                                         static_cast<T1*>(outputs[1].data()),
@@ -179,7 +179,7 @@ inline void evaluate(const std::shared_ptr<ov::op::internal::AUGRUSequence>& op,
                                         op->get_clip(),
                                         op->get_direction(),
                                         op->get_linear_before_reset(),
-                                        static_cast<T1*>(inputs[6].data()));
+                                        static_cast<const T1*>(inputs[6].data()));
 }
 }  // namespace augru_seq
diff --git a/src/plugins/template/src/plugin.cpp b/src/plugins/template/src/plugin.cpp
index d46339ef94c415..cbd1f07f74eb7d 100644
--- a/src/plugins/template/src/plugin.cpp
+++ b/src/plugins/template/src/plugin.cpp
@@ -72,7 +72,7 @@ std::shared_ptr<ov::Model> get_ov_model_from_blob(const ov::template_plugin::Plu
                                                   size_t offset,
                                                   const ov::AnyMap& properties) {
     if (auto blob_it = properties.find(ov::hint::compiled_blob.name()); blob_it != properties.end()) {
-        if (auto&& blob = blob_it->second.as<ov::Tensor>(); blob) {
+        if (auto blob = blob_it->second.as<ov::Tensor>(); blob) {
             ov::SharedStreamBuffer shared_buffer(reinterpret_cast<char*>(blob.data()), blob.get_byte_size());
             std::istream blob_stream(&shared_buffer);
             blob_stream.seekg(offset, std::ios::beg);
diff --git a/src/plugins/template/tests/functional/op_reference/constant.cpp b/src/plugins/template/tests/functional/op_reference/constant.cpp
index 520c993578e2fc..8d86ef5ed854ab 100644
--- a/src/plugins/template/tests/functional/op_reference/constant.cpp
+++ b/src/plugins/template/tests/functional/op_reference/constant.cpp
@@ -113,7 +113,7 @@ class ReferenceConstantLayerTest_MultiUse : public ReferenceConstantLayerTest {
         const auto A = std::make_shared<op::v0::Constant>(
             params.inType,
             params.inputShape,
-            std::vector<std::string>{std::to_string(*reinterpret_cast<int64_t*>(params.inputData.data()))});
+            std::vector<std::string>{std::to_string(*reinterpret_cast<const int64_t*>(params.inputData.data()))});
         return std::make_shared<Model>(A, ParameterVector{});
     }
 };
diff --git a/src/tests/functional/plugin/shared/include/behavior/compiled_model/compiled_model_base.hpp b/src/tests/functional/plugin/shared/include/behavior/compiled_model/compiled_model_base.hpp
index ee443d5603f587..832d0e711152bf 100644
--- a/src/tests/functional/plugin/shared/include/behavior/compiled_model/compiled_model_base.hpp
+++ b/src/tests/functional/plugin/shared/include/behavior/compiled_model/compiled_model_base.hpp
@@ -107,8 +107,8 @@ class OVCompiledModelBaseTest : public testing::WithParamInterface<...>
-            static_cast<ov::fundamental_type_for<Type>*>(expected.data()),       \
-            static_cast<ov::fundamental_type_for<Type>*>(actual.data()),         \
+            static_cast<const ov::fundamental_type_for<Type>*>(expected.data()), \
+            static_cast<const ov::fundamental_type_for<Type>*>(actual.data()),   \
             expected.get_size(),                                                 \
             dev_threshold,                                                       \
             abs_threshold);                                                      \
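In the plugin.cpp hunk, `auto&&` becomes `auto`: `as<ov::Tensor>()` returns a const value, and a non-const copy is needed so the non-const `data()` overload can be chosen for the `char*` cast; since tensor copies share the underlying storage, the copy is cheap. A hedged sketch of the pattern, with stream handling omitted and includes approximate:

    #include <openvino/runtime/properties.hpp>
    #include <openvino/runtime/tensor.hpp>

    void useBlob(const ov::AnyMap& properties) {
        if (auto it = properties.find(ov::hint::compiled_blob.name()); it != properties.end()) {
            // Non-const by-value copy: shares storage, allows writable data().
            if (auto blob = it->second.as<ov::Tensor>(); blob) {
                char* bytes = reinterpret_cast<char*>(blob.data());
                (void)bytes;  // ... wrap in a stream buffer and deserialize ...
            }
        }
    }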
diff --git a/src/tests/functional/shared_test_classes/src/single_op/generate_proposals.cpp b/src/tests/functional/shared_test_classes/src/single_op/generate_proposals.cpp
index 3232979cb3b4a1..ccaf0d60c0565d 100644
--- a/src/tests/functional/shared_test_classes/src/single_op/generate_proposals.cpp
+++ b/src/tests/functional/shared_test_classes/src/single_op/generate_proposals.cpp
@@ -112,8 +112,8 @@ void GenerateProposalsLayerTest::compare(const std::vector<ov::Tensor>& expected
             const auto actualNumRois = actual[i].get_shape()[0];
             ASSERT_LE(expectedNumRois, actualNumRois);
 
-            const auto actualBuffer = static_cast<uint8_t*>(actual[i].data());
-            const auto expectedBuffer = static_cast<uint8_t*>(expected[i].data());
+            const auto actualBuffer = static_cast<const uint8_t*>(actual[i].data());
+            const auto expectedBuffer = static_cast<const uint8_t*>(expected[i].data());
 
             const auto outputSize = i == 0 ? 4 : 1;
 
             rel_threshold = ov::test::utils::tensor_comparation::calculate_default_rel_threshold(
diff --git a/src/tests/test_utils/common_test_utils/src/test_case.cpp b/src/tests/test_utils/common_test_utils/src/test_case.cpp
index 6c002e32e65c7e..2d8bbb2578b86f 100644
--- a/src/tests/test_utils/common_test_utils/src/test_case.cpp
+++ b/src/tests/test_utils/common_test_utils/src/test_case.cpp
@@ -153,8 +153,8 @@ std::pair<testing::AssertionResult, size_t> TestCase::compare_results(size_t tol
         break;
     case element::Type_t::string: {
         res = ::testing::AssertionSuccess();
-        std::string* exp_strings = exp_result.data<std::string>();
-        std::string* res_strings = result_tensor.data<std::string>();
+        const std::string* exp_strings = exp_result.data<std::string>();
+        const std::string* res_strings = result_tensor.data<std::string>();
         for (size_t i = 0; i < exp_result.get_size(); ++i) {
             if (exp_strings[i] != res_strings[i]) {
                 res = ::testing::AssertionFailure() << "Wrong string value at index " << i << ", expected \"
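String tensors follow the same rule as raw buffers: on a const tensor, `data<std::string>()` now returns `const std::string*`, which is exactly what an element-wise comparison needs. A minimal sketch under that assumption:

    #include <openvino/runtime/tensor.hpp>
    #include <string>

    // Element-wise comparison of two read-only string tensors.
    bool equalStrings(const ov::Tensor& a, const ov::Tensor& b) {
        if (a.get_size() != b.get_size())
            return false;
        const std::string* pa = a.data<std::string>();  // const overload
        const std::string* pb = b.data<std::string>();
        for (size_t i = 0; i < a.get_size(); ++i)
            if (pa[i] != pb[i])
                return false;
        return true;
    }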