Skip to content

Commit 0d06b23

Browse files
anntiansAnnTian Shao
and
AnnTian Shao
authored
Update training validation to be handled per algo type (opensearch-project#2462)
* Update training validation to be handled per algo type Signed-off-by: AnnTian Shao <anntians@amazon.com> * fix bwc tests Signed-off-by: AnnTian Shao <anntians@amazon.com> * fixes to encoder and method classes Signed-off-by: AnnTian Shao <anntians@amazon.com> * fixes and moved all validation to encoder Signed-off-by: AnnTian Shao <anntians@amazon.com> * Moved error message to within encoder Signed-off-by: AnnTian Shao <anntians@amazon.com> * fix tests Signed-off-by: AnnTian Shao <anntians@amazon.com> --------- Signed-off-by: AnnTian Shao <anntians@amazon.com> Co-authored-by: AnnTian Shao <anntians@amazon.com>
1 parent f4779a5 commit 0d06b23

15 files changed

+270
-141
lines changed

qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/FaissSQIT.java

+25-6
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R
136136
List<Integer> efConstructionValues = ImmutableList.of(16, 32, 64, 128);
137137
List<Integer> efSearchValues = ImmutableList.of(16, 32, 64, 128);
138138

139-
int dimension = 2;
139+
int dimension = 128;
140140

141141
// Create an index
142142
/**
@@ -199,16 +199,35 @@ public void testHNSWSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16R
199199

200200
createKnnIndex(testIndex, mapping);
201201
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(testIndex)));
202-
Float[] vector1 = { -65523.76f, 65504.2f };
203-
Float[] vector2 = { -270.85f, 65514.2f };
204-
Float[] vector3 = { -150.9f, 65504.0f };
205-
Float[] vector4 = { -20.89f, 100000000.0f };
202+
203+
Float[] vector1 = new Float[dimension];
204+
Float[] vector2 = new Float[dimension];
205+
Float[] vector3 = new Float[dimension];
206+
Float[] vector4 = new Float[dimension];
207+
float[] queryVector = new float[dimension];
208+
int halfDimension = dimension / 2;
209+
210+
for (int i = 0; i < dimension; i++) {
211+
if (i < halfDimension) {
212+
vector1[i] = -65523.76f;
213+
vector2[i] = -270.85f;
214+
vector3[i] = -150.9f;
215+
vector4[i] = -20.89f;
216+
queryVector[i] = -10.5f;
217+
} else {
218+
vector1[i] = 65504.2f;
219+
vector2[i] = 65514.2f;
220+
vector3[i] = 65504.0f;
221+
vector4[i] = 100000000.0f;
222+
queryVector[i] = 25.48f;
223+
}
224+
}
225+
206226
addKnnDoc(testIndex, "1", TEST_FIELD, vector1);
207227
addKnnDoc(testIndex, "2", TEST_FIELD, vector2);
208228
addKnnDoc(testIndex, "3", TEST_FIELD, vector3);
209229
addKnnDoc(testIndex, "4", TEST_FIELD, vector4);
210230

211-
float[] queryVector = { -10.5f, 25.48f };
212231
int k = 4;
213232
Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, k), k);
214233
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), TEST_FIELD);

src/main/java/org/opensearch/knn/index/engine/AbstractKNNMethod.java

-47
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@
2424
import java.util.Set;
2525
import java.util.function.Function;
2626

27-
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
28-
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
29-
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
30-
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
31-
3227
/**
3328
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
3429
* It defines the common attributes and methods that all KNN methods should implement.
@@ -116,49 +111,7 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
116111

117112
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
118113
return (trainingConfigValidationInput) -> {
119-
120-
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
121-
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
122-
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();
123-
124114
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
125-
126-
// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
127-
if (knnMethodContext != null && knnMethodConfigContext != null) {
128-
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
129-
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
130-
.getParameters()
131-
.get(ENCODER_PARAMETER_PQ_M) != 0) {
132-
builder.valid(false);
133-
return builder.build();
134-
} else {
135-
builder.valid(true);
136-
}
137-
}
138-
139-
// validate number of training points should be greater than minimum clustering criteria defined in faiss
140-
if (knnMethodContext != null && trainingVectors != null) {
141-
long minTrainingVectorCount = 1000;
142-
143-
MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
144-
.getParameters()
145-
.get(METHOD_ENCODER_PARAMETER);
146-
147-
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
148-
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {
149-
150-
int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
151-
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
152-
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, code_size));
153-
}
154-
155-
if (trainingVectors < minTrainingVectorCount) {
156-
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
157-
return builder.build();
158-
} else {
159-
builder.valid(true);
160-
}
161-
}
162115
return builder.build();
163116
};
164117
}

src/main/java/org/opensearch/knn/index/engine/Encoder.java

+10
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,14 @@ default String getName() {
3636
* return {@link CompressionLevel#NOT_CONFIGURED}
3737
*/
3838
CompressionLevel calculateCompressionLevel(MethodComponentContext encoderContext, KNNMethodConfigContext knnMethodConfigContext);
39+
40+
/**
41+
* Validates config of encoder
42+
*
43+
* @return Validation output of encoder parameters
44+
*/
45+
default TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput validationInput) {
46+
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
47+
return builder.build();
48+
}
3949
}

