Skip to content

Commit 5a64ddb

Browse files
authored
first steps of forecast (#330)
* forecast error ranges * changes, fixing the dependence on centrality + tests * random seed * changes
1 parent 0b6a678 commit 5a64ddb

File tree

7 files changed

+350
-59
lines changed

7 files changed

+350
-59
lines changed

Java/core/src/main/java/com/amazon/randomcutforest/RandomCutForest.java

+43-15
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
import com.amazon.randomcutforest.returntypes.Neighbor;
6161
import com.amazon.randomcutforest.returntypes.OneSidedConvergingDiVectorAccumulator;
6262
import com.amazon.randomcutforest.returntypes.OneSidedConvergingDoubleAccumulator;
63+
import com.amazon.randomcutforest.returntypes.RangeVector;
6364
import com.amazon.randomcutforest.returntypes.SampleSummary;
6465
import com.amazon.randomcutforest.sampler.CompactSampler;
6566
import com.amazon.randomcutforest.sampler.IStreamSampler;
@@ -568,10 +569,11 @@ public float[] lastShingledPoint() {
568569

569570
/**
570571
*
571-
* @return the sequence index of the last known shingled point
572+
* @return the sequence index of the last known shingled point. If internal
573+
* shingling is not enabled, then this would correspond to the number of
574+
* updates
572575
*/
573576
public long nextSequenceIndex() {
574-
checkArgument(internalShinglingEnabled, "incorrect use");
575577
return stateCoordinator.getStore().getNextSequenceIndex();
576578
}
577579

@@ -1001,7 +1003,8 @@ public List<ConditionalTreeSample> getConditionalField(float[] point, int number
10011003

10021004
int[] liftedIndices = transformIndices(missingIndexes, point.length);
10031005
IMultiVisitorFactory<ConditionalTreeSample> visitorFactory = (tree, y) -> new ImputeVisitor(y,
1004-
tree.projectToTree(y), liftedIndices, tree.projectMissingIndices(liftedIndices), centrality);
1006+
tree.projectToTree(y), liftedIndices, tree.projectMissingIndices(liftedIndices), centrality,
1007+
tree.getRandomSeed());
10051008
return traverseForestMulti(transformToShingledPoint(point), visitorFactory, ConditionalTreeSample.collector);
10061009
}
10071010

@@ -1072,25 +1075,41 @@ public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boo
10721075
}
10731076

