Skip to content

Commit f089b5b

Browse files
Vikasht34 and navneet1v authored
Align dimensions to the nearest multiple of 8 in QuantizationState (#2010)
* Align dimensions to the nearest multiple of 8 in QuantizationState
  Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
* Update src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java
  Co-authored-by: Navneet Verma <vermanavneet003@gmail.com>
  Signed-off-by: Vikasht34 <viktari@amazon.com>
* Update QuantizationStateTests.java
  Signed-off-by: Vikasht34 <viktari@amazon.com>
---------
Signed-off-by: VIKASH TIWARI <viktari@amazon.com>
Signed-off-by: Vikasht34 <viktari@amazon.com>
Co-authored-by: Navneet Verma <vermanavneet003@gmail.com>
1 parent bf38c2e commit f089b5b

File tree

12 files changed: +92 -12 lines changed

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.apache.lucene.index.Sorter;
2525
import org.apache.lucene.util.IOUtils;
2626
import org.apache.lucene.util.RamUsageEstimator;
27-
import org.opensearch.knn.index.quantizationService.QuantizationService;
27+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
2828
import org.opensearch.knn.index.VectorDataType;
2929
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
3030
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.opensearch.knn.index.VectorDataType;
2525
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
2626
import org.opensearch.knn.index.engine.KNNEngine;
27-
import org.opensearch.knn.index.quantizationService.QuantizationService;
27+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
2828
import org.opensearch.knn.index.util.IndexUtil;
2929
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
3030
import org.opensearch.knn.indices.Model;

src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import lombok.experimental.UtilityClass;
99
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
10-
import org.opensearch.knn.index.quantizationService.QuantizationService;
10+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
1111
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1212
import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput;
1313
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.knn.index.quantizationService;
6+
package org.opensearch.knn.index.quantizationservice;
77

88
import lombok.extern.log4j.Log4j2;
99
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
@@ -49,7 +49,7 @@ public T getVectorAtThePosition(int position) throws IOException {
4949
}
5050
knnVectorValues.nextDoc();
5151
}
52-
// Return the vector and the updated index
52+
// Return the vector
5353
return knnVectorValues.getVector();
5454
}
5555
}

src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.knn.index.quantizationService;
6+
package org.opensearch.knn.index.quantizationservice;
77

88
import lombok.AccessLevel;
99
import lombok.NoArgsConstructor;

src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,14 @@ public int getDimensions() {
150150
if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) {
151151
throw new IllegalStateException("Error in getting Dimension: The thresholds array is not initialized.");
152152
}
153-
return thresholds.length * thresholds[0].length;
153+
int originalDimensions = thresholds[0].length;
154+
155+
// Align the original dimensions to the next multiple of 8 for each bit level
156+
int alignedDimensions = (originalDimensions + 7) & ~7;
157+
158+
// The final dimension count should consider the bit levels
159+
return thresholds.length * alignedDimensions;
160+
154161
}
155162

