Skip to content

Commit 32de396

Browse files
authored
rust impute v1 (#302)
1 parent 789a1ae commit 32de396

19 files changed

+1463
-230
lines changed

Rust/Cargo.toml

+12
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,23 @@ name = "rcf"
33
version = "3.0.0"
44
edition = "2021"
55

6+
[profile.test]
7+
opt-level = 3
8+
9+
[lib]
10+
name = "rcflib"
11+
path = "src/lib.rs"
12+
13+
[[bin]]
14+
name = "example"
15+
path = "src/example.rs"
616
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
717

18+
819
[dependencies]
920
num = "0.4"
1021
rayon = "1.5"
1122
rand = "*"
1223
rand_chacha = "0.3.0"
1324
rand_core = "0.6"
25+
parameterized_test = "0.1.0"
+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
2+
use num::abs;
3+
use crate::pointstore::PointStore;
4+
use crate::samplesummary::{SampleSummary, summarize};
5+
6+
/// Returns the coordinates of `point` found at the indices in `position`,
/// in the order the indices are listed.
///
/// # Panics
///
/// Panics if any index in `position` is out of bounds for `point`.
fn project_missing(point: &[f32], position: &[usize]) -> Vec<f32> {
    position.iter().map(|&index| point[index]).collect()
}
9+
10+
/// A conduit that summarizes the conditional samples derived from the trees.
///
/// Samples are given as `(point index, f32)` pairs: the index identifies a point in the
/// point store and the `f32` is a scalar value associated with that sample (used as a
/// distance when deriving per-sample weights). The list of missing fields is expressed in
/// the space of the full (potentially shingled) points.
///
/// The global mean, median and deviation are computed without any weighting or pruning,
/// whereas `summarize()` operates on somewhat denoised (re-weighted) data to produce a
/// list of summary points. According to the original author's note, `summarize()` is
/// redundant when `max_number` is 0. The combination is meant to provide both views with
/// little performance overhead; in the future a dynamic kernel could be leveraged, since
/// the entire point store is available and is itself dynamic.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct FieldSummarizer {
    /// The parameter that was used to derive the samples; it blends quantile-based and
    /// maximum-based distance thresholds when the samples are re-weighted.
    centrality: f64,
    /// `true` to summarize only the missing fields; `false` to summarize the entire
    /// (potentially shingled) point. The two modes have different, complementary uses.
    project: bool,
    /// Controls the summarization; in the current version this is an upper bound on the
    /// number of summary points in the resulting `SampleSummary`.
    max_number: usize,
    /// Distance function between two points, handed to `summarize()`.
    distance: fn(&[f32], &[f32]) -> f64,
}
35+
36+
impl FieldSummarizer {
37+
pub fn new(centrality: f64, project: bool, max_number: usize, distance: fn(&[f32], &[f32]) -> f64) -> Self {
38+
FieldSummarizer {
39+
centrality,
40+
project,
41+
max_number,
42+
distance
43+
}
44+
}
45+
46+
pub fn summarize_list(&self, pointstore: &dyn PointStore, point_list_with_distance: &[(usize, f32)], missing: &[usize]) -> SampleSummary {
47+
let mut distance_list: Vec<f32> = point_list_with_distance.iter().map(|a| a.1)
48+
.collect();
49+
distance_list.sort_by(|a, b| a.partial_cmp(&b).unwrap());
50+
let mut threshold = 0.0;
51+
if self.centrality > 0.0 {
52+
let mut always_include = 0;
53+
while always_include < point_list_with_distance.len() && distance_list[always_include] == 0.0 {
54+
always_include += 1;
55+
}
56+
threshold = self.centrality * (distance_list[always_include + (distance_list.len() - always_include) / 3] +
57+
distance_list[always_include + (distance_list.len() - always_include) / 2]) as f64;
58+
}
59+
threshold += (1.0 - self.centrality) * distance_list[point_list_with_distance.len() - 1] as f64;
60+
61+
let total_weight = point_list_with_distance.len() as f64;
62+
let dimensions = if !self.project || missing.len() == 0 {
63+
pointstore.get_copy(point_list_with_distance[0].0).len()
64+
} else {
65+
missing.len()
66+
};
67+
let mut mean = vec![0.0f32; dimensions];
68+
let mut deviation = vec![0.0f32; dimensions];
69+
let mut sum_values_sq = vec![0.0f64; dimensions];
70+
let mut sum_values = vec![0.0f64; dimensions];
71+
let mut vec = Vec::new();
72+
for i in 0..point_list_with_distance.len() {
73+
let point = if !self.project || missing.len() == 0 {
74+
pointstore.get_copy(point_list_with_distance[i].0)
75+
} else {
76+
project_missing(&pointstore.get_copy(point_list_with_distance[i].0), &missing)
77+
};
78+
for j in 0..dimensions {
79+
sum_values[j] += point[j] as f64;
80+
sum_values_sq[j] += point[j] as f64 * point[j] as f64;
81+
}
82+
/// the else can be filtered further
83+
let weight: f32 = if point_list_with_distance[i].1 <= threshold as f32 {
84+
1.0
85+
} else {
86+
threshold as f32 / point_list_with_distance[i].1
87+
};
88+
89+
vec.push((point, weight));
90+
};
91+
92+
for j in 0..dimensions {
93+
mean[j] = (sum_values[j] / total_weight as f64) as f32;
94+
let t: f64 = sum_values_sq[j] / total_weight as f64 - sum_values[j] * sum_values[j] / (total_weight as f64 * total_weight as f64);
95+
deviation[j] = f64::sqrt(if t > 0.0 { t } else { 0.0 }) as f32;
96+
};
97+
let mut median = vec![0.0f32; dimensions];
98+
for j in 0..dimensions {
99+
let mut v: Vec<f32> = vec.iter().map(|x| x.0[j]).collect();
100+
v.sort_by(|a, b| a.partial_cmp(b).unwrap());
101+
median[j] = v[vec.len() / 2];
102+
};
103+
104+
let mut summary = summarize(&vec, self.distance, self.max_number);
105+
SampleSummary {
106+
summary_points: summary.summary_points.clone(),
107+
relative_weight: summary.relative_weight.clone(),
108+
total_weight: summary.total_weight,
109+
mean,
110+
median,
111+
deviation
112+
}
113+
}
114+
}

