27
27
import com .amazon .randomcutforest .state .RandomCutForestState ;
28
28
import com .amazon .randomcutforest .state .sampler .CompactSamplerState ;
29
29
import com .amazon .randomcutforest .state .store .PointStoreDoubleMapper ;
30
+ import com .amazon .randomcutforest .state .store .PointStoreFloatMapper ;
30
31
import com .amazon .randomcutforest .state .store .PointStoreState ;
32
+ import com .amazon .randomcutforest .store .IPointStore ;
31
33
import com .amazon .randomcutforest .store .PointStoreDouble ;
34
+ import com .amazon .randomcutforest .store .PointStoreFloat ;
32
35
import com .amazon .randomcutforest .tree .CompactRandomCutTreeDouble ;
36
+ import com .amazon .randomcutforest .tree .CompactRandomCutTreeFloat ;
37
+ import com .amazon .randomcutforest .tree .ITree ;
33
38
import com .fasterxml .jackson .databind .ObjectMapper ;
34
39
35
40
public class V1JsonToV2StateConverter {
36
41
37
42
private final ObjectMapper mapper = new ObjectMapper ();
38
43
39
- public RandomCutForestState convert (String json ) throws IOException {
44
+ public RandomCutForestState convert (String json , Precision precision ) throws IOException {
40
45
V1SerializedRandomCutForest forest = mapper .readValue (json , V1SerializedRandomCutForest .class );
41
- return convert (forest );
46
+ return convert (forest , precision );
42
47
}
43
48
44
- public RandomCutForestState convert (Reader reader ) throws IOException {
49
+ public RandomCutForestState convert (Reader reader , Precision precision ) throws IOException {
45
50
V1SerializedRandomCutForest forest = mapper .readValue (reader , V1SerializedRandomCutForest .class );
46
- return convert (forest );
51
+ return convert (forest , precision );
47
52
}
48
53
49
- public RandomCutForestState convert (URL url ) throws IOException {
54
+ public RandomCutForestState convert (URL url , Precision precision ) throws IOException {
50
55
V1SerializedRandomCutForest forest = mapper .readValue (url , V1SerializedRandomCutForest .class );
51
- return convert (forest );
56
+ return convert (forest , precision );
52
57
}
53
58
54
- public RandomCutForestState convert (V1SerializedRandomCutForest serializedForest ) {
59
+ public RandomCutForestState convert (V1SerializedRandomCutForest serializedForest , Precision precision ) {
55
60
RandomCutForestState state = new RandomCutForestState ();
56
61
state .setNumberOfTrees (serializedForest .getNumberOfTrees ());
57
62
state .setDimensions (serializedForest .getDimensions ());
@@ -68,7 +73,7 @@ public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest
68
73
state .setSaveSamplerStateEnabled (true );
69
74
state .setSaveTreeStateEnabled (false );
70
75
state .setSaveCoordinatorStateEnabled (true );
71
- state .setPrecision (Precision . FLOAT_64 .name ());
76
+ state .setPrecision (precision .name ());
72
77
state .setCompressed (false );
73
78
state .setPartialTreeState (false );
74
79
@@ -78,35 +83,49 @@ public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest
78
83
state .setExecutionContext (executionContext );
79
84
80
85
SamplerConverter samplerConverter = new SamplerConverter (state .getDimensions (),
81
- state .getNumberOfTrees () * state .getSampleSize () + 1 );
86
+ state .getNumberOfTrees () * state .getSampleSize () + 1 , precision );
82
87
83
88
Arrays .stream (serializedForest .getExecutor ().getExecutor ().getTreeUpdaters ())
84
89
.map (V1SerializedRandomCutForest .TreeUpdater ::getSampler ).forEach (samplerConverter ::addSampler );
85
90
86
- state .setPointStoreState (samplerConverter .getPointStoreState ());
91
+ state .setPointStoreState (samplerConverter .getPointStoreState (precision ));
87
92
state .setCompactSamplerStates (samplerConverter .compactSamplerStates );
88
93
89
94
return state ;
90
95
}
91
96
92
97
static class SamplerConverter {
93
- private final PointStoreDouble pointStore ;
98
+ private final IPointStore pointStore ;
94
99
private final List <CompactSamplerState > compactSamplerStates ;
95
-
96
- public SamplerConverter (int dimensions , int capacity ) {
97
- pointStore = new PointStoreDouble (dimensions , capacity );
100
+ private final Precision precision ;
101
+ private final ITree globalTree ;
102
+
103
+ public SamplerConverter (int dimensions , int capacity , Precision precision ) {
104
+ if (precision == Precision .FLOAT_64 ) {
105
+ pointStore = new PointStoreDouble (dimensions , capacity );
106
+ globalTree = new CompactRandomCutTreeDouble .Builder ().pointStore (pointStore )
107
+ .maxSize (pointStore .getCapacity () + 1 ).storeSequenceIndexesEnabled (false )
108
+ .centerOfMassEnabled (false ).boundingBoxCacheFraction (1.0 ).build ();
109
+ } else {
110
+ pointStore = new PointStoreFloat (dimensions , capacity );
111
+ globalTree = new CompactRandomCutTreeFloat .Builder ().pointStore (pointStore )
112
+ .maxSize (pointStore .getCapacity () + 1 ).storeSequenceIndexesEnabled (false )
113
+ .centerOfMassEnabled (false ).boundingBoxCacheFraction (1.0 ).build ();
114
+ }
98
115
compactSamplerStates = new ArrayList <>();
116
+ this .precision = precision ;
99
117
}
100
118
101
- public PointStoreState getPointStoreState () {
102
- return new PointStoreDoubleMapper ().toState (pointStore );
119
+ public PointStoreState getPointStoreState (Precision precision ) {
120
+ if (precision == Precision .FLOAT_64 ) {
121
+ return new PointStoreDoubleMapper ().toState ((PointStoreDouble ) pointStore );
122
+ } else {
123
+ return new PointStoreFloatMapper ().toState ((PointStoreFloat ) pointStore );
124
+ }
103
125
}
104
126
105
127
public void addSampler (V1SerializedRandomCutForest .Sampler sampler ) {
106
128
V1SerializedRandomCutForest .WeightedSamples [] samples = sampler .getWeightedSamples ();
107
- CompactRandomCutTreeDouble tree = new CompactRandomCutTreeDouble .Builder ().pointStore (pointStore )
108
- .storeSequenceIndexesEnabled (false ).centerOfMassEnabled (false ).boundingBoxCacheFraction (1.0 )
109
- .build ();
110
129
int [] pointIndex = new int [samples .length ];
111
130
float [] weight = new float [samples .length ];
112
131
long [] sequenceIndex = new long [samples .length ];
@@ -115,7 +134,7 @@ public void addSampler(V1SerializedRandomCutForest.Sampler sampler) {
115
134
V1SerializedRandomCutForest .WeightedSamples sample = samples [i ];
116
135
double [] point = sample .getPoint ();
117
136
int index = pointStore .add (point , sample .getSequenceIndex ());
118
- pointIndex [i ] = tree .addPoint (index , 0L );
137
+ pointIndex [i ] = ( Integer ) globalTree .addPoint (index , 0L );
119
138
if (pointIndex [i ] != index ) {
120
139
pointStore .incrementRefCount (pointIndex [i ]);
121
140
pointStore .decrementRefCount (index );
0 commit comments