28
28
import org .opensearch .core .rest .RestStatus ;
29
29
import org .opensearch .ml .common .transport .batch .MLBatchIngestionInput ;
30
30
import org .opensearch .ml .engine .annotation .Ingester ;
31
-
32
- import com .google .common .annotations .VisibleForTesting ;
31
+ import org .opensearch .ml .engine .utils .S3Utils ;
33
32
34
33
import lombok .extern .log4j .Log4j2 ;
35
- import software .amazon .awssdk .auth .credentials .AwsBasicCredentials ;
36
- import software .amazon .awssdk .auth .credentials .AwsCredentials ;
37
- import software .amazon .awssdk .auth .credentials .AwsSessionCredentials ;
38
- import software .amazon .awssdk .auth .credentials .StaticCredentialsProvider ;
39
34
import software .amazon .awssdk .core .ResponseInputStream ;
40
- import software .amazon .awssdk .regions .Region ;
41
35
import software .amazon .awssdk .services .s3 .S3Client ;
42
36
import software .amazon .awssdk .services .s3 .model .GetObjectRequest ;
43
37
import software .amazon .awssdk .services .s3 .model .GetObjectResponse ;
@@ -54,7 +48,12 @@ public S3DataIngestion(Client client) {
54
48
55
49
@ Override
56
50
public double ingest (MLBatchIngestionInput mlBatchIngestionInput , int bulkSize ) {
57
- S3Client s3 = initS3Client (mlBatchIngestionInput );
51
+ String accessKey = mlBatchIngestionInput .getCredential ().get (ACCESS_KEY_FIELD );
52
+ String secretKey = mlBatchIngestionInput .getCredential ().get (SECRET_KEY_FIELD );
53
+ String sessionToken = mlBatchIngestionInput .getCredential ().get (SESSION_TOKEN_FIELD );
54
+ String region = mlBatchIngestionInput .getCredential ().get (REGION_FIELD );
55
+
56
+ S3Client s3 = S3Utils .initS3Client (accessKey , secretKey , region , sessionToken );
58
57
59
58
List <String > s3Uris = (List <String >) mlBatchIngestionInput .getDataSources ().get (SOURCE );
60
59
if (Objects .isNull (s3Uris ) || s3Uris .isEmpty ()) {
@@ -77,8 +76,8 @@ public double ingestSingleSource(
77
76
boolean isSoleSource ,
78
77
int bulkSize
79
78
) {
80
- String bucketName = getS3BucketName (s3Uri );
81
- String keyName = getS3KeyName (s3Uri );
79
+ String bucketName = S3Utils . getS3BucketName (s3Uri );
80
+ String keyName = S3Utils . getS3KeyName (s3Uri );
82
81
GetObjectRequest getObjectRequest = GetObjectRequest .builder ().bucket (bucketName ).key (keyName ).build ();
83
82
double successRate = 0 ;
84
83
@@ -153,55 +152,4 @@ public double ingestSingleSource(
153
152
154
153
return successRate ;
155
154
}
156
-
157
- private String getS3BucketName (String s3Uri ) {
158
- // Remove the "s3://" prefix
159
- String uriWithoutPrefix = s3Uri .substring (5 );
160
- // Find the first slash after the bucket name
161
- int slashIndex = uriWithoutPrefix .indexOf ('/' );
162
- // If there is no slash, the entire remaining string is the bucket name
163
- if (slashIndex == -1 ) {
164
- return uriWithoutPrefix ;
165
- }
166
- // Otherwise, the bucket name is the substring up to the first slash
167
- return uriWithoutPrefix .substring (0 , slashIndex );
168
- }
169
-
170
- private String getS3KeyName (String s3Uri ) {
171
- String uriWithoutPrefix = s3Uri .substring (5 );
172
- // Find the first slash after the bucket name
173
- int slashIndex = uriWithoutPrefix .indexOf ('/' );
174
- // If there is no slash, it means there is no key, return an empty string or handle as needed
175
- if (slashIndex == -1 ) {
176
- return "" ;
177
- }
178
- // The key name is the substring after the first slash
179
- return uriWithoutPrefix .substring (slashIndex + 1 );
180
- }
181
-
182
- @ VisibleForTesting
183
- public S3Client initS3Client (MLBatchIngestionInput mlBatchIngestionInput ) {
184
- String accessKey = mlBatchIngestionInput .getCredential ().get (ACCESS_KEY_FIELD );
185
- String secretKey = mlBatchIngestionInput .getCredential ().get (SECRET_KEY_FIELD );
186
- String sessionToken = mlBatchIngestionInput .getCredential ().get (SESSION_TOKEN_FIELD );
187
- String region = mlBatchIngestionInput .getCredential ().get (REGION_FIELD );
188
-
189
- AwsCredentials credentials = sessionToken == null
190
- ? AwsBasicCredentials .create (accessKey , secretKey )
191
- : AwsSessionCredentials .create (accessKey , secretKey , sessionToken );
192
-
193
- try {
194
- S3Client s3 = AccessController
195
- .doPrivileged (
196
- (PrivilegedExceptionAction <S3Client >) () -> S3Client
197
- .builder ()
198
- .region (Region .of (region )) // Specify the region here
199
- .credentialsProvider (StaticCredentialsProvider .create (credentials ))
200
- .build ()
201
- );
202
- return s3 ;
203
- } catch (PrivilegedActionException e ) {
204
- throw new RuntimeException ("Can't load credentials" , e );
205
- }
206
- }
207
155
}
0 commit comments