Rust/src/example.rs

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
2+
extern crate rand;
3+
extern crate rand_chacha;
4+
extern crate rcflib;
5+
use rand_chacha::ChaCha20Rng;
6+
use rand::{Rng, SeedableRng};
7+
8+
9+
use rcflib::multidimdatawithkey;
10+
use rcflib::multidimdatawithkey::MultiDimDataWithKey;
11+
use rcflib::rcf::{create_rcf, RCF};
12+
13+
fn main() {
14+
let shingle_size = 8;
15+
let base_dimension = 5;
16+
let data_size = 100000;
17+
let number_of_trees = 30;
18+
let capacity = 256;
19+
let initial_accept_fraction = 0.1;
20+
let dimensions = shingle_size * base_dimension;
21+
let _point_store_capacity = capacity * number_of_trees + 1;
22+
let time_decay = 0.1 / capacity as f64;
23+
let bounding_box_cache_fraction = 1.0;
24+
let random_seed = 17;
25+
let parallel_enabled: bool = false;
26+
let store_attributes: bool = false;
27+
let internal_shingling: bool = true;
28+
let internal_rotation = false;
29+
let noise = 5.0;
30+
31+
let mut forest: Box<dyn RCF> = create_rcf(
32+
dimensions,
33+
shingle_size,
34+
capacity,
35+
number_of_trees,
36+
random_seed,
37+
store_attributes,
38+
parallel_enabled,
39+
internal_shingling,
40+
internal_rotation,
41+
time_decay,
42+
initial_accept_fraction,
43+
bounding_box_cache_fraction,
44+
);
45+
46+
let mut rng = ChaCha20Rng::seed_from_u64(42);
47+
let mut amplitude = Vec::new();
48+
for _i in 0..base_dimension {
49+
amplitude.push( (1.0 + 0.2 * rng.gen::<f32>())*60.0);
50+
}
51+
52+
let data_with_key = multidimdatawithkey::MultiDimDataWithKey::multi_cosine(
53+
data_size,
54+
&vec![60;base_dimension],
55+
&amplitude,
56+
noise,
57+
0,
58+
base_dimension.into(),
59+
);
60+
61+
let mut score: f64 = 0.0;
62+
let _next_index = 0;
63+
let mut error = 0.0;
64+
let mut count = 0;
65+
66+
for i in 0..data_with_key.data.len() {
67+
68+
if (i > 200) {
69+
let next_values = forest.extrapolate(1);
70+
assert!(next_values.len() == base_dimension);
71+
error += next_values.iter().zip(&data_with_key.data[i]).map(|(x,y)| ((x-y) as f64 *(x-y) as f64)).sum::<f64>();
72+
count += base_dimension;
73+
}
74+
75+
let new_score = forest.score(&data_with_key.data[i]);
76+
//println!("{} {} score {}",y,i,new_score);
77+
/*
78+
if next_index < data_with_key.change_indices.len() && data_with_key.change_indices[next_index] == i {
79+
println!(" score at change {} position {} ", new_score, i);
80+
next_index += 1;
81+
}
82+
*/
83+
84+
score += new_score;
85+
forest.update(&data_with_key.data[i], 0);
86+
}
87+
88+
println!(
89+
"Average score {} ",
90+
(score / data_with_key.data.len() as f64)
91+
);
92+
println!("Success! {}", forest.get_entries_seen());
93+
println!("PointStore Size {} ", forest.get_point_store_size());
94+
println!("Total size {} bytes (approx)", forest.get_size());
95+
println!(" RMSE {}, noise {} ", f64::sqrt(error/count as f64), noise);
96+
}

0 commit comments

Comments
 (0)