Skip to content

Commit 00f9253

Browse files
authored
fixes for issues 390,391 and test coverage (#392)
1 parent a99f174 commit 00f9253

34 files changed

+1252
-178
lines changed

Java/core/src/main/java/com/amazon/randomcutforest/returntypes/RangeVector.java

+3
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ public RangeVector(int dimensions) {
5555
public RangeVector(float[] values, float[] upper, float[] lower) {
5656
checkArgument(values.length > 0, " dimensions must be > 0");
5757
checkArgument(values.length == upper.length && upper.length == lower.length, "dimensions must be equal");
58+
for (int i = 0; i < values.length; i++) {
59+
checkArgument(upper[i] >= values[i] && values[i] >= lower[i], "incorrect semantics");
60+
}
5861
this.values = Arrays.copyOf(values, values.length);
5962
this.upper = Arrays.copyOf(upper, upper.length);
6063
this.lower = Arrays.copyOf(lower, lower.length);

Java/core/src/main/java/com/amazon/randomcutforest/store/IndexIntervalManager.java

-8
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,4 @@ public void releaseIndex(int index) {
179179
lastInUse += 1;
180180
}
181181

182-
public int[] getFreeIndices() {
183-
int[] answer = new int[2 * lastInUse];
184-
for (int i = 0; i < 2 * lastInUse; i += 2) {
185-
answer[i] = freeIndexesStart[i / 2];
186-
answer[i + 1] = freeIndexesEnd[i / 2];
187-
}
188-
return answer;
189-
}
190182
}

Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java

+2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ public boolean isInternal(int index) {
7777

7878
public abstract int getRightIndex(int index);
7979

80+
public abstract int getParentIndex(int index);
81+
8082
public abstract void setRoot(int index);
8183

8284
protected abstract void decreaseMassOfInternalNode(int node);

Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
package com.amazon.randomcutforest.tree;
1717

18+
import static com.amazon.randomcutforest.CommonUtils.checkArgument;
19+
1820
import java.util.Arrays;
1921
import java.util.BitSet;
2022
import java.util.Stack;
@@ -202,4 +204,9 @@ public int[] getRightIndex() {
202204
return Arrays.copyOf(rightIndex, rightIndex.length);
203205
}
204206

207+
public int getParentIndex(int index) {
208+
checkArgument(parentIndex != null, "incorrect call");
209+
return parentIndex[index];
210+
}
211+
205212
}

Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
package com.amazon.randomcutforest.tree;
1717

18+
import static com.amazon.randomcutforest.CommonUtils.checkArgument;
1819
import static com.amazon.randomcutforest.CommonUtils.toCharArray;
1920
import static com.amazon.randomcutforest.CommonUtils.toIntArray;
2021

@@ -136,6 +137,11 @@ public int getRightIndex(int index) {
136137
return rightIndex[index];
137138
}
138139

140+
public int getParentIndex(int index) {
141+
checkArgument(parentIndex != null, "incorrect call");
142+
return parentIndex[index];
143+
}
144+
139145
public void setRoot(int index) {
140146
if (!isLeaf(index) && parentIndex != null) {
141147
parentIndex[index] = (char) capacity;

Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java

+5
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ public int getRightIndex(int index) {
141141
return rightIndex[index];
142142
}
143143

144+
public int getParentIndex(int index) {
145+
checkArgument(parentIndex != null, "incorrect call");
146+
return parentIndex[index];
147+
}
148+
144149
public void setRoot(int index) {
145150
if (!isLeaf(index) && parentIndex != null) {
146151
parentIndex[index] = (byte) capacity;

Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java

+16-11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static com.amazon.randomcutforest.CommonUtils.checkArgument;
1919
import static com.amazon.randomcutforest.CommonUtils.checkNotNull;
2020
import static com.amazon.randomcutforest.CommonUtils.checkState;
21+
import static com.amazon.randomcutforest.tree.AbstractNodeStore.DEFAULT_STORE_PARENT;
2122
import static com.amazon.randomcutforest.tree.AbstractNodeStore.Null;
2223
import static java.lang.Math.max;
2324

@@ -85,7 +86,8 @@ protected RandomCutTree(Builder<?> builder) {
8586
outputAfter = builder.outputAfter.orElse(max(1, numberOfLeaves / 4));
8687
dimension = (builder.dimension != 0) ? builder.dimension : pointStoreView.getDimensions();
8788
nodeStore = (builder.nodeStore != null) ? builder.nodeStore
88-
: AbstractNodeStore.builder().capacity(numberOfLeaves - 1).dimension(dimension).build();
89+
: AbstractNodeStore.builder().capacity(numberOfLeaves - 1).storeParent(builder.storeParent)
90+
.dimension(dimension).build();
8991
this.boundingBoxCacheFraction = builder.boundingBoxCacheFraction;
9092
this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
9193
this.centerOfMassEnabled = builder.centerOfMassEnabled;
@@ -201,7 +203,7 @@ protected Cut randomCut(double factor, float[] point, BoundingBox box) {
201203
// debugging
202204
// this should be an anomaly - no pun intended.
203205

204-
Random rng = new Random((long) factor);
206+
Random rng = new Random((long) (factor * Long.MAX_VALUE / 2));
205207
if (rng.nextDouble() < 0.5) {
206208
for (int i = 0; i < box.getDimensions(); i++) {
207209
float minValue = (float) box.getMinValue(i);
@@ -357,7 +359,7 @@ protected void manageAncestorsAdd(Stack<int[]> path, float[] point) {
357359
}
358360
if (boundingBoxCacheFraction > 0.0) {
359361
checkContainsAndRebuildBox(index, point, pointStoreView);
360-
checkContainsAndAddPoint(index, point);
362+
addPointInPlace(index, point);
361363
}
362364
}
363365
}
@@ -471,7 +473,7 @@ public boolean isLeaf(int index) {
471473
public boolean isInternal(int index) {
472474
// note that numberOfLeaves - 1 corresponds to an unspefied leaf in partial tree
473475
// 0 .. numberOfLeaves - 2 corresponds to internal nodes
474-
return index < numberOfLeaves - 1;
476+
return index < numberOfLeaves - 1 && index >= 0;
475477
}
476478

477479
public int convertToLeaf(int pointIndex) {
@@ -571,9 +573,9 @@ void copyBoxToData(int idx, BoundingBox box) {
571573
rangeSumData[idx] = box.getRangeSum();
572574
}
573575

574-
boolean checkContainsAndAddPoint(int index, float[] point) {
576+
void addPointInPlace(int index, float[] point) {
575577
int idx = translate(index);
576-
if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) {
578+
if (idx != Integer.MAX_VALUE) {
577579
int base = 2 * idx * dimension;
578580
int mid = base + dimension;
579581
double rangeSum = 0;
@@ -586,11 +588,8 @@ boolean checkContainsAndAddPoint(int index, float[] point) {
586588
for (int i = 0; i < dimension; i++) {
587589
rangeSum += boundingBoxData[mid + i] - boundingBoxData[base + i];
588590
}
589-
boolean answer = (rangeSumData[idx] == rangeSum);
590591
rangeSumData[idx] = rangeSum;
591-
return answer;
592592
}
593-
return false;
594593
}
595594

596595
public BoundingBox getBox(int index) {
@@ -639,7 +638,7 @@ boolean checkStrictlyContains(int index, float[] point) {
639638

640639
boolean checkContainsAndRebuildBox(int index, float[] point, IPointStoreView<float[]> pointStoreView) {
641640
int idx = translate(index);
642-
if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) {
641+
if (idx != Integer.MAX_VALUE) {
643642
if (!checkStrictlyContains(index, point)) {
644643
BoundingBox mutatedBoundingBox = reconstructBox(index, pointStoreView);
645644
copyBoxToData(idx, mutatedBoundingBox);
@@ -663,7 +662,7 @@ void addBox(int index, float[] point, BoundingBox box) {
663662
int idx = translate(index);
664663
if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum
665664
copyBoxToData(idx, box);
666-
checkContainsAndAddPoint(index, point);
665+
addPointInPlace(index, point);
667666
}
668667
}
669668
}
@@ -1007,6 +1006,7 @@ public static class Builder<T extends Builder<T>> {
10071006
protected IPointStoreView<float[]> pointStoreView;
10081007
protected AbstractNodeStore nodeStore;
10091008
protected int root = Null;
1009+
protected boolean storeParent = DEFAULT_STORE_PARENT;
10101010

10111011
public T capacity(int capacity) {
10121012
this.capacity = capacity;
@@ -1053,6 +1053,11 @@ public T setRoot(int root) {
10531053
return (T) this;
10541054
}
10551055

1056+
public T storeParent(boolean storeParent) {
1057+
this.storeParent = storeParent;
1058+
return (T) this;
1059+
}
1060+
10561061
public T storeSequenceIndexesEnabled(boolean storeSequenceIndexesEnabled) {
10571062
this.storeSequenceIndexesEnabled = storeSequenceIndexesEnabled;
10581063
return (T) this;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package com.amazon.randomcutforest.returntypes;
17+
18+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
19+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
25+
public class RangeVectorTest {
26+
27+
int dimensions;
28+
private RangeVector vector;
29+
30+
@BeforeEach
31+
public void setUp() {
32+
dimensions = 3;
33+
vector = new RangeVector(dimensions);
34+
}
35+
36+
@Test
37+
public void testNew() {
38+
assertThrows(IllegalArgumentException.class, () -> new RangeVector(0));
39+
assertThrows(IllegalArgumentException.class, () -> new RangeVector(new float[0]));
40+
float[] expected = new float[dimensions];
41+
assertArrayEquals(expected, vector.values);
42+
assertArrayEquals(expected, vector.upper);
43+
assertArrayEquals(expected, vector.lower);
44+
45+
float[] another = new float[0];
46+
assertThrows(IllegalArgumentException.class, () -> new RangeVector(another, another, another));
47+
assertThrows(IllegalArgumentException.class,
48+
() -> new RangeVector(expected, expected, new float[dimensions + 1]));
49+
assertThrows(IllegalArgumentException.class,
50+
() -> new RangeVector(expected, new float[dimensions + 1], expected));
51+
assertThrows(IllegalArgumentException.class,
52+
() -> new RangeVector(new float[dimensions + 1], expected, expected));
53+
assertDoesNotThrow(() -> new RangeVector(expected, expected, expected));
54+
55+
assertThrows(IllegalArgumentException.class,
56+
() -> new RangeVector(expected, new float[] { -1f, 0f, 0f }, expected));
57+
assertDoesNotThrow(() -> new RangeVector(expected, expected, new float[] { -1f, 0f, 0f }));
58+
59+
assertThrows(IllegalArgumentException.class,
60+
() -> new RangeVector(expected, new float[] { 1f, 0f, 0f }, new float[] { 1f, 0f, 0f }));
61+
assertDoesNotThrow(() -> new RangeVector(expected, new float[] { 1f, 0f, 0f }, new float[] { -1f, 0f, 0f }));
62+
}
63+
64+
@Test
65+
public void testScale() {
66+
vector.upper[0] = 1.1f;
67+
vector.upper[2] = 3.1f;
68+
vector.upper[1] = 3.1f;
69+
vector.lower[1] = -2.2f;
70+
71+
float z = 9.9f;
72+
assertThrows(IllegalArgumentException.class, () -> vector.scale(0, -1.0f));
73+
assertThrows(IllegalArgumentException.class, () -> vector.scale(-1, 1.0f));
74+
assertThrows(IllegalArgumentException.class, () -> vector.scale(dimensions + 1, 1.0f));
75+
vector.scale(0, z);
76+
77+
float[] expected = new float[] { 1.1f * 9.9f, 3.1f, 3.1f };
78+
assertArrayEquals(expected, vector.upper, 1e-6f);
79+
80+
expected = new float[] { 0.0f, -2.2f, 0.0f };
81+
assertArrayEquals(expected, vector.lower);
82+
83+
vector.scale(1, 2 * z);
84+
assertArrayEquals(new float[] { 1.1f * 9.9f, 3.1f * 2 * z, 3.1f }, vector.upper, 1e-6f);
85+
assertArrayEquals(new float[] { 0f, -2.2f * 2 * z, 0f }, vector.lower, 1e-6f);
86+
}
87+
88+
@Test
89+
public void testShift() {
90+
vector.upper[0] = 1.1f;
91+
vector.upper[2] = 3.1f;
92+
vector.lower[1] = -2.2f;
93+
94+
float z = -9.9f;
95+
assertThrows(IllegalArgumentException.class, () -> vector.shift(-1, z));
96+
assertThrows(IllegalArgumentException.class, () -> vector.shift(dimensions + 1, z));
97+
vector.shift(0, z);
98+
99+
float[] expected = new float[] { 1.1f - 9.9f, 0.0f, 3.1f };
100+
assertArrayEquals(expected, vector.upper, 1e-6f);
101+
102+
expected = new float[] { z, -2.2f, 0.0f };
103+
assertArrayEquals(expected, vector.lower);
104+
105+
assertArrayEquals(new float[] { z, 0, 0 }, vector.values, 1e-6f);
106+
}
107+
108+
}

0 commit comments

Comments
 (0)