This commit is contained in:
Folkert 2020-08-12 14:03:55 +02:00
parent bdd8751107
commit 8c86836101
7 changed files with 257 additions and 30 deletions

View File

@ -774,9 +774,6 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
// construct the blocks that may jump to this join point
build_exp_stmt(env, layout_ids, scope, parent, remainder);
// remove this join point again
scope.join_points.remove(&id);
for (ptr, param) in joinpoint_args.iter().zip(parameters.iter()) {
scope.insert(param.symbol, (param.layout.clone(), *ptr));
}
@ -789,6 +786,9 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
// put the continuation in
let result = build_exp_stmt(env, layout_ids, scope, parent, continuation);
// remove this join point again
scope.join_points.remove(&id);
cont_block.move_after(phi_block).unwrap();
result

View File

@ -449,4 +449,25 @@ mod gen_primitives {
i64
);
}
#[test]
fn factorial() {
assert_evals_to!(
indoc!(
r#"
factorial = \n, accum ->
when n is
0 ->
accum
_ ->
factorial (n - 1) (n * accum)
factorial 10 1
"#
),
3628800,
i64
);
}
}

View File

@ -627,16 +627,17 @@ impl<'a> Context<'a> {
let v_orig = v;
// NOTE deviation from lean, insert into local context
let mut ctx = self.clone();
ctx.local_context.join_points.insert(*j, (xs, v_orig));
let (v, v_live_vars) = {
let ctx = self.update_var_info_with_params(xs);
let ctx = ctx.update_var_info_with_params(xs);
ctx.visit_stmt(v)
};
let v = self.add_dec_for_dead_params(xs, v, &v_live_vars);
let mut ctx = self.clone();
// NOTE deviation from lean, insert into local context
ctx.local_context.join_points.insert(*j, (xs, v_orig));
let v = ctx.add_dec_for_dead_params(xs, v, &v_live_vars);
let mut ctx = ctx.clone();
update_jp_live_vars(*j, xs, v, &mut ctx.jp_live_vars);

View File

@ -379,7 +379,7 @@ impl<'a, 'i> Env<'a, 'i> {
}
#[derive(Clone, Debug, PartialEq, Copy, Eq, Hash)]
pub struct JoinPointId(Symbol);
pub struct JoinPointId(pub Symbol);
#[derive(Clone, Debug, PartialEq)]
pub struct Param<'a> {
@ -1009,9 +1009,6 @@ fn specialize<'a>(
debug_assert!(matches!(unified, roc_unify::unify::Unified::Success(_)));
//let ret_symbol = env.unique_symbol();
//let hole = env.arena.alloc(Stmt::Ret(ret_symbol));
//let specialized_body = with_hole(env, body, procs, layout_cache, ret_symbol, hole);
let specialized_body = from_can(env, body, procs, layout_cache);
// reset subs, so we don't get type errors when specializing for a different signature
@ -1031,6 +1028,11 @@ fn specialize<'a>(
proc_args.push((layout, *arg_name));
}
let proc_args = proc_args.into_bump_slice();
let specialized_body =
crate::tail_recursion::make_tail_recursive(env, proc_name, specialized_body, proc_args);
let ret_layout = layout_cache
.from_var(&env.arena, ret_var, env.subs)
.unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err));
@ -1040,7 +1042,7 @@ fn specialize<'a>(
let proc = Proc {
name: proc_name,
args: proc_args.into_bump_slice(),
args: proc_args,
body: specialized_body,
closes_over: closes_over_layout,
ret_layout,

View File

@ -14,6 +14,7 @@
pub mod inc_dec;
pub mod ir;
pub mod layout;
pub mod tail_recursion;
// Temporary, while we can build up test cases and optimize the exhaustiveness checking.
// For now, following this warning's advice will lead to nasty type inference errors.

View File

@ -0,0 +1,201 @@
use crate::ir::{CallType, Env, Expr, JoinPointId, Param, Stmt};
use crate::layout::Layout;
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_module::symbol::Symbol;
pub fn make_tail_recursive<'a>(
env: &mut Env<'a, '_>,
needle: Symbol,
stmt: Stmt<'a>,
args: &'a [(Layout<'a>, Symbol)],
) -> Stmt<'a> {
let id = JoinPointId(env.unique_symbol());
let alloced = env.arena.alloc(stmt);
match insert_jumps(env.arena, alloced, id, needle) {
None => alloced.clone(),
Some(new) => {
// jumps were inserted, we must now add a join point
let params = Vec::from_iter_in(
args.iter().map(|(layout, symbol)| Param {
symbol: *symbol,
layout: layout.clone(),
borrow: true,
}),
env.arena,
)
.into_bump_slice();
let args = Vec::from_iter_in(args.iter().map(|t| t.1), env.arena).into_bump_slice();
let jump = env.arena.alloc(Stmt::Jump(id, args));
Stmt::Join {
id,
remainder: jump,
parameters: params,
continuation: new,
}
}
}
}
fn insert_jumps<'a>(
arena: &'a Bump,
stmt: &'a Stmt<'a>,
goal_id: JoinPointId,
needle: Symbol,
) -> Option<&'a Stmt<'a>> {
use Stmt::*;
match stmt {
Let(
symbol,
Expr::FunctionCall {
call_type: CallType::ByName(fsym),
args,
..
},
_,
Stmt::Ret(rsym),
) if needle == *fsym && symbol == rsym => {
// replace the call and return with a jump
let jump = Stmt::Jump(goal_id, args);
Some(arena.alloc(jump))
}
Let(symbol, expr, layout, cont) => {
let opt_cont = insert_jumps(arena, cont, goal_id, needle);
if opt_cont.is_some() {
let cont = opt_cont.unwrap_or(cont);
Some(arena.alloc(Let(*symbol, expr.clone(), layout.clone(), cont)))
} else {
None
}
}
Join {
id,
parameters,
remainder,
continuation,
} => {
let opt_remainder = insert_jumps(arena, remainder, goal_id, needle);
let opt_continuation = insert_jumps(arena, continuation, goal_id, needle);
if opt_remainder.is_some() || opt_continuation.is_some() {
let remainder = opt_remainder.unwrap_or(remainder);
let continuation = opt_continuation.unwrap_or_else(|| *continuation);
Some(arena.alloc(Join {
id: *id,
parameters,
remainder,
continuation,
}))
} else {
None
}
}
Cond {
cond_symbol,
cond_layout,
branching_symbol,
branching_layout,
pass,
fail,
ret_layout,
} => {
let opt_pass = insert_jumps(arena, pass, goal_id, needle);
let opt_fail = insert_jumps(arena, fail, goal_id, needle);
if opt_pass.is_some() || opt_fail.is_some() {
let pass = opt_pass.unwrap_or(pass);
let fail = opt_fail.unwrap_or_else(|| *fail);
Some(arena.alloc(Cond {
cond_symbol: *cond_symbol,
cond_layout: cond_layout.clone(),
branching_symbol: *branching_symbol,
branching_layout: branching_layout.clone(),
pass,
fail,
ret_layout: ret_layout.clone(),
}))
} else {
None
}
}
Switch {
cond_symbol,
cond_layout,
branches,
default_branch,
ret_layout,
} => {
let opt_default = insert_jumps(arena, default_branch, goal_id, needle);
let mut did_change = false;
let opt_branches = Vec::from_iter_in(
branches.iter().map(|(label, branch)| {
match insert_jumps(arena, branch, goal_id, needle) {
None => None,
Some(branch) => {
did_change = true;
Some((*label, branch.clone()))
}
}
}),
arena,
);
if opt_default.is_some() || did_change {
let default_branch = opt_default.unwrap_or(default_branch);
let branches = if did_change {
let new = Vec::from_iter_in(
opt_branches.into_iter().zip(branches.iter()).map(
|(opt_branch, branch)| match opt_branch {
None => branch.clone(),
Some(new_branch) => new_branch,
},
),
arena,
);
new.into_bump_slice()
} else {
branches
};
Some(arena.alloc(Switch {
cond_symbol: *cond_symbol,
cond_layout: cond_layout.clone(),
default_branch,
branches,
ret_layout: ret_layout.clone(),
}))
} else {
None
}
}
Ret(_) => None,
Inc(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Inc(*symbol, cont))),
None => None,
},
Dec(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Dec(*symbol, cont))),
None => None,
},
Jump(_, _) => None,
RuntimeError(_) => None,
}
}

