Skip to content

Commit

Permalink
Opt: memory: linear for [group] const values
Browse files Browse the repository at this point in the history
For memories with constant values that have sorts which are linear
groups, there is a way to optimize linear-scan memory-checking.

This patch implements that optimization.
  • Loading branch information
alex-ozdemir committed Aug 19, 2024
1 parent 2b54efa commit 968f92b
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 7 deletions.
23 changes: 23 additions & 0 deletions examples/ZoKrates/pf/const_linear_lookup.zok
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
struct T {
field v
field w
field x
field y
field z
}

const T[9] TABLE = [
T { v: 1, w: 12, x: 13, y: 14, z: 15 },
T { v: 2, w: 22, x: 23, y: 24, z: 25 },
T { v: 3, w: 32, x: 33, y: 34, z: 35 },
T { v: 4, w: 42, x: 43, y: 44, z: 45 },
T { v: 5, w: 52, x: 53, y: 54, z: 55 },
T { v: 6, w: 62, x: 63, y: 64, z: 65 },
T { v: 7, w: 72, x: 73, y: 74, z: 75 },
T { v: 8, w: 82, x: 83, y: 84, z: 85 },
T { v: 9, w: 92, x: 93, y: 94, z: 95 }
]

def main(field i) -> field:
T t = TABLE[i]
return t.v + t.w + t.x + t.y + t.z
5 changes: 5 additions & 0 deletions examples/circ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,11 @@ fn main() {
"Final R1cs rounds: {}",
prover_data.precompute.stage_sizes().count() - 1
);
println!(
"Final Witext steps: {}, arguments: {}",
prover_data.precompute.num_steps(),
prover_data.precompute.num_step_args()
);
match action {
ProofAction::Count => (),
#[cfg(feature = "bellman")]
Expand Down
1 change: 1 addition & 0 deletions scripts/zokrates_test.zsh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function pf_test_isolate {
}

r1cs_test_count ./examples/ZoKrates/pf/mm4_cond.zok 120
r1cs_test_count ./examples/ZoKrates/pf/const_linear_lookup.zok 20
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsAdd.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOnCurve.zok
r1cs_test ./third_party/ZoKrates/zokrates_stdlib/stdlib/ecc/edwardsOrderCheck.zok
Expand Down
46 changes: 41 additions & 5 deletions src/ir/opt/mem/lin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,47 @@ impl RewritePass for Linearizer {
.unwrap_or_else(|| a.val.default_term()),
)
} else {
let mut fields = (0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
let value_sort = check(tup).as_tuple()[0].clone();
if value_sort.is_group() {
// if values are a group
// then emit v0 + ite(idx == i1, v1 - v0, 0) + ... it(idx = iN, vN - v0, 0)
// where +, -, 0 are defined by the group.
//
// we do this because if the values are constant, then the above sum is
// linear, which is very nice for most backends.
let mut fields =
(0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
let zero = value_sort.group_identity();
Some(
value_sort.group_add_nary(
std::iter::once(first.clone())
.chain(
a.key
.elems_iter()
.take(a.size)
.skip(1)
.zip(fields)
.map(|(idx_c, field)| {
term![Op::Ite;
term![Op::Eq; idx.clone(), idx_c],
value_sort.group_sub(field, first.clone()),
zero.clone()
]
}),
)
.collect(),
),
)
} else {
// otherwise, ite(idx == iN, vN, ... ite(idx == i1, v1, v0) ... )
let mut fields =
(0..a.size).map(|idx| term![Op::Field(idx); tup.clone()]);
let first = fields.next().unwrap();
Some(a.key.elems_iter().take(a.size).skip(1).zip(fields).fold(first, |acc, (idx_c, field)| {
term![Op::Ite; term![Op::Eq; idx.clone(), idx_c], field, acc]
}))
}
}
} else {
unreachable!()
Expand Down
11 changes: 9 additions & 2 deletions src/ir/opt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ pub enum Opt {
pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I) -> Computations {
for c in cs.comps.values() {
trace!("Before all opts: {}", text::serialize_computation(c));
info!("Before all opts: {} terms", c.stats().main.n_terms);
info!(
"Before all opts: {} terms",
c.stats().main.n_terms + c.stats().prec.n_terms
);
debug!("Before all opts: {:#?}", c.stats());
debug!("Before all opts: {:#?}", c.detailed_stats());
}
Expand Down Expand Up @@ -167,7 +170,11 @@ pub fn opt<I: IntoIterator<Item = Opt>>(mut cs: Computations, optimizations: I)
fits_in_bits_ip::fits_in_bits_ip(c);
}
}
info!("After {:?}: {} terms", i, c.stats().main.n_terms);
info!(
"After {:?}: {} terms",
i,
c.stats().main.n_terms + c.stats().prec.n_terms
);
debug!("After {:?}: {:#?}", i, c.stats());
trace!("After {:?}: {}", i, text::serialize_computation(c));
#[cfg(debug_assertions)]
Expand Down
87 changes: 87 additions & 0 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,93 @@ impl Sort {
pub fn is_scalar(&self) -> bool {
!matches!(self, Sort::Tuple(..) | Sort::Array(..) | Sort::Map(..))
}

/// Is this sort a group?
pub fn is_group(&self) -> bool {
match self {
Sort::BitVector(_) | Sort::Int | Sort::Field(_) | Sort::Bool => true,
Sort::F32 | Sort::F64 | Sort::Array(_) | Sort::Map(_) => false,
Sort::Tuple(fields) => fields.iter().all(|f| f.is_group()),
}
}

/// The (n-ary) group operation for these terms.
pub fn group_add_nary(&self, ts: Vec<Term>) -> Term {
debug_assert!(ts.iter().all(|t| &check(t) == self));
match self {
Sort::BitVector(_) => term(BV_ADD, ts),
Sort::Bool => term(XOR, ts),
Sort::Field(_) => term(PF_ADD, ts),
Sort::Int => term(INT_ADD, ts),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| {
sort.group_add_nary(
ts.iter()
.map(|t| term(Op::Field(i), vec![t.clone()]))
.collect(),
)
})
.collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group inverse
pub fn group_neg(&self, t: Term) -> Term {
debug_assert_eq!(&check(&t), self);
match self {
Sort::BitVector(_) => term(BV_NEG, vec![t]),
Sort::Bool => term(NOT, vec![t]),
Sort::Field(_) => term(PF_NEG, vec![t]),
Sort::Int => term(
INT_MUL,
vec![leaf_term(Op::new_const(Value::Int(Integer::from(-1i8)))), t],
),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts
.iter()
.enumerate()
.map(|(i, sort)| sort.group_neg(term(Op::Field(i), vec![t.clone()])))
.collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group identity
pub fn group_identity(&self) -> Term {
match self {
Sort::BitVector(n_bits) => bv_lit(0, *n_bits),
Sort::Bool => bool_lit(false),
Sort::Field(f) => pf_lit(f.new_v(0)),
Sort::Int => leaf_term(Op::new_const(Value::Int(Integer::from(0i8)))),
Sort::Tuple(sorts) => term(
Op::Tuple,
sorts.iter().map(|sort| sort.group_identity()).collect(),
),
_ => panic!("Not a group: {}", self),
}
}

/// Group operation
pub fn group_add(&self, s: Term, t: Term) -> Term {
debug_assert_eq!(&check(&s), self);
debug_assert_eq!(&check(&t), self);
self.group_add_nary(vec![s, t])
}

/// Group elimination
pub fn group_sub(&self, s: Term, t: Term) -> Term {
debug_assert_eq!(&check(&s), self);
debug_assert_eq!(&check(&t), self);
self.group_add(s, self.group_neg(t))
}
}

mod hc {
Expand Down
10 changes: 10 additions & 0 deletions src/target/r1cs/wit_comp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@ impl StagedWitComp {
pub fn num_stage_inputs(&self, n: usize) -> usize {
self.stages[n].inputs.len()
}

/// Number of steps
pub fn num_steps(&self) -> usize {
self.steps.len()
}

/// Number of step arguments
pub fn num_step_args(&self) -> usize {
self.step_args.len()
}
}

/// Evaluator interface
Expand Down

0 comments on commit 968f92b

Please sign in to comment.