changes after review

This commit is contained in:
Folkert 2023-06-23 22:36:21 +02:00
parent 0b03a0bc26
commit 654cf7b861
No known key found for this signature in database
GPG Key ID: 1F17F6FFD112B97C
17 changed files with 142 additions and 123 deletions

1
Cargo.lock generated
View File

@ -3669,6 +3669,7 @@ dependencies = [
name = "roc_mono"
version = "0.0.1"
dependencies = [
"arrayvec 0.7.2",
"bitvec",
"bumpalo",
"hashbrown 0.13.2",

View File

@ -87,7 +87,7 @@ macro_rules! map_symbol_to_lowlevel_and_arity {
LowLevel::PtrCast => unimplemented!(),
LowLevel::PtrStore => unimplemented!(),
LowLevel::PtrLoad => unimplemented!(),
LowLevel::PtrToStackValue => unimplemented!(),
LowLevel::Alloca => unimplemented!(),
LowLevel::RefCountIncRcPtr => unimplemented!(),
LowLevel::RefCountDecRcPtr=> unimplemented!(),
LowLevel::RefCountIncDataPtr => unimplemented!(),

View File

@ -102,16 +102,6 @@ impl<T: PartialEq> VecSet<T> {
{
self.elements.retain(f)
}
pub fn keep_if_in_both(&mut self, other: &Self) {
self.elements.retain(|e| other.contains(e));
}
pub fn keep_if_in_either(&mut self, other: Self) {
for e in other.elements {
self.insert(e);
}
}
}
impl<A: Ord> Extend<A> for VecSet<A> {

View File

@ -2791,14 +2791,7 @@ impl<
.storage_manager
.load_to_general_reg(&mut self.buf, structure);
let mask_symbol = self.debug_symbol("tag_id_mask");
let mask_reg = self
.storage_manager
.claim_general_reg(&mut self.buf, &mask_symbol);
ASM::mov_reg64_imm64(&mut self.buf, mask_reg, (!0b111) as _);
// mask out the tag id bits
ASM::and_reg64_reg64_reg64(&mut self.buf, mask_reg, ptr_reg, mask_reg);
let (mask_symbol, mask_reg) = self.clear_tag_id(ptr_reg);
let mut offset = 0;
for field in &other_fields[..index as usize] {
@ -2814,6 +2807,8 @@ impl<
element_layout,
*sym,
);
self.free_symbol(&mask_symbol)
}
UnionLayout::Recursive(tag_layouts) => {
let other_fields = tag_layouts[tag_id as usize];
@ -2824,21 +2819,12 @@ impl<
.load_to_general_reg(&mut self.buf, structure);
// mask out the tag id bits
let unmasked_reg = if union_layout
.stores_tag_id_as_data(self.storage_manager.target_info)
{
ptr_reg
let (unmasked_symbol, unmasked_reg) =
if union_layout.stores_tag_id_as_data(self.storage_manager.target_info) {
(None, ptr_reg)
} else {
let umasked_symbol = self.debug_symbol("unmasked");
let unmasked_reg = self
.storage_manager
.claim_general_reg(&mut self.buf, &umasked_symbol);
ASM::mov_reg64_imm64(&mut self.buf, unmasked_reg, (!0b111) as _);
ASM::and_reg64_reg64_reg64(&mut self.buf, unmasked_reg, ptr_reg, unmasked_reg);
unmasked_reg
let (mask_symbol, mask_reg) = self.clear_tag_id(ptr_reg);
(Some(mask_symbol), mask_reg)
};
let mut offset = 0;
@ -2855,6 +2841,10 @@ impl<
element_layout,
*sym,
);
if let Some(unmasked_symbol) = unmasked_symbol {
self.free_symbol(&unmasked_symbol);
}
}
}
}
@ -2911,14 +2901,7 @@ impl<
other_tags[tag_id as usize - 1]
};
let mask_symbol = self.debug_symbol("tag_id_mask");
let mask_reg = self
.storage_manager
.claim_general_reg(&mut self.buf, &mask_symbol);
ASM::mov_reg64_imm64(&mut self.buf, mask_reg, (!0b111) as _);
// mask out the tag id bits
ASM::and_reg64_reg64_reg64(&mut self.buf, mask_reg, ptr_reg, mask_reg);
let (mask_symbol, mask_reg) = self.clear_tag_id(ptr_reg);
let mut offset = 0;
for field in &other_fields[..index as usize] {
@ -2926,6 +2909,8 @@ impl<
}
ASM::add_reg64_reg64_imm32(&mut self.buf, sym_reg, mask_reg, offset as i32);
self.free_symbol(&mask_symbol);
}
UnionLayout::Recursive(tag_layouts) => {
let other_fields = tag_layouts[tag_id as usize];
@ -2935,21 +2920,12 @@ impl<
.load_to_general_reg(&mut self.buf, structure);
// mask out the tag id bits
let unmasked_reg = if union_layout
.stores_tag_id_as_data(self.storage_manager.target_info)
{
ptr_reg
let (unmasked_symbol, unmasked_reg) =
if union_layout.stores_tag_id_as_data(self.storage_manager.target_info) {
(None, ptr_reg)
} else {
let umasked_symbol = self.debug_symbol("unmasked");
let unmasked_reg = self
.storage_manager
.claim_general_reg(&mut self.buf, &umasked_symbol);
ASM::mov_reg64_imm64(&mut self.buf, unmasked_reg, (!0b111) as _);
ASM::and_reg64_reg64_reg64(&mut self.buf, unmasked_reg, ptr_reg, unmasked_reg);
unmasked_reg
let (mask_symbol, mask_reg) = self.clear_tag_id(ptr_reg);
(Some(mask_symbol), mask_reg)
};
let mut offset = 0;
@ -2958,6 +2934,10 @@ impl<
}
ASM::add_reg64_reg64_imm32(&mut self.buf, sym_reg, unmasked_reg, offset as i32);
if let Some(unmasked_symbol) = unmasked_symbol {
self.free_symbol(&unmasked_symbol);
}
}
}
}
@ -3986,6 +3966,19 @@ impl<
CC: CallConv<GeneralReg, FloatReg, ASM>,
> Backend64Bit<'a, 'r, GeneralReg, FloatReg, ASM, CC>
{
/// Emits code that computes `ptr_reg & !0b111` (the pointer with its low three
/// bits cleared, i.e. with the tag id bits masked out) into a freshly claimed
/// general-purpose register.
///
/// Returns the debug symbol that owns the new register together with the
/// register itself. `ptr_reg` is only read. The symbol stays claimed after this
/// returns; callers are expected to release it with `free_symbol` once the
/// masked pointer is no longer needed.
fn clear_tag_id(&mut self, ptr_reg: GeneralReg) -> (Symbol, GeneralReg) {
    let symbol = self.debug_symbol("unmasked");
    let reg = self
        .storage_manager
        .claim_general_reg(&mut self.buf, &symbol);

    // First load the mask into the new register, then AND it with the pointer,
    // overwriting the mask with the result.
    ASM::mov_reg64_imm64(&mut self.buf, reg, (!0b111) as _);
    ASM::and_reg64_reg64_reg64(&mut self.buf, reg, ptr_reg, reg);

    (symbol, reg)
}
fn compare(
&mut self,
op: CompareOperation,

View File

@ -1605,7 +1605,7 @@ trait Backend<'a> {
self.build_ptr_load(*sym, args[0], *ret_layout);
}
LowLevel::PtrToStackValue => {
LowLevel::Alloca => {
self.build_ptr_to_stack_value(*sym, args[0], arg_layouts[0]);
}

View File

@ -1502,7 +1502,7 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value());
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
union_field_ptr_at_index(
env,
layout_interner,
field_layouts,
@ -1518,7 +1518,7 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
let struct_type = basic_type_from_layout(env, layout_interner, struct_layout);
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
union_field_ptr_at_index(
env,
layout_interner,
field_layouts,
@ -1546,7 +1546,7 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value());
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
union_field_ptr_at_index(
env,
layout_interner,
field_layouts,
@ -1569,7 +1569,7 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
let struct_type = basic_type_from_layout(env, layout_interner, struct_layout);
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
union_field_ptr_at_index(
env,
layout_interner,
field_layouts,
@ -2156,7 +2156,7 @@ fn lookup_at_index_ptr<'a, 'ctx>(
struct_type: Option<StructType<'ctx>>,
target_loaded_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let elem_ptr = union_field_at_index_help(
let elem_ptr = union_field_ptr_at_index_help(
env,
layout_interner,
field_layouts,
@ -2179,7 +2179,7 @@ fn lookup_at_index_ptr<'a, 'ctx>(
cast_if_necessary_for_opaque_recursive_pointers(env.builder, result, target_loaded_type)
}
fn union_field_at_index_help<'a, 'ctx>(
fn union_field_ptr_at_index_help<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &STLayoutInterner<'a>,
field_layouts: &'a [InLayout<'a>],
@ -2213,7 +2213,7 @@ fn union_field_at_index_help<'a, 'ctx>(
.unwrap()
}
fn union_field_at_index<'a, 'ctx>(
fn union_field_ptr_at_index<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &STLayoutInterner<'a>,
field_layouts: &'a [InLayout<'a>],
@ -2222,7 +2222,7 @@ fn union_field_at_index<'a, 'ctx>(
value: PointerValue<'ctx>,
target_loaded_type: BasicTypeEnum<'ctx>,
) -> PointerValue<'ctx> {
let result = union_field_at_index_help(
let result = union_field_ptr_at_index_help(
env,
layout_interner,
field_layouts,

View File

@ -1325,7 +1325,7 @@ pub(crate) fn run_low_level<'a, 'ctx>(
.new_build_load(element_type, ptr.into_pointer_value(), "ptr_load")
}
PtrToStackValue => {
Alloca => {
arguments!(initial_value);
let ptr = entry_block_alloca_zerofill(env, initial_value.get_type(), "stack_value");

View File

@ -1979,8 +1979,8 @@ impl<'a> LowLevelCall<'a> {
);
}
PtrLoad => backend.expr_unbox(self.ret_symbol, self.arguments[0]),
PtrToStackValue => {
// PtrToStackValue : a -> Ptr a
Alloca => {
// Alloca : a -> Ptr a
let arg = self.arguments[0];
let arg_layout = backend.storage.symbol_layouts.get(&arg).unwrap();

View File

@ -120,7 +120,7 @@ pub enum LowLevel {
PtrCast,
PtrStore,
PtrLoad,
PtrToStackValue,
Alloca,
RefCountIncRcPtr,
RefCountDecRcPtr,
RefCountIncDataPtr,
@ -232,7 +232,7 @@ macro_rules! map_symbol_to_lowlevel {
LowLevel::PtrCast => unimplemented!(),
LowLevel::PtrStore => unimplemented!(),
LowLevel::PtrLoad => unimplemented!(),
LowLevel::PtrToStackValue => unimplemented!(),
LowLevel::Alloca => unimplemented!(),
LowLevel::RefCountIncRcPtr => unimplemented!(),
LowLevel::RefCountDecRcPtr=> unimplemented!(),
LowLevel::RefCountIncDataPtr => unimplemented!(),

View File

@ -27,6 +27,7 @@ roc_types = { path = "../types" }
ven_pretty = { path = "../../vendor/pretty" }
bitvec.workspace = true
arrayvec.workspace = true
bumpalo.workspace = true
hashbrown.workspace = true
parking_lot.workspace = true

View File

@ -1045,7 +1045,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[Ownership] {
PtrStore => arena.alloc_slice_copy(&[owned, owned]),
PtrLoad => arena.alloc_slice_copy(&[owned]),
PtrToStackValue => arena.alloc_slice_copy(&[owned]),
Alloca => arena.alloc_slice_copy(&[owned]),
PtrCast | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr
| RefCountIsUnique => {

View File

@ -1680,7 +1680,7 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
// only inserted for internal purposes. RC should not touch it
PtrStore => RC::NoRc,
PtrLoad => RC::NoRc,
PtrToStackValue => RC::NoRc,
Alloca => RC::NoRc,
PtrCast | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr
| RefCountIsUnique => {

View File

@ -7659,6 +7659,7 @@ fn substitute_in_expr<'a>(
None => None,
},
// currently only used for tail recursion modulo cons (TRMC)
UnionFieldPtrAtIndex {
structure,
tag_id,

View File

@ -674,7 +674,10 @@ pub(crate) enum LayoutWrapper<'a> {
pub enum LayoutRepr<'a> {
Builtin(Builtin<'a>),
Struct(&'a [InLayout<'a>]),
// A (heap allocated) reference-counted value
Boxed(InLayout<'a>),
// A pointer (heap or stack) without any reference counting
// Ptr is not user-facing. The compiler author must make sure that invariants are upheld
Ptr(InLayout<'a>),
Union(UnionLayout<'a>),
LambdaSet(LambdaSet<'a>),

View File

@ -10,7 +10,7 @@ use crate::layout::{
};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_collections::{MutMap, VecSet};
use roc_collections::{MutMap, VecMap, VecSet};
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
@ -423,15 +423,6 @@ impl TrmcCandidateSet {
self.active.insert(call);
}
fn extend(&mut self, other: Self) {
self.confirmed.keep_if_in_either(other.confirmed);
self.invalid.keep_if_in_either(other.invalid);
self.active.keep_if_in_either(other.active);
self.active.retain(|k| !self.invalid.contains(k));
self.confirmed.retain(|k| !self.invalid.contains(k));
}
fn retain<F>(&mut self, keep: F)
where
F: Fn(&Symbol) -> bool,
@ -443,7 +434,7 @@ impl TrmcCandidateSet {
}
self.active.retain(|k| !self.invalid.contains(k));
self.confirmed.retain(|k| !self.invalid.contains(k));
debug_assert!(!self.confirmed.iter().any(|x| self.invalid.contains(x)));
}
}
@ -465,26 +456,31 @@ where
return VecSet::default();
}
trmc_candidates_help(proc.name, &proc.body, TrmcCandidateSet::default()).confirmed
let mut candidate_set = TrmcCandidateSet::default();
trmc_candidates_help(proc.name, &proc.body, &mut candidate_set);
candidate_set.confirmed
}
fn trmc_candidates_help<'a>(
function_name: LambdaName,
stmt: &'_ Stmt<'a>,
mut candidates: TrmcCandidateSet,
) -> TrmcCandidateSet {
candidates: &mut TrmcCandidateSet,
) {
// if this stmt is the literal tail tag application and return, then this is a TRMC opportunity
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
// we pick the (syntactically) first one
for recursive_call in candidates.active.iter() {
if cons_info.arguments.contains(recursive_call) {
return TrmcCandidateSet {
confirmed: VecSet::singleton(*recursive_call),
active: VecSet::default(),
invalid: candidates.invalid,
};
}
// the tag application must directly use the result of the recursive call
let recursive_call = candidates
.active
.iter()
.copied()
.find(|call| cons_info.arguments.contains(call));
// if we find a usage, this is a confirmed TRMC call
if let Some(recursive_call) = recursive_call {
candidates.active.remove(&recursive_call);
candidates.confirmed.insert(recursive_call);
return;
}
}
@ -511,15 +507,9 @@ fn trmc_candidates_help<'a>(
.map(|(_, _, stmt)| stmt)
.chain([default_branch.1]);
let mut accum = candidates.clone();
for next in it {
let x = trmc_candidates_help(function_name, next, candidates.clone());
accum.extend(x);
trmc_candidates_help(function_name, next, candidates);
}
accum
}
Stmt::Refcounting(_, next) => trmc_candidates_help(function_name, next, candidates),
Stmt::Expect { remainder, .. }
@ -528,26 +518,63 @@ fn trmc_candidates_help<'a>(
Stmt::Join {
body, remainder, ..
} => {
let mut x = trmc_candidates_help(function_name, body, candidates.clone());
let y = trmc_candidates_help(function_name, remainder, candidates.clone());
trmc_candidates_help(function_name, body, candidates);
trmc_candidates_help(function_name, remainder, candidates);
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
}
}
x.extend(y);
x
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => candidates,
}
}
// TRMC (tail recursion modulo constructor) is an optimization for some recursive functions that return a recursive data type. The most basic example is a repeat function on linked lists:
//
// ```roc
// LinkedList a : [ Nil, Cons a (LinkedList a) ]
//
// repeat : a, Nat -> LinkedList a
// repeat = \element, n ->
// when n is
// 0 -> Nil
// _ -> Cons element (repeat element (n - 1))
// ```
//
// This function is recursive, but cannot use standard tail-call elimination, because the recursive call is not in tail position (i.e. the last thing happening before a return). Rather, the recursive call is an argument to a constructor of the recursive output type. This means that `repeat element n` will create `n` stack frames. For big inputs, a stack overflow is inevitable.
//
// But there is a trick: TRMC. Using TRMC and join points, we are able to convert this function into a loop, which uses only one stack frame for the whole process.
//
// ```pseudo-roc
// repeat : a, Nat -> LinkedList a
// repeat = \initialElement, initialN ->
// joinpoint trmc = \element, n, hole, head ->
// when n is
// 0 ->
// # write the value `Nil` into the hole
// *hole = Nil
// # dereference (load from) the pointer to the first element
// *head
//
// _ ->
// *hole = Cons element NULL
// newHole = &hole.Cons.1
// jump trmc element (n - 1) newHole head
// in
// # creates a stack allocation, gives a pointer to that stack allocation
// initial : Ptr (LinkedList a) = #alloca NULL
// jump trmc initialElement initialN initial initial
// ```
//
// The functionality here figures out whether this transformation can be applied in a valid way, and then performs the transformation.
#[derive(Clone)]
pub(crate) struct TrmcEnv<'a> {
/// Current hole to fill
hole_symbol: Symbol,
/// Pointer to the first constructor ("the head of the list")
head_symbol: Symbol,
joinpoint_id: JoinPointId,
return_layout: InLayout<'a>,
ptr_return_layout: InLayout<'a>,
trmc_calls: MutMap<Symbol, Option<Call<'a>>>,
trmc_calls: VecMap<Symbol, Option<Call<'a>>>,
}
#[derive(Debug)]
@ -599,7 +626,7 @@ impl<'a> TrmcEnv<'a> {
fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool {
match call.call_type {
CallType::ByName { name, .. } => {
// TODO are there other restrictions?
// because we do not allow polymorphic recursion, this is the only constraint
name == lambda_name
}
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
@ -617,7 +644,6 @@ impl<'a> TrmcEnv<'a> {
let ptr_write = Call {
call_type: crate::ir::CallType::LowLevel {
op: LowLevel::PtrStore,
// update_mode: env.next_update_mode_id(),
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([ptr, value]),
@ -639,9 +665,9 @@ impl<'a> TrmcEnv<'a> {
let arena = env.arena;
let return_layout = proc.ret_layout;
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
let mut joinpoint_parameters = Vec::with_capacity_in(proc.args.len() + 2, env.arena);
let mut new_proc_arguments = Vec::with_capacity_in(proc.args.len(), env.arena);
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 1, env.arena);
let mut jump_arguments = Vec::with_capacity_in(proc.args.len() + 2, env.arena);
for (i, (layout, old_symbol)) in proc.args.iter().enumerate() {
let symbol = env.named_unique_symbol(&format!("arg_{i}"));
@ -670,7 +696,7 @@ impl<'a> TrmcEnv<'a> {
let call = Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrToStackValue,
op: LowLevel::Alloca,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: arena.alloc([null_symbol]),
@ -744,9 +770,13 @@ impl<'a> TrmcEnv<'a> {
match stmt {
Stmt::Let(symbol, expr, layout, next) => {
// if this is a TRMC call,
// if this is a TRMC call, remember what the call looks like, so we can turn it
// into a jump later. The call is then removed from the Stmt
if let Some(opt_call) = self.trmc_calls.get_mut(symbol) {
debug_assert!(opt_call.is_none());
debug_assert!(
opt_call.is_none(),
"didn't expect to visit call again since symbols are unique"
);
let call = match expr {
Expr::Call(call) => call,

View File

@ -9,7 +9,7 @@ procedure Test.10 (Test.11):
procedure Test.2 (#Derived_gen.0, #Derived_gen.1):
let #Derived_gen.3 : [<rnu><null>, C I64 *self] = NullPointer;
let #Derived_gen.2 : Ptr([<rnu><null>, C I64 *self]) = lowlevel PtrToStackValue #Derived_gen.3;
let #Derived_gen.2 : Ptr([<rnu><null>, C I64 *self]) = lowlevel Alloca #Derived_gen.3;
joinpoint #Derived_gen.4 Test.4 Test.5 #Derived_gen.5 #Derived_gen.6:
let Test.22 : U8 = 1i64;
let Test.23 : U8 = GetTagId Test.5;

View File

@ -8,7 +8,7 @@ procedure Num.24 (#Attr.2, #Attr.3):
procedure Test.3 (#Derived_gen.0, #Derived_gen.1, #Derived_gen.2):
let #Derived_gen.4 : [<rnu>C *self I64 *self I32 Int1, <null>] = NullPointer;
let #Derived_gen.3 : Ptr([<rnu>C *self I64 *self I32 Int1, <null>]) = lowlevel PtrToStackValue #Derived_gen.4;
let #Derived_gen.3 : Ptr([<rnu>C *self I64 *self I32 Int1, <null>]) = lowlevel Alloca #Derived_gen.4;
joinpoint #Derived_gen.5 Test.9 Test.10 Test.11 #Derived_gen.6 #Derived_gen.7:
let Test.254 : U8 = 0i64;
let Test.255 : U8 = GetTagId Test.9;