Skip to content

Commit 7958322

Browse files
committed
[JAVA_API] Add get_element_type() to Tensor
1 parent 4272f47 commit 7958322

File tree

4 files changed

+24
-0
lines changed

4 files changed

+24
-0
lines changed

modules/java_api/src/main/cpp/openvino_java.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ extern "C"
6464
JNIEXPORT jlong JNICALL Java_org_intel_openvino_Tensor_TensorLong(JNIEnv *, jobject, jintArray, jlongArray);
6565
JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetSize(JNIEnv *, jobject, jlong);
6666
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_GetShape(JNIEnv *, jobject, jlong);
67+
JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetElementType(JNIEnv *, jobject, jlong);
6768
JNIEXPORT jfloatArray JNICALL Java_org_intel_openvino_Tensor_asFloat(JNIEnv *, jobject, jlong);
6869
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_asInt(JNIEnv *, jobject, jlong);
6970
JNIEXPORT void JNICALL Java_org_intel_openvino_Tensor_delete(JNIEnv *, jobject, jlong);

modules/java_api/src/main/cpp/tensor.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,19 @@ JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Tensor_GetShape(JNIEnv *env,
114114
return 0;
115115
}
116116

117+
JNIEXPORT jint JNICALL Java_org_intel_openvino_Tensor_GetElementType(JNIEnv *env, jobject, jlong addr)
118+
{
119+
JNI_METHOD(
120+
"GetElementType",
121+
Tensor *ov_tensor = (Tensor *)addr;
122+
123+
element::Type_t t_type = ov_tensor->get_element_type();
124+
jint type = static_cast<jint>(t_type);
125+
return type;
126+
)
127+
return 0;
128+
}
129+
117130
JNIEXPORT jfloatArray JNICALL Java_org_intel_openvino_Tensor_asFloat(JNIEnv *env, jobject, jlong addr)
118131
{
119132
JNI_METHOD(

modules/java_api/src/main/java/org/intel/openvino/Tensor.java

+7
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ public int[] get_shape() {
5656
return GetShape(nativeObj);
5757
}
5858

59+
/** Returns the tensor element type. */
60+
public ElementType get_element_type() {
61+
return ElementType.valueOf(GetElementType(nativeObj));
62+
}
63+
5964
/** Returns a tensor data as floating point array. */
6065
public float[] data() {
6166
return asFloat(nativeObj);
@@ -77,6 +82,8 @@ public int[] as_int() {
7782

7883
private static native int[] GetShape(long addr);
7984

85+
private static native int GetElementType(long addr);
86+
8087
private static native float[] asFloat(long addr);
8188

8289
private static native int[] asInt(long addr);

modules/java_api/src/test/java/org/intel/openvino/TensorTests.java

+3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public void testGetTensorFromFloat() {
1616

1717
assertArrayEquals(tensor.get_shape(), dimsArr);
1818
assertArrayEquals(tensor.data(), data, 0.0f);
19+
assertEquals(ElementType.f32, tensor.get_element_type());
1920
}
2021

2122
@Test
@@ -29,6 +30,7 @@ public void testGetTensorFromInt() {
2930
assertArrayEquals(dimsArr, tensor.get_shape());
3031
assertArrayEquals(inputData, tensor.as_int());
3132
assertEquals(size, tensor.get_size());
33+
assertEquals(ElementType.i32, tensor.get_element_type());
3234
}
3335

3436
@Test
@@ -41,5 +43,6 @@ public void testGetTensorFromLong() {
4143

4244
assertArrayEquals(dimsArr, tensor.get_shape());
4345
assertEquals(size, tensor.get_size());
46+
assertEquals(ElementType.i64, tensor.get_element_type());
4447
}
4548
}

0 commit comments

Comments
 (0)