Skip to content

Commit 5ab5a97

Browse files
authored
improving precision (#393)
* improving precision
* refactor plus more tests
* 100 percent coverage for tree and boundingbox
* tests and coverage
* 100 percent coverage for summarization
* 100 percent coverage for ErrorHandler
* refactor RCFCaster
1 parent 00f9253 commit 5ab5a97

File tree

49 files changed

+1747
-658
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1747
-658
lines changed

Java/core/src/main/java/com/amazon/randomcutforest/imputation/ConditionalSampleSummarizer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
179179
}
180180
int maxAllowed = min(queryPoint.length * MAX_NUMBER_OF_TYPICAL_PER_DIMENSION, MAX_NUMBER_OF_TYPICAL_ELEMENTS);
181181
maxAllowed = min(maxAllowed, num);
182-
SampleSummary projectedSummary = Summarizer.summarize(typicalPoints, maxAllowed, num, false);
182+
SampleSummary projectedSummary = Summarizer.l2summarize(typicalPoints, maxAllowed, num, false, 72);
183183

184184
float[][] pointList = new float[projectedSummary.summaryPoints.length][];
185185
float[] likelihood = new float[projectedSummary.summaryPoints.length];

Java/core/src/main/java/com/amazon/randomcutforest/returntypes/RangeVector.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public void shift(int i, float shift) {
9292

9393
public void scale(int i, float weight) {
9494
checkArgument(i >= 0 && i < values.length, "incorrect index");
95-
checkArgument(weight > 0, " negative weight not permitted");
95+
checkArgument(weight >= 0, " negative weight not permitted");
9696
values[i] = values[i] * weight;
9797
// managing precision
9898
upper[i] = max(upper[i] * weight, values[i]);

Java/core/src/main/java/com/amazon/randomcutforest/summarization/Center.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ public double getWeight() {
8383
return weight;
8484
}
8585

86-
public boolean captureBeforeReset(float[] point, BiFunction<float[], float[], Double> distance) {
87-
return previousWeight * distance.apply(point, representative) < 3 * previousSumOFRadius;
88-
}
89-
9086
// a standard reassignment using the median values and NOT the mean; the mean is
9187
// unlikely to
9288
// provide robust convergence
@@ -114,6 +110,9 @@ public double recompute(Function<Integer, float[]> getPoint, boolean approx,
114110
break;
115111
}
116112
}
113+
if (position == assignedPoints.size()) {
114+
position--;
115+
}
117116
representative[index] = getPoint.apply(assignedPoints.get(position).index)[index];
118117
}
119118
for (int j = 0; j < assignedPoints.size(); j++) {
@@ -163,7 +162,9 @@ public void absorb(ICluster<float[]> other, BiFunction<float[], float[], Double>
163162
}
164163

165164
public double distance(float[] point, BiFunction<float[], float[], Double> distance) {
166-
return distance.apply(point, representative);
165+
double t = distance.apply(point, representative);
166+
checkArgument(t >= 0, "distance cannot be negative");
167+
return t;
167168
}
168169

169170
@Override

Java/core/src/main/java/com/amazon/randomcutforest/summarization/GenericMultiCenter.java

+1-9
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,6 @@ public class GenericMultiCenter<R> implements ICluster<R> {
7272
this.shrinkage = shrinkage;
7373
}
7474

75-
public static <R> GenericMultiCenter<R> initialize(R coordinate, float weight) {
76-
return new GenericMultiCenter<>(coordinate, weight, DEFAULT_SHRINKAGE, DEFAULT_NUMBER_OF_REPRESENTATIVES);
77-
}
78-
7975
public static <R> GenericMultiCenter<R> initialize(R coordinate, float weight, double shrinkage,
8076
int numberOfRepresentatives) {
8177
checkArgument(shrinkage >= 0 && shrinkage <= 1.0, " parameter has to be in [0,1]");
@@ -130,10 +126,6 @@ public double getWeight() {
130126
return weight;
131127
}
132128

133-
public boolean captureBeforeReset(R point, BiFunction<R, R, Double> distanceFunction) {
134-
return previousWeight * distance(point, distanceFunction) < 3 * previousSumOFRadius;
135-
}
136-
137129
// reassignment may not be meaningful for generic types, without additional
138130
// information
139131
public double recompute(Function<Integer, R> getPoint, boolean flag, BiFunction<R, R, Double> distanceFunction) {
@@ -191,7 +183,7 @@ public void absorb(ICluster<R> other, BiFunction<R, R, Double> distance) {
191183
savedRepresentatives.remove(farthestIndex);
192184
}
193185

194-
// absorb the remainder into existing represen tatives
186+
// absorb the remainder into existing representatives
195187
for (Weighted<R> representative : savedRepresentatives) {
196188
double dist = distance.apply(representative.index, this.representatives.get(0).index);
197189
checkArgument(dist >= 0, "distance cannot be negative");

Java/core/src/main/java/com/amazon/randomcutforest/summarization/ICluster.java

+2-5
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,13 @@ public interface ICluster<R> {
4848
// weight computation
4949
double getWeight();
5050

51-
// is a point well expressed by the cluster? To be used in the future.
52-
boolean captureBeforeReset(R point, BiFunction<R, R, Double> distance);
53-
5451
// merge another cluster of same type
5552
void absorb(ICluster<R> other, BiFunction<R, R, Double> distance);
5653

57-
// distance of apoint from a cluster
54+
// distance of a point from a cluster, has to be non-negative
5855
double distance(R point, BiFunction<R, R, Double> distance);
5956

60-
// distance of another cluster from this cluster
57+
// distance of another cluster from this cluster, has to be non-negative
6158
double distance(ICluster<R> other, BiFunction<R, R, Double> distance);
6259

6360
// all potential representatives of a cluster; these are typically chosen to be

Java/core/src/main/java/com/amazon/randomcutforest/summarization/MultiCenter.java

+1-5
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ public class MultiCenter extends GenericMultiCenter<float[]> {
3333
this.assignedPoints = new ArrayList<>();
3434
}
3535

36-
public static MultiCenter initialize(float[] coordinate, float weight) {
37-
return new MultiCenter(coordinate, weight, DEFAULT_SHRINKAGE, DEFAULT_NUMBER_OF_REPRESENTATIVES);
38-
}
39-
4036
public static MultiCenter initialize(float[] coordinate, float weight, double shrinkage,
4137
int numberOfRepresentatives) {
4238
checkArgument(shrinkage >= 0 && shrinkage <= 1.0, " parameter has to be in [0,1]");
@@ -70,9 +66,9 @@ public double recompute(Function<Integer, float[]> getPoint, boolean force,
7066
previousSumOFRadius = sumOfRadius;
7167
sumOfRadius = 0;
7268
for (int j = 0; j < assignedPoints.size(); j++) {
69+
// distance() checks for negative values internally
7370
double addTerm = distance(getPoint.apply(assignedPoints.get(j).index), distanceFunction)
7471
* assignedPoints.get(j).weight;
75-
checkArgument(addTerm >= 0, "distances or weights cannot be negative");
7672
sumOfRadius += addTerm;
7773
}
7874
return (previousSumOFRadius - sumOfRadius);

Java/core/src/main/java/com/amazon/randomcutforest/summarization/Summarizer.java

+32-22
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ public static <R> void assignAndRecompute(List<Weighted<Integer>> sampledPoints,
105105
double minDist = Double.MAX_VALUE;
106106
int minDistNbr = -1;
107107
for (int i = 0; i < clusters.size(); i++) {
108+
// will check for negative distances
108109
dist[i] = clusters.get(i).distance(getPoint.apply(point.index), distance);
109-
checkArgument(dist[i] >= 0, "distance cannot be negative");
110110
if (minDist > dist[i]) {
111111
minDist = dist[i];
112112
minDistNbr = i;
@@ -201,10 +201,18 @@ public static <R> List<ICluster<R>> iterativeClustering(int maxAllowed, int init
201201
boolean phase2GlobalReassign, double overlapParameter, List<ICluster<R>> previousClustering) {
202202

203203
checkArgument(refs.size() > 0, "empty list, nothing to do");
204-
checkArgument(maxAllowed >= stopAt && stopAt > 0, "incorrect bounds on number of clusters");
204+
checkArgument(stopAt > 0, "has to stop at 1 cluster");
205+
checkArgument(stopAt <= maxAllowed, "cannot stop before achieving the limit");
205206

206207
Random rng = new Random(seed);
207-
double sampledSum = refs.stream().map(e -> (double) e.weight).reduce(Double::sum).get();
208+
double sampledSum = refs.stream().map(e -> {
209+
checkArgument(!Double.isNaN(e.weight), " weights have to be non-NaN");
210+
checkArgument(Double.isFinite(e.weight), " weights have to be finite");
211+
checkArgument(e.weight >= 0.0, () -> "negative weights are not meaningful" + e.weight);
212+
return (double) e.weight;
213+
}).reduce(0.0, Double::sum);
214+
checkArgument(sampledSum > 0, " total weight has to be positive");
215+
208216
ArrayList<ICluster<R>> centers = new ArrayList<>();
209217
if (refs.size() < 10 * (initial + 5)) {
210218
for (Weighted<Integer> point : refs) {
@@ -294,6 +302,8 @@ public static <R> List<ICluster<R>> iterativeClustering(int maxAllowed, int init
294302
}
295303
centers.sort(Comparator.comparingDouble(ICluster::getWeight));
296304
while (centers.get(0).getWeight() == 0.0) {
305+
// this line is reachable via zeroTest() in
306+
// SampleSummaryTest
297307
centers.remove(0);
298308
}
299309
if (inital < 1.2 * maxAllowed + 1) {
@@ -345,14 +355,14 @@ public static <R> List<ICluster<R>> summarize(List<Weighted<R>> points, int maxA
345355
List<ICluster<R>> previousClustering) {
346356
checkArgument(maxAllowed < 100, "are you sure you want more elements in the summary?");
347357
checkArgument(maxAllowed <= initial, "initial parameter should be at least maximum allowed in final result");
348-
checkArgument(stopAt > 0 && stopAt <= maxAllowed, "lower bound set incorrectly");
349358

350359
double totalWeight = points.stream().map(e -> {
351-
checkArgument(e.weight >= 0.0, "negative weights are not meaningful");
360+
checkArgument(!Double.isNaN(e.weight), " weights have to be non-NaN");
361+
checkArgument(Double.isFinite(e.weight), " weights have to be finite");
362+
checkArgument(e.weight >= 0.0, () -> "negative weights are not meaningful" + e.weight);
352363
return (double) e.weight;
353364
}).reduce(0.0, Double::sum);
354-
checkArgument(!Double.isNaN(totalWeight) && Double.isFinite(totalWeight),
355-
" weights have to finite and non-NaN");
365+
checkArgument(totalWeight > 0, " total weight has to be positive");
356366
Random rng = new Random(seed);
357367
// the following list is explicitly copied and sorted for potential efficiency
358368
List<Weighted<R>> sampledPoints = createSample(points, rng.nextLong(), 5 * LENGTH_BOUND, 0.005, 1.0);
@@ -363,8 +373,6 @@ public static <R> List<ICluster<R>> summarize(List<Weighted<R>> points, int maxA
363373
}
364374

365375
Function<Integer, R> getPoint = (i) -> sampledPoints.get(i).index;
366-
checkArgument(sampledPoints.size() > 0, "empty list, nothing to do");
367-
double sampledSum = sampledPoints.stream().map(e -> (double) e.weight).reduce(Double::sum).get();
368376

369377
return iterativeClustering(maxAllowed, initial, stopAt, refs, getPoint, distance, clusterInitializer,
370378
rng.nextLong(), parallelEnabled, phase2GlobalReassign, overlapParameter, previousClustering);
@@ -403,11 +411,13 @@ public static SampleSummary summarize(List<Weighted<float[]>> points, int maxAll
403411
checkArgument(maxAllowed <= initial, "initial parameter should be at least maximum allowed in final result");
404412

405413
double totalWeight = points.stream().map(e -> {
406-
checkArgument(e.weight >= 0.0, "negative weights are not meaningful");
414+
checkArgument(!Double.isNaN(e.weight), " weights have to be non-NaN");
415+
checkArgument(Double.isFinite(e.weight), " weights have to be finite");
416+
checkArgument(e.weight >= 0.0, () -> "negative weights are not meaningful" + e.weight);
407417
return (double) e.weight;
408418
}).reduce(0.0, Double::sum);
409-
checkArgument(!Double.isNaN(totalWeight) && Double.isFinite(totalWeight),
410-
" weights have to finite and non-NaN");
419+
checkArgument(totalWeight > 0, " total weight has to be positive");
420+
411421
Random rng = new Random(seed);
412422
// the following list is explicitly copied and sorted for potential efficiency
413423
List<Weighted<float[]>> sampledPoints = createSample(points, rng.nextLong(), 5 * LENGTH_BOUND, 0.005, 1.0);
@@ -458,24 +468,24 @@ public static SampleSummary summarize(float[][] points, int maxAllowed, int init
458468
* @param maxAllowed maximum number of groups/clusters
459469
* @param initial a parameter controlling the initialization
460470
* @param reassignPerStep if reassignment is to be performed each step
471+
* @param seed random seed
461472
* @return a summarization
462473
*/
463-
public static SampleSummary summarize(List<Weighted<float[]>> points, int maxAllowed, int initial,
464-
boolean reassignPerStep) {
465-
return summarize(points, maxAllowed, initial, reassignPerStep, Summarizer::L2distance, new Random().nextLong(),
466-
false);
474+
public static SampleSummary l2summarize(List<Weighted<float[]>> points, int maxAllowed, int initial,
475+
boolean reassignPerStep, long seed) {
476+
return summarize(points, maxAllowed, initial, reassignPerStep, Summarizer::L2distance, seed, false);
467477
}
468478

469479
/**
470480
* Same as above, with the most common use cases filled in
471481
*
472482
* @param points points in float[][], each of weight 1.0
473483
* @param maxAllowed maximum number of clusters one is interested in
484+
* @param seed random seed
474485
* @return a summarization
475486
*/
476-
public static SampleSummary summarize(float[][] points, int maxAllowed) {
477-
return summarize(points, maxAllowed, 4 * maxAllowed, false, Summarizer::L2distance, new Random().nextLong(),
478-
false);
487+
public static SampleSummary l2summarize(float[][] points, int maxAllowed, long seed) {
488+
return summarize(points, maxAllowed, 4 * maxAllowed, false, Summarizer::L2distance, seed, false);
479489
}
480490

481491
/**
@@ -529,9 +539,9 @@ public static <R> List<ICluster<R>> multiSummarize(R[] points, int maxAllowed, i
529539
clusterInitializer, seed, parallelEnabled, null);
530540
}
531541

532-
// same as above, with defaults
542+
// same as above, with multicenter instead of generic
533543
public static List<ICluster<float[]>> multiSummarize(float[][] points, int maxAllowed, double shrinkage,
534-
int numberOfRepresentatives) {
544+
int numberOfRepresentatives, long seed) {
535545

536546
ArrayList<Weighted<float[]>> weighted = new ArrayList<>();
537547
for (float[] point : points) {
@@ -540,7 +550,7 @@ public static List<ICluster<float[]>> multiSummarize(float[][] points, int maxAl
540550
BiFunction<float[], Float, ICluster<float[]>> clusterInitializer = (a, b) -> MultiCenter.initialize(a, b,
541551
shrinkage, numberOfRepresentatives);
542552
return summarize(weighted, maxAllowed, 4 * maxAllowed, 1, true, DEFAULT_SEPARATION_RATIO_FOR_MERGE,
543-
Summarizer::L2distance, clusterInitializer, new Random().nextLong(), true, null);
553+
Summarizer::L2distance, clusterInitializer, seed, true, null);
544554
}
545555

546556
}

Java/core/src/main/java/com/amazon/randomcutforest/tree/BoundingBox.java

-12
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,6 @@ public BoundingBox getMergedBox(IBoundingBoxView otherBox) {
9191
return new BoundingBox(minValuesMerged, maxValuesMerged, sum);
9292
}
9393

94-
public void replaceBox(float[] point) {
95-
System.arraycopy(point, 0, minValues, 0, point.length);
96-
System.arraycopy(point, 0, maxValues, 0, point.length);
97-
rangeSum = 0;
98-
}
99-
100-
public void copyFrom(BoundingBox otherBox) {
101-
System.arraycopy(otherBox.minValues, 0, minValues, 0, otherBox.minValues.length);
102-
System.arraycopy(otherBox.maxValues, 0, maxValues, 0, otherBox.maxValues.length);
103-
rangeSum = otherBox.rangeSum;
104-
}
105-
10694
public double probabilityOfCut(float[] point) {
10795
double range = 0;
10896
for (int i = 0; i < point.length; i++) {

0 commit comments

Comments
 (0)