Skip to content

Commit 66557de

Browse files
rajatkrishnaallneslikholat
authored
[JAVA_API] Wrapper for ov::PartialShape (#883)
* Get partial shape from output * Update PartialShape getDimension to align with C++ API * Use pointer arithmetic to access Partial Shape dimension Co-authored-by: Nesterov Alexander <nesterov.alexander@outlook.com> * Remove redundant delete method * Revert "Use pointer arithmetic to access Partial Shape dimension" --------- Co-authored-by: Nesterov Alexander <nesterov.alexander@outlook.com> Co-authored-by: Anna Likholat <anna.likholat@intel.com>
1 parent ec67eed commit 66557de

File tree

8 files changed

+132
-10
lines changed

8 files changed

+132
-10
lines changed

modules/java_api/src/main/cpp/dimension.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,3 @@ JNIEXPORT jint JNICALL Java_org_intel_openvino_Dimension_getLength(JNIEnv *env,
1717
)
1818
return 0;
1919
}
20-
21-
JNIEXPORT void JNICALL Java_org_intel_openvino_Dimension_delete(JNIEnv *, jobject, jlong addr)
22-
{
23-
Dimension *dim = (Dimension *)addr;
24-
delete dim;
25-
}

modules/java_api/src/main/cpp/openvino_java.hpp

+6-1
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,19 @@ extern "C"
107107

108108
// ov::Dimension
109109
JNIEXPORT jint JNICALL Java_org_intel_openvino_Dimension_getLength(JNIEnv *, jobject, jlong);
110-
JNIEXPORT void JNICALL Java_org_intel_openvino_Dimension_delete(JNIEnv *, jobject, jlong);
111110

112111
// ov::Output<ov::Node>
113112
JNIEXPORT jstring JNICALL Java_org_intel_openvino_Output_GetAnyName(JNIEnv *, jobject, jlong);
114113
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_Output_GetShape(JNIEnv *, jobject, jlong);
114+
JNIEXPORT jlong JNICALL Java_org_intel_openvino_Output_GetPartialShape(JNIEnv *, jobject, jlong);
115115
JNIEXPORT int JNICALL Java_org_intel_openvino_Output_GetElementType(JNIEnv *, jobject, jlong);
116116
JNIEXPORT void JNICALL Java_org_intel_openvino_Output_delete(JNIEnv *, jobject, jlong);
117117

118+
// ov::PartialShape
119+
JNIEXPORT jlong JNICALL Java_org_intel_openvino_PartialShape_GetDimension(JNIEnv *, jobject, jlong, jint);
120+
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMaxShape(JNIEnv *, jobject, jlong);
121+
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMinShape(JNIEnv *, jobject, jlong);
122+
118123
#ifdef __cplusplus
119124
}
120125
#endif

modules/java_api/src/main/cpp/output.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ JNIEXPORT int JNICALL Java_org_intel_openvino_Output_GetElementType(JNIEnv *env,
4747
return 0;
4848
}
4949

