Merge pull request #946 from rtfeldman/optimize-rc

Fuse RC operations
This commit is contained in:
Richard Feldman 2021-01-30 21:43:10 -05:00 committed by GitHub
commit 44fd0351be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1167 additions and 186 deletions

5
Cargo.lock generated
View File

@ -1572,9 +1572,9 @@ dependencies = [
[[package]]
name = "linked-hash-map"
version = "0.5.3"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8dd5a6d5999d9907cda8ed67bbd137d3af8085216c2ac62de5be860bd41f304a"
checksum = "7fb9b38af92608140b86b693604b9ffcc5824240a484d1ecd4795bacb2fe88f3"
[[package]]
name = "llvm-sys"
@ -2986,6 +2986,7 @@ version = "0.1.0"
dependencies = [
"bumpalo",
"indoc",
"linked-hash-map",
"maplit",
"pretty_assertions",
"quickcheck",

View File

@ -23,6 +23,24 @@ mod cli_run {
flags: &[&str],
expected_ending: &str,
use_valgrind: bool,
) {
check_output_with_stdin(
file,
"",
executable_filename,
flags,
expected_ending,
use_valgrind,
)
}
fn check_output_with_stdin(
file: &Path,
stdin_str: &str,
executable_filename: &str,
flags: &[&str],
expected_ending: &str,
use_valgrind: bool,
) {
let compile_out = run_roc(&[&["build", file.to_str().unwrap()], flags].concat());
if !compile_out.stderr.is_empty() {
@ -31,8 +49,10 @@ mod cli_run {
assert!(compile_out.status.success());
let out = if use_valgrind {
let (valgrind_out, raw_xml) =
run_with_valgrind(&[file.with_file_name(executable_filename).to_str().unwrap()]);
let (valgrind_out, raw_xml) = run_with_valgrind(
stdin_str,
&[file.with_file_name(executable_filename).to_str().unwrap()],
);
if valgrind_out.status.success() {
let memory_errors = extract_valgrind_errors(&raw_xml).unwrap_or_else(|err| {
@ -55,6 +75,7 @@ mod cli_run {
} else {
run_cmd(
file.with_file_name(executable_filename).to_str().unwrap(),
stdin_str,
&[],
)
};
@ -174,8 +195,9 @@ mod cli_run {
#[test]
#[serial(nqueens)]
fn run_nqueens_not_optimized() {
check_output(
check_output_with_stdin(
&example_file("benchmarks", "NQueens.roc"),
"",
"nqueens",
&[],
"4\n",

View File

@ -61,15 +61,29 @@ pub fn run_roc(args: &[&str]) -> Out {
}
#[allow(dead_code)]
pub fn run_cmd(cmd_name: &str, args: &[&str]) -> Out {
pub fn run_cmd(cmd_name: &str, stdin_str: &str, args: &[&str]) -> Out {
let mut cmd = Command::new(cmd_name);
for arg in args {
cmd.arg(arg);
}
let output = cmd
.output()
let mut child = cmd
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.unwrap_or_else(|_| panic!("failed to execute cmd `{}` in CLI test", cmd_name));
{
let stdin = child.stdin.as_mut().expect("Failed to open stdin");
stdin
.write_all(stdin_str.as_bytes())
.expect("Failed to write to stdin");
}
let output = child
.wait_with_output()
.unwrap_or_else(|_| panic!("failed to execute cmd `{}` in CLI test", cmd_name));
Out {
@ -80,7 +94,7 @@ pub fn run_cmd(cmd_name: &str, args: &[&str]) -> Out {
}
#[allow(dead_code)]
pub fn run_with_valgrind(args: &[&str]) -> (Out, String) {
pub fn run_with_valgrind(stdin_str: &str, args: &[&str]) -> (Out, String) {
//TODO: figure out if there is a better way to get the valgrind executable.
let mut cmd = Command::new("valgrind");
let named_tempfile =
@ -114,8 +128,23 @@ pub fn run_with_valgrind(args: &[&str]) -> (Out, String) {
cmd.arg(arg);
}
let output = cmd
.output()
cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
let mut child = cmd
.spawn()
.expect("failed to execute compiled `valgrind` binary in CLI test");
{
let stdin = child.stdin.as_mut().expect("Failed to open stdin");
stdin
.write_all(stdin_str.as_bytes())
.expect("Failed to write to stdin");
}
let output = child
.wait_with_output()
.expect("failed to execute compiled `valgrind` binary in CLI test");
let mut file = named_tempfile.into_file();

View File

@ -41,7 +41,7 @@ use roc_collections::all::{ImMap, MutSet};
use roc_module::ident::TagName;
use roc_module::low_level::LowLevel;
use roc_module::symbol::{Interns, ModuleId, Symbol};
use roc_mono::ir::{CallType, JoinPointId, Wrapped};
use roc_mono::ir::{BranchInfo, CallType, JoinPointId, ModifyRc, Wrapped};
use roc_mono::layout::{Builtin, ClosureLayout, Layout, LayoutIds, MemoryMode, UnionLayout};
use target_lexicon::CallingConvention;
@ -2040,7 +2040,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
cond_layout: cond_layout.clone(),
cond_symbol: *cond_symbol,
branches,
default_branch,
default_branch: default_branch.1,
ret_type,
};
@ -2113,24 +2113,65 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>(
// This doesn't currently do anything
context.i64_type().const_zero().into()
}
Inc(symbol, inc_amount, cont) => {
let (value, layout) = load_symbol_and_layout(scope, symbol);
let layout = layout.clone();
if layout.contains_refcounted() {
increment_refcount_layout(env, parent, layout_ids, *inc_amount, value, &layout);
}
/*
Inc(symbol1, 1, Dec(symbol2, cont)) if symbol1 == symbol2 => {
dbg!(symbol1);
build_exp_stmt(env, layout_ids, scope, parent, cont)
}
Dec(symbol, cont) => {
let (value, layout) = load_symbol_and_layout(scope, symbol);
*/
Refcounting(modify, cont) => {
use ModifyRc::*;
if layout.contains_refcounted() {
decrement_refcount_layout(env, parent, layout_ids, value, layout);
match modify {
Inc(symbol, inc_amount) => {
match cont {
Refcounting(ModifyRc::Dec(symbol1), contcont)
if *inc_amount == 1 && symbol == symbol1 =>
{
// the inc and dec cancel out
build_exp_stmt(env, layout_ids, scope, parent, contcont)
}
_ => {
let (value, layout) = load_symbol_and_layout(scope, symbol);
let layout = layout.clone();
if layout.contains_refcounted() {
increment_refcount_layout(
env,
parent,
layout_ids,
*inc_amount,
value,
&layout,
);
}
build_exp_stmt(env, layout_ids, scope, parent, cont)
}
}
}
Dec(symbol) => {
let (value, layout) = load_symbol_and_layout(scope, symbol);
if layout.contains_refcounted() {
decrement_refcount_layout(env, parent, layout_ids, value, layout);
}
build_exp_stmt(env, layout_ids, scope, parent, cont)
}
DecRef(symbol) => {
let (value, layout) = load_symbol_and_layout(scope, symbol);
if layout.is_refcounted() {
let value_ptr = value.into_pointer_value();
let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr);
refcount_ptr.decrement(env, layout);
}
build_exp_stmt(env, layout_ids, scope, parent, cont)
}
}
build_exp_stmt(env, layout_ids, scope, parent, cont)
}
RuntimeError(error_msg) => {
@ -2299,7 +2340,7 @@ fn extract_tag_discriminant_ptr<'a, 'ctx, 'env>(
struct SwitchArgsIr<'a, 'ctx> {
pub cond_symbol: Symbol,
pub cond_layout: Layout<'a>,
pub branches: &'a [(u64, roc_mono::ir::Stmt<'a>)],
pub branches: &'a [(u64, BranchInfo<'a>, roc_mono::ir::Stmt<'a>)],
pub default_branch: &'a roc_mono::ir::Stmt<'a>,
pub ret_type: BasicTypeEnum<'ctx>,
}
@ -2428,7 +2469,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
if let Layout::Builtin(Builtin::Int1) = cond_layout {
match (branches, default_branch) {
([(0, false_branch)], true_branch) | ([(1, true_branch)], false_branch) => {
([(0, _, false_branch)], true_branch) | ([(1, _, true_branch)], false_branch) => {
let then_block = context.append_basic_block(parent, "then_block");
let else_block = context.append_basic_block(parent, "else_block");
@ -2466,7 +2507,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
let default_block = context.append_basic_block(parent, "default");
let mut cases = Vec::with_capacity_in(branches.len(), arena);
for (int, _) in branches.iter() {
for (int, _, _) in branches.iter() {
// Switch constants must all be same type as switch value!
// e.g. this is incorrect, and will trigger a LLVM warning:
//
@ -2496,7 +2537,7 @@ fn build_switch_ir<'a, 'ctx, 'env>(
builder.build_switch(cond, default_block, &cases);
for ((_, branch_expr), (_, block)) in branches.iter().zip(cases) {
for ((_, _, branch_expr), (_, block)) in branches.iter().zip(cases) {
builder.position_at_end(block);
let branch_val = build_exp_stmt(env, layout_ids, scope, parent, branch_expr);

View File

@ -133,7 +133,7 @@ impl<'ctx> PointerToRefcount<'ctx> {
self.set_refcount(env, new_refcount);
}
fn decrement<'a, 'env>(&self, env: &Env<'a, 'ctx, 'env>, layout: &Layout<'a>) {
pub fn decrement<'a, 'env>(&self, env: &Env<'a, 'ctx, 'env>, layout: &Layout<'a>) {
let context = env.context;
let block = env.builder.get_insert_block().expect("to be in a function");
let di_location = env.builder.get_current_debug_location().unwrap();

View File

@ -1936,6 +1936,7 @@ mod gen_primitives {
}
#[test]
#[ignore]
fn rosetree_basic() {
assert_non_opt_evals_to!(
indoc!(

View File

@ -395,21 +395,19 @@ where
..
} => {
self.set_last_seen(*cond_symbol, stmt);
for (_, branch) in *branches {
for (_, _, branch) in *branches {
self.scan_ast(branch);
}
self.scan_ast(default_branch);
self.scan_ast(default_branch.1);
}
Stmt::Ret(sym) => {
self.set_last_seen(*sym, stmt);
}
Stmt::Rethrow => {}
Stmt::Inc(sym, _inc, following) => {
self.set_last_seen(*sym, stmt);
self.scan_ast(following);
}
Stmt::Dec(sym, following) => {
self.set_last_seen(*sym, stmt);
Stmt::Refcounting(modify, following) => {
let sym = modify.get_symbol();
self.set_last_seen(sym, stmt);
self.scan_ast(following);
}
Stmt::Join {

View File

@ -1767,7 +1767,7 @@ fn update<'a>(
}
MadeSpecializations {
module_id,
ident_ids,
mut ident_ids,
subs,
procedures,
external_specializations_requested,
@ -1779,32 +1779,40 @@ fn update<'a>(
state.module_cache.mono_problems.insert(module_id, problems);
state.procedures.extend(procedures);
state.constrained_ident_ids.insert(module_id, ident_ids);
state.timings.insert(module_id, module_timing);
for (module_id, requested) in external_specializations_requested {
let existing = match state
.module_cache
.external_specializations_requested
.entry(module_id)
{
Vacant(entry) => entry.insert(ExternalSpecializations::default()),
Occupied(entry) => entry.into_mut(),
};
existing.extend(requested);
}
let work = state
.dependencies
.notify(module_id, Phase::MakeSpecializations);
state.procedures.extend(procedures);
state.timings.insert(module_id, module_timing);
if state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations {
debug_assert!(work.is_empty(), "still work remaining {:?}", &work);
Proc::insert_refcount_operations(arena, &mut state.procedures);
Proc::optimize_refcount_operations(
arena,
module_id,
&mut ident_ids,
&mut state.procedures,
);
state.constrained_ident_ids.insert(module_id, ident_ids);
for (module_id, requested) in external_specializations_requested {
let existing = match state
.module_cache
.external_specializations_requested
.entry(module_id)
{
Vacant(entry) => entry.insert(ExternalSpecializations::default()),
Occupied(entry) => entry.into_mut(),
};
existing.extend(requested);
}
// display the mono IR of the module, for debug purposes
if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS {
let procs_string = state
@ -1830,6 +1838,21 @@ fn update<'a>(
// the originally requested module, we're all done!
return Ok(state);
} else {
state.constrained_ident_ids.insert(module_id, ident_ids);
for (module_id, requested) in external_specializations_requested {
let existing = match state
.module_cache
.external_specializations_requested
.entry(module_id)
{
Vacant(entry) => entry.insert(ExternalSpecializations::default()),
Occupied(entry) => entry.into_mut(),
};
existing.extend(requested);
}
start_tasks(work, &mut state, &injector, worker_listeners)?;
}

View File

@ -18,6 +18,7 @@ roc_problem = { path = "../problem" }
ven_pretty = { path = "../../vendor/pretty" }
bumpalo = { version = "3.2", features = ["collections"] }
ven_ena = { path = "../../vendor/ena" }
linked-hash-map = "0.5.4"
[dev-dependencies]
roc_constrain = { path = "../constrain" }

View File

@ -165,10 +165,10 @@ impl<'a> ParamMap<'a> {
default_branch,
..
} => {
stack.extend(branches.iter().map(|b| &b.1));
stack.push(default_branch);
stack.extend(branches.iter().map(|b| &b.2));
stack.push(default_branch.1);
}
Inc(_, _, _) | Dec(_, _) => unreachable!("these have not been introduced yet"),
Refcounting(_, _) => unreachable!("these have not been introduced yet"),
Ret(_) | Rethrow | Jump(_, _) | RuntimeError(_) => {
// these are terminal, do nothing
@ -508,12 +508,12 @@ impl<'a> BorrowInfState<'a> {
default_branch,
..
} => {
for (_, b) in branches.iter() {
for (_, _, b) in branches.iter() {
self.collect_stmt(b);
}
self.collect_stmt(default_branch);
self.collect_stmt(default_branch.1);
}
Inc(_, _, _) | Dec(_, _) => unreachable!("these have not been introduced yet"),
Refcounting(_, _) => unreachable!("these have not been introduced yet"),
Ret(_) | RuntimeError(_) | Rethrow => {
// these are terminal, do nothing

View File

@ -1,6 +1,6 @@
use crate::exhaustive::{Ctor, RenderAs, TagId, Union};
use crate::ir::{
DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt, Wrapped,
BranchInfo, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern, Procs, Stmt, Wrapped,
};
use crate::layout::{Builtin, Layout, LayoutCache, UnionLayout};
use roc_collections::all::{MutMap, MutSet};
@ -1355,20 +1355,86 @@ fn compile_test<'a>(
lhs: Symbol,
rhs: Symbol,
fail: &'a Stmt<'a>,
cond: Stmt<'a>,
) -> Stmt<'a> {
compile_test_help(
env,
ConstructorKnown::Neither,
ret_layout,
stores,
lhs,
rhs,
fail,
cond,
)
}
#[allow(clippy::too_many_arguments)]
fn compile_test_help<'a>(
env: &mut Env<'a, '_>,
branch_info: ConstructorKnown<'a>,
ret_layout: Layout<'a>,
stores: bumpalo::collections::Vec<'a, (Symbol, Layout<'a>, Expr<'a>)>,
lhs: Symbol,
rhs: Symbol,
fail: &'a Stmt<'a>,
mut cond: Stmt<'a>,
) -> Stmt<'a> {
// if test_symbol then cond else fail
let test_symbol = env.unique_symbol();
let arena = env.arena;
cond = crate::ir::cond(
env,
test_symbol,
Layout::Builtin(Builtin::Int1),
cond,
fail.clone(),
let (pass_info, fail_info) = {
use ConstructorKnown::*;
match branch_info {
Both {
scrutinee,
layout,
pass,
fail,
} => {
let pass_info = BranchInfo::Constructor {
scrutinee,
layout: layout.clone(),
tag_id: pass,
};
let fail_info = BranchInfo::Constructor {
scrutinee,
layout: layout.clone(),
tag_id: fail,
};
(pass_info, fail_info)
}
OnlyPass {
scrutinee,
layout,
tag_id,
} => {
let pass_info = BranchInfo::Constructor {
scrutinee,
layout: layout.clone(),
tag_id,
};
(pass_info, BranchInfo::None)
}
Neither => (BranchInfo::None, BranchInfo::None),
}
};
let branches = env.arena.alloc([(1u64, pass_info, cond)]);
let default_branch = (fail_info, &*env.arena.alloc(fail.clone()));
cond = Stmt::Switch {
cond_symbol: test_symbol,
cond_layout: Layout::Builtin(Builtin::Int1),
ret_layout,
);
branches,
default_branch,
};
let test = Expr::Call(crate::ir::Call {
call_type: crate::ir::CallType::LowLevel { op: LowLevel::Eq },
@ -1412,6 +1478,53 @@ fn compile_tests<'a>(
cond
}
enum ConstructorKnown<'a> {
Both {
scrutinee: Symbol,
layout: Layout<'a>,
pass: u8,
fail: u8,
},
OnlyPass {
scrutinee: Symbol,
layout: Layout<'a>,
tag_id: u8,
},
Neither,
}
impl<'a> ConstructorKnown<'a> {
fn from_test_chain(
cond_symbol: Symbol,
cond_layout: &Layout<'a>,
test_chain: &[(Path, Test)],
) -> Self {
match test_chain {
[(path, test)] => match (path, test) {
(Path::Empty, Test::IsCtor { tag_id, union, .. }) => {
if union.alternatives.len() == 2 {
// excluded middle: we also know the tag_id in the fail branch
ConstructorKnown::Both {
layout: cond_layout.clone(),
scrutinee: cond_symbol,
pass: *tag_id,
fail: (*tag_id == 0) as u8,
}
} else {
ConstructorKnown::OnlyPass {
layout: cond_layout.clone(),
scrutinee: cond_symbol,
tag_id: *tag_id,
}
}
}
_ => ConstructorKnown::Neither,
},
_ => ConstructorKnown::Neither,
}
}
}
// TODO procs and layout are currently unused, but potentially required
// for defining optional fields?
// if not, do remove
@ -1447,8 +1560,6 @@ fn decide_to_branching<'a>(
} => {
// generate a (nested) if-then-else
let (tests, guard) = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
let pass_expr = decide_to_branching(
env,
procs,
@ -1471,6 +1582,11 @@ fn decide_to_branching<'a>(
jumps,
);
let chain_branch_info =
ConstructorKnown::from_test_chain(cond_symbol, &cond_layout, &test_chain);
let (tests, guard) = stores_and_condition(env, cond_symbol, &cond_layout, test_chain);
let number_of_tests = tests.len() as i64 + guard.is_some() as i64;
debug_assert!(number_of_tests > 0);
@ -1478,7 +1594,26 @@ fn decide_to_branching<'a>(
let fail = env.arena.alloc(fail_expr);
if number_of_tests == 1 {
// if there is just one test, compile to a simple if-then-else
compile_tests(env, ret_layout, tests, guard, fail, pass_expr)
if guard.is_none() {
// use knowledge about constructors for optimization
debug_assert_eq!(tests.len(), 1);
let (new_stores, lhs, rhs, _layout) = tests.into_iter().next().unwrap();
compile_test_help(
env,
chain_branch_info,
ret_layout.clone(),
new_stores,
lhs,
rhs,
fail,
pass_expr,
)
} else {
compile_tests(env, ret_layout, tests, guard, fail, pass_expr)
}
} else {
// otherwise, we use a join point so the code for the `else` case
// is only generated once.
@ -1540,7 +1675,7 @@ fn decide_to_branching<'a>(
other => todo!("other {:?}", other),
};
branches.push((tag, branch));
branches.push((tag, BranchInfo::None, branch));
}
// We have learned more about the exact layout of the cond (based on the path)
@ -1549,7 +1684,7 @@ fn decide_to_branching<'a>(
cond_layout: inner_cond_layout,
cond_symbol: inner_cond_symbol,
branches: branches.into_bump_slice(),
default_branch: env.arena.alloc(default_branch),
default_branch: (BranchInfo::None, env.arena.alloc(default_branch)),
ret_layout,
};

View File

@ -0,0 +1,570 @@
use crate::ir::{BranchInfo, Expr, ModifyRc, Stmt, Wrapped};
use crate::layout::{Layout, UnionLayout};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use linked_hash_map::LinkedHashMap;
use roc_collections::all::MutMap;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
// This file is heavily inspired by the Perceus paper
//
// https://www.microsoft.com/en-us/research/uploads/prod/2020/11/perceus-tr-v1.pdf
//
// With how we insert RC instructions, this pattern is very common:
//
// when xs is
// Cons x xx ->
// inc x;
// inc xx;
// dec xs;
// ...
//
// This pattern is very inefficient, because it will first increment the tail (recursively),
// and then decrement it again. We can see this more clearly if we inline/specialize the `dec xs`
//
// when xs is
// Cons x xx ->
// inc x;
// inc xx;
// dec x;
// dec xx;
// decref xs
// ...
//
// Here `decref` non-recursively decrements (and possibly frees) `xs`. Now the idea is that we can
// fuse `inc x; dec x` by just doing nothing: they cancel out
//
// We can do slightly more, in the `Nil` case
//
// when xs is
// ...
// Nil ->
// dec xs;
// accum
//
// Here we know that `Nil` is represented by NULL (a linked list has a NullableUnwrapped layout),
// so we can just drop the `dec xs`
//
// # complications
//
// Let's work through the `Cons x xx` example
//
// First we need to know the constructor of `xs` in the particular block. This information would
// normally be lost when we compile pattern matches, but we keep it in the `BrachInfo` field of
// switch branches. here we also store the symbol that was switched on, and the layout of that
// symbol.
//
// Next, we need to know that `x` and `xx` alias the head and tail of `xs`. We store that
// information when encountering a `AccessAtIndex` into `xs`.
//
// In most cases these two pieces of information are enough. We keep track of a
// `LinkedHashMap<Symbol, i64>`: `LinkedHashMap` remembers insertion order, which is crucial here.
// The `i64` value represents the increment (positive value) or decrement (negative value). When
// the value is 0, increments and decrements have cancelled out and we just emit nothing.
//
// We need to do slightly more work in the case of
//
// when xs is
// Cons _ xx ->
// recurse xx (1 + accum)
//
// In this case, the head is not bound. That's OK when the list elements are not refcounted (or
// contain anything refcounted). But when they do, we can't expand the `dec xs` because there is no
// way to reference the head element.
//
// Our refcounting mechanism can't deal well with unused variables (it'll leak their memory). But
// we can insert the access after RC instructions have been inserted. So in the above case we
// actually get
//
// when xs is
// Cons _ xx ->
// let v1 = AccessAtIndex 1 xs
// inc v1;
// let xx = AccessAtIndex 2 xs
// inc xx;
// dec v1;
// dec xx;
// decref xs;
// recurse xx (1 + accum)
//
// Here we see another problem: the increments and decrements cannot be fused immediately.
// Therefore we add a rule that we can "push down" increments and decrements past
//
// - `Let`s binding a `AccessAtIndex`
// - refcount operations
//
// This allows the aforementioned `LinkedHashMap` to accumulate all changes, and then emit
// all (uncancelled) modifications at once before any "non-push-downable-stmt", hence:
//
// when xs is
// Cons _ xx ->
// let v1 = AccessAtIndex 1 xs
// let xx = AccessAtIndex 2 xs
// dec v1;
// decref xs;
// recurse xx (1 + accum)
pub struct Env<'a, 'i> {
/// bump allocator
pub arena: &'a Bump,
/// required for creating new `Symbol`s
pub home: ModuleId,
pub ident_ids: &'i mut IdentIds,
/// layout of the symbol
pub layout_map: MutMap<Symbol, Layout<'a>>,
/// record for each symbol, the aliases of its fields
pub alias_map: MutMap<Symbol, MutMap<u64, Symbol>>,
/// for a symbol (found in a `when x is`), record in which branch we are
pub constructor_map: MutMap<Symbol, u64>,
/// increments and decrements deferred until later
pub deferred: Deferred<'a>,
}
#[derive(Debug)]
pub struct Deferred<'a> {
pub inc_dec_map: LinkedHashMap<Symbol, i64>,
pub assignments: Vec<'a, (Symbol, Expr<'a>, Layout<'a>)>,
pub decrefs: Vec<'a, Symbol>,
}
impl<'a, 'i> Env<'a, 'i> {
fn insert_branch_info(&mut self, info: &BranchInfo<'a>) {
match info {
BranchInfo::Constructor {
layout,
scrutinee,
tag_id,
} => {
self.constructor_map.insert(*scrutinee, *tag_id as u64);
self.layout_map.insert(*scrutinee, layout.clone());
}
BranchInfo::None => (),
}
}
fn remove_branch_info(&mut self, info: &BranchInfo) {
match info {
BranchInfo::Constructor { scrutinee, .. } => {
self.constructor_map.remove(scrutinee);
self.layout_map.remove(scrutinee);
}
BranchInfo::None => (),
}
}
pub fn unique_symbol(&mut self) -> Symbol {
let ident_id = self.ident_ids.gen_unique();
self.home.register_debug_idents(&self.ident_ids);
Symbol::new(self.home, ident_id)
}
fn manual_unique_symbol(home: ModuleId, ident_ids: &mut IdentIds) -> Symbol {
let ident_id = ident_ids.gen_unique();
home.register_debug_idents(&ident_ids);
Symbol::new(home, ident_id)
}
}
fn layout_for_constructor<'a>(
_arena: &'a Bump,
layout: &Layout<'a>,
constructor: u64,
) -> ConstructorLayout<&'a [Layout<'a>]> {
use ConstructorLayout::*;
use Layout::*;
match layout {
Union(variant) => {
use UnionLayout::*;
match variant {
NullableUnwrapped {
nullable_id,
other_fields,
} => {
if (constructor > 0) == *nullable_id {
ConstructorLayout::IsNull
} else {
ConstructorLayout::HasFields(other_fields)
}
}
NullableWrapped {
nullable_id,
other_tags,
} => {
if constructor as i64 == *nullable_id {
ConstructorLayout::IsNull
} else {
ConstructorLayout::HasFields(other_tags[constructor as usize])
}
}
NonRecursive(fields) | Recursive(fields) => HasFields(fields[constructor as usize]),
NonNullableUnwrapped(fields) => {
debug_assert_eq!(constructor, 0);
HasFields(fields)
}
}
}
_ => unreachable!(),
}
}
fn work_for_constructor<'a>(
env: &mut Env<'a, '_>,
symbol: &Symbol,
) -> ConstructorLayout<Vec<'a, Symbol>> {
use ConstructorLayout::*;
let mut result = Vec::new_in(env.arena);
let constructor = match env.constructor_map.get(symbol) {
None => return ConstructorLayout::Unknown,
Some(v) => *v,
};
let full_layout = match env.layout_map.get(symbol) {
None => return ConstructorLayout::Unknown,
Some(v) => v,
};
let field_aliases = env.alias_map.get(symbol);
match layout_for_constructor(env.arena, full_layout, constructor) {
Unknown => Unknown,
IsNull => IsNull,
HasFields(cons_layout) => {
// figure out if there is at least one aliased refcounted field. Only then
// does it make sense to inline the decrement
let at_least_one_aliased = (|| {
for (i, field_layout) in cons_layout.iter().enumerate() {
if field_layout.contains_refcounted()
&& field_aliases.and_then(|map| map.get(&(i as u64))).is_some()
{
return true;
}
}
false
})();
// for each field, if it has refcounted content, check if it has an alias
// if so, use the alias, otherwise load the field.
for (i, field_layout) in cons_layout.iter().enumerate() {
if field_layout.contains_refcounted() {
match field_aliases.and_then(|map| map.get(&(i as u64))) {
Some(alias_symbol) => {
// the field was bound in a pattern match
result.push(*alias_symbol);
}
None if at_least_one_aliased => {
// the field was not bound in a pattern match
// we have to extract it now, but we only extract it
// if at least one field is aliased.
let expr = Expr::AccessAtIndex {
index: i as u64,
field_layouts: cons_layout,
structure: *symbol,
wrapped: Wrapped::MultiTagUnion,
};
// create a fresh symbol for this field
let alias_symbol = Env::manual_unique_symbol(env.home, env.ident_ids);
let layout = if let Layout::RecursivePointer = field_layout {
full_layout.clone()
} else {
field_layout.clone()
};
env.deferred.assignments.push((alias_symbol, expr, layout));
result.push(alias_symbol);
}
None => {
// if all refcounted fields were unaliased, generate a normal decrement
// of the whole structure (less code generated this way)
return ConstructorLayout::Unknown;
}
}
}
}
ConstructorLayout::HasFields(result)
}
}
}
fn can_push_inc_through(stmt: &Stmt) -> bool {
use Stmt::*;
match stmt {
Let(_, expr, _, _) => {
// we can always delay an increment/decrement until after a field access
matches!(expr, Expr::AccessAtIndex { .. } | Expr::Literal(_))
}
Refcounting(ModifyRc::Inc(_, _), _) => true,
Refcounting(ModifyRc::Dec(_), _) => true,
_ => false,
}
}
#[derive(Debug)]
enum ConstructorLayout<T> {
IsNull,
HasFields(T),
Unknown,
}
pub fn expand_and_cancel<'a>(env: &mut Env<'a, '_>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> {
use Stmt::*;
let mut deferred_default = Deferred {
inc_dec_map: Default::default(),
assignments: Vec::new_in(env.arena),
decrefs: Vec::new_in(env.arena),
};
let deferred = if can_push_inc_through(stmt) {
deferred_default
} else {
std::mem::swap(&mut deferred_default, &mut env.deferred);
deferred_default
};
let mut result = {
match stmt {
Let(mut symbol, expr, layout, cont) => {
env.layout_map.insert(symbol, layout.clone());
let mut expr = expr;
let mut layout = layout;
let mut cont = cont;
// prevent long chains of `Let`s from blowing the stack
let mut literal_stack = Vec::new_in(env.arena);
while !matches!(&expr, Expr::AccessAtIndex { .. } ) {
if let Stmt::Let(symbol1, expr1, layout1, cont1) = cont {
literal_stack.push((symbol, expr.clone(), layout.clone()));
symbol = *symbol1;
expr = expr1;
layout = layout1;
cont = cont1;
} else {
break;
}
}
let new_cont;
if let Expr::AccessAtIndex {
structure, index, ..
} = &expr
{
let entry = env
.alias_map
.entry(*structure)
.or_insert_with(MutMap::default);
entry.insert(*index, symbol);
new_cont = expand_and_cancel(env, cont);
// make sure to remove the alias, so other branches don't use it by accident
env.alias_map
.get_mut(structure)
.and_then(|map| map.remove(index));
} else {
new_cont = expand_and_cancel(env, cont);
}
let stmt = Let(symbol, expr.clone(), layout.clone(), new_cont);
let mut stmt = &*env.arena.alloc(stmt);
for (symbol, expr, layout) in literal_stack.into_iter().rev() {
stmt = env.arena.alloc(Stmt::Let(symbol, expr, layout, stmt));
}
stmt
}
Switch {
cond_symbol,
cond_layout,
ret_layout,
branches,
default_branch,
} => {
let mut new_branches = Vec::with_capacity_in(branches.len(), env.arena);
for (id, info, branch) in branches.iter() {
env.insert_branch_info(info);
let branch = expand_and_cancel(env, branch);
env.remove_branch_info(info);
env.constructor_map.remove(cond_symbol);
new_branches.push((*id, info.clone(), branch.clone()));
}
env.insert_branch_info(&default_branch.0);
let new_default = (
default_branch.0.clone(),
expand_and_cancel(env, default_branch.1),
);
env.remove_branch_info(&default_branch.0);
let stmt = Switch {
cond_symbol: *cond_symbol,
cond_layout: cond_layout.clone(),
ret_layout: ret_layout.clone(),
branches: new_branches.into_bump_slice(),
default_branch: new_default,
};
&*env.arena.alloc(stmt)
}
Refcounting(ModifyRc::DecRef(_symbol), _cont) => unreachable!("not introduced yet"),
Refcounting(ModifyRc::Dec(symbol), cont) => {
use ConstructorLayout::*;
match work_for_constructor(env, symbol) {
HasFields(dec_symbols) => {
// we can inline the decrement
// decref the current cell
env.deferred.decrefs.push(*symbol);
// and record decrements for all the fields
for dec_symbol in dec_symbols {
let count = env.deferred.inc_dec_map.entry(dec_symbol).or_insert(0);
*count -= 1;
}
}
Unknown => {
// we can't inline the decrement; just record it
let count = env.deferred.inc_dec_map.entry(*symbol).or_insert(0);
*count -= 1;
}
IsNull => {
// we decrement a value represented as `NULL` at runtime;
// we can drop this decrement completely
}
}
expand_and_cancel(env, cont)
}
Refcounting(ModifyRc::Inc(symbol, inc_amount), cont) => {
let count = env.deferred.inc_dec_map.entry(*symbol).or_insert(0);
*count += *inc_amount as i64;
expand_and_cancel(env, cont)
}
Invoke {
symbol,
call,
layout,
pass,
fail,
} => {
let pass = expand_and_cancel(env, pass);
let fail = expand_and_cancel(env, fail);
let stmt = Invoke {
symbol: *symbol,
call: call.clone(),
layout: layout.clone(),
pass,
fail,
};
env.arena.alloc(stmt)
}
Join {
id,
parameters,
continuation,
remainder,
} => {
let continuation = expand_and_cancel(env, continuation);
let remainder = expand_and_cancel(env, remainder);
let stmt = Join {
id: *id,
parameters,
continuation,
remainder,
};
env.arena.alloc(stmt)
}
Rethrow | Ret(_) | Jump(_, _) | RuntimeError(_) => stmt,
}
};
for symbol in deferred.decrefs {
let stmt = Refcounting(ModifyRc::DecRef(symbol), result);
result = env.arena.alloc(stmt);
}
// do all decrements
for (symbol, amount) in deferred.inc_dec_map.iter().rev() {
use std::cmp::Ordering;
match amount.cmp(&0) {
Ordering::Equal => {
// do nothing else
}
Ordering::Greater => {
// do nothing yet
}
Ordering::Less => {
// the RC insertion should not double decrement in a block
debug_assert_eq!(*amount, -1);
// insert missing decrements
let stmt = Refcounting(ModifyRc::Dec(*symbol), result);
result = env.arena.alloc(stmt);
}
}
}
for (symbol, amount) in deferred.inc_dec_map.into_iter().rev() {
use std::cmp::Ordering;
match amount.cmp(&0) {
Ordering::Equal => {
// do nothing else
}
Ordering::Greater => {
// insert missing increments
let stmt = Refcounting(ModifyRc::Inc(symbol, amount as u64), result);
result = env.arena.alloc(stmt);
}
Ordering::Less => {
// already done
}
}
}
for (symbol, expr, layout) in deferred.assignments {
let stmt = Stmt::Let(symbol, expr, layout, result);
result = env.arena.alloc(stmt);
}
result
}

View File

@ -1,5 +1,5 @@
use crate::borrow::ParamMap;
use crate::ir::{Expr, JoinPointId, Param, Proc, Stmt};
use crate::ir::{Expr, JoinPointId, ModifyRc, Param, Proc, Stmt};
use crate::layout::Layout;
use bumpalo::collections::Vec;
use bumpalo::Bump;
@ -52,8 +52,9 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet<Symbol>, MutSet<Symbol>) {
Rethrow => {}
Inc(symbol, _, cont) | Dec(symbol, cont) => {
result.insert(*symbol);
Refcounting(modify, cont) => {
let symbol = modify.get_symbol();
result.insert(symbol);
stack.push(cont);
}
@ -81,8 +82,8 @@ pub fn occuring_variables(stmt: &Stmt<'_>) -> (MutSet<Symbol>, MutSet<Symbol>) {
} => {
result.insert(*cond_symbol);
stack.extend(branches.iter().map(|(_, s)| s));
stack.push(default_branch);
stack.extend(branches.iter().map(|(_, _, s)| s));
stack.push(default_branch.1);
}
RuntimeError(_) => {}
@ -277,7 +278,8 @@ impl<'a> Context<'a> {
return stmt;
}
self.arena.alloc(Stmt::Inc(symbol, inc_amount, stmt))
let modify = ModifyRc::Inc(symbol, inc_amount);
self.arena.alloc(Stmt::Refcounting(modify, stmt))
}
fn add_dec(&self, symbol: Symbol, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> {
@ -293,7 +295,8 @@ impl<'a> Context<'a> {
return stmt;
}
self.arena.alloc(Stmt::Dec(symbol, stmt))
let modify = ModifyRc::Dec(symbol);
self.arena.alloc(Stmt::Refcounting(modify, stmt))
}
fn add_inc_before_consume_all(
@ -820,13 +823,13 @@ impl<'a> Context<'a> {
let case_live_vars = collect_stmt(stmt, &self.jp_live_vars, MutSet::default());
let branches = Vec::from_iter_in(
branches.iter().map(|(label, branch)| {
branches.iter().map(|(label, info, branch)| {
// TODO should we use ctor info like Lean?
let ctx = self.clone();
let (b, alt_live_vars) = ctx.visit_stmt(branch);
let b = ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b);
(*label, b.clone())
(*label, info.clone(), b.clone())
}),
self.arena,
)
@ -835,8 +838,12 @@ impl<'a> Context<'a> {
let default_branch = {
// TODO should we use ctor info like Lean?
let ctx = self.clone();
let (b, alt_live_vars) = ctx.visit_stmt(default_branch);
ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b)
let (b, alt_live_vars) = ctx.visit_stmt(default_branch.1);
(
default_branch.0.clone(),
ctx.add_dec_for_alt(&case_live_vars, &alt_live_vars, b),
)
};
let switch = self.arena.alloc(Switch {
@ -850,7 +857,7 @@ impl<'a> Context<'a> {
(switch, case_live_vars)
}
RuntimeError(_) | Inc(_, _, _) | Dec(_, _) => (stmt, MutSet::default()),
RuntimeError(_) | Refcounting(_, _) => (stmt, MutSet::default()),
}
}
}
@ -901,23 +908,12 @@ pub fn collect_stmt(
vars
}
Inc(symbol, _, cont) | Dec(symbol, cont) => {
vars.insert(*symbol);
Refcounting(modify, cont) => {
let symbol = modify.get_symbol();
vars.insert(symbol);
collect_stmt(cont, jp_live_vars, vars)
}
Jump(id, arguments) => {
vars.extend(arguments.iter().copied());
// NOTE deviation from Lean
// we fall through when no join point is available
if let Some(jvars) = jp_live_vars.get(id) {
vars.extend(jvars);
}
vars
}
Join {
id: j,
parameters,
@ -935,6 +931,18 @@ pub fn collect_stmt(
collect_stmt(b, &jp_live_vars, vars)
}
Jump(id, arguments) => {
vars.extend(arguments.iter().copied());
// NOTE deviation from Lean
// we fall through when no join point is available
if let Some(jvars) = jp_live_vars.get(id) {
vars.extend(jvars);
}
vars
}
Switch {
cond_symbol,
branches,
@ -943,11 +951,11 @@ pub fn collect_stmt(
} => {
vars.insert(*cond_symbol);
for (_, branch) in branches.iter() {
for (_, _info, branch) in branches.iter() {
vars.extend(collect_stmt(branch, jp_live_vars, vars.clone()));
}
vars.extend(collect_stmt(default_branch, jp_live_vars, vars.clone()));
vars.extend(collect_stmt(default_branch.1, jp_live_vars, vars.clone()));
vars
}

View File

@ -157,6 +157,36 @@ impl<'a> Proc<'a> {
crate::inc_dec::visit_proc(arena, borrow_params, proc);
}
}
pub fn optimize_refcount_operations<'i>(
arena: &'a Bump,
home: ModuleId,
ident_ids: &'i mut IdentIds,
procs: &mut MutMap<(Symbol, Layout<'a>), Proc<'a>>,
) {
use crate::expand_rc;
let deferred = expand_rc::Deferred {
inc_dec_map: Default::default(),
assignments: Vec::new_in(arena),
decrefs: Vec::new_in(arena),
};
let mut env = expand_rc::Env {
home,
arena,
ident_ids,
layout_map: Default::default(),
alias_map: Default::default(),
constructor_map: Default::default(),
deferred,
};
for (_, proc) in procs.iter_mut() {
let b = expand_rc::expand_and_cancel(&mut env, arena.alloc(proc.body.clone()));
proc.body = b.clone();
}
}
}
#[derive(Clone, Debug, Default)]
@ -729,8 +759,8 @@ pub fn cond<'a>(
fail: Stmt<'a>,
ret_layout: Layout<'a>,
) -> Stmt<'a> {
let branches = env.arena.alloc([(1u64, pass)]);
let default_branch = env.arena.alloc(fail);
let branches = env.arena.alloc([(1u64, BranchInfo::None, pass)]);
let default_branch = (BranchInfo::None, &*env.arena.alloc(fail));
Stmt::Switch {
cond_symbol,
@ -758,16 +788,15 @@ pub enum Stmt<'a> {
cond_layout: Layout<'a>,
/// The u64 in the tuple will be compared directly to the condition Expr.
/// If they are equal, this branch will be taken.
branches: &'a [(u64, Stmt<'a>)],
branches: &'a [(u64, BranchInfo<'a>, Stmt<'a>)],
/// If no other branches pass, this default branch will be taken.
default_branch: &'a Stmt<'a>,
default_branch: (BranchInfo<'a>, &'a Stmt<'a>),
/// Each branch must return a value of this type.
ret_layout: Layout<'a>,
},
Ret(Symbol),
Rethrow,
Inc(Symbol, u64, &'a Stmt<'a>),
Dec(Symbol, &'a Stmt<'a>),
Refcounting(ModifyRc, &'a Stmt<'a>),
Join {
id: JoinPointId,
parameters: &'a [Param<'a>],
@ -780,6 +809,91 @@ pub enum Stmt<'a> {
RuntimeError(&'a str),
}
/// in the block below, symbol `scrutinee` is assumed be be of shape `tag_id`
#[derive(Clone, Debug, PartialEq)]
pub enum BranchInfo<'a> {
None,
Constructor {
scrutinee: Symbol,
layout: Layout<'a>,
tag_id: u8,
},
}
impl<'a> BranchInfo<'a> {
pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
use BranchInfo::*;
match self {
Constructor {
tag_id,
scrutinee,
layout: _,
} if PRETTY_PRINT_IR_SYMBOLS => alloc
.hardline()
.append(" BranchInfo: { scrutinee: ")
.append(symbol_to_doc(alloc, *scrutinee))
.append(", tag_id: ")
.append(format!("{}", tag_id))
.append("} "),
_ => alloc.text(""),
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ModifyRc {
Inc(Symbol, u64),
Dec(Symbol),
DecRef(Symbol),
}
impl ModifyRc {
pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D) -> DocBuilder<'b, D, A>
where
D: DocAllocator<'b, A>,
D::Doc: Clone,
A: Clone,
{
use ModifyRc::*;
match self {
Inc(symbol, 1) => alloc
.text("inc ")
.append(symbol_to_doc(alloc, *symbol))
.append(";"),
Inc(symbol, n) => alloc
.text("inc ")
.append(alloc.text(format!("{}", n)))
.append(symbol_to_doc(alloc, *symbol))
.append(";"),
Dec(symbol) => alloc
.text("dec ")
.append(symbol_to_doc(alloc, *symbol))
.append(";"),
DecRef(symbol) => alloc
.text("decref ")
.append(symbol_to_doc(alloc, *symbol))
.append(";"),
}
}
pub fn get_symbol(&self) -> Symbol {
use ModifyRc::*;
match self {
Inc(symbol, _) => *symbol,
Dec(symbol) => *symbol,
DecRef(symbol) => *symbol,
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub enum Literal<'a> {
// Literals
@ -1133,13 +1247,18 @@ impl<'a> Stmt<'a> {
.text("let ")
.append(symbol_to_doc(alloc, *symbol))
//.append(" : ")
//.append(alloc.text(format!("{:?}", layout)))
//.append(alloc.text(format!("{:?}", _layout)))
.append(" = ")
.append(expr.to_doc(alloc))
.append(";")
.append(alloc.hardline())
.append(cont.to_doc(alloc)),
Refcounting(modify, cont) => modify
.to_doc(alloc)
.append(alloc.hardline())
.append(cont.to_doc(alloc)),
Invoke {
symbol,
call,
@ -1186,16 +1305,18 @@ impl<'a> Stmt<'a> {
..
} => {
match branches {
[(1, pass)] => {
let fail = default_branch;
[(1, info, pass)] => {
let fail = default_branch.1;
alloc
.text("if ")
.append(symbol_to_doc(alloc, *cond_symbol))
.append(" then")
.append(info.to_doc(alloc))
.append(alloc.hardline())
.append(pass.to_doc(alloc).indent(4))
.append(alloc.hardline())
.append(alloc.text("else"))
.append(default_branch.0.to_doc(alloc))
.append(alloc.hardline())
.append(fail.to_doc(alloc).indent(4))
}
@ -1204,12 +1325,12 @@ impl<'a> Stmt<'a> {
let default_doc = alloc
.text("default:")
.append(alloc.hardline())
.append(default_branch.to_doc(alloc).indent(4))
.append(default_branch.1.to_doc(alloc).indent(4))
.indent(4);
let branches_docs = branches
.iter()
.map(|(tag, expr)| {
.map(|(tag, _info, expr)| {
alloc
.text(format!("case {}:", tag))
.append(alloc.hardline())
@ -1267,25 +1388,6 @@ impl<'a> Stmt<'a> {
.append(alloc.intersperse(it, alloc.space()))
.append(";")
}
Inc(symbol, 1, cont) => alloc
.text("inc ")
.append(symbol_to_doc(alloc, *symbol))
.append(";")
.append(alloc.hardline())
.append(cont.to_doc(alloc)),
Inc(symbol, n, cont) => alloc
.text("inc ")
.append(alloc.text(format!("{}", n)))
.append(symbol_to_doc(alloc, *symbol))
.append(";")
.append(alloc.hardline())
.append(cont.to_doc(alloc)),
Dec(symbol, cont) => alloc
.text("dec ")
.append(symbol_to_doc(alloc, *symbol))
.append(";")
.append(alloc.hardline())
.append(cont.to_doc(alloc)),
}
}
@ -4657,17 +4759,17 @@ fn substitute_in_stmt_help<'a>(
default_branch,
ret_layout,
} => {
let opt_default = substitute_in_stmt_help(arena, default_branch, subs);
let opt_default = substitute_in_stmt_help(arena, default_branch.1, subs);
let mut did_change = false;
let opt_branches = Vec::from_iter_in(
branches.iter().map(|(label, branch)| {
branches.iter().map(|(label, info, branch)| {
match substitute_in_stmt_help(arena, branch, subs) {
None => None,
Some(branch) => {
did_change = true;
Some((*label, branch.clone()))
Some((*label, info.clone(), branch.clone()))
}
}
}),
@ -4675,7 +4777,10 @@ fn substitute_in_stmt_help<'a>(
);
if opt_default.is_some() || did_change {
let default_branch = opt_default.unwrap_or(default_branch);
let default_branch = (
default_branch.0.clone(),
opt_default.unwrap_or(default_branch.1),
);
let branches = if did_change {
let new = Vec::from_iter_in(
@ -4708,14 +4813,13 @@ fn substitute_in_stmt_help<'a>(
Some(s) => Some(arena.alloc(Ret(s))),
None => None,
},
Inc(symbol, inc, cont) => match substitute_in_stmt_help(arena, cont, subs) {
Some(cont) => Some(arena.alloc(Inc(*symbol, *inc, cont))),
None => None,
},
Dec(symbol, cont) => match substitute_in_stmt_help(arena, cont, subs) {
Some(cont) => Some(arena.alloc(Dec(*symbol, cont))),
None => None,
},
Refcounting(modify, cont) => {
// TODO should we substitute in the ModifyRc?
match substitute_in_stmt_help(arena, cont, subs) {
Some(cont) => Some(arena.alloc(Refcounting(*modify, cont))),
None => None,
}
}
Jump(id, args) => {
let mut did_change = false;

View File

@ -648,8 +648,9 @@ impl<'a> Layout<'a> {
| NonNullableUnwrapped(_) => true,
}
}
RecursivePointer => true,
Closure(_, closure_layout, _) => closure_layout.contains_refcounted(),
FunctionPointer(_, _) | RecursivePointer | Pointer(_) => false,
FunctionPointer(_, _) | Pointer(_) => false,
}
}
}

View File

@ -3,6 +3,7 @@
#![allow(clippy::large_enum_variant)]
pub mod borrow;
pub mod expand_rc;
pub mod inc_dec;
pub mod ir;
pub mod layout;

View File

@ -181,17 +181,17 @@ fn insert_jumps<'a>(
default_branch,
ret_layout,
} => {
let opt_default = insert_jumps(arena, default_branch, goal_id, needle);
let opt_default = insert_jumps(arena, default_branch.1, goal_id, needle);
let mut did_change = false;
let opt_branches = Vec::from_iter_in(
branches.iter().map(|(label, branch)| {
branches.iter().map(|(label, info, branch)| {
match insert_jumps(arena, branch, goal_id, needle) {
None => None,
Some(branch) => {
did_change = true;
Some((*label, branch.clone()))
Some((*label, info.clone(), branch.clone()))
}
}
}),
@ -199,7 +199,10 @@ fn insert_jumps<'a>(
);
if opt_default.is_some() || did_change {
let default_branch = opt_default.unwrap_or(default_branch);
let default_branch = (
default_branch.0.clone(),
opt_default.unwrap_or(default_branch.1),
);
let branches = if did_change {
let new = Vec::from_iter_in(
@ -228,12 +231,8 @@ fn insert_jumps<'a>(
None
}
}
Inc(symbol, inc, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Inc(*symbol, *inc, cont))),
None => None,
},
Dec(symbol, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Dec(*symbol, cont))),
Refcounting(modify, cont) => match insert_jumps(arena, cont, goal_id, needle) {
Some(cont) => Some(arena.alloc(Refcounting(*modify, cont))),
None => None,
},

View File

@ -856,14 +856,14 @@ mod test_mono {
joinpoint Test.8 Test.3:
ret Test.3;
in
let Test.12 = 1i64;
let Test.13 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.12 Test.13;
let Test.15 = 1i64;
let Test.16 = Index 0 Test.2;
let Test.17 = lowlevel Eq Test.15 Test.16;
if Test.17 then
let Test.14 = Index 1 Test.2;
let Test.15 = 3i64;
let Test.16 = lowlevel Eq Test.15 Test.14;
if Test.16 then
let Test.12 = Index 1 Test.2;
let Test.13 = 3i64;
let Test.14 = lowlevel Eq Test.13 Test.12;
if Test.14 then
let Test.9 = 1i64;
jump Test.8 Test.9;
else
@ -1933,28 +1933,24 @@ mod test_mono {
let Test.16 = S Test.19 Test.18;
let Test.14 = S Test.17 Test.16;
let Test.2 = S Test.15 Test.14;
let Test.7 = 0i64;
let Test.8 = Index 0 Test.2;
let Test.13 = lowlevel Eq Test.7 Test.8;
let Test.11 = 0i64;
let Test.12 = Index 0 Test.2;
let Test.13 = lowlevel Eq Test.11 Test.12;
if Test.13 then
let Test.9 = Index 1 Test.2;
inc Test.9;
let Test.10 = 0i64;
let Test.11 = Index 0 Test.9;
dec Test.9;
let Test.12 = lowlevel Eq Test.10 Test.11;
if Test.12 then
let Test.7 = Index 1 Test.2;
let Test.8 = 0i64;
let Test.9 = Index 0 Test.7;
let Test.10 = lowlevel Eq Test.8 Test.9;
if Test.10 then
let Test.4 = Index 1 Test.2;
inc Test.4;
dec Test.2;
let Test.3 = 1i64;
decref Test.2;
ret Test.3;
else
dec Test.2;
let Test.5 = 0i64;
dec Test.2;
ret Test.5;
else
dec Test.2;
let Test.6 = 0i64;
ret Test.6;
"#

View File

@ -5,6 +5,7 @@ app "nqueens"
main : Task.Task {} []
main =
# Task.after Task.getInt \n ->
queens 6
|> Str.fromInt
|> Task.putLine
@ -18,10 +19,10 @@ length : ConsList a -> I64
length = \xs -> lengthHelp xs 0
lengthHelp : ConsList a, I64 -> I64
lengthHelp = \xs, acc ->
when xs is
lengthHelp = \foobar, acc ->
when foobar is
Cons _ lrest -> lengthHelp lrest (1 + acc)
Nil -> acc
Cons _ rest -> lengthHelp rest (1 + acc)
safe : I64, I64, ConsList I64 -> Bool
safe = \queen, diagonal, xs ->
@ -41,8 +42,8 @@ appendSafe = \k, soln, solns ->
else
appendSafe (k - 1) soln solns
extend = \n, acc, solns ->
when solns is
extend = \n, acc, solutions ->
when solutions is
Nil -> acc
Cons soln rest -> extend n (appendSafe n soln acc) rest

View File

@ -6,7 +6,8 @@ platform folkertdev/foo
provides [ mainForHost ]
effects Effect
{
putLine : Str -> Effect {}
putLine : Str -> Effect {},
getInt : Effect { value: I64, errorCode: [ A, B ], isError: Bool }
}
mainForHost : Task.Task {} [] as Fx

View File

@ -1,5 +1,5 @@
interface Task
exposes [ Task, succeed, fail, after, map, putLine ]
exposes [ Task, succeed, fail, after, map, putLine, getInt ]
imports [ Effect ]
@ -15,7 +15,6 @@ fail : err -> Task * err
fail = \val ->
Effect.always (Err val)
after : Task a err, (a -> Task b err) -> Task b err
after = \effect, transform ->
Effect.after effect \result ->
@ -32,3 +31,16 @@ map = \effect, transform ->
putLine : Str -> Task {} *
putLine = \line -> Effect.map (Effect.putLine line) (\_ -> Ok {})
getInt : Task I64 []
getInt =
Effect.after Effect.getInt \{ isError, value, errorCode } ->
when isError is
True ->
when errorCode is
# A -> Task.fail InvalidCharacter
# B -> Task.fail IOError
_ -> Task.succeed -1
False ->
Task.succeed value

View File

@ -4,6 +4,7 @@ const RocStr = str.RocStr;
const testing = std.testing;
const expectEqual = testing.expectEqual;
const expect = testing.expect;
const maxInt = std.math.maxInt;
const mem = std.mem;
const Allocator = mem.Allocator;
@ -96,3 +97,39 @@ pub export fn roc_fx_putLine(rocPath: str.RocStr) i64 {
return 0;
}
const GetInt = extern struct {
value: i64,
error_code: u8,
is_error: bool,
};
pub export fn roc_fx_getInt() GetInt {
if (roc_fx_getInt_help()) |value| {
const get_int = GetInt{ .is_error = false, .value = value, .error_code = 0 };
return get_int;
} else |err| switch (err) {
error.InvalidCharacter => {
return GetInt{ .is_error = true, .value = 0, .error_code = 0 };
},
else => {
return GetInt{ .is_error = true, .value = 0, .error_code = 1 };
},
}
return 0;
}
fn roc_fx_getInt_help() !i64 {
const stdin = std.io.getStdIn().inStream();
var buf: [40]u8 = undefined;
const line: []u8 = (try stdin.readUntilDelimiterOrEof(&buf, '\n')) orelse "";
return std.fmt.parseInt(i64, line, 10);
}
fn readLine() []u8 {
const stdin = std.io.getStdIn().reader();
return (stdin.readUntilDelimiterOrEof(&line_buf, '\n') catch unreachable) orelse "";
}