@@ -100,6 +100,78 @@ public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException {
100
100
assertEquals (expectedRamBytesUsed , actualRamBytesUsed );
101
101
}
102
102
103
+ public void testMultiBitScalarQuantizationStateGetDimensions_withDimensionNotMultipleOf8_thenSuccess () {
104
+ ScalarQuantizationParams params = new ScalarQuantizationParams (ScalarQuantizationType .TWO_BIT );
105
+
106
+ // Case 1: 3 thresholds, each with 2 dimensions
107
+ float [][] thresholds1 = { { 0.5f , 1.5f }, { 1.0f , 2.0f }, { 1.5f , 2.5f } };
108
+ MultiBitScalarQuantizationState state1 = new MultiBitScalarQuantizationState (params , thresholds1 );
109
+ int expectedDimensions1 = 24 ; // The next multiple of 8 considering all bits
110
+ assertEquals (expectedDimensions1 , state1 .getDimensions ());
111
+
112
+ // Case 2: 1 threshold, with 5 dimensions (5 bits, should align to 8)
113
+ float [][] thresholds2 = { { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f } };
114
+ MultiBitScalarQuantizationState state2 = new MultiBitScalarQuantizationState (params , thresholds2 );
115
+ int expectedDimensions2 = 8 ; // The next multiple of 8 considering all bits
116
+ assertEquals (expectedDimensions2 , state2 .getDimensions ());
117
+
118
+ // Case 3: 4 thresholds, each with 7 dimensions (28 bits, should align to 32)
119
+ float [][] thresholds3 = {
120
+ { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f },
121
+ { 1.0f , 2.0f , 3.0f , 4.0f , 5.0f , 6.0f , 7.0f },
122
+ { 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f , 7.5f },
123
+ { 2.0f , 3.0f , 4.0f , 5.0f , 6.0f , 7.0f , 8.0f } };
124
+ MultiBitScalarQuantizationState state3 = new MultiBitScalarQuantizationState (params , thresholds3 );
125
+ int expectedDimensions3 = 32 ; // The next multiple of 8 considering all bits
126
+ assertEquals (expectedDimensions3 , state3 .getDimensions ());
127
+
128
+ // Case 4: 2 thresholds, each with 8 dimensions (16 bits, already aligned)
129
+ float [][] thresholds4 = { { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f , 7.5f }, { 1.0f , 2.0f , 3.0f , 4.0f , 5.0f , 6.0f , 7.0f , 8.0f } };
130
+ MultiBitScalarQuantizationState state4 = new MultiBitScalarQuantizationState (params , thresholds4 );
131
+ int expectedDimensions4 = 16 ; // Already aligned to 8
132
+ assertEquals (expectedDimensions4 , state4 .getDimensions ());
133
+
134
+ // Case 5: 2 thresholds, each with 6 dimensions (12 bits, should align to 16)
135
+ float [][] thresholds5 = { { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f }, { 1.0f , 2.0f , 3.0f , 4.0f , 5.0f , 6.0f } };
136
+ MultiBitScalarQuantizationState state5 = new MultiBitScalarQuantizationState (params , thresholds5 );
137
+ int expectedDimensions5 = 16 ; // The next multiple of 8 considering all bits
138
+ assertEquals (expectedDimensions5 , state5 .getDimensions ());
139
+ }
140
+
141
+ public void testOneBitScalarQuantizationStateGetDimensions_withDimensionNotMultipleOf8_thenSuccess () {
142
+ ScalarQuantizationParams params = new ScalarQuantizationParams (ScalarQuantizationType .ONE_BIT );
143
+
144
+ // Case 1: 5 dimensions (should align to 8)
145
+ float [] thresholds1 = { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f };
146
+ OneBitScalarQuantizationState state1 = new OneBitScalarQuantizationState (params , thresholds1 );
147
+ int expectedDimensions1 = 8 ; // The next multiple of 8
148
+ assertEquals (expectedDimensions1 , state1 .getDimensions ());
149
+
150
+ // Case 2: 7 dimensions (should align to 8)
151
+ float [] thresholds2 = { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f };
152
+ OneBitScalarQuantizationState state2 = new OneBitScalarQuantizationState (params , thresholds2 );
153
+ int expectedDimensions2 = 8 ; // The next multiple of 8
154
+ assertEquals (expectedDimensions2 , state2 .getDimensions ());
155
+
156
+ // Case 3: 8 dimensions (already aligned to 8)
157
+ float [] thresholds3 = { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f , 7.5f };
158
+ OneBitScalarQuantizationState state3 = new OneBitScalarQuantizationState (params , thresholds3 );
159
+ int expectedDimensions3 = 8 ; // Already aligned to 8
160
+ assertEquals (expectedDimensions3 , state3 .getDimensions ());
161
+
162
+ // Case 4: 10 dimensions (should align to 16)
163
+ float [] thresholds4 = { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f , 7.5f , 8.5f , 9.5f };
164
+ OneBitScalarQuantizationState state4 = new OneBitScalarQuantizationState (params , thresholds4 );
165
+ int expectedDimensions4 = 16 ; // The next multiple of 8
166
+ assertEquals (expectedDimensions4 , state4 .getDimensions ());
167
+
168
+ // Case 5: 16 dimensions (already aligned to 16)
169
+ float [] thresholds5 = { 0.5f , 1.5f , 2.5f , 3.5f , 4.5f , 5.5f , 6.5f , 7.5f , 8.5f , 9.5f , 10.5f , 11.5f , 12.5f , 13.5f , 14.5f , 15.5f };
170
+ OneBitScalarQuantizationState state5 = new OneBitScalarQuantizationState (params , thresholds5 );
171
+ int expectedDimensions5 = 16 ; // Already aligned to 16
172
+ assertEquals (expectedDimensions5 , state5 .getDimensions ());
173
+ }
174
+
103
175
public void testMultiBitScalarQuantizationStateRamBytesUsedManualCalculation () throws IOException {
104
176
ScalarQuantizationParams params = new ScalarQuantizationParams (ScalarQuantizationType .TWO_BIT );
105
177
float [][] thresholds = { { 0.5f , 1.5f , 2.5f }, { 1.0f , 2.0f , 3.0f } };
0 commit comments