mirror of
https://github.com/roc-lang/roc.git
synced 2024-11-10 10:02:38 +03:00
Merge pull request #5407 from roc-lang/last-seen-join-points
Last seen join points
This commit is contained in:
commit
b32cd5687b
@ -3,6 +3,12 @@ test-gen-llvm = "test -p test_gen"
|
||||
test-gen-dev = "test -p roc_gen_dev -p test_gen --no-default-features --features gen-dev"
|
||||
test-gen-wasm = "test -p roc_gen_wasm -p test_gen --no-default-features --features gen-wasm"
|
||||
test-gen-llvm-wasm = "test -p roc_gen_wasm -p test_gen --no-default-features --features gen-llvm-wasm"
|
||||
|
||||
nextest-gen-llvm = "nextest run -p test_gen"
|
||||
nextest-gen-dev = "nextest run -p roc_gen_dev -p test_gen --no-default-features --features gen-dev"
|
||||
nextest-gen-wasm = "nextest run -p roc_gen_wasm -p test_gen --no-default-features --features gen-wasm"
|
||||
nextest-gen-llvm-wasm = "nextest run -p roc_gen_wasm -p test_gen --no-default-features --features gen-llvm-wasm"
|
||||
|
||||
uitest = "test -p uitest"
|
||||
|
||||
[target.wasm32-unknown-unknown]
|
||||
|
@ -5,6 +5,8 @@
|
||||
// See github.com/roc-lang/roc/issues/800 for discussion of the large_enum_variant check.
|
||||
#![allow(clippy::large_enum_variant, clippy::upper_case_acronyms)]
|
||||
|
||||
use std::collections::hash_map::Entry;
|
||||
|
||||
use bumpalo::{collections::Vec, Bump};
|
||||
use roc_builtins::bitcode::{self, FloatWidth, IntWidth};
|
||||
use roc_collections::all::{MutMap, MutSet};
|
||||
@ -91,6 +93,190 @@ struct ListArgument<'a> {
|
||||
element_width: Symbol,
|
||||
}
|
||||
|
||||
// Track when a variable is last used (and hence when it can be disregarded). This is non-trivial
|
||||
// in the presence of join points. Consider this example:
|
||||
//
|
||||
// let len = 3
|
||||
//
|
||||
// joinpoint f = \a ->
|
||||
// joinpoint g = \b ->
|
||||
// # len is used here
|
||||
// in
|
||||
// ...
|
||||
// in
|
||||
// ...
|
||||
//
|
||||
// we have to keep `len` alive until after the joinpoint goes out of scope!
|
||||
#[derive(Debug, Default)]
|
||||
struct LastSeenMap<'a> {
|
||||
last_seen: MutMap<Symbol, *const Stmt<'a>>,
|
||||
join_map: MutMap<JoinPointId, &'a [Param<'a>]>,
|
||||
}
|
||||
|
||||
impl<'a> LastSeenMap<'a> {
|
||||
fn set_last_seen(&mut self, symbol: Symbol, stmt: &'a Stmt<'a>) {
|
||||
self.last_seen.insert(symbol, stmt);
|
||||
}
|
||||
|
||||
/// scan_ast runs through the ast and fill the last seen map.
|
||||
/// This must iterate through the ast in the same way that build_stmt does. i.e. then before else.
|
||||
fn scan_ast(root: &'a Stmt<'a>) -> MutMap<Symbol, *const Stmt<'a>> {
|
||||
let mut this: Self = Default::default();
|
||||
|
||||
this.scan_ast_help(root);
|
||||
|
||||
this.last_seen
|
||||
}
|
||||
|
||||
fn scan_ast_help(&mut self, stmt: &'a Stmt<'a>) {
|
||||
match stmt {
|
||||
Stmt::Let(sym, expr, _, following) => {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
match expr {
|
||||
Expr::Literal(_) => {}
|
||||
Expr::NullPointer => {}
|
||||
|
||||
Expr::Call(call) => self.scan_ast_call(call, stmt),
|
||||
|
||||
Expr::Tag { arguments, .. } => {
|
||||
for sym in *arguments {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
Expr::ExprBox { symbol } => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
}
|
||||
Expr::ExprUnbox { symbol } => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
}
|
||||
Expr::Struct(syms) => {
|
||||
for sym in *syms {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
Expr::StructAtIndex { structure, .. } => {
|
||||
self.set_last_seen(*structure, stmt);
|
||||
}
|
||||
Expr::GetTagId { structure, .. } => {
|
||||
self.set_last_seen(*structure, stmt);
|
||||
}
|
||||
Expr::UnionAtIndex { structure, .. } => {
|
||||
self.set_last_seen(*structure, stmt);
|
||||
}
|
||||
Expr::Array { elems, .. } => {
|
||||
for elem in *elems {
|
||||
if let ListLiteralElement::Symbol(sym) = elem {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
Expr::Reuse {
|
||||
symbol, arguments, ..
|
||||
} => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
for sym in *arguments {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
}
|
||||
Expr::EmptyArray => {}
|
||||
Expr::RuntimeErrorFunction(_) => {}
|
||||
}
|
||||
self.scan_ast_help(following);
|
||||
}
|
||||
|
||||
Stmt::Switch {
|
||||
cond_symbol,
|
||||
branches,
|
||||
default_branch,
|
||||
..
|
||||
} => {
|
||||
self.set_last_seen(*cond_symbol, stmt);
|
||||
for (_, _, branch) in *branches {
|
||||
self.scan_ast_help(branch);
|
||||
}
|
||||
self.scan_ast_help(default_branch.1);
|
||||
}
|
||||
Stmt::Ret(sym) => {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
Stmt::Refcounting(modify, following) => {
|
||||
let sym = modify.get_symbol();
|
||||
|
||||
self.set_last_seen(sym, stmt);
|
||||
self.scan_ast_help(following);
|
||||
}
|
||||
Stmt::Join {
|
||||
parameters,
|
||||
body: continuation,
|
||||
remainder,
|
||||
id: JoinPointId(sym),
|
||||
..
|
||||
} => {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
self.join_map.insert(JoinPointId(*sym), parameters);
|
||||
self.scan_ast_help(remainder);
|
||||
|
||||
for (symbol, symbol_stmt) in Self::scan_ast(continuation) {
|
||||
match self.last_seen.entry(symbol) {
|
||||
Entry::Occupied(mut occupied) => {
|
||||
// lives for the joinpoint
|
||||
occupied.insert(stmt);
|
||||
}
|
||||
Entry::Vacant(vacant) => {
|
||||
// lives for some time within the continuation
|
||||
vacant.insert(symbol_stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for param in *parameters {
|
||||
self.set_last_seen(param.symbol, stmt);
|
||||
}
|
||||
}
|
||||
Stmt::Jump(JoinPointId(sym), symbols) => {
|
||||
if let Some(parameters) = self.join_map.get(&JoinPointId(*sym)) {
|
||||
// Keep the parameters around. They will be overwritten when jumping.
|
||||
for param in *parameters {
|
||||
self.set_last_seen(param.symbol, stmt);
|
||||
}
|
||||
}
|
||||
for sym in *symbols {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
|
||||
Stmt::Dbg { .. } => todo!("dbg not implemented in the dev backend"),
|
||||
Stmt::Expect { .. } => todo!("expect is not implemented in the dev backend"),
|
||||
Stmt::ExpectFx { .. } => todo!("expect-fx is not implemented in the dev backend"),
|
||||
|
||||
Stmt::Crash(msg, _crash_tag) => {
|
||||
self.set_last_seen(*msg, stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scan_ast_call(&mut self, call: &roc_mono::ir::Call, stmt: &'a roc_mono::ir::Stmt<'a>) {
|
||||
let roc_mono::ir::Call {
|
||||
call_type,
|
||||
arguments,
|
||||
} = call;
|
||||
|
||||
for sym in *arguments {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
|
||||
match call_type {
|
||||
CallType::ByName { .. } => {}
|
||||
CallType::LowLevel { .. } => {}
|
||||
CallType::HigherOrder { .. } => {}
|
||||
CallType::Foreign { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait Backend<'a> {
|
||||
fn env(&self) -> &Env<'a>;
|
||||
fn interns(&self) -> &Interns;
|
||||
@ -273,14 +459,16 @@ trait Backend<'a> {
|
||||
proc.ret_layout,
|
||||
);
|
||||
|
||||
let body = self.env().arena.alloc(proc.body);
|
||||
|
||||
self.reset(proc_name, proc.is_self_recursive);
|
||||
self.load_args(proc.args, &proc.ret_layout);
|
||||
for (layout, sym) in proc.args {
|
||||
self.set_layout_map(*sym, layout);
|
||||
}
|
||||
self.scan_ast(&proc.body);
|
||||
self.scan_ast(body);
|
||||
self.create_free_map();
|
||||
self.build_stmt(layout_ids, &proc.body, &proc.ret_layout);
|
||||
self.build_stmt(layout_ids, body, &proc.ret_layout);
|
||||
|
||||
let mut helper_proc_names = bumpalo::vec![in self.env().arena];
|
||||
helper_proc_names.reserve(self.helper_proc_symbols().len());
|
||||
@ -2033,11 +2221,6 @@ trait Backend<'a> {
|
||||
/// free_symbol frees any registers or stack space used to hold a symbol.
|
||||
fn free_symbol(&mut self, sym: &Symbol);
|
||||
|
||||
/// set_last_seen sets the statement a symbol was last seen in.
|
||||
fn set_last_seen(&mut self, sym: Symbol, stmt: &Stmt<'a>) {
|
||||
self.last_seen_map().insert(sym, stmt);
|
||||
}
|
||||
|
||||
/// last_seen_map gets the map from symbol to when it is last seen in the function.
|
||||
fn last_seen_map(&mut self) -> &mut MutMap<Symbol, *const Stmt<'a>>;
|
||||
|
||||
@ -2080,140 +2263,7 @@ trait Backend<'a> {
|
||||
|
||||
/// scan_ast runs through the ast and fill the last seen map.
|
||||
/// This must iterate through the ast in the same way that build_stmt does. i.e. then before else.
|
||||
fn scan_ast(&mut self, stmt: &Stmt<'a>) {
|
||||
// Join map keeps track of join point parameters so that we can keep them around while they still might be jumped to.
|
||||
let mut join_map: MutMap<JoinPointId, &'a [Param<'a>]> = MutMap::default();
|
||||
match stmt {
|
||||
Stmt::Let(sym, expr, _, following) => {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
match expr {
|
||||
Expr::Literal(_) => {}
|
||||
Expr::NullPointer => {}
|
||||
|
||||
Expr::Call(call) => self.scan_ast_call(call, stmt),
|
||||
|
||||
Expr::Tag { arguments, .. } => {
|
||||
for sym in *arguments {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
Expr::ExprBox { symbol } => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
}
|
||||
Expr::ExprUnbox { symbol } => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
}
|
||||
Expr::Struct(syms) => {
|
||||
for sym in *syms {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
Expr::StructAtIndex { structure, .. } => {
|
||||
self.set_last_seen(*structure, stmt);
|
||||
}
|
||||
Expr::GetTagId { structure, .. } => {
|
||||
self.set_last_seen(*structure, stmt);
|
||||
}
|
||||
Expr::UnionAtIndex { structure, .. } => {
|
||||
self.set_last_seen(*structure, stmt);
|
||||
}
|
||||
Expr::Array { elems, .. } => {
|
||||
for elem in *elems {
|
||||
if let ListLiteralElement::Symbol(sym) = elem {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
Expr::Reuse {
|
||||
symbol, arguments, ..
|
||||
} => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
for sym in *arguments {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
Expr::Reset { symbol, .. } | Expr::ResetRef { symbol, .. } => {
|
||||
self.set_last_seen(*symbol, stmt);
|
||||
}
|
||||
Expr::EmptyArray => {}
|
||||
Expr::RuntimeErrorFunction(_) => {}
|
||||
}
|
||||
self.scan_ast(following);
|
||||
}
|
||||
|
||||
Stmt::Switch {
|
||||
cond_symbol,
|
||||
branches,
|
||||
default_branch,
|
||||
..
|
||||
} => {
|
||||
self.set_last_seen(*cond_symbol, stmt);
|
||||
for (_, _, branch) in *branches {
|
||||
self.scan_ast(branch);
|
||||
}
|
||||
self.scan_ast(default_branch.1);
|
||||
}
|
||||
Stmt::Ret(sym) => {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
Stmt::Refcounting(modify, following) => {
|
||||
let sym = modify.get_symbol();
|
||||
|
||||
self.set_last_seen(sym, stmt);
|
||||
self.scan_ast(following);
|
||||
}
|
||||
Stmt::Join {
|
||||
parameters,
|
||||
body: continuation,
|
||||
remainder,
|
||||
id: JoinPointId(sym),
|
||||
..
|
||||
} => {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
join_map.insert(JoinPointId(*sym), parameters);
|
||||
for param in *parameters {
|
||||
self.set_last_seen(param.symbol, stmt);
|
||||
}
|
||||
self.scan_ast(remainder);
|
||||
self.scan_ast(continuation);
|
||||
}
|
||||
Stmt::Jump(JoinPointId(sym), symbols) => {
|
||||
if let Some(parameters) = join_map.get(&JoinPointId(*sym)) {
|
||||
// Keep the parameters around. They will be overwritten when jumping.
|
||||
for param in *parameters {
|
||||
self.set_last_seen(param.symbol, stmt);
|
||||
}
|
||||
}
|
||||
for sym in *symbols {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
}
|
||||
|
||||
Stmt::Dbg { .. } => todo!("dbg not implemented in the dev backend"),
|
||||
Stmt::Expect { .. } => todo!("expect is not implemented in the dev backend"),
|
||||
Stmt::ExpectFx { .. } => todo!("expect-fx is not implemented in the dev backend"),
|
||||
|
||||
Stmt::Crash(msg, _crash_tag) => {
|
||||
self.set_last_seen(*msg, stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scan_ast_call(&mut self, call: &roc_mono::ir::Call, stmt: &roc_mono::ir::Stmt<'a>) {
|
||||
let roc_mono::ir::Call {
|
||||
call_type,
|
||||
arguments,
|
||||
} = call;
|
||||
|
||||
for sym in *arguments {
|
||||
self.set_last_seen(*sym, stmt);
|
||||
}
|
||||
|
||||
match call_type {
|
||||
CallType::ByName { .. } => {}
|
||||
CallType::LowLevel { .. } => {}
|
||||
CallType::HigherOrder { .. } => {}
|
||||
CallType::Foreign { .. } => {}
|
||||
}
|
||||
fn scan_ast(&mut self, stmt: &'a Stmt<'a>) {
|
||||
*self.last_seen_map() = LastSeenMap::scan_ast(stmt);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user