Skip to content

Commit

Permalink
Forward inlining on lambdas produces better results. This is due to a…
Browse files Browse the repository at this point in the history
… forward pass being able to apply an argument that may have no_inline at the top where as vice-versa would reduce the arg first.
  • Loading branch information
MicroProofs committed Nov 12, 2024
1 parent 3814a54 commit 9891d99
Showing 1 changed file with 74 additions and 55 deletions.
129 changes: 74 additions & 55 deletions crates/uplc/src/optimize/shrinker2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ impl Term<Name> {
id_gen: &mut IdGen,
with: &mut impl FnMut(Option<usize>, &mut Term<Name>, Vec<Args>, &Scope, &mut Context),
context: &mut Context,
inline_lambda: bool,
) {
match self {
Term::Apply { function, argument } => {
Expand All @@ -888,6 +889,7 @@ impl Term<Name> {
id_gen,
with,
context,
inline_lambda,
);
let apply_id = id_gen.next_id();

Expand All @@ -901,6 +903,7 @@ impl Term<Name> {
id_gen,
with,
context,
inline_lambda,
);

with(Some(apply_id), self, vec![], scope, context);
Expand All @@ -911,7 +914,7 @@ impl Term<Name> {

arg_stack.push(Args::Force(force_id));

f.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context);
f.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context, inline_lambda);

with(Some(force_id), self, vec![], scope, context);
}
Expand All @@ -925,14 +928,16 @@ impl Term<Name> {
})
.unwrap_or_default();

d.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context);
d.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context, inline_lambda);

with(None, self, delay_arg, scope, context);
}
Term::Lambda {
parameter_name,
body,
} => {
let p = parameter_name.clone();

// Lambda pops one item off the arg stack. If there is no item then it is a unsaturated lambda
// NO_INLINE lambdas come in with 0 arguments on the arg stack
let args = if parameter_name.text == NO_INLINE {
Expand All @@ -947,51 +952,53 @@ impl Term<Name> {
.unwrap_or_default()
};

let body = Rc::make_mut(body);
if inline_lambda {
// Pass in either one or zero args.
// For lambda we run the function with first then recurse on the body or replaced term

body.traverse_uplc_with_helper(scope, arg_stack, id_gen, with, context);
with(None, self, args, scope, context);

with(None, self, args, scope, context);
match self {
Term::Lambda {
parameter_name,
body,
} if parameter_name.text == p.text && parameter_name.unique == p.unique => {
let body = Rc::make_mut(body);
body.traverse_uplc_with_helper(
scope,
arg_stack,
id_gen,
with,
context,
inline_lambda,
);
}

// if lambda_first {
// // Pass in either one or zero args.
// // For lambda we run the function with first then recurse on the body or replaced term

// if p.text.contains("pred") {
// println!("ARG STACK IS {:#?} WITH NAME {}", args, p.text);
// }
// with(None, self, args, scope, context);

// match self {
// Term::Lambda {
// parameter_name,
// body,
// } if parameter_name.text == p.text && parameter_name.unique == p.unique => {
// let body = Rc::make_mut(body);
// body.traverse_uplc_with_helper(
// scope,
// arg_stack,
// id_gen,
// with,
// context,
// lambda_first,
// );
// }

// Term::Constr { .. } => todo!(),
// Term::Case { .. } => todo!(),
// other => other.traverse_uplc_with_helper(
// scope,
// arg_stack,
// id_gen,
// with,
// context,
// lambda_first,
// ),
// }
// } else {

// }
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
other => other.traverse_uplc_with_helper(
scope,
arg_stack,
id_gen,
with,
context,
inline_lambda,
),
}
} else {
let body = Rc::make_mut(body);

body.traverse_uplc_with_helper(
scope,
arg_stack,
id_gen,
with,
context,
inline_lambda,
);

with(None, self, args, scope, context);
}
}

Term::Case { .. } => todo!(),
Expand Down Expand Up @@ -1743,6 +1750,7 @@ impl Term<Name> {
impl Program<Name> {
fn traverse_uplc_with(
self,
inline_lambda: bool,
with: &mut impl FnMut(Option<usize>, &mut Term<Name>, Vec<Args>, &Scope, &mut Context),
) -> (Self, Context) {
let mut term = self.term;
Expand All @@ -1759,7 +1767,14 @@ impl Program<Name> {
node_count: 0,
};

term.traverse_uplc_with_helper(&scope, arg_stack, &mut id_gen, with, &mut context);
term.traverse_uplc_with_helper(
&scope,
arg_stack,
&mut id_gen,
with,
&mut context,
inline_lambda,
);
(
Program {
version: self.version,
Expand All @@ -1771,13 +1786,13 @@ impl Program<Name> {
// This one runs the optimizations that are only done a single time
pub fn run_once_pass(self) -> Self {
let program = self
.traverse_uplc_with(&mut |id, term, _arg_stack, scope, context| {
.traverse_uplc_with(false, &mut |id, term, _arg_stack, scope, context| {
term.inline_constr_ops(id, vec![], scope, context);
})
.0;

let (program, context) =
program.traverse_uplc_with(&mut |id, term, arg_stack, scope, context| {
program.traverse_uplc_with(false, &mut |id, term, arg_stack, scope, context| {
term.bls381_compressor(id, vec![], scope, context);
term.builtin_force_reducer(id, arg_stack, scope, context);
term.remove_inlined_ids(id, vec![], scope, context);
Expand Down Expand Up @@ -1828,7 +1843,7 @@ impl Program<Name> {
}

pub fn multi_pass(self) -> (Self, Context) {
self.traverse_uplc_with(&mut |id, term, arg_stack, scope, context| {
self.traverse_uplc_with(true, &mut |id, term, arg_stack, scope, context| {
let mut changed;

changed = term.lambda_reducer(id, arg_stack.clone(), scope, context);
Expand Down Expand Up @@ -1868,7 +1883,7 @@ impl Program<Name> {
}

pub fn clean_up(self) -> Self {
self.traverse_uplc_with(&mut |id, term, _arg_stack, scope, context| {
self.traverse_uplc_with(true, &mut |id, term, _arg_stack, scope, context| {
term.remove_no_inlines(id, vec![], scope, context);
})
.0
Expand All @@ -1887,8 +1902,9 @@ impl Program<Name> {

let mut final_ids: IndexMap<Vec<usize>, ()> = IndexMap::new();

let (step_a, _) =
self.traverse_uplc_with(&mut |_id, term, arg_stack, scope, _context| match term {
let (step_a, _) = self.traverse_uplc_with(
false,
&mut |_id, term, arg_stack, scope, _context| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let arg_stack = arg_stack
Expand Down Expand Up @@ -1979,7 +1995,8 @@ impl Program<Name> {
Term::Constr { .. } => todo!(),
Term::Case { .. } => todo!(),
_ => {}
});
},
);

id_mapped_curry_terms
.into_iter()
Expand All @@ -2003,8 +2020,9 @@ impl Program<Name> {
}
});

let (mut step_b, _) =
step_a.traverse_uplc_with(&mut |id, term, arg_stack, scope, _context| match term {
let (mut step_b, _) = step_a.traverse_uplc_with(
false,
&mut |id, term, arg_stack, scope, _context| match term {
Term::Builtin(func) => {
if func.can_curry_builtin() && arg_stack.len() == func.arity() {
let mut arg_stack = arg_stack
Expand Down Expand Up @@ -2094,7 +2112,8 @@ impl Program<Name> {
}
}
}
});
},
);

let mut interner = CodeGenInterner::new();

Expand Down

0 comments on commit 9891d99

Please sign in to comment.