@@ -64,18 +64,21 @@ public class ConditionalSampleSummarizer {
64
64
*/
65
65
protected boolean project = false ;
66
66
67
- public ConditionalSampleSummarizer (int [] missingDimensions , float [] queryPoint , double centrality ) {
68
- this .missingDimensions = Arrays .copyOf (missingDimensions , missingDimensions .length );
69
- this .queryPoint = Arrays .copyOf (queryPoint , queryPoint .length );
70
- this .centrality = centrality ;
71
- }
67
+ protected int numberOfReps = 1 ; // passed through to Summarizer.summarize(...) when typical points are computed; defaults to 1
72
68
73
- public ConditionalSampleSummarizer (int [] missingDimensions , float [] queryPoint , double centrality ,
74
- boolean project ) {
69
+ protected double shrinkage = 0 ; // passed through as the last argument of Summarizer.summarize(...); defaults to 0
70
+
71
+ protected int shingleSize = 1 ; // when != 1 and project is false, summarize() keeps only the trailing (dimensions / shingleSize) entries of each leaf point. NOTE(review): used as an integer divisor and never validated here -- confirm callers cannot pass 0
72
+
73
// Builds a summarizer for conditional samples around a (partially specified) query point.
//
// This change folds the former 3-argument and 4-argument constructors into this single
// 7-argument form; NOTE(review): any caller still using the removed overloads has to be
// updated in the same change.
//
// missingDimensions - indices into a leaf point; when project is true, summarize() extracts
//                     exactly these coordinates from every sample. Copied, not aliased.
// queryPoint        - its length is what summarize() uses as the number of dimensions.
//                     Copied, not aliased.
// centrality        - when > 0, summarize() applies a distance-threshold filter to the sorted
//                     samples; otherwise every de-duplicated sample is retained.
// project           - true: summarize over the missingDimensions coordinates only.
// numberOfReps      - forwarded unchanged to Summarizer.summarize(...).
// shrinkage         - forwarded unchanged to Summarizer.summarize(...).
// shingleSize       - when != 1 and project is false, only the last (dimensions / shingleSize)
//                     coordinates of each leaf point are kept.
//                     NOTE(review): no range check is performed; a value of 0 would raise
//                     ArithmeticException from the integer division in summarize() -- confirm
//                     whether validation belongs here.
+ public ConditionalSampleSummarizer (int [] missingDimensions , float [] queryPoint , double centrality , boolean project ,
74
+ int numberOfReps , double shrinkage , int shingleSize ) {
75
75
// Arrays.copyOf gives this instance its own arrays, independent of the caller's
this .missingDimensions = Arrays .copyOf (missingDimensions , missingDimensions .length );
76
76
this .queryPoint = Arrays .copyOf (queryPoint , queryPoint .length );
77
77
this .centrality = centrality ;
78
78
this .project = project ;
79
// the three assignments below are new in this change; values are stored as given, without validation
+ this .numberOfReps = numberOfReps ;
80
+ this .shrinkage = shrinkage ;
81
+ this .shingleSize = shingleSize ;
79
82
}
80
83
81
84
public SampleSummary summarize (List <ConditionalTreeSample > alist ) {
@@ -102,21 +105,28 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
102
105
List <ConditionalTreeSample > newList = ConditionalTreeSample .dedup (alist );
103
106
104
107
newList .sort ((o1 , o2 ) -> Double .compare (o1 .distance , o2 .distance ));
108
+ int dimensions = queryPoint .length ;
105
109
106
- ArrayList <Weighted <float []>> points = new ArrayList <>();
107
- newList .stream ().forEach (e -> {
108
- if (!project ) {
109
- points .add (new Weighted <>(e .leafPoint , (float ) e .weight ));
110
- } else {
111
- float [] values = new float [missingDimensions .length ];
112
- for (int i = 0 ; i < missingDimensions .length ; i ++) {
113
- values [i ] = e .leafPoint [missingDimensions [i ]];
110
+ if (!addTypical ) {
111
+ ArrayList <Weighted <float []>> points = new ArrayList <>();
112
+ newList .stream ().forEach (e -> {
113
+ if (!project ) {
114
+ if (shingleSize == 1 ) {
115
+ points .add (new Weighted <>(e .leafPoint , (float ) e .weight ));
116
+ } else {
117
+ float [] values = Arrays .copyOfRange (e .leafPoint , dimensions - dimensions / shingleSize ,
118
+ dimensions );
119
+ points .add (new Weighted <>(values , (float ) e .weight ));
120
+ }
121
+ } else {
122
+ float [] values = new float [missingDimensions .length ];
123
+ for (int i = 0 ; i < missingDimensions .length ; i ++) {
124
+ values [i ] = e .leafPoint [missingDimensions [i ]];
125
+ }
126
+ points .add (new Weighted <>(values , (float ) e .weight ));
114
127
}
115
- points .add (new Weighted <>(values , (float ) e .weight ));
116
- }
117
- });
128
+ });
118
129
119
- if (!addTypical ) {
120
130
return new SampleSummary (points );
121
131
}
122
132
@@ -131,34 +141,37 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
131
141
* exact matches would go against the dynamic sampling based use of RCF.
132
142
**/
133
143
134
- int dimensions = queryPoint .length ;
135
-
136
- double threshold = centrality * newList .get (0 ).distance ;
137
- double currentWeight = 0 ;
138
- int alwaysInclude = 0 ;
139
- double remainderWeight = totalWeight ;
140
- while (newList .get (alwaysInclude ).distance == 0 ) {
141
- remainderWeight -= newList .get (alwaysInclude ).weight ;
142
- ++alwaysInclude ;
143
- if (alwaysInclude == newList .size ()) {
144
- break ;
144
+ int num = 0 ;
145
+ if (centrality > 0 ) {
146
+ double threshold = centrality * newList .get (0 ).distance + 1e-6 ;
147
+ double currentWeight = 0 ;
148
+ int alwaysInclude = 0 ;
149
+ double remainderWeight = totalWeight ;
150
+ while (newList .get (alwaysInclude ).distance == 0 ) {
151
+ remainderWeight -= newList .get (alwaysInclude ).weight ;
152
+ ++alwaysInclude ;
153
+ if (alwaysInclude == newList .size ()) {
154
+ break ;
155
+ }
145
156
}
146
- }
147
- for (int j = 1 ; j < newList .size (); j ++) {
148
- if ((currentWeight < remainderWeight / 3 && currentWeight + newList .get (j ).weight >= remainderWeight / 3 )
149
- || (currentWeight < remainderWeight / 2
150
- && currentWeight + newList .get (j ).weight >= remainderWeight / 2 )) {
151
- threshold = centrality * newList .get (j ).distance ;
157
+ for (int j = 1 ; j < newList .size (); j ++) {
158
+ if ((currentWeight < remainderWeight / 3
159
+ && currentWeight + newList .get (j ).weight >= remainderWeight / 3 )
160
+ || (currentWeight < remainderWeight / 2
161
+ && currentWeight + newList .get (j ).weight >= remainderWeight / 2 )) {
162
+ threshold = centrality * newList .get (j ).distance ;
163
+ }
164
+ currentWeight += newList .get (j ).weight ;
152
165
}
153
- currentWeight += newList . get ( j ). weight ;
154
- }
155
- // note that the threshold is currently centrality * (some distance in the list)
156
- // thus the sequel uses a convex combination; and setting centrality = 0 removes
157
- // the entire filtering based on distances
158
- threshold += ( 1 - centrality ) * newList . get ( newList . size () - 1 ). distance ;
159
- int num = 0 ;
160
- while ( num < newList . size () && newList . get ( num ). distance <= threshold ) {
161
- ++ num ;
166
+ // note that the threshold is currently centrality * (some distance in the list)
167
+ // thus the sequel uses a convex combination; and setting centrality = 0 removes
168
+ // the entire filtering based on distances
169
+ threshold += ( 1 - centrality ) * newList . get ( newList . size () - 1 ). distance ;
170
+ while ( num < newList . size () && newList . get ( num ). distance <= threshold ) {
171
+ ++ num ;
172
+ }
173
+ } else {
174
+ num = newList . size () ;
162
175
}
163
176
164
177
ArrayList <Weighted <float []>> typicalPoints = new ArrayList <>();
@@ -171,26 +184,21 @@ public SampleSummary summarize(List<ConditionalTreeSample> alist, boolean addTyp
171
184
values [i ] = e .leafPoint [missingDimensions [i ]];
172
185
}
173
186
} else {
174
- values = Arrays .copyOf (e .leafPoint , dimensions );
187
+ if (shingleSize == 1 ) {
188
+ values = e .leafPoint ;
189
+ } else {
190
+ values = Arrays .copyOfRange (e .leafPoint , dimensions - dimensions / shingleSize , dimensions );
191
+ }
175
192
}
176
193
typicalPoints .add (new Weighted <>(values , (float ) e .weight ));
177
194
}
178
195
int maxAllowed = min (queryPoint .length * MAX_NUMBER_OF_TYPICAL_PER_DIMENSION , MAX_NUMBER_OF_TYPICAL_ELEMENTS );
179
196
maxAllowed = min (maxAllowed , num );
180
- SampleSummary projectedSummary = Summarizer .l2summarize (typicalPoints , maxAllowed , num , false , 72 );
181
197
182
- float [][] pointList = new float [projectedSummary .summaryPoints .length ][];
183
- float [] likelihood = new float [projectedSummary .summaryPoints .length ];
184
-
185
- for (int i = 0 ; i < projectedSummary .summaryPoints .length ; i ++) {
186
- pointList [i ] = Arrays .copyOf (queryPoint , dimensions );
187
- for (int j = 0 ; j < missingDimensions .length ; j ++) {
188
- pointList [i ][missingDimensions [j ]] = projectedSummary .summaryPoints [i ][j ];
189
- }
190
- likelihood [i ] = projectedSummary .relativeWeight [i ];
191
- }
198
+ SampleSummary projectedSummary = Summarizer .summarize (typicalPoints , maxAllowed , num , false ,
199
+ Summarizer ::L2distance , 72 , false , numberOfReps , shrinkage );
192
200
193
- return new SampleSummary (points , pointList , likelihood );
201
+ return new SampleSummary (typicalPoints , projectedSummary );
194
202
}
195
203
196
204
}
0 commit comments