Commit f2984b5
Fix Confidence Adjustment for Larger Shingle Sizes (#407)
* Fix Confidence Adjustment for Larger Shingle Sizes

This PR makes further adjustments to the confidence calculation issue discussed in PR 405. While PR 405 resolved the issue for a shingle size of 4, it did not achieve the same results for larger shingle sizes such as 8.

Key changes:

1. Refinement of the seenValues calculation. Previously, the formula increased confidence even as numImputed (the number of imputations seen) increased, because seenValues (all values seen) also increased. This PR fixes the issue by counting only non-imputed values as seenValues.

2. Upper bound for numImputed. numImputed is now upper bounded by the shingle size, so the impute fraction, calculated as numberOfImputed * 1.0 / shingleSize, can no longer exceed 1.

3. Decrementing numberOfImputed. numberOfImputed is now decremented when there is no imputation. Previously, numberOfImputed remained unchanged during an imputation because an increment and a decrement canceled out, keeping the imputation fraction constant. This PR ensures the imputation fraction accurately reflects the current state, so that the forest update decision, which relies on the imputation fraction, functions correctly: the forest is updated only when the imputation fraction is below the threshold of 0.5.

Testing:

* Added test scenarios with various shingle sizes to verify the changes.

Signed-off-by: Kaituo Li <kaituo@amazon.com>

* added comment

Signed-off-by: Kaituo Li <kaituo@amazon.com>

---------

Signed-off-by: Kaituo Li <kaituo@amazon.com>
1 parent 07aab4a commit f2984b5
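The three changes above can be condensed into a small sketch. This is an illustrative model, not the library's code: the class `ImputeFractionSketch` and its `observe` method are hypothetical, and only the field names (`numberOfImputed`, `shingleSize`, `valuesSeen`, `useImputedFraction`) mirror identifiers from the diff below.

```java
// Hypothetical illustration of the PR's bookkeeping; not the library's code.
public class ImputeFractionSketch {
    int shingleSize = 8;
    int numberOfImputed = 0;
    int valuesSeen = 0;          // change 1: counts only non-imputed values
    double useImputedFraction = 0.5;

    void observe(boolean imputed) {
        if (imputed) {
            // change 2: cap numberOfImputed at the shingle size so the
            // fraction numberOfImputed / shingleSize never exceeds 1
            numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
        } else {
            if (numberOfImputed > 0) {
                // change 3: decrement when a real (non-imputed) value arrives
                numberOfImputed = numberOfImputed - 1;
            }
            valuesSeen++;
        }
    }

    boolean updateAllowed() {
        double fraction = Math.min(1.0, numberOfImputed * 1.0 / shingleSize);
        // the forest is updated only while the imputed fraction is below 0.5
        return fraction < useImputedFraction;
    }

    public static void main(String[] args) {
        ImputeFractionSketch s = new ImputeFractionSketch();
        for (int i = 0; i < 20; i++) s.observe(true);   // long run of imputations
        System.out.println(s.numberOfImputed);          // 8 (capped)
        System.out.println(s.updateAllowed());          // false
        for (int i = 0; i < 5; i++) s.observe(false);   // real values arrive
        System.out.println(s.updateAllowed());          // true (3/8 < 0.5)
    }
}
```

Under this model a long run of imputations saturates rather than overflows, and a handful of real values brings the fraction back below the 0.5 threshold.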

File tree

3 files changed: +120 -52 lines changed

Java/core/src/main/java/com/amazon/randomcutforest/preprocessor/ImputePreprocessor.java

+66-41
@@ -80,44 +80,6 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
         return point;
     }
 
-    /**
-     * the timestamps are now used to calculate the number of imputed tuples in the
-     * shingle
-     *
-     * @param timestamp the timestamp of the current input
-     */
-    @Override
-    protected void updateTimestamps(long timestamp) {
-        /*
-         * For imputations done on timestamps other than the current one (specified by
-         * the timestamp parameter), the timestamp of the imputed tuple matches that of
-         * the input tuple, and we increment numberOfImputed. For imputations done at
-         * the current timestamp (if all input values are missing), the timestamp of the
-         * imputed tuple is the current timestamp, and we increment numberOfImputed.
-         *
-         * To check if imputed values are still present in the shingle, we use the first
-         * condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
-         * because previousTimeStamps has a size equal to the shingle size and is filled
-         * with the current timestamp. However, there are scenarios where we might miss
-         * decrementing numberOfImputed:
-         *
-         * 1. Not all values in the shingle are imputed. 2. We accumulated
-         * numberOfImputed when the current timestamp had missing values.
-         *
-         * As a result, this could cause the data quality measure to decrease
-         * continuously since we are always counting missing values that should
-         * eventually be reset to zero. The second condition <pre> timestamp >
-         * previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
-         * will decrement numberOfImputed when we move to a new timestamp, provided
-         * numberOfImputed is greater than zero.
-         */
-        if (previousTimeStamps[0] == previousTimeStamps[1]
-                || (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
-            numberOfImputed = numberOfImputed - 1;
-        }
-        super.updateTimestamps(timestamp);
-    }
-
     /**
      * decides if the forest should be updated, this is needed for imputation on the
      * fly. The main goal of this function is to avoid runaway sequences where a
@@ -128,7 +90,10 @@ protected void updateTimestamps(long timestamp) {
      */
     protected boolean updateAllowed() {
         double fraction = numberOfImputed * 1.0 / (shingleSize);
-        if (numberOfImputed == shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
+        if (fraction > 1) {
+            fraction = 1;
+        }
+        if (numberOfImputed >= shingleSize - 1 && previousTimeStamps[0] != previousTimeStamps[1]
                 && (transformMethod == DIFFERENCE || transformMethod == NORMALIZE_DIFFERENCE)) {
             // this shingle is disconnected from the previously seen values
             // these transformations will have little meaning
@@ -144,10 +109,57 @@ protected boolean updateAllowed() {
             // two different points).
             return false;
         }
+
         dataQuality[0].update(1 - fraction);
         return (fraction < useImputedFraction && internalTimeStamp >= shingleSize);
     }
 
+    @Override
+    protected void updateTimestamps(long timestamp) {
+        /*
+         * For imputations done on timestamps other than the current one (specified by
+         * the timestamp parameter), the timestamp of the imputed tuple matches that of
+         * the input tuple, and we increment numberOfImputed. For imputations done at
+         * the current timestamp (if all input values are missing), the timestamp of the
+         * imputed tuple is the current timestamp, and we increment numberOfImputed.
+         *
+         * To check if imputed values are still present in the shingle, we use the
+         * condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
+         * because previousTimeStamps has a size equal to the shingle size and is filled
+         * with the current timestamp.
+         *
+         * For example, if the last 10 values were imputed and the shingle size is 8,
+         * the condition will most likely return false until all 10 imputed values are
+         * removed from the shingle.
+         *
+         * However, there are scenarios where we might miss decrementing
+         * numberOfImputed:
+         *
+         * 1. Not all values in the shingle are imputed. 2. We accumulated
+         * numberOfImputed when the current timestamp had missing values.
+         *
+         * As a result, this could cause the data quality measure to decrease
+         * continuously since we are always counting missing values that should
+         * eventually be reset to zero. To address the issue, we add code in method
+         * updateForest to decrement numberOfImputed when we move to a new timestamp,
+         * provided there is no imputation. This ensures the imputation fraction does
+         * not increase as long as the imputation is continuing. This also ensures that
+         * the forest update decision, which relies on the imputation fraction,
+         * functions correctly. The forest is updated only when the imputation fraction
+         * is below the threshold of 0.5.
+         *
+         * Also, why can't we combine the decrement code between updateTimestamps and
+         * updateForest together? This would cause Consistency.ImputeTest to fail when
+         * testing with and without imputation, as the RCF scores would not change. The
+         * method updateTimestamps is used in other places (e.g., updateState and
+         * dischargeInitial), not only in updateForest.
+         */
+        if (previousTimeStamps[0] == previousTimeStamps[1]) {
+            numberOfImputed = numberOfImputed - 1;
+        }
+        super.updateTimestamps(timestamp);
+    }
+
     /**
      * the following function mutates the forest, the lastShingledPoint,
      * lastShingledInput as well as previousTimeStamps, and adds the shingled input
@@ -168,7 +180,13 @@ void updateForest(boolean changeForest, double[] input, long timestamp, RandomCu
         updateShingle(input, scaledInput);
         updateTimestamps(timestamp);
         if (isFullyImputed) {
-            numberOfImputed = numberOfImputed + 1;
+            // numberOfImputed is capped at the shingle size to ensure that the impute
+            // fraction, calculated as numberOfImputed * 1.0 / shingleSize, does not
+            // exceed 1.
+            numberOfImputed = Math.min(numberOfImputed + 1, shingleSize);
+        } else if (numberOfImputed > 0) {
+            // Decrement numberOfImputed when the new value is not imputed
+            numberOfImputed = numberOfImputed - 1;
         }
         if (changeForest) {
             if (forest.isInternalShinglingEnabled()) {
@@ -190,7 +208,14 @@ public void update(double[] point, float[] rcfPoint, long timestamp, int[] missi
             return;
         }
         generateShingle(point, timestamp, missing, getTimeFactor(timeStampDeviations[1]), true, forest);
-        ++valuesSeen;
+        // The confidence formula depends on numImputed (the number of recent
+        // imputations seen) and seenValues (all values seen). To ensure confidence
+        // decreases when numImputed increases, we count only non-imputed values as
+        // seenValues.
+        if (missing == null || missing.length != point.length) {
+            ++valuesSeen;
+        }
     }
 
     protected double getTimeFactor(Deviation deviation) {
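The change to valuesSeen in the last hunk can be illustrated with a toy confidence measure. The formula below is hypothetical, not the library's real confidence calculation; it only has the assumed shape described in the commit message: confidence grows with the number of observed values and shrinks with the imputed fraction.

```java
// Toy model of the confidence trade-off; the formula is hypothetical.
public class ConfidenceSketch {
    // Assumed shape: increases with seenValues, decreases with imputedFraction.
    static double toyConfidence(int seenValues, double imputedFraction) {
        return (seenValues / (seenValues + 10.0)) * (1.0 - imputedFraction);
    }

    public static void main(String[] args) {
        double fraction = 0.5; // half of the shingle imputed, held constant
        // Old behavior: imputed points also incremented seenValues, so confidence
        // kept climbing even while every incoming value was imputed.
        System.out.println(toyConfidence(30, fraction) > toyConfidence(20, fraction)); // true
        // A fully imputed shingle drives the toy confidence to zero.
        System.out.println(toyConfidence(20, 1.0)); // 0.0
    }
}
```

With the PR's fix, seenValues is frozen during imputation, so in this toy model confidence no longer rises while imputations continue.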

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/PredictorCorrector.java

+2-1
@@ -961,7 +961,8 @@ public void setLastScore(double[] score) {
     }
 
     void validateIgnore(double[] shift, int length) {
-        checkArgument(shift.length == length, () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
+        checkArgument(shift.length == length,
+                () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
         for (double element : shift) {
             checkArgument(element >= 0, "has to be non-negative");
         }
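The reformatted call above passes a `Supplier<String>` so the error message is only formatted when the check fails. A minimal sketch of that pattern follows; `checkArgument` here is a hypothetical helper written for illustration, not the library's exact signature.

```java
import java.util.Locale;
import java.util.function.Supplier;

public class CheckArgumentSketch {
    // Hypothetical helper mirroring the lazy-message pattern in the diff:
    // the Supplier defers String.format until the check actually fails.
    static void checkArgument(boolean condition, Supplier<String> message) {
        if (!condition) {
            throw new IllegalArgumentException(message.get());
        }
    }

    static void validateLength(double[] shift, int length) {
        checkArgument(shift.length == length,
                () -> String.format(Locale.ROOT, "has to be of length %d but is %d", length, shift.length));
    }

    public static void main(String[] args) {
        validateLength(new double[] { 1.0, 2.0 }, 2); // passes; message never formatted
        try {
            validateLength(new double[] { 1.0 }, 2);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // has to be of length 2 but is 1
        }
    }
}
```

Deferring the format call this way avoids string construction on the hot path when the argument is valid.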

Java/parkservices/src/test/java/com/amazon/randomcutforest/parkservices/MissingValueTest.java

+52-10
@@ -19,24 +19,38 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Locale;
 import java.util.Random;
+import java.util.stream.Stream;
 
+import org.junit.jupiter.api.extension.ExtensionContext;
 import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.EnumSource;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.ArgumentsProvider;
+import org.junit.jupiter.params.provider.ArgumentsSource;
 
 import com.amazon.randomcutforest.config.ForestMode;
 import com.amazon.randomcutforest.config.ImputationMethod;
 import com.amazon.randomcutforest.config.Precision;
 import com.amazon.randomcutforest.config.TransformMethod;
 
 public class MissingValueTest {
+    private static class EnumAndValueProvider implements ArgumentsProvider {
+        @Override
+        public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
+            return Stream.of(ImputationMethod.PREVIOUS, ImputationMethod.ZERO, ImputationMethod.FIXED_VALUES)
+                    .flatMap(method -> Stream.of(4, 8, 16) // example shingle sizes
+                            .map(shingleSize -> Arguments.of(method, shingleSize)));
+        }
+    }
+
     @ParameterizedTest
-    @EnumSource(ImputationMethod.class)
-    public void testConfidence(ImputationMethod method) {
+    @ArgumentsSource(EnumAndValueProvider.class)
+    public void testConfidence(ImputationMethod method, int shingleSize) {
         // Create and populate a random cut forest
 
-        int shingleSize = 4;
         int numberOfTrees = 50;
         int sampleSize = 256;
         Precision precision = Precision.FLOAT_32;
@@ -45,11 +59,19 @@ public void testConfidence(ImputationMethod method) {
         long count = 0;
 
         int dimensions = baseDimensions * shingleSize;
-        ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
+        ThresholdedRandomCutForest.Builder forestBuilder = new ThresholdedRandomCutForest.Builder<>().compact(true)
                 .dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
                 .sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
-                .fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
-                .transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();
+                .forestMode(ForestMode.STREAMING_IMPUTE).transformMethod(TransformMethod.NORMALIZE).autoAdjust(true);
+
+        if (method == ImputationMethod.FIXED_VALUES) {
+            // we cannot pass fillValues when the method is not FIXED_VALUES; otherwise
+            // we would impute the filled-in values regardless of the imputation method
+            forestBuilder.fillValues(new double[] { 3 });
+        }
+
+        ThresholdedRandomCutForest forest = forestBuilder.build();
 
         // Define the size and range
         int size = 400;
@@ -75,18 +97,38 @@ public void testConfidence(ImputationMethod method) {
             float[] rcfPoint = result.getRCFPoint();
             double scale = result.getScale()[0];
             double shift = result.getShift()[0];
-            double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
+            double[] actual = new double[] { (rcfPoint[shingleSize - 1] * scale) + shift };
             if (method == ImputationMethod.ZERO) {
                 assertEquals(0, actual[0], 0.001d);
+                if (count == 300) {
+                    assertTrue(result.getAnomalyGrade() > 0);
+                }
             } else if (method == ImputationMethod.FIXED_VALUES) {
                 assertEquals(3.0d, actual[0], 0.001d);
+                if (count == 300) {
+                    assertTrue(result.getAnomalyGrade() > 0);
+                }
+            } else if (method == ImputationMethod.PREVIOUS) {
+                assertEquals(0, result.getAnomalyGrade(), 0.001d,
+                        "count: " + count + " actual: " + Arrays.toString(actual));
             }
         } else {
             AnomalyDescriptor result = forest.process(point, newStamp);
-            if ((count > 100 && count < 300) || count >= 326) {
+            // After 325 there is a period where confidence decreases before it starts
+            // increasing again. We are not sure exactly where confidence will start
+            // increasing after decreasing, so we only check the behavior after
+            // 325 + shingleSize.
+            int backupPoint = 325 + shingleSize;
+            if ((count > 100 && count < 300) || count >= backupPoint) {
                 // The first 65+ observations give 0 confidence.
                 // Confidence starts increasing after 1 observed point
-                assertTrue(result.getDataConfidence() > lastConfidence);
+                assertTrue(result.getDataConfidence() > lastConfidence,
+                        String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
+                                result.getDataConfidence(), lastConfidence));
+            } else if (count < 325 && count > 300) {
+                assertTrue(result.getDataConfidence() < lastConfidence,
+                        String.format(Locale.ROOT, "count: %d, confidence: %f, last confidence: %f", count,
+                                result.getDataConfidence(), lastConfidence));
             }
             lastConfidence = result.getDataConfidence();
         }
