Skip to content

Commit 223473a

Browse files
authored
cache and convertor issues (#263)
* cache and convertor issues * fixes * name change and spotless
1 parent a444a64 commit 223473a

File tree

9 files changed

+175
-45
lines changed

9 files changed

+175
-45
lines changed

Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractCompactRandomCutTree.java

+5
Original file line numberDiff line numberDiff line change
@@ -575,4 +575,9 @@ public T maxSize(int maxSize) {
575575

576576
}
577577

578+
@Override
579+
public void setBoundingBoxCacheFraction(double fraction) {
580+
boxCache.setCacheFraction(fraction);
581+
super.setBoundingBoxCacheFraction(fraction);
582+
}
578583
}

Java/core/src/main/java/com/amazon/randomcutforest/tree/BoxCache.java

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ protected BoxCache(long seed, double cacheFraction, int maxSize) {
3838

3939
abstract void initialize();
4040

41+
public void setCacheFraction(double cacheFraction) {
42+
this.cacheFraction = cacheFraction;
43+
initialize();
44+
}
45+
4146
boolean isDirectMap() {
4247
return cacheFraction >= 0.3;
4348
}

Java/core/src/main/java/com/amazon/randomcutforest/tree/IBoxCache.java

+8
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,12 @@ public interface IBoxCache<Point> {
5959
*/
6060
void addToBox(int index, Point point);
6161

62+
/**
63+
* changes the fraction of boxes cached dynamically
64+
*
65+
* @param fraction new fraction of boxes to be cached
66+
*/
67+
68+
void setCacheFraction(double fraction);
69+
6270
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package com.amazon.randomcutforest.tree;
17+
18+
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
20+
import java.util.Random;
21+
22+
import org.junit.jupiter.api.Test;
23+
24+
import com.amazon.randomcutforest.RandomCutForest;
25+
import com.amazon.randomcutforest.RandomCutForestTest;
26+
import com.amazon.randomcutforest.config.Precision;
27+
28+
public class BoxCacheTest {
29+
30+
@Test
31+
public void testChangingBoundingBoxFloat32() {
32+
int dimensions = 4;
33+
int numberOfTrees = 1;
34+
int sampleSize = 64;
35+
int dataSize = 100000 * sampleSize;
36+
Random random = new Random();
37+
long seed = random.nextLong();
38+
double[][] big = RandomCutForestTest.generateShingledData(dataSize, dimensions, 2);
39+
RandomCutForest forest = RandomCutForest.builder().compact(true).dimensions(dimensions)
40+
.numberOfTrees(numberOfTrees).sampleSize(sampleSize).precision(Precision.FLOAT_32).randomSeed(seed)
41+
.boundingBoxCacheFraction(0).build();
42+
RandomCutForest otherForest = RandomCutForest.builder().compact(true).dimensions(dimensions)
43+
.numberOfTrees(numberOfTrees).sampleSize(sampleSize).precision(Precision.FLOAT_32).randomSeed(seed)
44+
.boundingBoxCacheFraction(1).build();
45+
int num = 0;
46+
for (double[] point : big) {
47+
++num;
48+
if (num % sampleSize == 0) {
49+
forest.setBoundingBoxCacheFraction(random.nextDouble());
50+
}
51+
assertEquals(forest.getAnomalyScore(point), otherForest.getAnomalyScore(point));
52+
forest.update(point);
53+
otherForest.update(point);
54+
}
55+
}
56+
57+
@Test
58+
public void testChangingBoundingBoxFloat64() {
59+
int dimensions = 10;
60+
int numberOfTrees = 1;
61+
int sampleSize = 256;
62+
int dataSize = 4000 * sampleSize;
63+
Random random = new Random();
64+
double[][] big = RandomCutForestTest.generateShingledData(dataSize, dimensions, 2);
65+
RandomCutForest forest = RandomCutForest.builder().compact(true).dimensions(dimensions)
66+
.numberOfTrees(numberOfTrees).sampleSize(sampleSize).precision(Precision.FLOAT_64)
67+
.randomSeed(random.nextLong()).boundingBoxCacheFraction(random.nextDouble()).build();
68+
69+
for (double[] point : big) {
70+
forest.setBoundingBoxCacheFraction(random.nextDouble());
71+
forest.update(point);
72+
}
73+
}
74+
75+
}

Java/examples/src/main/java/com/amazon/randomcutforest/examples/Main.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616
package com.amazon.randomcutforest.examples;
1717

18+
import java.util.Map;
19+
import java.util.TreeMap;
20+
1821
import com.amazon.randomcutforest.examples.dynamicinference.DynamicDensity;
1922
import com.amazon.randomcutforest.examples.dynamicinference.DynamicNearNeighbor;
2023
import com.amazon.randomcutforest.examples.serialization.JsonExample;
2124
import com.amazon.randomcutforest.examples.serialization.ProtostuffExample;
2225

23-
import java.util.Map;
24-
import java.util.TreeMap;
25-
2626
public class Main {
2727

2828
public static final String ARCHIVE_NAME = "randomcutforest-examples-1.0.jar";

Java/examples/src/main/java/com/amazon/randomcutforest/examples/dynamicinference/DynamicDensity.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515

1616
package com.amazon.randomcutforest.examples.dynamicinference;
1717

18-
import com.amazon.randomcutforest.RandomCutForest;
19-
import com.amazon.randomcutforest.examples.Example;
20-
import com.amazon.randomcutforest.returntypes.DensityOutput;
18+
import static com.amazon.randomcutforest.testutils.ExampleDataSets.generate;
19+
import static com.amazon.randomcutforest.testutils.ExampleDataSets.rotateClockWise;
20+
import static java.lang.Math.PI;
2121

2222
import java.io.BufferedWriter;
2323
import java.io.FileWriter;
2424

25-
import static com.amazon.randomcutforest.testutils.ExampleDataSets.generate;
26-
import static com.amazon.randomcutforest.testutils.ExampleDataSets.rotateClockWise;
27-
import static java.lang.Math.PI;
25+
import com.amazon.randomcutforest.RandomCutForest;
26+
import com.amazon.randomcutforest.examples.Example;
27+
import com.amazon.randomcutforest.returntypes.DensityOutput;
2828

2929
public class DynamicDensity implements Example {
3030

@@ -39,8 +39,8 @@ public String command() {
3939

4040
@Override
4141
public String description() {
42-
return "shows two potential use of dynamic density computations; estimating density as well " +
43-
"as its directional components";
42+
return "shows two potential use of dynamic density computations; estimating density as well "
43+
+ "as its directional components";
4444
}
4545

4646
/**

Java/examples/src/main/java/com/amazon/randomcutforest/examples/dynamicinference/DynamicNearNeighbor.java

+7-9
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
package com.amazon.randomcutforest.examples.dynamicinference;
1717

18-
import com.amazon.randomcutforest.RandomCutForest;
19-
import com.amazon.randomcutforest.examples.Example;
18+
import static com.amazon.randomcutforest.testutils.ExampleDataSets.generate;
19+
import static com.amazon.randomcutforest.testutils.ExampleDataSets.rotateClockWise;
20+
import static java.lang.Math.PI;
2021

2122
import java.io.BufferedWriter;
2223
import java.io.FileWriter;
2324

24-
import static com.amazon.randomcutforest.testutils.ExampleDataSets.generate;
25-
import static com.amazon.randomcutforest.testutils.ExampleDataSets.rotateClockWise;
26-
import static java.lang.Math.PI;
25+
import com.amazon.randomcutforest.RandomCutForest;
26+
import com.amazon.randomcutforest.examples.Example;
2727

2828
public class DynamicNearNeighbor implements Example {
2929

@@ -38,12 +38,10 @@ public String command() {
3838

3939
@Override
4040
public String description() {
41-
return "shows an example of dynamic near neighbor computation where both the data and query are " +
42-
"evolving in time";
41+
return "shows an example of dynamic near neighbor computation where both the data and query are "
42+
+ "evolving in time";
4343
}
4444

45-
46-
4745
@Override
4846
public void run() throws Exception {
4947
int newDimensions = 2;

Java/serialization/src/main/java/com/amazon/randomcutforest/serialize/json/v1/V1JsonToV2StateConverter.java

+39-20
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,36 @@
2727
import com.amazon.randomcutforest.state.RandomCutForestState;
2828
import com.amazon.randomcutforest.state.sampler.CompactSamplerState;
2929
import com.amazon.randomcutforest.state.store.PointStoreDoubleMapper;
30+
import com.amazon.randomcutforest.state.store.PointStoreFloatMapper;
3031
import com.amazon.randomcutforest.state.store.PointStoreState;
32+
import com.amazon.randomcutforest.store.IPointStore;
3133
import com.amazon.randomcutforest.store.PointStoreDouble;
34+
import com.amazon.randomcutforest.store.PointStoreFloat;
3235
import com.amazon.randomcutforest.tree.CompactRandomCutTreeDouble;
36+
import com.amazon.randomcutforest.tree.CompactRandomCutTreeFloat;
37+
import com.amazon.randomcutforest.tree.ITree;
3338
import com.fasterxml.jackson.databind.ObjectMapper;
3439

3540
public class V1JsonToV2StateConverter {
3641

3742
private final ObjectMapper mapper = new ObjectMapper();
3843

39-
public RandomCutForestState convert(String json) throws IOException {
44+
public RandomCutForestState convert(String json, Precision precision) throws IOException {
4045
V1SerializedRandomCutForest forest = mapper.readValue(json, V1SerializedRandomCutForest.class);
41-
return convert(forest);
46+
return convert(forest, precision);
4247
}
4348

44-
public RandomCutForestState convert(Reader reader) throws IOException {
49+
public RandomCutForestState convert(Reader reader, Precision precision) throws IOException {
4550
V1SerializedRandomCutForest forest = mapper.readValue(reader, V1SerializedRandomCutForest.class);
46-
return convert(forest);
51+
return convert(forest, precision);
4752
}
4853

49-
public RandomCutForestState convert(URL url) throws IOException {
54+
public RandomCutForestState convert(URL url, Precision precision) throws IOException {
5055
V1SerializedRandomCutForest forest = mapper.readValue(url, V1SerializedRandomCutForest.class);
51-
return convert(forest);
56+
return convert(forest, precision);
5257
}
5358

54-
public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest) {
59+
public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest, Precision precision) {
5560
RandomCutForestState state = new RandomCutForestState();
5661
state.setNumberOfTrees(serializedForest.getNumberOfTrees());
5762
state.setDimensions(serializedForest.getDimensions());
@@ -68,7 +73,7 @@ public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest
6873
state.setSaveSamplerStateEnabled(true);
6974
state.setSaveTreeStateEnabled(false);
7075
state.setSaveCoordinatorStateEnabled(true);
71-
state.setPrecision(Precision.FLOAT_64.name());
76+
state.setPrecision(precision.name());
7277
state.setCompressed(false);
7378
state.setPartialTreeState(false);
7479

@@ -78,35 +83,49 @@ public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest
7883
state.setExecutionContext(executionContext);
7984

8085
SamplerConverter samplerConverter = new SamplerConverter(state.getDimensions(),
81-
state.getNumberOfTrees() * state.getSampleSize() + 1);
86+
state.getNumberOfTrees() * state.getSampleSize() + 1, precision);
8287

8388
Arrays.stream(serializedForest.getExecutor().getExecutor().getTreeUpdaters())
8489
.map(V1SerializedRandomCutForest.TreeUpdater::getSampler).forEach(samplerConverter::addSampler);
8590

86-
state.setPointStoreState(samplerConverter.getPointStoreState());
91+
state.setPointStoreState(samplerConverter.getPointStoreState(precision));
8792
state.setCompactSamplerStates(samplerConverter.compactSamplerStates);
8893

8994
return state;
9095
}
9196

9297
static class SamplerConverter {
93-
private final PointStoreDouble pointStore;
98+
private final IPointStore pointStore;
9499
private final List<CompactSamplerState> compactSamplerStates;
95-
96-
public SamplerConverter(int dimensions, int capacity) {
97-
pointStore = new PointStoreDouble(dimensions, capacity);
100+
private final Precision precision;
101+
private final ITree globalTree;
102+
103+
public SamplerConverter(int dimensions, int capacity, Precision precision) {
104+
if (precision == Precision.FLOAT_64) {
105+
pointStore = new PointStoreDouble(dimensions, capacity);
106+
globalTree = new CompactRandomCutTreeDouble.Builder().pointStore(pointStore)
107+
.maxSize(pointStore.getCapacity() + 1).storeSequenceIndexesEnabled(false)
108+
.centerOfMassEnabled(false).boundingBoxCacheFraction(1.0).build();
109+
} else {
110+
pointStore = new PointStoreFloat(dimensions, capacity);
111+
globalTree = new CompactRandomCutTreeFloat.Builder().pointStore(pointStore)
112+
.maxSize(pointStore.getCapacity() + 1).storeSequenceIndexesEnabled(false)
113+
.centerOfMassEnabled(false).boundingBoxCacheFraction(1.0).build();
114+
}
98115
compactSamplerStates = new ArrayList<>();
116+
this.precision = precision;
99117
}
100118

101-
public PointStoreState getPointStoreState() {
102-
return new PointStoreDoubleMapper().toState(pointStore);
119+
public PointStoreState getPointStoreState(Precision precision) {
120+
if (precision == Precision.FLOAT_64) {
121+
return new PointStoreDoubleMapper().toState((PointStoreDouble) pointStore);
122+
} else {
123+
return new PointStoreFloatMapper().toState((PointStoreFloat) pointStore);
124+
}
103125
}
104126

105127
public void addSampler(V1SerializedRandomCutForest.Sampler sampler) {
106128
V1SerializedRandomCutForest.WeightedSamples[] samples = sampler.getWeightedSamples();
107-
CompactRandomCutTreeDouble tree = new CompactRandomCutTreeDouble.Builder().pointStore(pointStore)
108-
.storeSequenceIndexesEnabled(false).centerOfMassEnabled(false).boundingBoxCacheFraction(1.0)
109-
.build();
110129
int[] pointIndex = new int[samples.length];
111130
float[] weight = new float[samples.length];
112131
long[] sequenceIndex = new long[samples.length];
@@ -115,7 +134,7 @@ public void addSampler(V1SerializedRandomCutForest.Sampler sampler) {
115134
V1SerializedRandomCutForest.WeightedSamples sample = samples[i];
116135
double[] point = sample.getPoint();
117136
int index = pointStore.add(point, sample.getSequenceIndex());
118-
pointIndex[i] = tree.addPoint(index, 0L);
137+
pointIndex[i] = (Integer) globalTree.addPoint(index, 0L);
119138
if (pointIndex[i] != index) {
120139
pointStore.incrementRefCount(pointIndex[i]);
121140
pointStore.decrementRefCount(index);

Java/serialization/src/test/java/com/amazon/randomcutforest/serialize/json/v1/V1JsonToV2StateConverterTest.java

+25-5
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,18 @@
2525
import java.io.InputStreamReader;
2626
import java.nio.charset.StandardCharsets;
2727
import java.util.Random;
28+
import java.util.stream.Stream;
2829

2930
import org.junit.jupiter.api.BeforeEach;
3031
import org.junit.jupiter.params.ParameterizedTest;
31-
import org.junit.jupiter.params.provider.EnumSource;
32+
import org.junit.jupiter.params.provider.Arguments;
33+
import org.junit.jupiter.params.provider.MethodSource;
3234

3335
import com.amazon.randomcutforest.RandomCutForest;
36+
import com.amazon.randomcutforest.config.Precision;
3437
import com.amazon.randomcutforest.state.RandomCutForestMapper;
3538
import com.amazon.randomcutforest.state.RandomCutForestState;
39+
import com.fasterxml.jackson.databind.ObjectMapper;
3640

3741
public class V1JsonToV2StateConverterTest {
3842

@@ -44,8 +48,8 @@ public void setUp() {
4448
}
4549

4650
@ParameterizedTest
47-
@EnumSource(V1JsonResource.class)
48-
public void testConvert(V1JsonResource jsonResource) {
51+
@MethodSource("args")
52+
public void testConvert(V1JsonResource jsonResource, Precision precision) {
4953
String resource = jsonResource.getResource();
5054
try (InputStream is = V1JsonToV2StateConverterTest.class.getResourceAsStream(jsonResource.getResource());
5155
BufferedReader rr = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8));) {
@@ -57,7 +61,7 @@ public void testConvert(V1JsonResource jsonResource) {
5761
}
5862

5963
String json = b.toString();
60-
RandomCutForestState state = converter.convert(json);
64+
RandomCutForestState state = converter.convert(json, precision);
6165

6266
assertEquals(jsonResource.getDimensions(), state.getDimensions());
6367
assertEquals(jsonResource.getNumberOfTrees(), state.getNumberOfTrees());
@@ -72,12 +76,15 @@ public void testConvert(V1JsonResource jsonResource) {
7276
// with a few points
7377

7478
Random random = new Random(0);
75-
for (int i = 0; i < 10; i++) {
79+
for (int i = 0; i < 100; i++) {
7680
double[] point = getPoint(jsonResource.getDimensions(), random);
7781
double score = forest.getAnomalyScore(point);
7882
assertTrue(score > 0);
7983
forest.update(point);
8084
}
85+
String newString = new ObjectMapper().writeValueAsString(new RandomCutForestMapper().toState(forest));
86+
System.out.println(" Old size " + json.length() + ", new Size " + newString.length()
87+
+ ", improvement factor " + json.length() / newString.length());
8188
} catch (IOException e) {
8289
fail("Unable to load JSON resource");
8390
}
@@ -90,4 +97,17 @@ private double[] getPoint(int dimensions, Random random) {
9097
}
9198
return point;
9299
}
100+
101+
static Stream<Arguments> args() {
102+
return jsonParams().flatMap(
103+
classParameter -> precision().map(testParameter -> Arguments.of(classParameter, testParameter)));
104+
}
105+
106+
static Stream<Precision> precision() {
107+
return Stream.of(Precision.FLOAT_32, Precision.FLOAT_64);
108+
}
109+
110+
static Stream<V1JsonResource> jsonParams() {
111+
return Stream.of(V1JsonResource.FOREST_1, V1JsonResource.FOREST_2);
112+
}
93113
}

0 commit comments

Comments
 (0)