package org.opensearch.ml.engine.ingest;

+ import java.io.BufferedReader;
+ import java.io.InputStreamReader;
+ import java.net.HttpURLConnection;
+ import java.net.URL;
+ import java.security.AccessController;
+ import java.security.PrivilegedActionException;
+ import java.security.PrivilegedExceptionAction;
+ import java.util.ArrayList;
+ import java.util.HashMap;
+ import java.util.List;
+ import java.util.Map;
+ import java.util.concurrent.CompletableFuture;
+ import java.util.concurrent.atomic.AtomicInteger;
+
+ import org.json.JSONArray;
+ import org.json.JSONObject;
+ import org.opensearch.OpenSearchStatusException;
+ import org.opensearch.action.bulk.BulkRequest;
+ import org.opensearch.action.bulk.BulkResponse;
+ import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
+ import org.opensearch.core.action.ActionListener;
+ import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.engine.annotation.Ingester;

@Log4j2
@Ingester("openai")
public class openAIDataIngestion implements Ingestable {
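+     // Credential-map key under which the caller supplies the OpenAI API key, and the
+     // base URL of the OpenAI Files API; file content is fetched from API_URL + fileId + "/content"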
+     private static final String API_KEY = "openAI_key";
+     private static final String API_URL = "https://api.openai.com/v1/files/";
+
    public static final String SOURCE = "source";
    private final Client client;

@@ -19,7 +44,121 @@ public openAIDataIngestion(Client client) {
    @Override
    public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
        double successRate = 0;
+         try {
+             String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
+             String fileId = mlBatchIngestionInput.getDataSources().get(SOURCE);
+             URL url = new URL(API_URL + fileId + "/content");
+
+             HttpURLConnection connection = (HttpURLConnection) url.openConnection();
+             connection.setRequestMethod("GET");
+             connection.setRequestProperty("Authorization", "Bearer " + apiKey);
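+             // Open the stream inside doPrivileged: plugin code typically runs under the
+             // Java security manager, which would otherwise block the outbound network read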
+
+             InputStreamReader inputStreamReader = AccessController
+                 .doPrivileged((PrivilegedExceptionAction<InputStreamReader>) () -> new InputStreamReader(connection.getInputStream()));
+             BufferedReader reader = new BufferedReader(inputStreamReader);
+
+             List<String> linesBuffer = new ArrayList<>();
+             String line;
+             int lineCount = 0;
+             // Atomic counters for tracking success and failure
+             AtomicInteger successfulBatches = new AtomicInteger(0);
+             AtomicInteger failedBatches = new AtomicInteger(0);
+             // List of CompletableFutures to track batch ingestion operations
+             List<CompletableFuture<Void>> futures = new ArrayList<>();
+
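+             // Stream the file line by line and issue an asynchronous bulk request for
+             // every 100 lines, tracking each request with its own CompletableFuture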
+             while ((line = reader.readLine()) != null) {
+                 linesBuffer.add(line);
+                 lineCount++;
+
+                 // Process every 100 lines
+                 if (lineCount == 100) {
+                     // Create a CompletableFuture that will be completed by the bulkResponseListener
+                     CompletableFuture<Void> future = new CompletableFuture<>();
+                     batchIngest(linesBuffer, mlBatchIngestionInput, getBulkResponseListener(successfulBatches, failedBatches, future));
+
+                     futures.add(future);
+                     linesBuffer.clear();
+                     lineCount = 0;
+                 }
+             }
+             // Process any remaining lines in the buffer
+             if (!linesBuffer.isEmpty()) {
+                 CompletableFuture<Void> future = new CompletableFuture<>();
+                 batchIngest(linesBuffer, mlBatchIngestionInput, getBulkResponseListener(successfulBatches, failedBatches, future));
+                 futures.add(future);
+             }
+
+             reader.close();
+             // Combine all futures and wait for completion; join() rethrows if any batch
+             // future completed exceptionally, which lands in the catch block below
+             CompletableFuture<Void> allFutures = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
+             allFutures.join();
+             // Guard against division by zero when the source file is empty
+             int totalBatches = successfulBatches.get() + failedBatches.get();
+             successRate = totalBatches == 0 ? 0 : (double) successfulBatches.get() / totalBatches * 100;
+         } catch (PrivilegedActionException e) {
+             throw new RuntimeException("Failed to read from OpenAI file API", e);
+         } catch (Exception e) {
+             log.error(e.getMessage());
+             throw new OpenSearchStatusException("Failed to batch ingest: " + e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR);
+         }

        return successRate;
    }
+
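+     // Builds a listener that records the batch outcome and completes the per-batch
+     // future, so ingest() can block on allOf(...) until every bulk request finishes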
+     private ActionListener<BulkResponse> getBulkResponseListener(
+         AtomicInteger successfulBatches,
+         AtomicInteger failedBatches,
+         CompletableFuture<Void> future
+     ) {
+         return ActionListener.wrap(bulkResponse -> {
+             if (bulkResponse.hasFailures()) {
+                 failedBatches.incrementAndGet();
+                 // Mark the future as completed with an exception
+                 future.completeExceptionally(new RuntimeException(bulkResponse.buildFailureMessage()));
+                 // Return early so a failed batch is not also counted as successful below
+                 return;
+             }
+             log.debug("Batch ingested successfully");
+             successfulBatches.incrementAndGet();
+             future.complete(null); // Mark the future as completed successfully
+         }, e -> {
+             log.error("Failed to bulk ingest the batch", e);
+             failedBatches.incrementAndGet();
+             future.completeExceptionally(e); // Mark the future as completed with an exception
+         });
+     }
+
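+     // Each source line is expected to be one JSON record from the OpenAI batch output
+     // file (JSONL); the embeddings are read from response.body.data, keyed by custom_id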
+     private void batchIngest(
+         List<String> sourceLines,
+         MLBatchIngestionInput mlBatchIngestionInput,
+         ActionListener<BulkResponse> bulkResponseListener
+     ) {
+         BulkRequest bulkRequest = new BulkRequest();
+         sourceLines.stream().forEach(jsonStr -> {
+             JSONObject jsonObject = new JSONObject(jsonStr);
+             String customId = jsonObject.getString("custom_id");
+             JSONObject responseBody = jsonObject.getJSONObject("response").getJSONObject("body");
+             JSONArray dataArray = responseBody.getJSONArray("data");
+             Map<String, Object> jsonMap = processFieldMapping(customId, dataArray, mlBatchIngestionInput.getFieldMapping());
+             IndexRequest indexRequest = new IndexRequest(mlBatchIngestionInput.getIndexName()).source(jsonMap);
+
+             bulkRequest.add(indexRequest);
+         });
+         client.bulk(bulkRequest, bulkResponseListener);
+     }
+
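+     // Pairs the i-th embedding in dataArray with the i-th entry of fieldMapping; note
+     // this relies on the map's iteration order matching the order of the embedding results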
+     private Map<String, Object> processFieldMapping(String customId, JSONArray dataArray, Map<String, String> fieldMapping) {
+         Map<String, Object> jsonMap = new HashMap<>();
+         if (dataArray.length() == fieldMapping.size()) {
+             int index = 0;
+             for (Map.Entry<String, String> mapping : fieldMapping.entrySet()) {
+                 // key is the field name of the input string; value is the field name for the embedding output
+                 JSONObject dataItem = dataArray.getJSONObject(index);
+                 jsonMap.put(mapping.getValue(), dataItem.getJSONArray("embedding"));
+                 index++;
+             }
+             jsonMap.put("id", customId);
+         } else {
+             throw new IllegalArgumentException("The fieldMapping and source data do not match");
+         }
+         return jsonMap;
+     }
}