50+
JNIEXPORT jlong JNICALL Java_org_intel_openvino_Output_GetPartialShape(JNIEnv *env, jobject obj, jlong addr) {
51+
JNI_METHOD("GetPartialShape",
52+
Output<Node> *output = (Output<Node> *)addr;
53+
const PartialShape& partialShape = output->get_partial_shape();
54+
55+
return (jlong) &partialShape;
56+
)
57+
return 0;
58+
}
59+
5060
JNIEXPORT void JNICALL Java_org_intel_openvino_Output_delete(JNIEnv *, jobject, jlong addr)
5161
{
5262
Output<Node> *obj = (Output<Node> *)addr;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright (C) 2020-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
#include <jni.h> // JNI header provided by JDK
5+
#include "openvino/openvino.hpp"
6+
7+
#include "openvino_java.hpp"
8+
#include "jni_common.hpp"
9+
10+
using namespace ov;
11+
12+
JNIEXPORT jlong JNICALL Java_org_intel_openvino_PartialShape_GetDimension(JNIEnv *env, jobject obj, jlong addr, jint index) {
13+
JNI_METHOD("GetDimension",
14+
PartialShape* partial_shape = (PartialShape *)addr;
15+
return (jlong) &(*partial_shape)[index];
16+
)
17+
return 0;
18+
}
19+
20+
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMaxShape(JNIEnv *env, jobject obj, jlong addr) {
21+
JNI_METHOD("GetMaxShape",
22+
PartialShape* partial_shape = (PartialShape *)addr;
23+
Shape max_shape = partial_shape->get_max_shape();
24+
25+
jintArray result = env->NewIntArray(max_shape.size());
26+
if (!result) {
27+
throw std::runtime_error("Out of memory!");
28+
} jint *arr = env->GetIntArrayElements(result, nullptr);
29+
30+
for (int i = 0; i < max_shape.size(); ++i)
31+
arr[i] = max_shape[i];
32+
33+
env->ReleaseIntArrayElements(result, arr, 0);
34+
return result;
35+
)
36+
return 0;
37+
}
38+
39+
JNIEXPORT jintArray JNICALL Java_org_intel_openvino_PartialShape_GetMinShape(JNIEnv *env, jobject obj, jlong addr) {
40+
JNI_METHOD("GetMinShape",
41+
PartialShape* partial_shape = (PartialShape *)addr;
42+
Shape min_shape = partial_shape->get_min_shape();
43+
44+
jintArray result = env->NewIntArray(min_shape.size());
45+
if (!result) {
46+
throw std::runtime_error("Out of memory!");
47+
} jint *arr = env->GetIntArrayElements(result, nullptr);
48+
49+
for (int i = 0; i < min_shape.size(); ++i)
50+
arr[i] = min_shape[i];
51+
52+
env->ReleaseIntArrayElements(result, arr, 0);
53+
return result;
54+
)
55+
return 0;
56+
}

modules/java_api/src/main/java/org/intel/openvino/Dimension.java

-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,4 @@ public int get_length() {
2323

2424
/*----------------------------------- native methods -----------------------------------*/
2525
private static native int getLength(long addr);
26-
27-
@Override
28-
protected native void delete(long nativeObj);
2926
}

modules/java_api/src/main/java/org/intel/openvino/Output.java

+7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ public int[] get_shape() {
2020
return GetShape(nativeObj);
2121
}
2222

23+
/** Returns the partial shape of the output referred to by this output handle. */
24+
public PartialShape get_partial_shape() {
25+
return new PartialShape(GetPartialShape(nativeObj));
26+
}
27+
2328
/** Returns the element type of the output referred to by this output handle. */
2429
public ElementType get_element_type() {
2530
return ElementType.valueOf(GetElementType(nativeObj));
@@ -30,6 +35,8 @@ public ElementType get_element_type() {
3035

3136
private static native int[] GetShape(long addr);
3237

38+
private static native long GetPartialShape(long addr);
39+
3340
private static native int GetElementType(long addr);
3441

3542
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (C) 2020-2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package org.intel.openvino;
5+
6+
/** This class represents the definitions and operations about partial shape. */
7+
public class PartialShape extends Wrapper {
8+
9+
public PartialShape(long addr) {
10+
super(addr);
11+
}
12+
13+
/**
14+
* Get the dimension at specified index of a partial shape.
15+
*
16+
* @param index The index of dimension.
17+
* @return The particular dimension of partial shape.
18+
*/
19+
public Dimension get_dimension(int index) {
20+
return new Dimension(GetDimension(nativeObj, index));
21+
}
22+
23+
/** Returns the max bounding shape. */
24+
public int[] get_max_shape() {
25+
return GetMaxShape(nativeObj);
26+
}
27+
28+
/** Returns the min bounding shape. */
29+
public int[] get_min_shape() {
30+
return GetMinShape(nativeObj);
31+
}
32+
33+
/*----------------------------------- native methods -----------------------------------*/
34+
private static native long GetDimension(long addr, int index);
35+
36+
private static native int[] GetMaxShape(long addr);
37+
38+
private static native int[] GetMinShape(long addr);
39+
}

modules/java_api/src/test/java/org/intel/openvino/ModelTests.java

+14
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ public void testGetShape() {
4444
assertArrayEquals("Shape", ref, outputs.get(0).get_shape());
4545
}
4646

47+
@Test
48+
public void testGetPartialShape() {
49+
ArrayList<Output> outputs = net.outputs();
50+
int[] ref = new int[] {1, 10};
51+
52+
PartialShape partialShape = outputs.get(0).get_partial_shape();
53+
for (int i = 0; i < ref.length; i++) {
54+
Dimension dim = partialShape.get_dimension(i);
55+
assertEquals(ref[i], dim.get_length());
56+
}
57+
assertArrayEquals("MaxShape", ref, partialShape.get_max_shape());
58+
assertArrayEquals("MinShape", ref, partialShape.get_min_shape());
59+
}
60+
4761
@Test
4862
public void testReshape() {
4963
int[] inpDims = net.input().get_shape();

0 commit comments

Comments
 (0)