
Commit aa47ba7

RCF4.0 and PredictiveRCF (#401)
* RCF4.0 and PredictiveRCF
* fixes and spotless
* fixes and improvements
* fixes for normalization equalling forest outputReady
* fixes and internal shingling
* small corrections
* changes and fixes
* test changes
* optimization and stability
* comments etc.
* fix
* shortening noisy test
* changes
* comments
* changes
1 parent 748f683 commit aa47ba7

137 files changed: +180,498 −2,783 lines changed


Java/benchmark/pom.xml (+3 −3)

@@ -6,7 +6,7 @@
     <parent>
         <groupId>software.amazon.randomcutforest</groupId>
         <artifactId>randomcutforest-parent</artifactId>
-        <version>3.8.0</version>
+        <version>4.0.0-SNAPSHOT</version>
     </parent>

     <artifactId>randomcutforest-benchmark</artifactId>
@@ -50,12 +50,12 @@
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-core</artifactId>
-            <version>2.14.2</version>
+            <version>2.16.0</version>
         </dependency>
         <dependency>
             <groupId>com.fasterxml.jackson.core</groupId>
             <artifactId>jackson-databind</artifactId>
-            <version>2.14.2</version>
+            <version>2.16.0</version>
         </dependency>
         <dependency>
             <groupId>io.protostuff</groupId>

Java/core/pom.xml (+14 −2)

@@ -6,7 +6,7 @@
     <parent>
         <groupId>software.amazon.randomcutforest</groupId>
         <artifactId>randomcutforest-parent</artifactId>
-        <version>3.8.0</version>
+        <version>4.0.0-SNAPSHOT</version>
     </parent>

     <artifactId>randomcutforest-core</artifactId>
@@ -22,7 +22,7 @@
         <dependency>
             <groupId>org.projectlombok</groupId>
             <artifactId>lombok</artifactId>
-            <version>1.18.24</version>
+            <version>1.18.30</version>
             <scope>provided</scope>
         </dependency>
         <dependency>
@@ -55,5 +55,17 @@
             <artifactId>powermock-api-easymock</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-core</artifactId>
+            <version>2.16.0</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <version>2.16.0</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>

Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java (+8)

@@ -217,6 +217,10 @@ public static double[] toDoubleArray(float[] array) {
         return result;
     }

+    public static double[] toDoubleArrayNullable(float[] array) {
+        return (array == null) ? null : toDoubleArray(array);
+    }
+
     public static float[] toFloatArray(double[] array) {
         checkNotNull(array, "array must not be null");
         float[] result = new float[array.length];
@@ -227,6 +231,10 @@ public static float[] toFloatArray(double[] array) {
         return result;
     }

+    public static float[] toFloatArrayNullable(double[] array) {
+        return (array == null) ? null : toFloatArray(array);
+    }
+
     public static int[] toIntArray(byte[] values) {
         checkNotNull(values, "array must not be null");
         int[] result = new int[values.length];
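
A minimal usage sketch of the two new helpers (the surrounding class and values are illustrative; only toDoubleArrayNullable and toFloatArrayNullable come from the diff above). Unlike toDoubleArray/toFloatArray, the nullable variants pass a null array through instead of failing the null check.

import static com.amazon.randomcutforest.CommonUtils.toDoubleArrayNullable;
import static com.amazon.randomcutforest.CommonUtils.toFloatArrayNullable;

public class NullableConversionSketch {
    public static void main(String[] args) {
        float[] maybeMissing = null;
        // returns null instead of throwing "array must not be null"
        double[] asDouble = toDoubleArrayNullable(maybeMissing);
        System.out.println(asDouble == null); // true

        double[] present = { 1.0, 2.0, 3.0 };
        float[] asFloat = toFloatArrayNullable(present); // {1.0f, 2.0f, 3.0f}
        System.out.println(asFloat.length); // 3
    }
}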

Java/core/src/main/java/com/amazon/randomcutforest/MultiVisitor.java (+4 −3)

@@ -38,12 +38,13 @@ public interface MultiVisitor<R> extends Visitor<R> {
     boolean trigger(final INodeView node);

     /**
-     * Return a copy of this visitor. The original visitor plus the copy will each
-     * traverse one branch of the tree.
+     * Return a partial copy of this visitor. The original visitor plus the copy
+     * will each traverse one branch of the tree. The fields not copied will be
+     * filled in by the branches of the tree
      *
      * @return a copy of this visitor
      */
-    MultiVisitor<R> newCopy();
+    MultiVisitor<R> newPartialCopy();

     /**
      * Combine two visitors. The state of the argument visitor should be combined

Java/core/src/main/java/com/amazon/randomcutforest/PredictiveRandomCutForest.java (+565)
Large diffs are not rendered by default.

Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java (+137 −97)
Large diffs are not rendered by default.

Java/core/src/main/java/com/amazon/randomcutforest/executor/AbstractForestUpdateExecutor.java (+17 −3)

@@ -36,6 +36,7 @@ public abstract class AbstractForestUpdateExecutor<PointReference, Point> {

     protected final IStateCoordinator<PointReference, Point> updateCoordinator;
     protected final ComponentList<PointReference, Point> components;
+    protected boolean currentlySampling = true;

     /**
      * Create a new AbstractForestUpdateExecutor.
@@ -59,17 +60,26 @@ protected AbstractForestUpdateExecutor(IStateCoordinator<PointReference, Point>
      * @param point The point used to update the forest.
      */
     public void update(Point point) {
+        update(point, false);
+    }
+
+    public void update(Point point, boolean updateShingleOnly) {
         long internalSequenceNumber = updateCoordinator.getTotalUpdates();
         IPointStore<?, ?> store = updateCoordinator.getStore();
         if (store != null && store.isInternalShinglingEnabled()) {
             internalSequenceNumber -= store.getShingleSize() - 1;
         }
-        update(point, internalSequenceNumber);
+        update(point, internalSequenceNumber, updateShingleOnly);
     }

     public void update(Point point, long sequenceNumber) {
-        PointReference updateInput = updateCoordinator.initUpdate(point, sequenceNumber);
-        List<UpdateResult<PointReference>> results = (updateInput == null) ? Collections.emptyList()
+        update(point, sequenceNumber, false);
+    }
+
+    public void update(Point point, long sequenceNumber, boolean updateShingleOnly) {
+        PointReference updateInput = updateCoordinator.initUpdate(point, sequenceNumber, updateShingleOnly);
+        boolean propagate = (updateInput != null) && currentlySampling;
+        List<UpdateResult<PointReference>> results = (!propagate) ? Collections.emptyList()
                 : updateInternal(updateInput, sequenceNumber);
         updateCoordinator.completeUpdate(results, updateInput);
     }
@@ -86,4 +96,8 @@ public void update(Point point, long sequenceNumber) {
      */
     protected abstract List<UpdateResult<PointReference>> updateInternal(PointReference updateInput, long currentIndex);

+    public void setCurrentlySampling(boolean value) {
+        currentlySampling = value;
+    }
+
 }
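
An illustrative sketch (not the library class) of the new control flow above: initUpdate is now passed the updateShingleOnly flag, and the result is propagated to the trees only when a usable reference was produced and the executor is currently sampling. The mock initUpdate below is an assumption standing in for the real coordinator; per the IStateCoordinator javadoc change in this commit, a shingle-only update yields a null reference.

public class UpdateFlowSketch {
    boolean currentlySampling = true;

    // stand-in for IStateCoordinator.initUpdate: shingle-only updates produce no
    // usable reference, so nothing reaches the trees
    Integer initUpdate(float[] point, long sequenceNumber, boolean updateShingleOnly) {
        return updateShingleOnly ? null : 42;
    }

    void setCurrentlySampling(boolean value) {
        currentlySampling = value;
    }

    void update(float[] point, long sequenceNumber, boolean updateShingleOnly) {
        Integer updateInput = initUpdate(point, sequenceNumber, updateShingleOnly);
        boolean propagate = (updateInput != null) && currentlySampling;
        System.out.println("propagate to trees? " + propagate);
    }

    public static void main(String[] args) {
        UpdateFlowSketch sketch = new UpdateFlowSketch();
        sketch.update(new float[] { 1.0f }, 0L, false); // normal update   -> true
        sketch.update(new float[] { 2.0f }, 1L, true);  // shingle-only    -> false
        sketch.setCurrentlySampling(false);
        sketch.update(new float[] { 3.0f }, 2L, false); // sampling paused -> false
    }
}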

Java/core/src/main/java/com/amazon/randomcutforest/executor/IStateCoordinator.java (+8 −6)

@@ -33,12 +33,16 @@ public interface IStateCoordinator<PointReference, Point> {
      * Transform the input point into a value that can be submitted to IUpdatable
      * instances.
      *
-     * @param point The input point.
-     * @param sequenceNumber the sequence number associated with the point
+     * @param point             The input point.
+     * @param sequenceNumber    the sequence number associated with the point
+     * @param updateShingleOnly Only update the shingles (Provide a null reference)
+     *                          or, also update the point store (provide a usable
+     *                          reference)
+     *
     * @return The point transformed into the representation expected by an
     *         IUpdatable instance.
     */
-    PointReference initUpdate(Point point, long sequenceNumber);
+    PointReference initUpdate(Point point, long sequenceNumber, boolean updateShingleOnly);

     /**
      * Complete the update. This method is called by IStateCoordinator after all
@@ -56,7 +60,5 @@ public interface IStateCoordinator<PointReference, Point> {

     void setTotalUpdates(long totalUpdates);

-    default IPointStore<PointReference, Point> getStore() {
-        return null;
-    }
+    IPointStore<PointReference, Point> getStore();
 }

Java/core/src/main/java/com/amazon/randomcutforest/executor/ParallelForestTraversalExecutor.java (+2 −2)

@@ -34,7 +34,7 @@
  */
 public class ParallelForestTraversalExecutor extends AbstractForestTraversalExecutor {

-    private ForkJoinPool forkJoinPool;
+    ForkJoinPool forkJoinPool;
     private final int threadPoolSize;

     public ParallelForestTraversalExecutor(ComponentList<?, ?> treeExecutors, int threadPoolSize) {
@@ -96,7 +96,7 @@ public <R, S> S traverseForestMulti(float[] point, IMultiVisitorFactory<R> visit
                 () -> components.parallelStream().map(c -> c.traverseMulti(point, visitorFactory)).collect(collector));
     }

-    private <T> T submitAndJoin(Callable<T> callable) {
+    <T> T submitAndJoin(Callable<T> callable) {
         if (forkJoinPool == null) {
             forkJoinPool = new ForkJoinPool(threadPoolSize);
         }

Java/core/src/main/java/com/amazon/randomcutforest/executor/ParallelForestUpdateExecutor.java (+2 −2)

@@ -32,7 +32,7 @@
 public class ParallelForestUpdateExecutor<PointReference, Point>
         extends AbstractForestUpdateExecutor<PointReference, Point> {

-    private ForkJoinPool forkJoinPool;
+    ForkJoinPool forkJoinPool;
     private final int threadPoolSize;

     public ParallelForestUpdateExecutor(IStateCoordinator<PointReference, Point> updateCoordinator,
@@ -48,7 +48,7 @@ protected List<UpdateResult<PointReference>> updateInternal(PointReference point
                 .filter(UpdateResult::isStateChange).collect(Collectors.toList()));
     }

-    private <T> T submitAndJoin(Callable<T> callable) {
+    <T> T submitAndJoin(Callable<T> callable) {
         if (forkJoinPool == null) {
             forkJoinPool = new ForkJoinPool(threadPoolSize);
         }

Java/core/src/main/java/com/amazon/randomcutforest/executor/PointStoreCoordinator.java (+2 −2)

@@ -38,8 +38,8 @@ public PointStoreCoordinator(IPointStore<Integer, Point> store) {
     }

     @Override
-    public Integer initUpdate(Point point, long sequenceNumber) {
-        int index = store.add(point, sequenceNumber);
+    public Integer initUpdate(Point point, long sequenceNumber, boolean updateShingleOnly) {
+        int index = store.add(point, sequenceNumber, updateShingleOnly);
         return (index == PointStore.INFEASIBLE_POINTSTORE_INDEX) ? null : index;
     }

Java/core/src/main/java/com/amazon/randomcutforest/imputation/ConditionalSampleSummarizer.java (+66 −58)

@@ -64,18 +64,21 @@ public class ConditionalSampleSummarizer {
      */
     protected boolean project = false;

-    public ConditionalSampleSummarizer(int[] missingDimensions, float[] queryPoint, double centrality) {
-        this.missingDimensions = Arrays.copyOf(missingDimensions, missingDimensions.length);
-        this.queryPoint = Arrays.copyOf(queryPoint, queryPoint.length);
-        this.centrality = centrality;
-    }
+    protected int numberOfReps = 1;

-    public ConditionalSampleSummarizer(int[] missingDimensions, float[] queryPoint, double centrality,
-            boolean project) {
+    protected double shrinkage = 0;
+
+    protected int shingleSize = 1;
+
+    public ConditionalSampleSummarizer(int[] missingDimensions, float[] queryPoint, double centrality, boolean project,
+            int numberOfReps, double shrinkage, int shingleSize) {
         this.missingDimensions = Arrays.copyOf(missingDimensions, missingDimensions.length);
         this.queryPoint = Arrays.copyOf(queryPoint, queryPoint.length);
         this.centrality = centrality;
         this.project = project;
+        this.numberOfReps = numberOfReps;
+        this.shrinkage = shrinkage;
+        this.shingleSize = shingleSize;
     }

     public SampleSummary summarize(List<ConditionalTreeSample> alist) {
@@ -102,21 +105,28 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
         List<ConditionalTreeSample> newList = ConditionalTreeSample.dedup(alist);

         newList.sort((o1, o2) -> Double.compare(o1.distance, o2.distance));
+        int dimensions = queryPoint.length;

-        ArrayList<Weighted<float[]>> points = new ArrayList<>();
-        newList.stream().forEach(e -> {
-            if (!project) {
-                points.add(new Weighted<>(e.leafPoint, (float) e.weight));
-            } else {
-                float[] values = new float[missingDimensions.length];
-                for (int i = 0; i < missingDimensions.length; i++) {
-                    values[i] = e.leafPoint[missingDimensions[i]];
+        if (!addTypical) {
+            ArrayList<Weighted<float[]>> points = new ArrayList<>();
+            newList.stream().forEach(e -> {
+                if (!project) {
+                    if (shingleSize == 1) {
+                        points.add(new Weighted<>(e.leafPoint, (float) e.weight));
+                    } else {
+                        float[] values = Arrays.copyOfRange(e.leafPoint, dimensions - dimensions / shingleSize,
+                                dimensions);
+                        points.add(new Weighted<>(values, (float) e.weight));
+                    }
+                } else {
+                    float[] values = new float[missingDimensions.length];
+                    for (int i = 0; i < missingDimensions.length; i++) {
+                        values[i] = e.leafPoint[missingDimensions[i]];
+                    }
+                    points.add(new Weighted<>(values, (float) e.weight));
                 }
-                points.add(new Weighted<>(values, (float) e.weight));
-            }
-        });
+            });

-        if (!addTypical) {
             return new SampleSummary(points);
         }

@@ -131,34 +141,37 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
          * exact matches would go against the dynamic sampling based use of RCF.
          **/

-        int dimensions = queryPoint.length;
-
-        double threshold = centrality * newList.get(0).distance;
-        double currentWeight = 0;
-        int alwaysInclude = 0;
-        double remainderWeight = totalWeight;
-        while (newList.get(alwaysInclude).distance == 0) {
-            remainderWeight -= newList.get(alwaysInclude).weight;
-            ++alwaysInclude;
-            if (alwaysInclude == newList.size()) {
-                break;
+        int num = 0;
+        if (centrality > 0) {
+            double threshold = centrality * newList.get(0).distance + 1e-6;
+            double currentWeight = 0;
+            int alwaysInclude = 0;
+            double remainderWeight = totalWeight;
+            while (newList.get(alwaysInclude).distance == 0) {
+                remainderWeight -= newList.get(alwaysInclude).weight;
+                ++alwaysInclude;
+                if (alwaysInclude == newList.size()) {
+                    break;
+                }
             }
-        }
-        for (int j = 1; j < newList.size(); j++) {
-            if ((currentWeight < remainderWeight / 3 && currentWeight + newList.get(j).weight >= remainderWeight / 3)
-                    || (currentWeight < remainderWeight / 2
-                            && currentWeight + newList.get(j).weight >= remainderWeight / 2)) {
-                threshold = centrality * newList.get(j).distance;
+            for (int j = 1; j < newList.size(); j++) {
+                if ((currentWeight < remainderWeight / 3
+                        && currentWeight + newList.get(j).weight >= remainderWeight / 3)
+                        || (currentWeight < remainderWeight / 2
+                                && currentWeight + newList.get(j).weight >= remainderWeight / 2)) {
+                    threshold = centrality * newList.get(j).distance;
+                }
+                currentWeight += newList.get(j).weight;
             }
-            currentWeight += newList.get(j).weight;
-        }
-        // note that the threshold is currently centrality * (some distance in the list)
-        // thus the sequel uses a convex combination; and setting centrality = 0 removes
-        // the entire filtering based on distances
-        threshold += (1 - centrality) * newList.get(newList.size() - 1).distance;
-        int num = 0;
-        while (num < newList.size() && newList.get(num).distance <= threshold) {
-            ++num;
+            // note that the threshold is currently centrality * (some distance in the list)
+            // thus the sequel uses a convex combination; and setting centrality = 0 removes
+            // the entire filtering based on distances
+            threshold += (1 - centrality) * newList.get(newList.size() - 1).distance;
+            while (num < newList.size() && newList.get(num).distance <= threshold) {
+                ++num;
+            }
+        } else {
+            num = newList.size();
         }

         ArrayList<Weighted<float[]>> typicalPoints = new ArrayList<>();
@@ -171,26 +184,21 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
                     values[i] = e.leafPoint[missingDimensions[i]];
                 }
             } else {
-                values = Arrays.copyOf(e.leafPoint, dimensions);
+                if (shingleSize == 1) {
+                    values = e.leafPoint;
+                } else {
+                    values = Arrays.copyOfRange(e.leafPoint, dimensions - dimensions / shingleSize, dimensions);
+                }
             }
             typicalPoints.add(new Weighted<>(values, (float) e.weight));
         }
         int maxAllowed = min(queryPoint.length * MAX_NUMBER_OF_TYPICAL_PER_DIMENSION, MAX_NUMBER_OF_TYPICAL_ELEMENTS);
         maxAllowed = min(maxAllowed, num);
-        SampleSummary projectedSummary = Summarizer.l2summarize(typicalPoints, maxAllowed, num, false, 72);

-        float[][] pointList = new float[projectedSummary.summaryPoints.length][];
-        float[] likelihood = new float[projectedSummary.summaryPoints.length];
-
-        for (int i = 0; i < projectedSummary.summaryPoints.length; i++) {
-            pointList[i] = Arrays.copyOf(queryPoint, dimensions);
-            for (int j = 0; j < missingDimensions.length; j++) {
-                pointList[i][missingDimensions[j]] = projectedSummary.summaryPoints[i][j];
-            }
-            likelihood[i] = projectedSummary.relativeWeight[i];
-        }
+        SampleSummary projectedSummary = Summarizer.summarize(typicalPoints, maxAllowed, num, false,
+                Summarizer::L2distance, 72, false, numberOfReps, shrinkage);

-        return new SampleSummary(points, pointList, likelihood);
+        return new SampleSummary(typicalPoints, projectedSummary);
     }

 }
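
A construction sketch for the consolidated constructor (the parameter values below are illustrative assumptions; only the constructor shape and the centrality behaviour come from the diff above). The two shorter constructors were removed, so callers now supply numberOfReps, shrinkage, and shingleSize explicitly, and a centrality of 0 skips the distance-based filtering of tree samples entirely.

import com.amazon.randomcutforest.imputation.ConditionalSampleSummarizer;

public class SummarizerConstructionSketch {
    public static void main(String[] args) {
        int shingleSize = 4;
        int baseDimensions = 2;
        int dimensions = shingleSize * baseDimensions;   // shingled query point
        float[] queryPoint = new float[dimensions];
        // impute the most recent input, i.e. the last baseDimensions entries
        int[] missingDimensions = { dimensions - 2, dimensions - 1 };

        ConditionalSampleSummarizer summarizer = new ConditionalSampleSummarizer(missingDimensions, queryPoint,
                1.0 /* centrality */, false /* project */, 1 /* numberOfReps */, 0.0 /* shrinkage */, shingleSize);
        System.out.println(summarizer.getClass().getSimpleName() + " constructed for shingleSize " + shingleSize);
    }
}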