156163
/**

src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ public int getBytesPerVector() {
123123
@Override
124124
public int getDimensions() {
125125
// For one-bit quantization, the dimension for indexing is just the length of the thresholds array.
126-
return meanThresholds.length;
126+
// Align the original dimensions to the next multiple of 8
127+
return (meanThresholds.length + 7) & ~7;
127128
}
128129

129130
/**

src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
1818
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory;
1919
import org.opensearch.knn.index.engine.KNNEngine;
20-
import org.opensearch.knn.index.quantizationService.QuantizationService;
20+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
2121
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
2222
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
2323
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;

src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer;
1717
import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory;
1818
import org.opensearch.knn.index.engine.KNNEngine;
19-
import org.opensearch.knn.index.quantizationService.QuantizationService;
19+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
2020
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
2121
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
2222
import org.opensearch.knn.index.vectorvalues.TestVectorValues;

src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import org.opensearch.knn.KNNTestCase;
1010
import org.opensearch.knn.index.VectorDataType;
1111
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
12-
import org.opensearch.knn.index.quantizationService.QuantizationService;
12+
import org.opensearch.knn.index.quantizationservice.QuantizationService;
1313
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1414
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
1515
import org.opensearch.knn.index.vectorvalues.TestVectorValues;

src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.knn.index.quantizationService;
6+
package org.opensearch.knn.index.quantizationservice;
77

88
import org.opensearch.knn.KNNTestCase;
99
import org.junit.Before;

src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java

+72
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,78 @@ public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException {
100100
assertEquals(expectedRamBytesUsed, actualRamBytesUsed);
101101
}
102102

103+
public void testMultiBitScalarQuantizationStateGetDimensions_withDimensionNotMultipleOf8_thenSuccess() {
104+
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
105+
106+
// Case 1: 3 thresholds, each with 2 dimensions (each bit level aligns 2 -> 8, so 3 * 8 = 24)
107+
float[][] thresholds1 = { { 0.5f, 1.5f }, { 1.0f, 2.0f }, { 1.5f, 2.5f } };
108+
MultiBitScalarQuantizationState state1 = new MultiBitScalarQuantizationState(params, thresholds1);
109+
int expectedDimensions1 = 24; // The next multiple of 8 considering all bits
110+
assertEquals(expectedDimensions1, state1.getDimensions());
111+
112+
// Case 2: 1 threshold, with 5 dimensions (5 bits, should align to 8)
113+
float[][] thresholds2 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f } };
114+
MultiBitScalarQuantizationState state2 = new MultiBitScalarQuantizationState(params, thresholds2);
115+
int expectedDimensions2 = 8; // The next multiple of 8 considering all bits
116+
assertEquals(expectedDimensions2, state2.getDimensions());
117+
118+
// Case 3: 4 thresholds, each with 7 dimensions (each bit level aligns 7 -> 8, so 4 * 8 = 32)
119+
float[][] thresholds3 = {
120+
{ 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f },
121+
{ 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f },
122+
{ 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
123+
{ 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } };
124+
MultiBitScalarQuantizationState state3 = new MultiBitScalarQuantizationState(params, thresholds3);
125+
int expectedDimensions3 = 32; // The next multiple of 8 considering all bits
126+
assertEquals(expectedDimensions3, state3.getDimensions());
127+
128+
// Case 4: 2 thresholds, each with 8 dimensions (16 bits, already aligned)
129+
float[][] thresholds4 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } };
130+
MultiBitScalarQuantizationState state4 = new MultiBitScalarQuantizationState(params, thresholds4);
131+
int expectedDimensions4 = 16; // Already aligned to 8
132+
assertEquals(expectedDimensions4, state4.getDimensions());
133+
134+
// Case 5: 2 thresholds, each with 6 dimensions (each bit level aligns 6 -> 8, so 2 * 8 = 16)
135+
float[][] thresholds5 = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f } };
136+
MultiBitScalarQuantizationState state5 = new MultiBitScalarQuantizationState(params, thresholds5);
137+
int expectedDimensions5 = 16; // The next multiple of 8 considering all bits
138+
assertEquals(expectedDimensions5, state5.getDimensions());
139+
}
140+
141+
public void testOneBitScalarQuantizationStateGetDimensions_withDimensionNotMultipleOf8_thenSuccess() {
142+
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
143+
144+
// Case 1: 5 dimensions (should align to 8)
145+
float[] thresholds1 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f };
146+
OneBitScalarQuantizationState state1 = new OneBitScalarQuantizationState(params, thresholds1);
147+
int expectedDimensions1 = 8; // The next multiple of 8
148+
assertEquals(expectedDimensions1, state1.getDimensions());
149+
150+
// Case 2: 7 dimensions (should align to 8)
151+
float[] thresholds2 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f };
152+
OneBitScalarQuantizationState state2 = new OneBitScalarQuantizationState(params, thresholds2);
153+
int expectedDimensions2 = 8; // The next multiple of 8
154+
assertEquals(expectedDimensions2, state2.getDimensions());
155+
156+
// Case 3: 8 dimensions (already aligned to 8)
157+
float[] thresholds3 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f };
158+
OneBitScalarQuantizationState state3 = new OneBitScalarQuantizationState(params, thresholds3);
159+
int expectedDimensions3 = 8; // Already aligned to 8
160+
assertEquals(expectedDimensions3, state3.getDimensions());
161+
162+
// Case 4: 10 dimensions (should align to 16)
163+
float[] thresholds4 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f };
164+
OneBitScalarQuantizationState state4 = new OneBitScalarQuantizationState(params, thresholds4);
165+
int expectedDimensions4 = 16; // The next multiple of 8
166+
assertEquals(expectedDimensions4, state4.getDimensions());
167+
168+
// Case 5: 16 dimensions (already aligned to 16)
169+
float[] thresholds5 = { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f, 9.5f, 10.5f, 11.5f, 12.5f, 13.5f, 14.5f, 15.5f };
170+
OneBitScalarQuantizationState state5 = new OneBitScalarQuantizationState(params, thresholds5);
171+
int expectedDimensions5 = 16; // Already aligned to 16
172+
assertEquals(expectedDimensions5, state5.getDimensions());
173+
}
174+
103175
public void testMultiBitScalarQuantizationStateRamBytesUsedManualCalculation() throws IOException {
104176
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
105177
float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } };

0 commit comments

Comments
 (0)