Skip to content

Commit 63aeaab

Browse files
upgrade djl version to 0.28.0 (opensearch-project#2578) (opensearch-project#2580)
* upgrade djl version to latest 0.28.0 Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> * force onnxruntime_gpu to 1.16.3 Signed-off-by: Yaliang Wu <ylwu@amazon.com> --------- Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com> Signed-off-by: Yaliang Wu <ylwu@amazon.com> Co-authored-by: Yaliang Wu <ylwu@amazon.com> (cherry picked from commit 01c85cb) Co-authored-by: Bhavana Ramaram <rbhavna@amazon.com>
1 parent 0e27181 commit 63aeaab

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

ml-algorithms/build.gradle

+5-5
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,22 @@ dependencies {
4242
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.7.0'
4343
implementation group: 'com.google.guava', name: 'guava', version: '32.1.2-jre'
4444
implementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
45-
implementation platform("ai.djl:bom:0.21.0")
46-
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo', version: '0.21.0'
45+
implementation platform("ai.djl:bom:0.28.0")
46+
implementation group: 'ai.djl.pytorch', name: 'pytorch-model-zoo'
4747
implementation group: 'ai.djl', name: 'api'
4848
implementation group: 'ai.djl.huggingface', name: 'tokenizers'
49-
implementation("ai.djl.onnxruntime:onnxruntime-engine:0.21.0") {
49+
implementation("ai.djl.onnxruntime:onnxruntime-engine") {
5050
exclude group: "com.microsoft.onnxruntime", module: "onnxruntime"
5151
}
5252
def os = DefaultNativePlatform.currentOperatingSystem
5353
//arm/macos doesn't support GPU
5454
if (os.macOsX || System.getProperty("os.arch") == "aarch64") {
5555
dependencies {
56-
implementation "com.microsoft.onnxruntime:onnxruntime:1.14.0"
56+
implementation "com.microsoft.onnxruntime:onnxruntime:1.16.3!!"
5757
}
5858
} else {
5959
dependencies {
60-
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.14.0"
60+
implementation "com.microsoft.onnxruntime:onnxruntime_gpu:1.16.3!!"
6161
}
6262
}
6363

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModel.java

+1
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ protected void loadModel(
253253
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
254254
try {
255255
System.setProperty("PYTORCH_PRECXX11", "true");
256+
System.setProperty("PYTORCH_VERSION", "1.13.1");
256257
System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString());
257258
// DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw
258259
// access denied exception

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/DLModelExecute.java

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ private void loadModel(File modelZipFile, String modelId, String modelName, Stri
131131
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
132132
try {
133133
System.setProperty("PYTORCH_PRECXX11", "true");
134+
System.setProperty("PYTORCH_VERSION", "1.13.1");
134135
System.setProperty("DJL_CACHE_DIR", mlEngine.getMlCachePath().toAbsolutePath().toString());
135136
// DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw
136137
// access denied exception

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/question_answering/QuestionAnsweringModelTest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -163,15 +163,15 @@ public void initModel_predict_ONNX_QuestionAnswering() throws URISyntaxException
163163
.modelFormat(MLModelFormat.ONNX)
164164
.name("test_model_name")
165165
.modelId("test_model_id")
166-
.algorithm(FunctionName.TEXT_SIMILARITY)
166+
.algorithm(FunctionName.QUESTION_ANSWERING)
167167
.version("1.0.0")
168168
.modelState(MLModelState.TRAINED)
169169
.build();
170170
modelZipFile = new File(getClass().getResource("question_answering_onnx.zip").toURI());
171171
params.put(MODEL_ZIP_FILE, modelZipFile);
172172

173173
questionAnsweringModel.initModel(model, params, encryptor);
174-
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_SIMILARITY).inputDataset(inputDataSet).build();
174+
MLInput mlInput = MLInput.builder().algorithm(FunctionName.QUESTION_ANSWERING).inputDataset(inputDataSet).build();
175175
ModelTensorOutput output = (ModelTensorOutput) questionAnsweringModel.predict(mlInput);
176176
List<ModelTensors> mlModelOutputs = output.getMlModelOutputs();
177177
assertEquals(1, mlModelOutputs.size());

0 commit comments

Comments
 (0)