View File

@ -587,8 +587,8 @@ mod test_mono {
ret Test.9;
procedure Num.14 (#Attr.2, #Attr.3):
let Test.10 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.10;
let Test.11 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Test.11;
let Test.8 = 1f64;
let Test.1 = Array [Test.8];
@ -1082,22 +1082,23 @@ mod test_mono {
indoc!(
r#"
procedure Test.0 (Test.2, Test.3):
let Test.15 = true;
let Test.16 = 0i64;
let Test.17 = lowlevel Eq Test.16 Test.2;
let Test.14 = lowlevel And Test.17 Test.15;
if Test.14 then
ret Test.3;
else
let Test.12 = 1i64;
let Test.9 = CallByName Num.15 Test.2 Test.12;
let Test.10 = CallByName Num.16 Test.2 Test.3;
let Test.8 = CallByName Test.0 Test.9 Test.10;
ret Test.8;
jump Test.20 Test.2 Test.3;
joinpoint Test.20 Test.2 Test.3:
let Test.17 = true;
let Test.18 = 0i64;
let Test.19 = lowlevel Eq Test.18 Test.2;
let Test.16 = lowlevel And Test.19 Test.17;
if Test.16 then
ret Test.3;
else
let Test.13 = 1i64;
let Test.9 = CallByName Num.15 Test.2 Test.13;
let Test.10 = CallByName Num.16 Test.2 Test.3;
jump Test.20 Test.9 Test.10;
procedure Num.15 (#Attr.2, #Attr.3):
let Test.13 = lowlevel NumSub #Attr.2 #Attr.3;
ret Test.13;
let Test.14 = lowlevel NumSub #Attr.2 #Attr.3;
ret Test.14;
procedure Num.16 (#Attr.2, #Attr.3):
let Test.11 = lowlevel NumMul #Attr.2 #Attr.3;