diff --git a/Cargo.toml b/Cargo.toml index e09494942..ef196a92d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,7 @@ approx = "0.5.0" default = [] # frontends c = ["lang-c"] -zok = ["zokrates_parser", "zokrates_pest_ast", "typed-arena", "petgraph"] +zok = ["smt", "zokrates_parser", "zokrates_pest_ast", "typed-arena", "petgraph"] datalog = ["pest", "pest-ast", "pest_derive", "from-pest", "lazy_static"] # backends smt = ["rsmt2", "ieee754"] diff --git a/circ_opt/README.md b/circ_opt/README.md index bbc8fdd5c..488c75a8b 100644 --- a/circ_opt/README.md +++ b/circ_opt/README.md @@ -129,6 +129,16 @@ Options: - waksman: Use the AS-Waksman network - msh: Use the (keyed) multi-set hash + --ram-rom + ROM approach + + [env: RAM_ROM=] + [default: haboeck] + + Possible values: + - haboeck: Use Haboeck's argument + - permute: Use permute-and-check + --fmt-use-default-field Which field to use @@ -210,6 +220,8 @@ Options: How to argue that indices are only repeated in blocks [env: RAM_INDEX=] [default: uniqueness] [possible values: sort, uniqueness] --ram-permutation How to argue that indices are only repeated in blocks [env: RAM_PERMUTATION=] [default: msh] [possible values: waksman, msh] + --ram-rom + ROM approach [env: RAM_ROM=] [default: haboeck] [possible values: haboeck, permute] --fmt-use-default-field Which field to use [env: FMT_USE_DEFAULT_FIELD=] [default: true] [possible values: true, false] --fmt-hide-field @@ -253,6 +265,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -298,6 +311,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -341,6 +355,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -384,6 +399,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ 
-427,6 +443,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -470,6 +487,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -513,6 +531,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -556,6 +575,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -602,6 +622,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -646,6 +667,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -692,6 +714,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -736,6 +759,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -782,6 +806,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, @@ -826,6 +851,7 @@ BinaryOpt { range: Sort, index: Uniqueness, permutation: Msh, + rom: Haboeck, }, fmt: FmtOpt { use_default_field: true, diff --git a/circ_opt/src/lib.rs b/circ_opt/src/lib.rs index 63ccced13..34a171286 100644 --- a/circ_opt/src/lib.rs +++ b/circ_opt/src/lib.rs @@ -239,6 +239,14 @@ pub struct RamOpt { default_value = "msh" )] pub permutation: PermutationStrategy, + /// ROM approach + #[arg( + long = "ram-rom", + env = "RAM_ROM", + value_enum, + default_value = "haboeck" + )] + pub rom: RomStrategy, } #[derive(ValueEnum, Debug, PartialEq, Eq, Clone, Copy)] @@ -286,6 +294,21 @@ impl Default for PermutationStrategy { } } +#[derive(ValueEnum, Debug, PartialEq, Eq, Clone, Copy)] +/// How to argue that accesses have 
read the correct ROM values +pub enum RomStrategy { + /// Use Haboeck's argument + Haboeck, + /// Use permute-and-check + Permute, +} + +impl Default for RomStrategy { + fn default() -> Self { + RomStrategy::Haboeck + } +} + /// Options for the prime field used #[derive(Args, Debug, Clone, PartialEq, Eq)] pub struct FmtOpt { diff --git a/examples/ZoKrates/pf/mem/rom.zok b/examples/ZoKrates/pf/mem/rom.zok new file mode 100644 index 000000000..25f09a151 --- /dev/null +++ b/examples/ZoKrates/pf/mem/rom.zok @@ -0,0 +1,22 @@ +const u32 VAL_LEN = 3 +const u32 RAM_LEN = 20 +const u32 ACCESSES = 400 + +struct Val { + field x + field y +} + +const transcript Val[RAM_LEN] array = [Val{x: 0, y: 0}, ...[Val{x: 10, y: 10}; RAM_LEN-1]] + +def main(private field[ACCESSES] y) -> field: + field result = 0 + + for u32 i in 0..ACCESSES do + Val v = array[y[i]] + result = result + v.x + v.y + endfor + return result + + + diff --git a/scripts/ram_test.zsh b/scripts/ram_test.zsh index 1b4a8ce90..18696ea94 100755 --- a/scripts/ram_test.zsh +++ b/scripts/ram_test.zsh @@ -50,6 +50,15 @@ function transcript_type_test { fi } +function cs_count_test { + ex_name=$1 + cs_upper_bound=$2 + rm -rf P V pi + output=$($BIN $ex_name r1cs --action count |& cat) + n_constraints=$(echo "$output" | grep 'Final R1cs size:' | grep -Eo '\b[0-9]+\b') + [[ $n_constraints -lt $cs_upper_bound ]] || (echo "Got $n_constraints, expected < $cs_upper_bound" && exit 1) +} + transcript_count_test ./examples/ZoKrates/pf/mem/volatile.zok 1 transcript_count_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok 1 transcript_count_test ./examples/ZoKrates/pf/mem/volatile_struct.zok 1 @@ -59,6 +68,9 @@ transcript_count_test ./examples/ZoKrates/pf/mem/arr_of_str_of_arr.zok 1 transcript_type_test ./examples/ZoKrates/pf/mem/volatile_struct.zok "RAM" transcript_type_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok "covering ROM" +# A=400; N=20; L=2; expected cost ~= N + A(L+1) = 1220 +cs_count_test ./examples/ZoKrates/pf/mem/rom.zok 1230 + 
ram_test ./examples/ZoKrates/pf/mem/two_level_ptr.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split" ram_test ./examples/ZoKrates/pf/mem/volatile.zok groth16 "--ram-permutation waksman --ram-index sort --ram-range bit-split" # waksman is broken for non-scalar array values diff --git a/src/ir/opt/mem/ram.rs b/src/ir/opt/mem/ram.rs index f340d0296..d00e25e1f 100644 --- a/src/ir/opt/mem/ram.rs +++ b/src/ir/opt/mem/ram.rs @@ -71,12 +71,13 @@ pub struct AccessCfg { split_times: bool, waksman: bool, covering_rom: bool, + haboeck: bool, } impl AccessCfg { /// Create a new configuration pub fn new(field: FieldT, opt: RamOpt, create: bool) -> Self { - use circ_opt::{IndexStrategy, PermutationStrategy, RangeStrategy}; + use circ_opt::{IndexStrategy, PermutationStrategy, RangeStrategy, RomStrategy}; Self { false_: bool_lit(false), true_: bool_lit(true), @@ -88,6 +89,7 @@ impl AccessCfg { split_times: opt.range == RangeStrategy::BitSplit, waksman: opt.permutation == PermutationStrategy::Waksman, covering_rom: false, + haboeck: opt.rom == RomStrategy::Haboeck, } } /// Create a default configuration, with this field. @@ -103,6 +105,7 @@ impl AccessCfg { split_times: false, waksman: false, covering_rom: false, + haboeck: true, } } /// Create a new default configuration @@ -278,6 +281,7 @@ impl Access { } } + /// Serialize a value as field elements. 
fn val_to_field_elements(val: &Term, c: &AccessCfg, out: &mut Vec) { match check(val) { Sort::Field(_) | Sort::Bool | Sort::BitVector(_) => out.push(scalar_to_field(val, c)), diff --git a/src/ir/opt/mem/ram/checker.rs b/src/ir/opt/mem/ram/checker.rs index 2ccce764d..481c3de53 100644 --- a/src/ir/opt/mem/ram/checker.rs +++ b/src/ir/opt/mem/ram/checker.rs @@ -7,6 +7,7 @@ use circ_fields::FieldT; use log::{debug, trace}; mod permutation; +mod rom; /// Check a RAM pub fn check_ram(c: &mut Computation, ram: Ram) { @@ -35,150 +36,160 @@ pub fn check_ram(c: &mut Computation, ram: Ram) { let v_s = ram.val_sort.clone(); let mut assertions = Vec::new(); - // (1) sort the transcript, checking only that we've applied a permutation. - let sorted_accesses = if ram.cfg.waksman { - let mut new_bit_var = - |name: &str, val: Term| c.new_var(&ns.fqn(name), Sort::Bool, PROVER_VIS, Some(val)); - permutation::waksman(&ram.accesses, &ram.cfg, &v_s, &mut new_bit_var) + if ram.cfg.covering_rom && ram.cfg.haboeck { + assertions.push(rom::check_covering_rom(c, ns.subspace("haboeck"), ram)); } else { - let mut new_var = - |name: &str, val: Term| c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val)); - permutation::msh( - &ram.accesses, - &ns, - &ram.cfg, - &mut new_var, - &v_s, - &mut assertions, - ) - }; + // (1) sort the transcript, checking only that we've applied a permutation. 
+ let sorted_accesses = if ram.cfg.waksman { + let mut new_bit_var = + |name: &str, val: Term| c.new_var(&ns.fqn(name), Sort::Bool, PROVER_VIS, Some(val)); + permutation::waksman(&ram.accesses, &ram.cfg, &v_s, &mut new_bit_var) + } else { + let mut new_var = |name: &str, val: Term| { + c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val)) + }; + permutation::msh( + &ram.accesses, + &ns, + &ram.cfg, + &mut new_var, + &v_s, + &mut assertions, + ) + }; - // (2) check the sorted transcript - let n = sorted_accesses.len(); - let mut accs = sorted_accesses; + // (2) check the sorted transcript + let n = sorted_accesses.len(); + let mut accs = sorted_accesses; - let zero = pf_lit(f.new_v(0)); - let one = pf_lit(f.new_v(1)); - fn sub(a: &Term, b: &Term) -> Term { - term![PF_ADD; a.clone(), term![PF_NEG; b.clone()]] - } + let zero = pf_lit(f.new_v(0)); + let one = pf_lit(f.new_v(1)); + fn sub(a: &Term, b: &Term) -> Term { + term![PF_ADD; a.clone(), term![PF_NEG; b.clone()]] + } - if ram.cfg.covering_rom { - // the covering ROM case - let mut new_var = - |name: &str, val: Term| c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val)); - assertions.push(term_c![EQ; zero, accs[0].idx]); - for j in 0..(n - 1) { - // previous entry - let i = &accs[j].idx; - let v = accs[j].val_hash.as_ref().expect("missing value hash"); - // this entry - let i_n = &accs[j + 1].idx; - let v_n = accs[j + 1].val_hash.as_ref().expect("missing value hash"); + if ram.cfg.covering_rom { + // the non-Haboeck covering ROM case + let mut new_var = |name: &str, val: Term| { + c.new_var(&ns.fqn(name), f_s.clone(), PROVER_VIS, Some(val)) + }; + assertions.push(term_c![EQ; zero, accs[0].idx]); + for j in 0..(n - 1) { + // previous entry + let i = &accs[j].idx; + let v = accs[j].val_hash.as_ref().expect("missing value hash"); + // this entry + let i_n = &accs[j + 1].idx; + let v_n = accs[j + 1].val_hash.as_ref().expect("missing value hash"); - let i_d = sub(i_n, i); - let v_d = sub(v_n, v); + let 
i_d = sub(i_n, i); + let v_d = sub(v_n, v); - // (i' - i)(i' - i - 1) = 0 - assertions.push(term![EQ; term![PF_MUL; i_d.clone(), sub(&i_d, &one)], zero.clone()]); + // (i' - i)(i' - i - 1) = 0 + assertions + .push(term![EQ; term![PF_MUL; i_d.clone(), sub(&i_d, &one)], zero.clone()]); - // r = 1/(i' - i) - let r = new_var(&format!("r{}", j), term_c![PF_RECIP; i_d]); - // (i' - i)r(v' - v) = v' - v [v' = v OR i' != i] - assertions.push(term![EQ; term![PF_MUL; i_d, r, v_d.clone()], v_d]); - } - assertions.push(term_c![EQ; ram.cfg.pf_lit(ram.size-1), accs[n - 1].idx]); - } else { - // the general RAM case - // set c value if needed. - if !only_init { - accs[0].create = FieldBit::from_bool_lit(&ram.cfg, true); - for i in 0..(n - 1) { - let create = term![NOT; term![EQ; accs[i].idx.clone(), accs[i+1].idx.clone()]]; - accs[i + 1].create = FieldBit::from_bool(&ram.cfg, create); + // r = 1/(i' - i) + let r = new_var(&format!("r{}", j), term_c![PF_RECIP; i_d]); + // (i' - i)r(v' - v) = v' - v [v' = v OR i' != i] + assertions.push(term![EQ; term![PF_MUL; i_d, r, v_d.clone()], v_d]); + } + assertions.push(term_c![EQ; ram.cfg.pf_lit(ram.size-1), accs[n - 1].idx]); + } else { + // the general RAM case + // set c value if needed. + if !only_init { + accs[0].create = FieldBit::from_bool_lit(&ram.cfg, true); + for i in 0..(n - 1) { + let create = term![NOT; term![EQ; accs[i].idx.clone(), accs[i+1].idx.clone()]]; + accs[i + 1].create = FieldBit::from_bool(&ram.cfg, create); + } } - } - // (3a) v' = ite(a',v',v) - for i in 0..(n - 1) { - accs[i + 1].val = - term_c![Op::Ite; accs[i+1].active.b.clone(), accs[i+1].val, accs[i].val]; - } + // (3a) v' = ite(a',v',v) + for i in 0..(n - 1) { + accs[i + 1].val = + term_c![Op::Ite; accs[i+1].active.b.clone(), accs[i+1].val, accs[i].val]; + } - assertions.push(accs[0].create.b.clone()); + assertions.push(accs[0].create.b.clone()); - let mut deltas = Vec::new(); - // To: check some condition on the start? 
- for j in 0..(n - 1) { - // previous entry - let i = &accs[j].idx; - let t = &accs[j].time; - let v = accs[j].val_hash.as_ref().expect("missing value hash"); - // this entry - let i_n = &accs[j + 1].idx; - let t_n = &accs[j + 1].time; - let v_n = accs[j + 1].val_hash.as_ref().expect("missing value hash"); - let c_n = &accs[j + 1].create; - let w_n = &accs[j + 1].write; + let mut deltas = Vec::new(); + // To: check some condition on the start? + for j in 0..(n - 1) { + // previous entry + let i = &accs[j].idx; + let t = &accs[j].time; + let v = accs[j].val_hash.as_ref().expect("missing value hash"); + // this entry + let i_n = &accs[j + 1].idx; + let t_n = &accs[j + 1].time; + let v_n = accs[j + 1].val_hash.as_ref().expect("missing value hash"); + let c_n = &accs[j + 1].create; + let w_n = &accs[j + 1].write; - let v_p = if only_init { - v.clone() - } else { - term![ITE; c_n.b.clone(), default.clone(), v.clone()] - }; + let v_p = if only_init { + v.clone() + } else { + term![ITE; c_n.b.clone(), default.clone(), v.clone()] + }; - // delta = (1 - c')(t' - t) - deltas.push(term![PF_MUL; c_n.nf.clone(), sub(t_n, t)]); + // delta = (1 - c')(t' - t) + deltas.push(term![PF_MUL; c_n.nf.clone(), sub(t_n, t)]); - // check c value if not computed: (i' - i)(1 - c') = 0 - if only_init { + // check c value if not computed: (i' - i)(1 - c') = 0 + if only_init { + assertions + .push(term![EQ; term![PF_MUL; sub(i_n, i), c_n.nf.clone()], zero.clone()]); + } + // writes allow a value change: (v' - v)(1 - w') = 0 assertions - .push(term![EQ; term![PF_MUL; sub(i_n, i), c_n.nf.clone()], zero.clone()]); + .push(term![EQ; term![PF_MUL; sub(v_n, &v_p), w_n.nf.clone()], zero.clone()]); } - // writes allow a value change: (v' - v)(1 - w') = 0 - assertions.push(term![EQ; term![PF_MUL; sub(v_n, &v_p), w_n.nf.clone()], zero.clone()]); - } - // check that index blocks are unique - if !only_init { - if ram.cfg.sort_indices { - let bits = ram.size.next_power_of_two().ilog2() as usize; - 
trace!("Index difference checks ({bits} bits)"); - assertions.push(term![Op::PfFitsInBits(bits); accs[0].idx.clone()]); - for j in 0..(n - 1) { - let d = pf_sub(accs[j + 1].idx.clone(), accs[j].idx.clone()); - assertions.push(term![Op::PfFitsInBits(bits); d]); + // check that index blocks are unique + if !only_init { + if ram.cfg.sort_indices { + let bits = ram.size.next_power_of_two().ilog2() as usize; + trace!("Index difference checks ({bits} bits)"); + assertions.push(term![Op::PfFitsInBits(bits); accs[0].idx.clone()]); + for j in 0..(n - 1) { + let d = pf_sub(accs[j + 1].idx.clone(), accs[j].idx.clone()); + assertions.push(term![Op::PfFitsInBits(bits); d]); + } + } else { + derivative_gcd( + c, + accs.iter().map(|a| a.idx.clone()).collect(), + accs.iter().map(|a| a.create.b.clone()).collect(), + &ns, + &mut assertions, + f, + ); } - } else { - derivative_gcd( - c, - accs.iter().map(|a| a.idx.clone()).collect(), - accs.iter().map(|a| a.create.b.clone()).collect(), - &ns, - &mut assertions, - f, - ); } - } - // check ranges - assertions.push(c.outputs[0].clone()); - #[allow(clippy::type_complexity)] - let range_checker: Box< - dyn Fn(&mut Computation, Vec, &Namespace, &mut Vec, usize, &FieldT), - > = if ram.cfg.split_times { - Box::new(&bit_split_range_check) - } else { - Box::new(&range_check) - }; - range_checker( - c, - deltas, - &ns.subspace("time"), - &mut assertions, - ram.next_time + 1, - f, - ); + // check ranges + assertions.push(c.outputs[0].clone()); + #[allow(clippy::type_complexity)] + let range_checker: Box< + dyn Fn(&mut Computation, Vec, &Namespace, &mut Vec, usize, &FieldT), + > = if ram.cfg.split_times { + Box::new(&bit_split_range_check) + } else if ram.cfg.haboeck { + Box::new(&haboeck_range_check) + } else { + Box::new(&range_check) + }; + range_checker( + c, + deltas, + &ns.subspace("time"), + &mut assertions, + ram.next_time + 1, + f, + ); + } } c.outputs[0] = term(AND, assertions); } @@ -235,6 +246,21 @@ fn range_check( 
assertions.push(term![EQ; sorted.last().unwrap().clone(), end]); } +/// Haboeck range check +fn haboeck_range_check( + c: &mut Computation, + values: Vec, + ns: &Namespace, + assertions: &mut Vec, + n: usize, + f: &FieldT, +) { + let ns = ns.subspace("range"); + let f_sort = Sort::Field(f.clone()); + let haystack: Vec = f_sort.elems_iter().take(n).collect(); + assertions.push(rom::lookup(c, ns, haystack, values)); +} + /// Ensure that each element of `values` is in `[0, n)`. /// /// Assumes that each value is a field element. diff --git a/src/ir/opt/mem/ram/checker/rom.rs b/src/ir/opt/mem/ram/checker/rom.rs new file mode 100644 index 000000000..dec9554f6 --- /dev/null +++ b/src/ir/opt/mem/ram/checker/rom.rs @@ -0,0 +1,103 @@ +//! Implementation of ROM checking based on https://eprint.iacr.org/2022/1530.pdf +//! +//! Cost: about (N + A)(L + 1) where the ROM size is N, there are A reads, and values have size L. +//! If the ROM contents are fixed, cost drops to N + A(L + 1) + +use super::super::hash::UniversalHasher; +use super::{Access, Ram}; +use crate::front::PROVER_VIS; +use crate::ir::opt::cfold::fold; +use crate::ir::term::*; +use crate::util::ns::Namespace; + +use log::debug; + +/// The implementation of Haboeck's lookup argument. +/// +/// Takes haystack, needles, and returns a term which should be asserted to ensure that each needle +/// is in haystack. 
+pub fn lookup(c: &mut Computation, ns: Namespace, haystack: Vec, needles: Vec) -> Term { + debug!( + "Haboeck lookup haystack {}, needles {}", + haystack.len(), + needles.len() + ); + if haystack.is_empty() { + assert!(needles.is_empty()); + return bool_lit(true); + } + let sort = check(&haystack[0]); + let f = sort.as_pf().clone(); + let array_op = Op::Array(sort.clone(), sort.clone()); + let haystack_array = term(array_op.clone(), haystack.clone()); + let needles_array = term(array_op.clone(), needles.clone()); + let counts_pre = unmake_array(term![Op::ExtOp(ExtOp::Haboeck); haystack_array, needles_array]); + let counts: Vec = counts_pre + .into_iter() + .enumerate() + .map(|(i, coeff)| { + c.new_var( + &ns.fqn(format!("c{i}")), + sort.clone(), + PROVER_VIS, + Some(coeff), + ) + }) + .collect(); + let key = term( + Op::PfChallenge(ns.fqn("key"), f.clone()), + haystack + .iter() + .chain(&needles) + .chain(&counts) + .cloned() + .collect(), + ); + let haysum = term( + PF_ADD, + counts + .into_iter() + .zip(haystack) + .map(|(ct, hay)| term![PF_DIV; ct, term![PF_ADD; hay, key.clone()]]) + .collect(), + ); + let needlesum = term( + PF_ADD, + needles + .into_iter() + .map(|needle| term![PF_RECIP; term![PF_ADD; needle, key.clone()]]) + .collect(), + ); + term![Op::Eq; haysum, needlesum] +} + +/// Returns a term to assert. 
+pub fn check_covering_rom(c: &mut Computation, ns: Namespace, ram: Ram) -> Term { + assert!(ram.cfg.covering_rom); + let f = &ram.cfg.field; + if ram.accesses.is_empty() { + return bool_lit(true); + } + // (addr, value) + let mut reads: Vec> = Default::default(); + let mut writes: Vec> = Default::default(); + for a in &ram.accesses { + let mut access = vec![a.idx.clone()]; + Access::val_to_field_elements(&a.val, &ram.cfg, &mut access); + match fold(&a.write.b, &[]).as_bool_opt() { + Some(true) => writes.push(access), + Some(false) => reads.push(access), + None => panic!(), + } + } + assert!(!writes.is_empty()); + let uhf = UniversalHasher::new( + ns.fqn("uhf_key"), + f, + reads.iter().chain(&writes).flatten().cloned().collect(), + writes[0].len(), + ); + let write_hashes = writes.into_iter().map(|a| uhf.hash(a)).collect(); + let read_hashes = reads.into_iter().map(|a| uhf.hash(a)).collect(); + lookup(c, ns.subspace("scalar"), write_hashes, read_hashes) +} diff --git a/src/ir/term/ext.rs b/src/ir/term/ext.rs index a53fd4aea..ffbdf712d 100644 --- a/src/ir/term/ext.rs +++ b/src/ir/term/ext.rs @@ -5,6 +5,7 @@ use super::{Sort, Term, Value}; use circ_hc::Node; use serde::{Deserialize, Serialize}; +mod haboeck; mod poly; mod ram; mod sort; @@ -15,6 +16,8 @@ mod waksman; /// /// Often evaluatable, but not compilable. pub enum ExtOp { + /// See [haboeck]. + Haboeck, /// See [ram::eval] PersistentRamSplit, /// Given an array of tuples, returns a reordering such that the result is sorted. 
@@ -29,6 +32,7 @@ impl ExtOp { /// Its arity pub fn arity(&self) -> Option { match self { + ExtOp::Haboeck => Some(2), ExtOp::PersistentRamSplit => Some(2), ExtOp::Sort => Some(1), ExtOp::Waksman => Some(1), @@ -38,6 +42,7 @@ impl ExtOp { /// Type-check, given argument sorts pub fn check(&self, arg_sorts: &[&Sort]) -> Result { match self { + ExtOp::Haboeck => haboeck::check(arg_sorts), ExtOp::PersistentRamSplit => ram::check(arg_sorts), ExtOp::Sort => sort::check(arg_sorts), ExtOp::Waksman => waksman::check(arg_sorts), @@ -47,6 +52,7 @@ impl ExtOp { /// Evaluate, given argument values pub fn eval(&self, args: &[&Value]) -> Value { match self { + ExtOp::Haboeck => haboeck::eval(args), ExtOp::PersistentRamSplit => ram::eval(args), ExtOp::Sort => sort::eval(args), ExtOp::Waksman => waksman::eval(args), @@ -60,6 +66,7 @@ impl ExtOp { /// Parse, from bytes. pub fn parse(bytes: &[u8]) -> Option { match bytes { + b"haboeck" => Some(ExtOp::Haboeck), b"persistent_ram_split" => Some(ExtOp::PersistentRamSplit), b"uniq_deri_gcd" => Some(ExtOp::UniqDeriGcd), b"sort" => Some(ExtOp::Sort), @@ -67,6 +74,16 @@ impl ExtOp { _ => None, } } + /// To string + pub fn to_str(&self) -> &'static str { + match self { + ExtOp::Haboeck => "haboeck", + ExtOp::PersistentRamSplit => "persistent_ram_split", + ExtOp::UniqDeriGcd => "uniq_deri_gcd", + ExtOp::Sort => "sort", + ExtOp::Waksman => "Waksman", + } + } } #[cfg(test)] diff --git a/src/ir/term/ext/haboeck.rs b/src/ir/term/ext/haboeck.rs new file mode 100644 index 000000000..9f9628d52 --- /dev/null +++ b/src/ir/term/ext/haboeck.rs @@ -0,0 +1,52 @@ +//! Witness computation for Haboeck's lookup argument +//! +//! https://eprint.iacr.org/2022/1530.pdf +//! +//! Given a haystack array of values h1, ..., hN and an array of needles v1, ..., vA, outputs +//! an array of counts c1, ..., cN such that hi occurs ci times in v1, ..., vA. +//! +//! 
All input and output arrays must be field -> field + +use crate::ir::term::ty::*; +use crate::ir::term::*; + +/// Type-check [super::ExtOp::Haboeck]. +pub fn check(arg_sorts: &[&Sort]) -> Result { + if let &[haystack, needles] = arg_sorts { + let (key0, value0, _n) = ty::array_or(haystack, "haystack must be an array")?; + let (key1, value1, _a) = ty::array_or(needles, "needles must be an array")?; + let key0 = pf_or(key0, "haystack indices must be field")?; + let key1 = pf_or(key1, "needles indices must be field")?; + let value0 = pf_or(value0, "haystack values must be field")?; + let value1 = pf_or(value1, "needles values must be field")?; + eq_or(key0, key1, "field must be the same")?; + eq_or(key1, value0, "field must be the same")?; + eq_or(value0, value1, "field must be the same")?; + Ok(haystack.clone()) + } else { + // wrong arg count + Err(TypeErrorReason::ExpectedArgs(2, arg_sorts.len())) + } +} + +/// Evaluate [super::ExtOp::Haboeck]. +pub fn eval(args: &[&Value]) -> Value { + let haystack = args[0].as_array().values(); + let sort = args[0].sort().as_array().0.clone(); + let field = sort.as_pf().clone(); + let needles = args[1].as_array().values(); + let haystack_item_index: FxHashMap = haystack + .iter() + .enumerate() + .map(|(i, v)| (v.clone(), i)) + .collect(); + let mut counts = vec![0; haystack.len()]; + for needle in needles { + counts[*haystack_item_index.get(&needle).expect("missing needle")] += 1; + } + let field_counts: Vec = counts + .into_iter() + .map(|c| Value::Field(field.new_v(c))) + .collect(); + Value::Array(Array::from_vec(sort.clone(), sort, field_counts)) +} diff --git a/src/ir/term/ext/test.rs b/src/ir/term/ext/test.rs index 5d9f5115e..d200c93ee 100644 --- a/src/ir/term/ext/test.rs +++ b/src/ir/term/ext/test.rs @@ -149,3 +149,65 @@ fn persistent_ram_split_eval() { ); assert_eq!(&actual_output, expected_output.get("output").unwrap()); } + +fn haboeck_eval(haystack: &[usize], needles: &[usize], counts: &[usize]) { + let t = 
text::parse_term( + format!( + " + (declare ( + (haystack (array (mod 17) (mod 17) {})) + (needles (array (mod 17) (mod 17) {})) + ) + (haboeck haystack needles))", + haystack.len(), + needles.len() + ) + .as_bytes(), + ); + assert_eq!(haystack.len(), counts.len()); + let haystack: Vec = haystack.iter().map(|h| format!("#f{}", h)).collect(); + let needles: Vec = needles.iter().map(|h| format!("#f{}", h)).collect(); + let counts: Vec = counts.iter().map(|h| format!("#f{}", h)).collect(); + + let inputs = text::parse_value_map( + format!( + "(set_default_modulus 17 + (let + ( + (haystack (#l (mod 17) ({}))) + (needles (#l (mod 17) ({}))) + ) false))", + haystack.join(" "), + needles.join(" ") + ) + .as_bytes(), + ); + let expected_output = text::parse_value_map( + format!( + "(set_default_modulus 17 + (let + ( + (counts (#l (mod 17) ({}))) + ) false))", + counts.join(" ") + ) + .as_bytes(), + ); + let actual_output = eval(&t, &inputs); + assert_eq!(&actual_output, expected_output.get("counts").unwrap()); +} + +#[test] +fn haboeck_eval_2_6_full() { + haboeck_eval(&[1, 2], &[1, 1, 2, 2, 1, 2], &[3, 3]); +} + +#[test] +fn haboeck_eval_4_4_full() { + haboeck_eval(&[1, 2, 4, 3], &[1, 2, 3, 4], &[1, 1, 1, 1]); +} + +#[test] +fn haboeck_eval_6_2() { + haboeck_eval(&[6, 8, 3, 4, 1, 2], &[1, 1], &[0, 0, 0, 0, 2, 0]); +} diff --git a/src/ir/term/fmt.rs b/src/ir/term/fmt.rs index f13a5d500..ae02c3818 100644 --- a/src/ir/term/fmt.rs +++ b/src/ir/term/fmt.rs @@ -183,6 +183,7 @@ impl DisplayIr for Op { Op::FpToFp(a) => write!(f, "(fp2fp {a})"), Op::PfUnOp(a) => write!(f, "{a}"), Op::PfNaryOp(a) => write!(f, "{a}"), + Op::PfDiv => write!(f, "/"), Op::IntNaryOp(a) => write!(f, "{a}"), Op::IntBinPred(a) => write!(f, "{a}"), Op::UbvToPf(a) => write!(f, "(bv2pf {})", a.modulus()), @@ -225,6 +226,7 @@ impl DisplayIr for Op { impl DisplayIr for ext::ExtOp { fn ir_fmt(&self, f: &mut IrFormatter) -> FmtResult { match self { + ext::ExtOp::Haboeck => write!(f, "haboeck"), 
ext::ExtOp::PersistentRamSplit => write!(f, "persistent_ram_split"), ext::ExtOp::UniqDeriGcd => write!(f, "uniq_deri_gcd"), ext::ExtOp::Sort => write!(f, "sort"), diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs index b046ad189..8ea30a841 100644 --- a/src/ir/term/mod.rs +++ b/src/ir/term/mod.rs @@ -135,6 +135,8 @@ pub enum Op { PfChallenge(String, FieldT), /// Requires the input pf element to fit in this many (unsigned) bits. PfFitsInBits(usize), + /// Prime-field division + PfDiv, /// Integer n-ary operator IntNaryOp(IntNaryOp), @@ -247,6 +249,8 @@ pub const BV_CONCAT: Op = Op::BvConcat; pub const PF_NEG: Op = Op::PfUnOp(PfUnOp::Neg); /// prime-field reciprocal pub const PF_RECIP: Op = Op::PfUnOp(PfUnOp::Recip); +/// prime-field division +pub const PF_DIV: Op = Op::PfDiv; /// prime-field addition pub const PF_ADD: Op = Op::PfNaryOp(PfNaryOp::Add); /// prime-field multiplication @@ -296,6 +300,7 @@ impl Op { Op::SbvToFp(_) => Some(1), Op::FpToFp(_) => Some(1), Op::PfUnOp(_) => Some(1), + Op::PfDiv => Some(2), Op::PfNaryOp(_) => None, Op::PfChallenge(_, _) => None, Op::PfFitsInBits(..) 
=> Some(1), @@ -1426,6 +1431,11 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap) -> PfUnOp::Neg => -a, } }), + Op::PfDiv => Value::Field({ + let a = args[0].as_pf().clone(); + let b = args[1].as_pf().clone(); + a * b.recip() + }), Op::PfNaryOp(o) => Value::Field({ let mut xs = args.iter().map(|a| a.as_pf().clone()); let f = xs.next().unwrap(); diff --git a/src/ir/term/text/mod.rs b/src/ir/term/text/mod.rs index 1efb6a623..1d6a30542 100644 --- a/src/ir/term/text/mod.rs +++ b/src/ir/term/text/mod.rs @@ -281,6 +281,7 @@ impl<'src> IrInterp<'src> { Leaf(Ident, b"+") => Ok(Op::PfNaryOp(PfNaryOp::Add)), Leaf(Ident, b"*") => Ok(Op::PfNaryOp(PfNaryOp::Mul)), Leaf(Ident, b"pfrecip") => Ok(Op::PfUnOp(PfUnOp::Recip)), + Leaf(Ident, b"/") => Ok(Op::PfDiv), Leaf(Ident, b"-") => Ok(Op::PfUnOp(PfUnOp::Neg)), Leaf(Ident, b"<") => Ok(INT_LT), Leaf(Ident, b"<=") => Ok(INT_LE), @@ -1324,4 +1325,20 @@ mod test { let t2 = parse_term(s.as_bytes()); assert_eq!(t, t2); } + + #[test] + fn haboeck_roundtrip() { + let t = parse_term( + b" + (declare ( + (haystack (array (mod 17) (mod 17) 5)) + (needles (array (mod 17) (mod 17) 8)) + ) + (haboeck haystack needles))", + ); + let s = serialize_term(&t); + println!("{s}"); + let t2 = parse_term(s.as_bytes()); + assert_eq!(t, t2); + } } diff --git a/src/ir/term/ty.rs b/src/ir/term/ty.rs index 2f5a2ace9..70ad5df84 100644 --- a/src/ir/term/ty.rs +++ b/src/ir/term/ty.rs @@ -55,6 +55,7 @@ fn check_dependencies(t: &Term) -> Vec { Op::SbvToFp(_) => Vec::new(), Op::FpToFp(_) => Vec::new(), Op::PfUnOp(_) => vec![t.cs()[0].clone()], + Op::PfDiv => vec![t.cs()[0].clone()], Op::PfNaryOp(_) => vec![t.cs()[0].clone()], Op::IntNaryOp(_) => Vec::new(), Op::IntBinPred(_) => Vec::new(), @@ -131,6 +132,7 @@ fn check_raw_step(t: &Term, tys: &TypeTable) -> Result { Op::FpToFp(64) => Ok(Sort::F64), Op::FpToFp(32) => Ok(Sort::F32), Op::PfUnOp(_) => Ok(get_ty(&t.cs()[0]).clone()), + Op::PfDiv => Ok(get_ty(&t.cs()[0]).clone()), Op::PfNaryOp(_) => 
Ok(get_ty(&t.cs()[0]).clone()), Op::IntNaryOp(_) => Ok(Sort::Int), Op::IntBinPred(_) => Ok(Sort::Bool), @@ -230,9 +232,8 @@ pub fn check_raw(t: &Term) -> Result { if p.upgrade().is_some() { to_check.pop(); continue; - } else { - term_tys.remove(&weak); } + term_tys.remove(&weak); } if !back.1 { back.1 = true; @@ -360,6 +361,9 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result Ok(Sort::Field(m.clone())), (Op::PfFitsInBits(_), &[a]) => pf_or(a, "pf fits in bits").map(|_| Sort::Bool), (Op::PfUnOp(_), &[a]) => pf_or(a, "pf unary op").map(|a| a.clone()), + (Op::PfDiv, &[a, b]) => { + eq_or(&pf_or(a, "pf / op").map(|a| a.clone())?, b, "pf / op").cloned() + } (Op::IntNaryOp(_), a) => { let ctx = "int nary op"; all_eq_or(a.iter().cloned(), ctx) @@ -385,7 +389,7 @@ pub fn rec_check_raw_helper(oper: &Op, a: &[&Sort]) -> Result { let ctx = "array op"; a.iter() - .try_fold((), |(), ai| eq_or(v, ai, ctx)) + .try_fold((), |(), ai| eq_or(v, ai, ctx).map(|_| ())) .map(|_| Sort::Array(Box::new(k.clone()), Box::new(v.clone()), a.len())) } (Op::Tuple, a) => Ok(Sort::Tuple(a.iter().map(|a| (*a).clone()).collect())), @@ -617,9 +621,13 @@ pub(super) fn tuple_or<'a>(a: &'a Sort, ctx: &'static str) -> Result<&'a [Sort], } } -pub(super) fn eq_or(a: &Sort, b: &Sort, ctx: &'static str) -> Result<(), TypeErrorReason> { +pub(super) fn eq_or<'a>( + a: &'a Sort, + b: &'a Sort, + ctx: &'static str, +) -> Result<&'a Sort, TypeErrorReason> { if a == b { - Ok(()) + Ok(a) } else { Err(TypeErrorReason::NotEqual(a.clone(), b.clone(), ctx)) } diff --git a/src/target/r1cs/trans.rs b/src/target/r1cs/trans.rs index 33dcc4b0a..40fe957c1 100644 --- a/src/target/r1cs/trans.rs +++ b/src/target/r1cs/trans.rs @@ -1094,6 +1094,17 @@ impl<'cfg> ToR1cs<'cfg> { i } }, + Op::PfDiv => match self.cfg.r1cs.div_by_zero { + FieldDivByZero::Incomplete => { + // x * div = y + let y = self.get_pf(&c.cs()[0]).clone(); + let x = self.get_pf(&c.cs()[1]).clone(); + let div = self.fresh_wit("div", term![PF_DIV; 
y.0.clone(), x.0.clone()]); + self.constraint(x.1, div.1.clone(), y.1); + div + } + _ => unimplemented!(), + }, _ => panic!("Non-field in embed_pf: {}", c), }; self.cache.insert(c.clone(), EmbeddedTerm::Field(lc));