src/main/java/org/opensearch/knn/index/engine/TrainingConfigValidationOutput.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
@Builder
2121
@AllArgsConstructor
2222
public class TrainingConfigValidationOutput {
23-
private boolean valid;
24-
private long minTrainingVectorCount;
23+
private Boolean valid;
24+
private Long minTrainingVectorCount;
25+
private String errorMessage;
2526
}

src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java

+27
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,33 @@ static MethodComponentContext getEncoderMethodComponent(MethodComponentContext m
129129
return (MethodComponentContext) object;
130130
}
131131

132+
protected String getEncoderName(KNNMethodContext knnMethodContext) {
133+
if (isEncoderSpecified(knnMethodContext) == false) {
134+
return null;
135+
}
136+
137+
MethodComponentContext methodComponentContext = getEncoderComponentContext(knnMethodContext);
138+
if (methodComponentContext == null) {
139+
return null;
140+
}
141+
142+
return methodComponentContext.getName();
143+
}
144+
145+
protected MethodComponentContext getEncoderComponentContext(KNNMethodContext knnMethodContext) {
146+
if (isEncoderSpecified(knnMethodContext) == false) {
147+
return null;
148+
}
149+
150+
return (MethodComponentContext) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_ENCODER_PARAMETER);
151+
}
152+
153+
protected boolean isEncoderSpecified(KNNMethodContext knnMethodContext) {
154+
return knnMethodContext != null
155+
&& knnMethodContext.getMethodComponentContext().getParameters() != null
156+
&& knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_ENCODER_PARAMETER);
157+
}
158+
132159
@Override
133160
protected SpaceType convertUserToMethodSpaceType(SpaceType spaceType) {
134161
// While FAISS doesn't directly support cosine similarity, we can leverage the mathematical

src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissPQEncoder.java

+53
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313

1414
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
1515
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
16+
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
17+
18+
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
19+
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
20+
import org.opensearch.knn.index.engine.KNNMethodContext;
1621

1722
/**
1823
* Abstract class for Faiss PQ encoders. This class provides the common logic for product quantization based encoders
@@ -89,4 +94,52 @@ public CompressionLevel calculateCompressionLevel(
8994
// compression
9095
return CompressionLevel.MAX_COMPRESSION_LEVEL;
9196
}
97+
98+
@Override
99+
public TrainingConfigValidationOutput validateEncoderConfig(TrainingConfigValidationInput trainingConfigValidationInput) {
100+
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
101+
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
102+
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();
103+
104+
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
105+
106+
// validate ENCODER_PARAMETER_PQ_M is divisible by vector dimension
107+
if (knnMethodContext != null && knnMethodConfigContext != null) {
108+
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
109+
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
110+
.getParameters()
111+
.get(ENCODER_PARAMETER_PQ_M) != 0) {
112+
builder.valid(false);
113+
builder.errorMessage("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
114+
return builder.build();
115+
} else {
116+
builder.valid(true);
117+
}
118+
}
119+
120+
// validate number of training points should be greater than minimum clustering criteria defined in faiss
121+
if (knnMethodContext != null && trainingVectors != null) {
122+
long minTrainingVectorCount = 1000;
123+
124+
MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
125+
.getParameters()
126+
.get(METHOD_ENCODER_PARAMETER);
127+
128+
if (encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {
129+
130+
int code_size = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
131+
minTrainingVectorCount = (long) Math.pow(2, code_size);
132+
}
133+
134+
if (trainingVectors < minTrainingVectorCount) {
135+
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
136+
builder.errorMessage(String.format("Number of training points should be greater than %d", minTrainingVectorCount));
137+
return builder.build();
138+
} else {
139+
builder.valid(true);
140+
}
141+
}
142+
143+
return builder.build();
144+
}
92145
}

src/main/java/org/opensearch/knn/index/engine/faiss/FaissHNSWMethod.java

+23
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616
import org.opensearch.knn.index.engine.MethodComponent;
1717
import org.opensearch.knn.index.engine.MethodComponentContext;
1818
import org.opensearch.knn.index.engine.Parameter;
19+
import org.opensearch.knn.index.engine.KNNMethodContext;
20+
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
21+
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
1922

2023
import java.util.Arrays;
2124
import java.util.Collections;
2225
import java.util.List;
2326
import java.util.Map;
2427
import java.util.Set;
28+
import java.util.function.Function;
2529
import java.util.stream.Collectors;
2630

2731
import static org.opensearch.knn.common.KNNConstants.FAISS_HNSW_DESCRIPTION;
@@ -124,4 +128,23 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter()
124128
SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent))
125129
);
126130
}
131+
132+
@Override
133+
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
134+
return (trainingConfigValidationInput) -> {
135+
136+
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
137+
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
138+
139+
if (isEncoderSpecified(knnMethodContext) == false) {
140+
return builder.build();
141+
}
142+
Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext));
143+
if (encoder == null) {
144+
return builder.build();
145+
}
146+
147+
return encoder.validateEncoderConfig(trainingConfigValidationInput);
148+
};
149+
}
127150
}

src/main/java/org/opensearch/knn/index/engine/faiss/FaissIVFMethod.java

+23
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
import org.opensearch.knn.index.engine.MethodComponent;
1616
import org.opensearch.knn.index.engine.MethodComponentContext;
1717
import org.opensearch.knn.index.engine.Parameter;
18+
import org.opensearch.knn.index.engine.KNNMethodContext;
19+
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
20+
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
1821

1922
import java.util.Arrays;
2023
import java.util.Collections;
2124
import java.util.List;
2225
import java.util.Map;
2326
import java.util.Set;
27+
import java.util.function.Function;
2428
import java.util.stream.Collectors;
2529

2630
import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
@@ -150,4 +154,23 @@ private static Parameter.MethodComponentContextParameter initEncoderParameter()
150154
SUPPORTED_ENCODERS.values().stream().collect(Collectors.toMap(Encoder::getName, Encoder::getMethodComponent))
151155
);
152156
}
157+
158+
@Override
159+
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
160+
return (trainingConfigValidationInput) -> {
161+
162+
KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
163+
TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();
164+
165+
if (isEncoderSpecified(knnMethodContext) == false) {
166+
return builder.build();
167+
}
168+
Encoder encoder = SUPPORTED_ENCODERS.get(getEncoderName(knnMethodContext));
169+
if (encoder == null) {
170+
return builder.build();
171+
}
172+
173+
return encoder.validateEncoderConfig(trainingConfigValidationInput);
174+
};
175+
}
153176
}

src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java

+31
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import org.opensearch.knn.index.engine.MethodComponent;
1616
import org.opensearch.knn.index.engine.MethodComponentContext;
1717
import org.opensearch.knn.index.engine.ResolvedMethodContext;
18+
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
19+
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
1820
import org.opensearch.knn.index.mapper.CompressionLevel;
1921
import org.opensearch.knn.index.mapper.Mode;
2022

@@ -73,6 +75,9 @@ public ResolvedMethodContext resolveMethod(
7375
encoderMap
7476
);
7577

78+
// Validate encoder parameters
79+
validateEncoderConfig(resolvedKNNMethodContext, knnMethodConfigContext, encoderMap);
80+
7681
// Validate that resolved compression doesnt have any conflicts
7782
validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel);
7883
knnMethodConfigContext.setCompressionLevel(resolvedCompressionLevel);
@@ -148,6 +153,32 @@ private void validateConfig(KNNMethodConfigContext knnMethodConfigContext) {
148153
}
149154
}
150155

156+
protected void validateEncoderConfig(
157+
KNNMethodContext resolvedKnnMethodContext,
158+
KNNMethodConfigContext knnMethodConfigContext,
159+
Map<String, Encoder> encoderMap
160+
) {
161+
if (isEncoderSpecified(resolvedKnnMethodContext) == false) {
162+
return;
163+
}
164+
Encoder encoder = encoderMap.get(getEncoderName(resolvedKnnMethodContext));
165+
if (encoder == null) {
166+
return;
167+
}
168+
169+
TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
170+
171+
TrainingConfigValidationOutput validationOutput = encoder.validateEncoderConfig(
172+
inputBuilder.knnMethodContext(resolvedKnnMethodContext).knnMethodConfigContext(knnMethodConfigContext).build()
173+
);
174+
175+
if (validationOutput.getValid() != null && !validationOutput.getValid()) {
176+
ValidationException validationException = new ValidationException();
177+
validationException.addValidationError(validationOutput.getErrorMessage());
178+
throw validationException;
179+
}
180+
}
181+
151182
private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMethodConfigContext) {
152183
if (CompressionLevel.isConfigured(knnMethodConfigContext.getCompressionLevel())) {
153184
return knnMethodConfigContext.getCompressionLevel();

src/main/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportAction.java

+2-4
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,9 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
154154
TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
155155
inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
156156
);
157-
if (!validation.isValid()) {
157+
if (validation.getValid() != null && !validation.getValid()) {
158158
ValidationException exception = new ValidationException();
159-
exception.addValidationError(
160-
String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount())
161-
);
159+
exception.addValidationError(validation.getErrorMessage());
162160
listener.onFailure(exception);
163161
return;
164162
}

0 commit comments

Comments
 (0)