5
5
6
6
package org .opensearch .ml .engine .ingest ;
7
7
8
- import static org .opensearch .ml .common .connector .AbstractConnector .ACCESS_KEY_FIELD ;
9
- import static org .opensearch .ml .common .connector .AbstractConnector .SECRET_KEY_FIELD ;
10
- import static org .opensearch .ml .common .connector .AbstractConnector .SESSION_TOKEN_FIELD ;
11
- import static org .opensearch .ml .common .connector .HttpConnector .REGION_FIELD ;
12
-
13
8
import java .io .BufferedReader ;
14
9
import java .io .InputStreamReader ;
15
10
import java .nio .charset .StandardCharsets ;
28
23
import org .opensearch .core .rest .RestStatus ;
29
24
import org .opensearch .ml .common .transport .batch .MLBatchIngestionInput ;
30
25
import org .opensearch .ml .engine .annotation .Ingester ;
31
-
32
- import com .google .common .annotations .VisibleForTesting ;
26
+ import org .opensearch .ml .common .utils .S3Utils ;
33
27
34
28
import lombok .extern .log4j .Log4j2 ;
35
- import software .amazon .awssdk .auth .credentials .AwsBasicCredentials ;
36
- import software .amazon .awssdk .auth .credentials .AwsCredentials ;
37
- import software .amazon .awssdk .auth .credentials .AwsSessionCredentials ;
38
- import software .amazon .awssdk .auth .credentials .StaticCredentialsProvider ;
39
29
import software .amazon .awssdk .core .ResponseInputStream ;
40
- import software .amazon .awssdk .regions .Region ;
41
30
import software .amazon .awssdk .services .s3 .S3Client ;
42
31
import software .amazon .awssdk .services .s3 .model .GetObjectRequest ;
43
32
import software .amazon .awssdk .services .s3 .model .GetObjectResponse ;
44
33
import software .amazon .awssdk .services .s3 .model .S3Exception ;
45
34
35
+ import static org .opensearch .ml .common .connector .AbstractConnector .*;
36
+ import static org .opensearch .ml .common .connector .HttpConnector .REGION_FIELD ;
37
+
46
38
@ Log4j2
47
39
@ Ingester ("s3" )
48
40
public class S3DataIngestion extends AbstractIngestion {
@@ -54,7 +46,12 @@ public S3DataIngestion(Client client) {
54
46
55
47
@ Override
56
48
public double ingest (MLBatchIngestionInput mlBatchIngestionInput , int bulkSize ) {
57
- S3Client s3 = initS3Client (mlBatchIngestionInput );
49
+ String accessKey = mlBatchIngestionInput .getCredential ().get (ACCESS_KEY_FIELD );
50
+ String secretKey = mlBatchIngestionInput .getCredential ().get (SECRET_KEY_FIELD );
51
+ String sessionToken = mlBatchIngestionInput .getCredential ().get (SESSION_TOKEN_FIELD );
52
+ String region = mlBatchIngestionInput .getCredential ().get (REGION_FIELD );
53
+
54
+ S3Client s3 = S3Utils .initS3Client (accessKey , secretKey , region , sessionToken );
58
55
59
56
List <String > s3Uris = (List <String >) mlBatchIngestionInput .getDataSources ().get (SOURCE );
60
57
if (Objects .isNull (s3Uris ) || s3Uris .isEmpty ()) {
@@ -77,8 +74,8 @@ public double ingestSingleSource(
77
74
boolean isSoleSource ,
78
75
int bulkSize
79
76
) {
80
- String bucketName = getS3BucketName (s3Uri );
81
- String keyName = getS3KeyName (s3Uri );
77
+ String bucketName = S3Utils . getS3BucketName (s3Uri );
78
+ String keyName = S3Utils . getS3KeyName (s3Uri );
82
79
GetObjectRequest getObjectRequest = GetObjectRequest .builder ().bucket (bucketName ).key (keyName ).build ();
83
80
double successRate = 0 ;
84
81
@@ -153,55 +150,4 @@ public double ingestSingleSource(
153
150
154
151
return successRate ;
155
152
}
156
-
157
- private String getS3BucketName (String s3Uri ) {
158
- // Remove the "s3://" prefix
159
- String uriWithoutPrefix = s3Uri .substring (5 );
160
- // Find the first slash after the bucket name
161
- int slashIndex = uriWithoutPrefix .indexOf ('/' );
162
- // If there is no slash, the entire remaining string is the bucket name
163
- if (slashIndex == -1 ) {
164
- return uriWithoutPrefix ;
165
- }
166
- // Otherwise, the bucket name is the substring up to the first slash
167
- return uriWithoutPrefix .substring (0 , slashIndex );
168
- }
169
-
170
- private String getS3KeyName (String s3Uri ) {
171
- String uriWithoutPrefix = s3Uri .substring (5 );
172
- // Find the first slash after the bucket name
173
- int slashIndex = uriWithoutPrefix .indexOf ('/' );
174
- // If there is no slash, it means there is no key, return an empty string or handle as needed
175
- if (slashIndex == -1 ) {
176
- return "" ;
177
- }
178
- // The key name is the substring after the first slash
179
- return uriWithoutPrefix .substring (slashIndex + 1 );
180
- }
181
-
182
- @ VisibleForTesting
183
- public S3Client initS3Client (MLBatchIngestionInput mlBatchIngestionInput ) {
184
- String accessKey = mlBatchIngestionInput .getCredential ().get (ACCESS_KEY_FIELD );
185
- String secretKey = mlBatchIngestionInput .getCredential ().get (SECRET_KEY_FIELD );
186
- String sessionToken = mlBatchIngestionInput .getCredential ().get (SESSION_TOKEN_FIELD );
187
- String region = mlBatchIngestionInput .getCredential ().get (REGION_FIELD );
188
-
189
- AwsCredentials credentials = sessionToken == null
190
- ? AwsBasicCredentials .create (accessKey , secretKey )
191
- : AwsSessionCredentials .create (accessKey , secretKey , sessionToken );
192
-
193
- try {
194
- S3Client s3 = AccessController
195
- .doPrivileged (
196
- (PrivilegedExceptionAction <S3Client >) () -> S3Client
197
- .builder ()
198
- .region (Region .of (region )) // Specify the region here
199
- .credentialsProvider (StaticCredentialsProvider .create (credentials ))
200
- .build ()
201
- );
202
- return s3 ;
203
- } catch (PrivilegedActionException e ) {
204
- throw new RuntimeException ("Can't load credentials" , e );
205
- }
206
- }
207
153
}
0 commit comments