Skip to content

Commit c8d7b8a

Browse files
committed
add openAI ingester
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent fd9e77f commit c8d7b8a

File tree

2 files changed

+140
-1
lines changed

2 files changed

+140
-1
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/S3DataIngestion.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ private void batchIngest(
166166
private void processFieldMapping(Map<String, Object> jsonMap, Map<String, String> fieldMapping) {
167167
List<List> smOutput = (List<List>) jsonMap.get("SageMakerOutput");
168168
List<String> smInput = (List<String>) jsonMap.get("content");
169-
if (smInput.size() == smInput.size() && smInput.size() == fieldMapping.size()) {
169+
if (smInput.size() == smOutput.size() && smInput.size() == fieldMapping.size()) {
170170
int index = 0;
171171
for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
172172
// key is the field name for input String, value is the field name for embedded output

ml-algorithms/src/main/java/org/opensearch/ml/engine/ingest/openAIDataIngestion.java

+139
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,28 @@
11
package org.opensearch.ml.engine.ingest;
22

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import org.json.JSONArray;
import org.json.JSONObject;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;
628

@@ -9,6 +31,9 @@
931
@Log4j2
1032
@Ingester("openai")
1133
public class openAIDataIngestion implements Ingestable {
34+
private static final String API_KEY = "openAI_key";
35+
private static final String API_URL = "https://api.openai.com/v1/files/";
36+
1237
public static final String SOURCE = "source";
1338
private final Client client;
1439

@@ -19,7 +44,121 @@ public openAIDataIngestion(Client client) {
1944
@Override
2045
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
2146
double successRate = 0;
47+
try {
48+
String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
49+
String fileId = mlBatchIngestionInput.getDataSources().get(SOURCE);
50+
URL url = new URL(API_URL + fileId + "/content");
51+
52+
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
53+
connection.setRequestMethod("GET");
54+
connection.setRequestProperty("Authorization", "Bearer " + apiKey);
55+
56+
InputStreamReader inputStreamReader = AccessController
57+
.doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
58+
BufferedReader reader = new BufferedReader(inputStreamReader);
59+
60+
List<String> linesBuffer = new ArrayList<>();
61+
String line;
62+
int lineCount = 0;
63+
// Atomic counters for tracking success and failure
64+
AtomicInteger successfulBatches = new AtomicInteger(0);
65+
AtomicInteger failedBatches = new AtomicInteger(0);
66+
// List of CompletableFutures to track batch ingestion operations
67+
List<CompletableFuture<Void>> futures = new ArrayList<>();
68+
69+
while ((line = reader.readLine()) != null) {
70+
linesBuffer.add(line);
71+
lineCount++;
72+
73+
// Process every 100 lines
74+
if (lineCount == 100) {
75+
// Create a CompletableFuture that will be completed by the bulkResponseListener
76+
CompletableFuture<Void> future = new CompletableFuture<>();
77+
batchIngest(linesBuffer, mlBatchIngestionInput, getBulkResponseListener(successfulBatches, failedBatches, future));
78+
79+
futures.add(future);
80+
linesBuffer.clear();
81+
lineCount = 0;
82+
}
83+
}
84+
// Process any remaining lines in the buffer
85+
if (!linesBuffer.isEmpty()) {
86+
CompletableFuture<Void> future = new CompletableFuture<>();
87+
batchIngest(linesBuffer, mlBatchIngestionInput, getBulkResponseListener(successfulBatches, failedBatches, future));
88+
futures.add(future);
89+
}
90+
91+
reader.close();
92+
// Combine all futures and wait for completion
93+
CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
94+
// Wait for all tasks to complete
95+
allFutures.join();
96+
int totalBatches = successfulBatches.get() + failedBatches.get();
97+
successRate = (double) successfulBatches.get() / totalBatches * 100;
98+
} catch (PrivilegedActionException e) {
99+
throw new RuntimeException("Failed to read from OpenAI file API: ", e);
100+
} catch (Exception e) {
101+
log.error(e.getMessage());
102+
throw new OpenSearchStatusException("Failed to batch ingest: " + e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR);
103+
}
22104

23105
return successRate;
24106
}
107+
108+
private ActionListener<BulkResponse> getBulkResponseListener(
109+
AtomicInteger successfulBatches,
110+
AtomicInteger failedBatches,
111+
CompletableFuture<Void> future
112+
) {
113+
return ActionListener.wrap(bulkResponse -> {
114+
if (bulkResponse.hasFailures()) {
115+
failedBatches.incrementAndGet();
116+
future.completeExceptionally(new RuntimeException(bulkResponse.buildFailureMessage())); // Mark the future as completed
117+
// with an exception
118+
}
119+
log.debug("Batch Ingestion successfully");
120+
successfulBatches.incrementAndGet();
121+
future.complete(null); // Mark the future as completed successfully
122+
}, e -> {
123+
log.error("Failed to bulk update model state", e);
124+
failedBatches.incrementAndGet();
125+
future.completeExceptionally(e); // Mark the future as completed with an exception
126+
});
127+
}
128+
129+
private void batchIngest(
130+
List<String> sourceLines,
131+
MLBatchIngestionInput mlBatchIngestionInput,
132+
ActionListener<BulkResponse> bulkResponseListener
133+
) {
134+
BulkRequest bulkRequest = new BulkRequest();
135+
sourceLines.stream().forEach(jsonStr -> {
136+
JSONObject jsonObject = new JSONObject(jsonStr);
137+
String customId = jsonObject.getString("custom_id");
138+
JSONObject responseBody = jsonObject.getJSONObject("response").getJSONObject("body");
139+
JSONArray dataArray = responseBody.getJSONArray("data");
140+
Map<String, Object> jsonMap = processFieldMapping(customId, dataArray, mlBatchIngestionInput.getFieldMapping());
141+
IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
142+
143+
bulkRequest.add(indexRequest);
144+
});
145+
client.bulk(bulkRequest, bulkResponseListener);
146+
}
147+
148+
private Map<String, Object> processFieldMapping(String customId, JSONArray dataArray, Map<String, String> fieldMapping) {
149+
Map<String, Object> jsonMap = new HashMap<>();
150+
if (dataArray.length() == fieldMapping.size()) {
151+
int index = 0;
152+
for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
153+
// key is the field name for input String, value is the field name for embedded output
154+
JSONObject dataItem = dataArray.getJSONObject(index);
155+
jsonMap.put(mapping.getValue(), dataItem.getJSONArray("embedding"));
156+
index++;
157+
}
158+
jsonMap.put("id", customId);
159+
} else {
160+
throw new IllegalArgumentException("the fieldMapping and source data do not match");
161+
}
162+
return jsonMap;
163+
}
25164
}

0 commit comments

Comments
 (0)