Merge pull request #5407 from roc-lang/last-seen-join-points

Last seen join points
This commit is contained in:
Brendan Hansknecht 2023-05-14 16:29:44 +00:00 committed by GitHub
commit b32cd5687b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 198 additions and 142 deletions

View File

@ -3,6 +3,12 @@ test-gen-llvm = "test -p test_gen"
test-gen-dev = "test -p roc_gen_dev -p test_gen --no-default-features --features gen-dev"
test-gen-wasm = "test -p roc_gen_wasm -p test_gen --no-default-features --features gen-wasm"
test-gen-llvm-wasm = "test -p roc_gen_wasm -p test_gen --no-default-features --features gen-llvm-wasm"
nextest-gen-llvm = "nextest run -p test_gen"
nextest-gen-dev = "nextest run -p roc_gen_dev -p test_gen --no-default-features --features gen-dev"
nextest-gen-wasm = "nextest run -p roc_gen_wasm -p test_gen --no-default-features --features gen-wasm"
nextest-gen-llvm-wasm = "nextest run -p roc_gen_wasm -p test_gen --no-default-features --features gen-llvm-wasm"
uitest = "test -p uitest"
[target.wasm32-unknown-unknown]

View File

@ -5,6 +5,8 @@
// See github.com/roc-lang/roc/issues/800 for discussion of the large_enum_variant check.
#![allow(clippy::large_enum_variant, clippy::upper_case_acronyms)]
use std::collections::hash_map::Entry;
use bumpalo::{collections::Vec, Bump};
use roc_builtins::bitcode::{self, FloatWidth, IntWidth};
use roc_collections::all::{MutMap, MutSet};
@ -91,6 +93,190 @@ struct ListArgument<'a> {
element_width: Symbol,
}
// Track when a variable is last used (and hence when it can be disregarded). This is non-trivial
// in the presence of join points. Consider this example:
//
// let len = 3
//
// joinpoint f = \a ->
// joinpoint g = \b ->
// # len is used here
// in
// ...
// in
// ...
//
// we have to keep `len` alive until after the joinpoint goes out of scope!
#[derive(Debug, Default)]
struct LastSeenMap<'a> {
last_seen: MutMap<Symbol, *const Stmt<'a>>,
join_map: MutMap<JoinPointId, &'a [Param<'a>]>,
}
impl<'a> LastSeenMap<'a> {
fn set_last_seen(&mut self, symbol: Symbol, stmt: &'a Stmt<'a>) {
self.last_seen.insert(symbol, stmt);
}
/// scan_ast runs through the ast and fill the last seen map.
/// This must iterate through the ast in the same way that build_stmt does. i.e. then before else.
fn scan_ast(root: &'a Stmt<'a>) -> MutMap<Symbol, *const Stmt<'a>> {
let mut this: Self = Default::default();
this.scan_ast_help(root);
this.last_seen
}
fn scan_ast_help(&mut self, stmt: &'a Stmt<'a>) {
match stmt {
Stmt::Let(sym, expr, _, following) => {
self.set_last_seen(*sym, stmt);
match expr {
Expr::Literal(_) => {}
Expr::NullPointer => {}
Expr::Call(call) => self.scan_ast_call(call, stmt),
Expr::Tag { arguments, .. } => {
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
}
Expr::ExprBox { symbol } => {
self.set_last_seen(*symbol, stmt);
}
Expr::ExprUnbox { symbol } => {
self.set_last_seen(*symbol, stmt);
}
Expr::Struct(syms) => {
for sym in *syms {
self.set_last_seen(*sym, stmt);
}
}
Expr::StructAtIndex { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::GetTagId { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::UnionAtIndex { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::Array { elems, .. } => {
for elem in *elems {
if let ListLiteralElement::Symbol(sym) = elem {
self.set_last_seen(*sym, stmt);
}
}
}
Expr::Reuse {
symbol, arguments, ..
} => {
self.set_last_seen(*symbol, stmt);
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
}
Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => {
self.set_last_seen(*symbol, stmt);
}
Expr::EmptyArray => {}
Expr::RuntimeErrorFunction(_) => {}
}
self.scan_ast_help(following);
}
Stmt::Switch {
cond_symbol,
branches,
default_branch,
..
} => {
self.set_last_seen(*cond_symbol, stmt);
for (_, _, branch) in *branches {
self.scan_ast_help(branch);
}
self.scan_ast_help(default_branch.1);
}
Stmt::Ret(sym) => {
self.set_last_seen(*sym, stmt);
}
Stmt::Refcounting(modify, following) => {
let sym = modify.get_symbol();
self.set_last_seen(sym, stmt);
self.scan_ast_help(following);
}
Stmt::Join {
parameters,
body: continuation,
remainder,
id: JoinPointId(sym),
..
} => {
self.set_last_seen(*sym, stmt);
self.join_map.insert(JoinPointId(*sym), parameters);
self.scan_ast_help(remainder);
for (symbol, symbol_stmt) in Self::scan_ast(continuation) {
match self.last_seen.entry(symbol) {
Entry::Occupied(mut occupied) => {
// lives for the joinpoint
occupied.insert(stmt);
}
Entry::Vacant(vacant) => {
// lives for some time within the continuation
vacant.insert(symbol_stmt);
}
}
}
for param in *parameters {
self.set_last_seen(param.symbol, stmt);
}
}
Stmt::Jump(JoinPointId(sym), symbols) => {
if let Some(parameters) = self.join_map.get(&JoinPointId(*sym)) {
// Keep the parameters around. They will be overwritten when jumping.
for param in *parameters {
self.set_last_seen(param.symbol, stmt);
}
}
for sym in *symbols {
self.set_last_seen(*sym, stmt);
}
}
Stmt::Dbg { .. } => todo!("dbg not implemented in the dev backend"),
Stmt::Expect { .. } => todo!("expect is not implemented in the dev backend"),
Stmt::ExpectFx { .. } => todo!("expect-fx is not implemented in the dev backend"),
Stmt::Crash(msg, _crash_tag) => {
self.set_last_seen(*msg, stmt);
}
}
}
fn scan_ast_call(&mut self, call: &roc_mono::ir::Call, stmt: &'a roc_mono::ir::Stmt<'a>) {
let roc_mono::ir::Call {
call_type,
arguments,
} = call;
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
match call_type {
CallType::ByName { .. } => {}
CallType::LowLevel { .. } => {}
CallType::HigherOrder { .. } => {}
CallType::Foreign { .. } => {}
}
}
}
trait Backend<'a> {
fn env(&self) -> &Env<'a>;
fn interns(&self) -> &Interns;
@ -273,14 +459,16 @@ trait Backend<'a> {
proc.ret_layout,
);
let body = self.env().arena.alloc(proc.body);
self.reset(proc_name, proc.is_self_recursive);
self.load_args(proc.args, &proc.ret_layout);
for (layout, sym) in proc.args {
self.set_layout_map(*sym, layout);
}
self.scan_ast(&proc.body);
self.scan_ast(body);
self.create_free_map();
self.build_stmt(layout_ids, &proc.body, &proc.ret_layout);
self.build_stmt(layout_ids, body, &proc.ret_layout);
let mut helper_proc_names = bumpalo::vec![in self.env().arena];
helper_proc_names.reserve(self.helper_proc_symbols().len());
@ -2033,11 +2221,6 @@ trait Backend<'a> {
/// free_symbol frees any registers or stack space used to hold a symbol.
fn free_symbol(&mut self, sym: &Symbol);
/// set_last_seen sets the statement a symbol was last seen in.
fn set_last_seen(&mut self, sym: Symbol, stmt: &Stmt<'a>) {
self.last_seen_map().insert(sym, stmt);
}
/// last_seen_map gets the map from symbol to when it is last seen in the function.
fn last_seen_map(&mut self) -> &mut MutMap<Symbol, *const Stmt<'a>>;
@ -2080,140 +2263,7 @@ trait Backend<'a> {
/// scan_ast runs through the ast and fill the last seen map.
/// This must iterate through the ast in the same way that build_stmt does. i.e. then before else.
fn scan_ast(&mut self, stmt: &Stmt<'a>) {
// Join map keeps track of join point parameters so that we can keep them around while they still might be jumped to.
let mut join_map: MutMap<JoinPointId, &'a [Param<'a>]> = MutMap::default();
match stmt {
Stmt::Let(sym, expr, _, following) => {
self.set_last_seen(*sym, stmt);
match expr {
Expr::Literal(_) => {}
Expr::NullPointer => {}
Expr::Call(call) => self.scan_ast_call(call, stmt),
Expr::Tag { arguments, .. } => {
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
}
Expr::ExprBox { symbol } => {
self.set_last_seen(*symbol, stmt);
}
Expr::ExprUnbox { symbol } => {
self.set_last_seen(*symbol, stmt);
}
Expr::Struct(syms) => {
for sym in *syms {
self.set_last_seen(*sym, stmt);
}
}
Expr::StructAtIndex { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::GetTagId { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::UnionAtIndex { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::Array { elems, .. } => {
for elem in *elems {
if let ListLiteralElement::Symbol(sym) = elem {
self.set_last_seen(*sym, stmt);
}
}
}
Expr::Reuse {
symbol, arguments, ..
} => {
self.set_last_seen(*symbol, stmt);
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
}
Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => {
self.set_last_seen(*symbol, stmt);
}
Expr::EmptyArray => {}
Expr::RuntimeErrorFunction(_) => {}
}
self.scan_ast(following);
}
Stmt::Switch {
cond_symbol,
branches,
default_branch,
..
} => {
self.set_last_seen(*cond_symbol, stmt);
for (_, _, branch) in *branches {
self.scan_ast(branch);
}
self.scan_ast(default_branch.1);
}
Stmt::Ret(sym) => {
self.set_last_seen(*sym, stmt);
}
Stmt::Refcounting(modify, following) => {
let sym = modify.get_symbol();
self.set_last_seen(sym, stmt);
self.scan_ast(following);
}
Stmt::Join {
parameters,
body: continuation,
remainder,
id: JoinPointId(sym),
..
} => {
self.set_last_seen(*sym, stmt);
join_map.insert(JoinPointId(*sym), parameters);
for param in *parameters {
self.set_last_seen(param.symbol, stmt);
}
self.scan_ast(remainder);
self.scan_ast(continuation);
}
Stmt::Jump(JoinPointId(sym), symbols) => {
if let Some(parameters) = join_map.get(&JoinPointId(*sym)) {
// Keep the parameters around. They will be overwritten when jumping.
for param in *parameters {
self.set_last_seen(param.symbol, stmt);
}
}
for sym in *symbols {
self.set_last_seen(*sym, stmt);
}
}
Stmt::Dbg { .. } => todo!("dbg not implemented in the dev backend"),
Stmt::Expect { .. } => todo!("expect is not implemented in the dev backend"),
Stmt::ExpectFx { .. } => todo!("expect-fx is not implemented in the dev backend"),
Stmt::Crash(msg, _crash_tag) => {
self.set_last_seen(*msg, stmt);
}
}
}
fn scan_ast_call(&mut self, call: &roc_mono::ir::Call, stmt: &roc_mono::ir::Stmt<'a>) {
let roc_mono::ir::Call {
call_type,
arguments,
} = call;
for sym in *arguments {
self.set_last_seen(*sym, stmt);
}
match call_type {
CallType::ByName { .. } => {}
CallType::LowLevel { .. } => {}
CallType::HigherOrder { .. } => {}
CallType::Foreign { .. } => {}
}
fn scan_ast(&mut self, stmt: &'a Stmt<'a>) {
*self.last_seen_map() = LastSeenMap::scan_ast(stmt);
}
}