10741077
public float[] extrapolateBasic(float[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) {
1078+
return extrapolateWithRanges(point, horizon, blockSize, cyclic, shingleIndex, 1.0).values;
1079+
}
1080+
1081+
// the following is provided for maximum flexibilty from the calling entity;
1082+
// but likely use is extrapolateFromShingle(), which abstracts away rotation
1083+
// etc.
1084+
public RangeVector extrapolateWithRanges(float[] point, int horizon, int blockSize, boolean cyclic,
1085+
int shingleIndex, double centrality) {
10751086
checkArgument(0 < blockSize && blockSize < dimensions,
10761087
"blockSize must be between 0 and dimensions (exclusive)");
10771088
checkArgument(dimensions % blockSize == 0, "dimensions must be evenly divisible by blockSize");
10781089
checkArgument(0 <= shingleIndex && shingleIndex < dimensions / blockSize,
10791090
"shingleIndex must be between 0 (inclusive) and dimensions / blockSize");
10801091

1081-
float[] result = new float[blockSize * horizon];
1092+
RangeVector result = new RangeVector(blockSize * horizon);
10821093
int[] missingIndexes = new int[blockSize];
10831094
float[] queryPoint = Arrays.copyOf(point, dimensions);
10841095

10851096
if (cyclic) {
1086-
extrapolateBasicCyclic(result, horizon, blockSize, shingleIndex, queryPoint, missingIndexes);
1097+
extrapolateBasicCyclic(result, horizon, blockSize, shingleIndex, queryPoint, missingIndexes, centrality);
10871098
} else {
1088-
extrapolateBasicSliding(result, horizon, blockSize, queryPoint, missingIndexes);
1099+
extrapolateBasicSliding(result, horizon, blockSize, queryPoint, missingIndexes, centrality);
10891100
}
10901101

10911102
return result;
10921103
}
10931104

1105+
// external management of shingle; can function for both internal and external
1106+
// shingling
1107+
// however blocksize has to be externally managed
1108+
public RangeVector extrapolateFromShingle(float[] shingle, int horizon, int blockSize, double centrality) {
1109+
return extrapolateWithRanges(shingle, horizon, blockSize, isRotationEnabled(),
1110+
((int) nextSequenceIndex()) % shingleSize, centrality);
1111+
}
1112+
10941113
/**
10951114
* Given an initial shingled point, extrapolate the stream into the future to
10961115
* produce a forecast. This method is intended to be called when the input data
@@ -1130,7 +1149,8 @@ public double[] extrapolateBasic(ShingleBuilder builder, int horizon) {
11301149
builder.getShingleIndex());
11311150
}
11321151

1133-
void extrapolateBasicSliding(float[] result, int horizon, int blockSize, float[] queryPoint, int[] missingIndexes) {
1152+
void extrapolateBasicSliding(RangeVector result, int horizon, int blockSize, float[] queryPoint,
1153+
int[] missingIndexes, double centrality) {
11341154
int resultIndex = 0;
11351155

11361156
Arrays.fill(missingIndexes, 0);
@@ -1142,16 +1162,20 @@ void extrapolateBasicSliding(float[] result, int horizon, int blockSize, float[]
11421162
// shift all entries in the query point left by 1 block
11431163
System.arraycopy(queryPoint, blockSize, queryPoint, 0, dimensions - blockSize);
11441164

1145-
float[] imputedPoint = imputeMissingValues(queryPoint, blockSize, missingIndexes);
1165+
SampleSummary imputedSummary = getConditionalFieldSummary(queryPoint, blockSize, missingIndexes,
1166+
centrality);
11461167
for (int y = 0; y < blockSize; y++) {
1147-
result[resultIndex++] = queryPoint[dimensions - blockSize + y] = imputedPoint[dimensions - blockSize
1148-
+ y];
1168+
result.values[resultIndex] = queryPoint[dimensions - blockSize + y] = imputedSummary.median[dimensions
1169+
- blockSize + y];
1170+
result.lower[resultIndex] = imputedSummary.lower[dimensions - blockSize + y];
1171+
result.upper[resultIndex] = imputedSummary.upper[dimensions - blockSize + y];
1172+
resultIndex++;
11491173
}
11501174
}
11511175
}
11521176

1153-
void extrapolateBasicCyclic(float[] result, int horizon, int blockSize, int shingleIndex, float[] queryPoint,
1154-
int[] missingIndexes) {
1177+
void extrapolateBasicCyclic(RangeVector result, int horizon, int blockSize, int shingleIndex, float[] queryPoint,
1178+
int[] missingIndexes, double centrality) {
11551179

11561180
int resultIndex = 0;
11571181
int currentPosition = shingleIndex;
@@ -1162,11 +1186,15 @@ void extrapolateBasicCyclic(float[] result, int horizon, int blockSize, int shin
11621186
missingIndexes[y] = (currentPosition + y) % dimensions;
11631187
}
11641188

1165-
float[] imputedPoint = imputeMissingValues(queryPoint, blockSize, missingIndexes);
1189+
SampleSummary imputedSummary = getConditionalFieldSummary(queryPoint, blockSize, missingIndexes,
1190+
centrality);
11661191

11671192
for (int y = 0; y < blockSize; y++) {
1168-
result[resultIndex++] = queryPoint[(currentPosition + y)
1169-
% dimensions] = imputedPoint[(currentPosition + y) % dimensions];
1193+
result.values[resultIndex] = queryPoint[(currentPosition + y)
1194+
% dimensions] = imputedSummary.median[(currentPosition + y) % dimensions];
1195+
result.lower[resultIndex] = imputedSummary.lower[(currentPosition + y) % dimensions];
1196+
result.upper[resultIndex] = imputedSummary.upper[(currentPosition + y) % dimensions];
1197+
resultIndex++;
11701198
}
11711199

11721200
currentPosition = (currentPosition + blockSize) % dimensions;

Java/core/src/main/java/com/amazon/randomcutforest/imputation/ImputeVisitor.java

+21-3
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static com.amazon.randomcutforest.CommonUtils.checkArgument;
1919

2020
import java.util.Arrays;
21+
import java.util.Random;
2122

2223
import com.amazon.randomcutforest.CommonUtils;
2324
import com.amazon.randomcutforest.MultiVisitor;
@@ -66,6 +67,10 @@ public class ImputeVisitor implements MultiVisitor<ConditionalTreeSample> {
6667
*/
6768
protected double centrality;
6869

70+
protected long randomSeed;
71+
72+
protected double randomRank;
73+
6974
protected boolean converged;
7075

7176
protected int pointIndex;
@@ -84,10 +89,11 @@ public class ImputeVisitor implements MultiVisitor<ConditionalTreeSample> {
8489
* space
8590
*/
8691
public ImputeVisitor(float[] liftedPoint, float[] queryPoint, int[] liftedMissingIndexes, int[] missingIndexes,
87-
double centrality) {
92+
double centrality, long randomSeed) {
8893
this.queryPoint = Arrays.copyOf(queryPoint, queryPoint.length);
8994
this.missing = new boolean[queryPoint.length];
9095
this.centrality = centrality;
96+
this.randomSeed = randomSeed;
9197
this.dimensionsUsed = new int[queryPoint.length];
9298

9399
if (missingIndexes == null) {
@@ -108,7 +114,7 @@ public ImputeVisitor(float[] liftedPoint, float[] queryPoint, int[] liftedMissin
108114
public ImputeVisitor(float[] queryPoint, int numberOfMissingIndices, int[] missingIndexes) {
109115
this(queryPoint, Arrays.copyOf(queryPoint, queryPoint.length),
110116
Arrays.copyOf(missingIndexes, Math.min(numberOfMissingIndices, missingIndexes.length)),
111-
Arrays.copyOf(missingIndexes, Math.min(numberOfMissingIndices, missingIndexes.length)), 1.0);
117+
Arrays.copyOf(missingIndexes, Math.min(numberOfMissingIndices, missingIndexes.length)), 1.0, 0L);
112118
}
113119

114120
/**
@@ -121,6 +127,8 @@ public ImputeVisitor(float[] queryPoint, int numberOfMissingIndices, int[] missi
121127
this.queryPoint = Arrays.copyOf(original.queryPoint, length);
122128
this.missing = Arrays.copyOf(original.missing, length);
123129
this.dimensionsUsed = new int[original.dimensionsUsed.length];
130+
this.randomSeed = new Random(original.randomSeed).nextLong();
131+
this.centrality = original.centrality;
124132
anomalyRank = DEFAULT_INIT_VALUE;
125133
distance = DEFAULT_INIT_VALUE;
126134
}
@@ -174,6 +182,12 @@ public void acceptLeaf(final INodeView leafNode, final int depthOfNode) {
174182
}
175183
}
176184

185+
if (centrality < 1.0) {
186+
Random rng = new Random(randomSeed);
187+
randomSeed = rng.nextLong();
188+
randomRank = rng.nextDouble();
189+
}
190+
177191
this.distance = distance;
178192
if (distance <= 0) {
179193
converged = true;
@@ -226,8 +240,12 @@ public MultiVisitor<ConditionalTreeSample> newCopy() {
226240
return new ImputeVisitor(this);
227241
}
228242

243+
double adjustedRank() {
244+
return (1 - centrality) * randomRank + centrality * anomalyRank;
245+
}
246+
229247
protected boolean updateCombine(ImputeVisitor other) {
230-
return other.anomalyRank < anomalyRank;
248+
return other.adjustedRank() < adjustedRank();
231249
}
232250

233251
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package com.amazon.randomcutforest.returntypes;
17+
18+
import static com.amazon.randomcutforest.CommonUtils.checkArgument;
19+
20+
import java.util.Arrays;
21+
22+
/**
23+
* A RangeVector is used when we want to track a quantity and its upper and
24+
* lower bounds
25+
*/
26+
public class RangeVector {
27+
28+
public final float[] values;
29+
30+
/**
31+
* An array of values corresponding to the upper ranges in each dimension.
32+
*/
33+
public final float[] upper;
34+
/**
35+
* An array of values corresponding to the lower ranges in each dimension
36+
*/
37+
public final float[] lower;
38+
39+
public RangeVector(int dimensions) {
40+
checkArgument(dimensions > 0, "dimensions must be greater than 0");
41+
values = new float[dimensions];
42+
upper = new float[dimensions];
43+
lower = new float[dimensions];
44+
}
45+
46+
/**
47+
* Construct a new RangeVector with the given number of spatial dimensions.
48+
*
49+
* @param values the values being estimated in a range
50+
* @param upper the higher values of the ranges
51+
* @param lower the lower values in the ranges
52+
*/
53+
public RangeVector(float[] values, float[] upper, float[] lower) {
54+
checkArgument(values.length > 0, " dimensions must be > 0");
55+
checkArgument(values.length == upper.length && upper.length == lower.length, "dimensions must be equal");
56+
this.values = Arrays.copyOf(values, values.length);
57+
this.upper = Arrays.copyOf(upper, upper.length);
58+
this.lower = Arrays.copyOf(lower, lower.length);
59+
}
60+
61+
public RangeVector(float[] values) {
62+
checkArgument(values.length > 0, "dimensions must be > 0 ");
63+
this.values = Arrays.copyOf(values, values.length);
64+
this.upper = Arrays.copyOf(values, values.length);
65+
this.lower = Arrays.copyOf(values, values.length);
66+
}
67+
68+
/**
69+
* Create a deep copy of the base RangeVector.
70+
*
71+
* @param base The RangeVector to copy.
72+
*/
73+
public RangeVector(RangeVector base) {
74+
int dimensions = base.values.length;
75+
this.values = Arrays.copyOf(base.values, dimensions);
76+
this.upper = Arrays.copyOf(base.upper, dimensions);
77+
this.lower = Arrays.copyOf(base.lower, dimensions);
78+
}
79+
}

0 commit comments

Comments
 (0)