This commit is contained in:
Folkert 2023-06-18 14:21:48 +02:00
parent 4a9514d2c4
commit 0247237fe8
No known key found for this signature in database
GPG Key ID: 1F17F6FFD112B97C
16 changed files with 625 additions and 163 deletions

View File

@ -1422,6 +1422,38 @@ fn expr_spec<'a>(
builder.add_get_tuple_field(block, variant_id, index)
}
},
UnionFieldPtrAtIndex {
index,
tag_id,
structure,
union_layout,
} => {
let index = (*index) as u32;
let tag_value_id = env.symbols[structure];
let type_name_bytes = recursive_tag_union_name_bytes(union_layout).as_bytes();
let type_name = TypeName(&type_name_bytes);
// unwrap the named wrapper
let union_id = builder.add_unwrap_named(block, MOD_APP, type_name, tag_value_id)?;
// now we have a tuple (cell, union { ... }); decompose
let heap_cell = builder.add_get_tuple_field(block, union_id, TAG_CELL_INDEX)?;
let union_data = builder.add_get_tuple_field(block, union_id, TAG_DATA_INDEX)?;
// we're reading from this value, so touch the heap cell
builder.add_touch(block, heap_cell)?;
// next, unwrap the union at the tag id that we've got
let variant_id = builder.add_unwrap_union(block, union_data, *tag_id as u32)?;
let value = builder.add_get_tuple_field(block, variant_id, index)?;
// construct the box. Here the heap_cell of the tag is re-used, I'm hoping that that
// conveys to morphic that we're borrowing into the existing tag?!
builder.add_make_tuple(block, &[heap_cell, value])
}
StructAtIndex {
index, structure, ..
} => {

View File

@ -85,7 +85,9 @@ macro_rules! map_symbol_to_lowlevel_and_arity {
// these are used internally and not tied to a symbol
LowLevel::Hash => unimplemented!(),
LowLevel::PtrCast => unimplemented!(),
LowLevel::PtrWrite => unimplemented!(),
LowLevel::PtrStore => unimplemented!(),
LowLevel::PtrLoad => unimplemented!(),
LowLevel::PtrToZeroed => unimplemented!(),
LowLevel::RefCountIncRcPtr => unimplemented!(),
LowLevel::RefCountDecRcPtr=> unimplemented!(),
LowLevel::RefCountIncDataPtr => unimplemented!(),

View File

@ -135,6 +135,10 @@ flags! {
/// instructions.
ROC_PRINT_IR_AFTER_REFCOUNT
/// Writes a pretty-printed mono IR to stderr after the tail recursion (modulo cons)
/// has been applied.
ROC_PRINT_IR_AFTER_TRMC
/// Writes a pretty-printed mono IR to stderr after performing dropspecialization.
/// Which inlines drop functions to remove pairs of alloc/dealloc instructions of its children.
ROC_PRINT_IR_AFTER_DROP_SPECIALIZATION

View File

@ -163,6 +163,9 @@ impl<'a> LastSeenMap<'a> {
Expr::UnionAtIndex { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::UnionFieldPtrAtIndex { structure, .. } => {
self.set_last_seen(*structure, stmt);
}
Expr::Array { elems, .. } => {
for elem in *elems {
if let ListLiteralElement::Symbol(sym) = elem {
@ -794,6 +797,14 @@ trait Backend<'a> {
} => {
self.load_union_at_index(sym, structure, *tag_id, *index, union_layout);
}
Expr::UnionFieldPtrAtIndex {
structure,
tag_id,
union_layout,
index,
} => {
todo!();
}
Expr::GetTagId {
structure,
union_layout,
@ -1581,7 +1592,7 @@ trait Backend<'a> {
self.build_ptr_cast(sym, &args[0])
}
LowLevel::PtrWrite => {
LowLevel::PtrStore => {
let element_layout = match self.interner().get_repr(*ret_layout) {
LayoutRepr::Boxed(boxed) => boxed,
_ => unreachable!("cannot write to {:?}", self.interner().dbg(*ret_layout)),
@ -1589,6 +1600,10 @@ trait Backend<'a> {
self.build_ptr_write(*sym, args[0], args[1], element_layout);
}
LowLevel::PtrLoad => {
//
todo!()
}
LowLevel::RefCountDecRcPtr => self.build_fn_call(
sym,
bitcode::UTILS_DECREF_RC_PTR.to_string(),

View File

@ -23,7 +23,7 @@ use inkwell::passes::{PassManager, PassManagerBuilder};
use inkwell::types::{
AnyType, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FunctionType, IntType, StructType,
};
use inkwell::values::BasicValueEnum::{self};
use inkwell::values::BasicValueEnum;
use inkwell::values::{
BasicMetadataValueEnum, CallSiteValue, FunctionValue, InstructionValue, IntValue, PointerValue,
StructValue,
@ -1379,12 +1379,13 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
layout_interner.get_repr(layout),
);
lookup_at_index_ptr2(
lookup_at_index_ptr(
env,
layout_interner,
field_layouts,
*index as usize,
ptr,
None,
target_loaded_type,
)
}
@ -1404,7 +1405,7 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
field_layouts,
*index as usize,
argument.into_pointer_value(),
struct_type.into_struct_type(),
Some(struct_type.into_struct_type()),
target_loaded_type,
)
}
@ -1430,12 +1431,13 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
layout_interner.get_repr(layout),
);
lookup_at_index_ptr2(
lookup_at_index_ptr(
env,
layout_interner,
field_layouts,
*index as usize,
ptr,
None,
target_loaded_type,
)
}
@ -1463,13 +1465,117 @@ pub(crate) fn build_exp_expr<'a, 'ctx>(
// the tag id is not stored
*index as usize,
argument.into_pointer_value(),
struct_type.into_struct_type(),
Some(struct_type.into_struct_type()),
target_loaded_type,
)
}
}
}
UnionFieldPtrAtIndex {
tag_id,
structure,
index,
union_layout,
} => {
// cast the argument bytes into the desired shape for this tag
let (argument, structure_layout) = scope.load_symbol_and_layout(structure);
let ret_repr = layout_interner.get_repr(layout);
let pointer_value = match union_layout {
UnionLayout::NonRecursive(_) => unreachable!(),
UnionLayout::Recursive(tag_layouts) => {
debug_assert!(argument.is_pointer_value());
let field_layouts = tag_layouts[*tag_id as usize];
let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value());
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
env,
layout_interner,
field_layouts,
None,
*index as usize,
ptr,
target_loaded_type,
)
}
UnionLayout::NonNullableUnwrapped(field_layouts) => {
let struct_layout = LayoutRepr::struct_(field_layouts);
let struct_type = basic_type_from_layout(env, layout_interner, struct_layout);
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
env,
layout_interner,
field_layouts,
Some(struct_type.into_struct_type()),
*index as usize,
argument.into_pointer_value(),
target_loaded_type,
)
}
UnionLayout::NullableWrapped {
nullable_id,
other_tags,
} => {
debug_assert!(argument.is_pointer_value());
debug_assert_ne!(*tag_id, *nullable_id);
let tag_index = if *tag_id < *nullable_id {
*tag_id
} else {
tag_id - 1
};
let field_layouts = other_tags[tag_index as usize];
let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value());
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
env,
layout_interner,
field_layouts,
None,
*index as usize,
ptr,
target_loaded_type,
)
.into()
}
UnionLayout::NullableUnwrapped {
nullable_id,
other_fields,
} => {
debug_assert!(argument.is_pointer_value());
debug_assert_ne!(*tag_id != 0, *nullable_id);
let field_layouts = other_fields;
let struct_layout = LayoutRepr::struct_(field_layouts);
let struct_type = basic_type_from_layout(env, layout_interner, struct_layout);
let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr);
union_field_at_index(
env,
layout_interner,
field_layouts,
Some(struct_type.into_struct_type()),
// the tag id is not stored
*index as usize,
argument.into_pointer_value(),
target_loaded_type,
)
}
};
pointer_value.into()
}
GetTagId {
structure,
union_layout,
@ -2025,21 +2131,20 @@ fn lookup_at_index_ptr<'a, 'ctx>(
field_layouts: &[InLayout<'a>],
index: usize,
value: PointerValue<'ctx>,
struct_type: StructType<'ctx>,
struct_type: Option<StructType<'ctx>>,
target_loaded_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
let builder = env.builder;
let ptr = env.builder.build_pointer_cast(
let elem_ptr = union_field_at_index_help(
env,
layout_interner,
field_layouts,
struct_type,
index,
value,
struct_type.ptr_type(AddressSpace::default()),
"cast_lookup_at_index_ptr",
);
let elem_ptr = builder
.new_build_struct_gep(struct_type, ptr, index as u32, "at_index_struct_gep")
.unwrap();
let field_layout = field_layouts[index];
let result = load_roc_value(
env,
@ -2054,19 +2159,23 @@ fn lookup_at_index_ptr<'a, 'ctx>(
cast_if_necessary_for_opaque_recursive_pointers(env.builder, result, target_loaded_type)
}
fn lookup_at_index_ptr2<'a, 'ctx>(
fn union_field_at_index_help<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &STLayoutInterner<'a>,
field_layouts: &'a [InLayout<'a>],
opt_struct_type: Option<StructType<'ctx>>,
index: usize,
value: PointerValue<'ctx>,
target_loaded_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
) -> PointerValue<'ctx> {
let builder = env.builder;
let struct_layout = LayoutRepr::struct_(field_layouts);
let struct_type =
basic_type_from_layout(env, layout_interner, struct_layout).into_struct_type();
let struct_type = match opt_struct_type {
Some(st) => st,
None => {
let struct_layout = LayoutRepr::struct_(field_layouts);
basic_type_from_layout(env, layout_interner, struct_layout).into_struct_type()
}
};
let data_ptr = env.builder.build_pointer_cast(
value,
@ -2074,27 +2183,40 @@ fn lookup_at_index_ptr2<'a, 'ctx>(
"cast_lookup_at_index_ptr",
);
let elem_ptr = builder
builder
.new_build_struct_gep(
struct_type,
data_ptr,
index as u32,
"at_index_struct_gep_data",
)
.unwrap();
.unwrap()
}
let field_layout = field_layouts[index];
let result = load_roc_value(
fn union_field_at_index<'a, 'ctx>(
env: &Env<'a, 'ctx, '_>,
layout_interner: &STLayoutInterner<'a>,
field_layouts: &'a [InLayout<'a>],
opt_struct_type: Option<StructType<'ctx>>,
index: usize,
value: PointerValue<'ctx>,
target_loaded_type: BasicTypeEnum<'ctx>,
) -> PointerValue<'ctx> {
let result = union_field_at_index_help(
env,
layout_interner,
layout_interner.get_repr(field_layout),
elem_ptr,
"load_at_index_ptr",
field_layouts,
opt_struct_type,
index,
value,
);
// A recursive pointer in the loaded structure is stored as a `i64*`, but the loaded layout
// might want a more precise structure. As such, cast it to the refined type if needed.
cast_if_necessary_for_opaque_recursive_pointers(env.builder, result, target_loaded_type)
let from_value: BasicValueEnum = result.into();
let to_type: BasicTypeEnum = target_loaded_type;
cast_if_necessary_for_opaque_recursive_pointers(env.builder, from_value, to_type)
.into_pointer_value()
}
pub fn reserve_with_refcount<'a, 'ctx>(
@ -3071,7 +3193,7 @@ pub fn cast_if_necessary_for_opaque_recursive_pointers<'ctx>(
to_type: BasicTypeEnum<'ctx>,
) -> BasicValueEnum<'ctx> {
if from_value.get_type() != to_type
// Only perform the cast if the target types are transumatble.
// Only perform the cast if the target types are transmutable.
&& equivalent_type_constructors(&from_value.get_type(), &to_type)
{
complex_bitcast(

View File

@ -1304,8 +1304,28 @@ pub(crate) fn run_low_level<'a, 'ctx>(
.into()
}
PtrStore | PtrLoad | PtrToZeroed | RefCountIncRcPtr | RefCountDecRcPtr
| RefCountIncDataPtr | RefCountDecDataPtr => {
PtrStore => {
arguments!(ptr, value);
env.builder.build_store(ptr.into_pointer_value(), value);
// ptr
env.context.struct_type(&[], false).const_zero().into()
}
PtrLoad => {
arguments!(ptr);
let ret_repr = layout_interner.get_repr(layout);
let element_type = basic_type_from_layout(env, layout_interner, ret_repr);
env.builder
.new_build_load(element_type, ptr.into_pointer_value(), "ptr_load")
}
PtrToZeroed => todo!(),
RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr => {
unreachable!("Not used in LLVM backend: {:?}", op);
}

View File

@ -1079,6 +1079,13 @@ impl<'a, 'r> WasmBackend<'a, 'r> {
index,
} => self.expr_union_at_index(*structure, *tag_id, union_layout, *index, sym),
Expr::UnionFieldPtrAtIndex {
structure,
tag_id,
union_layout,
index,
} => todo!(),
Expr::ExprBox { symbol: arg_sym } => self.expr_box(sym, *arg_sym, layout, storage),
Expr::ExprUnbox { symbol: arg_sym } => self.expr_unbox(sym, *arg_sym),

View File

@ -16,7 +16,7 @@ use roc_can::module::{
};
use roc_collections::{default_hasher, BumpMap, MutMap, MutSet, VecMap, VecSet};
use roc_constrain::module::constrain_module;
use roc_debug_flags::dbg_do;
use roc_debug_flags::{dbg_do, ROC_PRINT_IR_AFTER_TRMC};
#[cfg(debug_assertions)]
use roc_debug_flags::{
ROC_CHECK_MONO_IR, ROC_PRINT_IR_AFTER_DROP_SPECIALIZATION, ROC_PRINT_IR_AFTER_REFCOUNT,
@ -3104,6 +3104,16 @@ fn update<'a>(
let ident_ids = state.constrained_ident_ids.get_mut(&module_id).unwrap();
roc_mono::tail_recursion::apply_trmc(
arena,
&mut layout_interner,
module_id,
ident_ids,
&mut state.procedures,
);
debug_print_ir!(state, &layout_interner, ROC_PRINT_IR_AFTER_TRMC);
inc_dec::insert_inc_dec_operations(
arena,
&layout_interner,

View File

@ -230,7 +230,9 @@ macro_rules! map_symbol_to_lowlevel {
// these are used internally and not tied to a symbol
LowLevel::Hash => unimplemented!(),
LowLevel::PtrCast => unimplemented!(),
LowLevel::PtrWrite => unimplemented!(),
LowLevel::PtrStore => unimplemented!(),
LowLevel::PtrLoad => unimplemented!(),
LowLevel::PtrToZeroed => unimplemented!(),
LowLevel::RefCountIncRcPtr => unimplemented!(),
LowLevel::RefCountDecRcPtr=> unimplemented!(),
LowLevel::RefCountIncDataPtr => unimplemented!(),

View File

@ -741,6 +741,14 @@ impl<'a> BorrowInfState<'a> {
self.if_is_owned_then_own(z, *x);
}
UnionFieldPtrAtIndex { structure: x, .. } => {
// if the structure (record/tag/array) is owned, the extracted value is
self.if_is_owned_then_own(*x, z);
// if the extracted value is owned, the structure must be too
self.if_is_owned_then_own(z, *x);
}
GetTagId { structure: x, .. } => {
// if the structure (record/tag/array) is owned, the extracted value is
self.if_is_owned_then_own(*x, z);
@ -1035,7 +1043,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[Ownership] {
unreachable!("These lowlevel operations are turned into mono Expr's")
}
PtrWrite => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
PtrStore => arena.alloc_slice_copy(&[owned, borrowed]),
PtrLoad => arena.alloc_slice_copy(&[owned]),
PtrToZeroed => arena.alloc_slice_copy(&[owned]),
PtrCast | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr
| RefCountIsUnique => {

View File

@ -429,6 +429,15 @@ impl<'a, 'r> Ctx<'a, 'r> {
} => self.with_sym_layout(structure, |ctx, _def_line, layout| {
ctx.check_union_at_index(structure, layout, union_layout, tag_id, index)
}),
&Expr::UnionFieldPtrAtIndex {
structure,
tag_id,
union_layout,
index,
} => self.with_sym_layout(structure, |ctx, _def_line, layout| {
// TODO: I suspect this will fail because the output layout has an extra Box layer?
ctx.check_union_at_index(structure, layout, union_layout, tag_id, index)
}),
Expr::Array { elem_layout, elems } => {
for elem in elems.iter() {
match elem {

View File

@ -211,7 +211,20 @@ fn specialize_drops_stmt<'a, 'i>(
// TODO perhaps we need the union_layout later as well? if so, create a new function/map to store it.
environment.add_union_child(*structure, *binding, *tag_id, *index);
// Generated code might know the tag of the union without switching on it.
// So if we unionAtIndex, we must know the tag and we can use it to specialize the drop.
// So if we UnionAtIndex, we must know the tag and we can use it to specialize the drop.
environment.symbol_tag.insert(*structure, *tag_id);
alloc_let_with_continuation!(environment)
}
Expr::UnionFieldPtrAtIndex {
structure,
tag_id,
union_layout: _,
index,
} => {
// TODO perhaps we need the union_layout later as well? if so, create a new function/map to store it.
environment.add_union_child(*structure, *binding, *tag_id, *index);
// Generated code might know the tag of the union without switching on it.
// So if we UnionFieldPtrAtIndex, we must know the tag and we can use it to specialize the drop.
environment.symbol_tag.insert(*structure, *tag_id);
alloc_let_with_continuation!(environment)
}
@ -1666,8 +1679,13 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC {
unreachable!("These lowlevel operations are turned into mono Expr's")
}
PtrCast | PtrWrite | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr
| RefCountDecDataPtr | RefCountIsUnique => {
// only inserted for internal purposes. RC should not touch it
PtrStore => RC::NoRc,
PtrLoad => RC::NoRc,
PtrToZeroed => RC::NoRc,
PtrCast | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr
| RefCountIsUnique => {
unreachable!("Only inserted *after* borrow checking: {:?}", lowlevel);
}
}

View File

@ -345,17 +345,16 @@ impl<'v> RefcountEnvironment<'v> {
// A groupby or something similar would be nice here.
let mut symbol_usage = MutMap::default();
for symbol in symbols {
match {
self.symbols_rc_types
.get(&symbol)
.expect("Expected symbol to be in the map")
} {
match self.symbols_rc_types.get(&symbol) {
// If the symbol is reference counted, we need to increment the usage count.
VarRcType::ReferenceCounted => {
Some(VarRcType::ReferenceCounted) => {
*symbol_usage.entry(symbol).or_default() += 1;
}
// If the symbol is not reference counted, we don't need to do anything.
VarRcType::NotReferenceCounted => continue,
Some(VarRcType::NotReferenceCounted) => continue,
None => {
internal_error!("symbol {symbol:?} does not have an rc type")
}
}
}
symbol_usage
@ -891,6 +890,7 @@ fn insert_refcount_operations_binding<'a>(
Expr::GetTagId { structure, .. }
| Expr::StructAtIndex { structure, .. }
| Expr::UnionAtIndex { structure, .. }
| Expr::UnionFieldPtrAtIndex { structure, .. }
| Expr::ExprUnbox { symbol: structure } => {
// All structures are alive at this point and don't have to be copied in order to take an index out/get tag id/copy values to the stack.
// But we do want to make sure to decrement this item if it is the last reference.
@ -904,6 +904,7 @@ fn insert_refcount_operations_binding<'a>(
match expr {
Expr::StructAtIndex { .. }
| Expr::UnionAtIndex { .. }
| Expr::UnionFieldPtrAtIndex { .. }
| Expr::ExprUnbox { .. } => insert_inc_stmt(arena, *binding, 1, new_stmt),
// No usage of an element of a reference counted symbol. No need to increment.
Expr::GetTagId { .. } => new_stmt,

View File

@ -403,41 +403,6 @@ impl<'a> Proc<'a> {
w.push(b'\n');
String::from_utf8(w).unwrap()
}
fn make_tail_recursive<I>(&mut self, interner: &mut I, env: &mut Env<'a, '_>)
where
I: LayoutInterner<'a>,
{
let mut args = Vec::with_capacity_in(self.args.len(), env.arena);
let mut proc_args = Vec::with_capacity_in(self.args.len(), env.arena);
for (layout, symbol) in self.args {
let new = env.unique_symbol();
args.push((*layout, *symbol, new));
proc_args.push((*layout, new));
}
use self::SelfRecursive::*;
if let SelfRecursive(id) = self.is_self_recursive {
if crate::tail_recursion::is_trmc_candidate(interner, self) {
*self = crate::tail_recursion::TrmcEnv::init(env, interner, self);
} else {
let transformed = crate::tail_recursion::make_tail_recursive(
env.arena,
id,
self.name,
self.body.clone(),
args.into_bump_slice(),
self.ret_layout,
);
if let Some(with_tco) = transformed {
self.body = with_tco;
self.args = proc_args.into_bump_slice();
}
}
}
}
}
/// A host-exposed function must be specialized; it's a seed for subsequent specializations
@ -1032,7 +997,7 @@ impl<'a> Procs<'a> {
MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher());
for (symbol, layout, mut proc) in self.specialized.into_iter_assert_done() {
proc.make_tail_recursive(&mut layout_cache.interner, env);
// proc.make_tail_recursive(&mut layout_cache.interner, env);
let key = (symbol, layout);
specialized_procs.insert(key, proc);
@ -1888,6 +1853,12 @@ pub enum Expr<'a> {
union_layout: UnionLayout<'a>,
index: u64,
},
UnionFieldPtrAtIndex {
structure: Symbol,
tag_id: TagIdIntType,
union_layout: UnionLayout<'a>,
index: u64,
},
Array {
elem_layout: InLayout<'a>,
@ -2105,6 +2076,19 @@ impl<'a> Expr<'a> {
..
} => text!(alloc, "UnionAtIndex (Id {}) (Index {}) ", tag_id, index)
.append(symbol_to_doc(alloc, *structure, pretty)),
UnionFieldPtrAtIndex {
tag_id,
structure,
index,
..
} => text!(
alloc,
"UnionFieldPtrAtIndex (Id {}) (Index {}) ",
tag_id,
index
)
.append(symbol_to_doc(alloc, *structure, pretty)),
}
}
@ -7678,6 +7662,21 @@ fn substitute_in_expr<'a>(
}),
None => None,
},
UnionFieldPtrAtIndex {
structure,
tag_id,
index,
union_layout,
} => match substitute(subs, *structure) {
Some(structure) => Some(UnionFieldPtrAtIndex {
structure,
tag_id: *tag_id,
index: *index,
union_layout: *union_layout,
}),
None => None,
},
}
}

View File

@ -1,11 +1,88 @@
#![allow(clippy::manual_map)]
use crate::borrow::Ownership;
use crate::ir::{Call, CallType, Env, Expr, JoinPointId, Param, Proc, SelfRecursive, Stmt};
use crate::layout::{InLayout, LambdaName, LayoutInterner, LayoutRepr, TagIdIntType, UnionLayout};
use crate::ir::{
Call, CallType, Expr, JoinPointId, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId,
};
use crate::layout::{
InLayout, LambdaName, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType,
UnionLayout,
};
use bumpalo::collections::Vec;
use bumpalo::Bump;
use roc_module::symbol::Symbol;
use roc_collections::MutMap;
use roc_module::low_level::LowLevel;
use roc_module::symbol::{IdentIds, ModuleId, Symbol};
pub struct Env<'a, 'i> {
arena: &'a Bump,
home: ModuleId,
interner: &'i mut STLayoutInterner<'a>,
ident_ids: &'i mut IdentIds,
}
impl<'a, 'i> Env<'a, 'i> {
pub fn unique_symbol(&mut self) -> Symbol {
let ident_id = self.ident_ids.gen_unique();
Symbol::new(self.home, ident_id)
}
pub fn named_unique_symbol(&mut self, name: &str) -> Symbol {
let ident_id = self.ident_ids.add_str(name);
Symbol::new(self.home, ident_id)
}
}
pub fn apply_trmc<'a, 'i>(
arena: &'a Bump,
interner: &'i mut STLayoutInterner<'a>,
home: ModuleId,
ident_ids: &'i mut IdentIds,
procs: &mut MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>,
) {
let mut env = Env {
arena,
interner,
home,
ident_ids,
};
let env = &mut env;
for (_, proc) in procs {
use self::SelfRecursive::*;
if let SelfRecursive(id) = proc.is_self_recursive {
if crate::tail_recursion::is_trmc_candidate(env.interner, proc) {
let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc);
*proc = new_proc;
} else {
let mut args = Vec::with_capacity_in(proc.args.len(), arena);
let mut proc_args = Vec::with_capacity_in(proc.args.len(), arena);
for (layout, symbol) in proc.args {
let new = env.unique_symbol();
args.push((*layout, *symbol, new));
proc_args.push((*layout, new));
}
let transformed = crate::tail_recursion::make_tail_recursive(
arena,
id,
proc.name,
proc.body.clone(),
args.into_bump_slice(),
proc.ret_layout,
);
if let Some(with_tco) = transformed {
proc.body = with_tco;
proc.args = proc_args.into_bump_slice();
}
}
}
}
}
/// Make tail calls into loops (using join points)
///
@ -325,7 +402,7 @@ fn insert_jumps<'a>(
}
}
pub(crate) fn is_trmc_candidate<'a, I>(interner: &I, proc: &Proc<'a>) -> bool
pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool
where
I: LayoutInterner<'a>,
{
@ -338,10 +415,68 @@ where
}
// and return a recursive tag union
match interner.get_repr(proc.ret_layout) {
LayoutRepr::Union(union_layout) => union_layout.is_recursive(),
_ => false,
if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive())
{
return false;
}
has_cons_in_tail_position(&proc.body, proc.name)
}
fn has_cons_in_tail_position(initial_stmt: &Stmt<'_>, function_name: LambdaName) -> bool {
// we are looking for code of the form
//
// let x = Tag a b c
// ret x
let mut stack = vec![(None, initial_stmt)];
while let Some((recursive_call, stmt)) = stack.pop() {
match stmt {
Stmt::Let(symbol, expr, _, next) => {
if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) {
// must use the result of a recursive call directly as an argument
if let Some(recursive_call) = recursive_call {
if cons_info.arguments.contains(&recursive_call) {
return true;
}
}
}
let recursive_call = recursive_call
.or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol));
stack.push((recursive_call, next));
}
Stmt::Switch {
branches,
default_branch,
..
} => {
for (_, _, stmt) in branches.iter() {
stack.push((recursive_call, stmt));
}
stack.push((recursive_call, default_branch.1));
}
Stmt::Refcounting(_, next) => {
stack.push((recursive_call, next));
}
Stmt::Expect { remainder, .. }
| Stmt::ExpectFx { remainder, .. }
| Stmt::Dbg { remainder, .. } => {
stack.push((recursive_call, remainder));
}
Stmt::Join {
body, remainder, ..
} => {
stack.push((recursive_call, body));
stack.push((recursive_call, remainder));
}
Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ }
}
}
false
}
#[derive(Clone)]
@ -358,6 +493,7 @@ pub(crate) struct TrmcEnv<'a> {
recursive_call: Option<(Symbol, Call<'a>)>,
}
#[derive(Debug)]
struct ConstructorInfo<'a> {
tag_layout: UnionLayout<'a>,
tag_id: TagIdIntType,
@ -365,25 +501,18 @@ struct ConstructorInfo<'a> {
}
impl<'a> TrmcEnv<'a> {
fn is_recursive_expr(&mut self, expr: &Expr<'a>) -> Option<Call<'a>> {
if let Expr::Call(call) = expr {
self.is_recursive_call(call).then_some(call.clone())
} else {
None
}
}
fn is_terminal_constructor(&mut self, stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
#[inline(always)]
fn is_terminal_constructor(stmt: &Stmt<'a>) -> Option<ConstructorInfo<'a>> {
match stmt {
Stmt::Let(s1, expr, _layout, Stmt::Ret(s2)) if s1 == s2 => {
self.get_contructor_info(expr)
Self::get_contructor_info(expr)
}
_ => None,
}
}
fn get_contructor_info(&mut self, expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
fn get_contructor_info(expr: &Expr<'a>) -> Option<ConstructorInfo<'a>> {
if let Expr::Tag {
tag_layout,
tag_id,
@ -402,16 +531,19 @@ impl<'a> TrmcEnv<'a> {
}
}
fn is_recursive_call(&mut self, call: &Call<'a>) -> bool {
fn is_recursive_expr(expr: &Expr<'a>, lambda_name: LambdaName<'_>) -> Option<Call<'a>> {
if let Expr::Call(call) = expr {
Self::is_recursive_call(call, lambda_name).then_some(call.clone())
} else {
None
}
}
fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool {
match call.call_type {
CallType::ByName {
name,
ret_layout,
arg_layouts,
specialization_id,
} => {
CallType::ByName { name, .. } => {
// TODO are there other restrictions?
name == self.function_name
name == lambda_name
}
CallType::Foreign { .. } | CallType::LowLevel { .. } | CallType::HigherOrder(_) => {
false
@ -421,16 +553,16 @@ impl<'a> TrmcEnv<'a> {
fn ptr_write(
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
return_layout: InLayout<'a>,
_return_layout: InLayout<'a>,
ptr: Symbol,
value: Symbol,
next: &'a Stmt<'a>,
) -> Stmt<'a> {
let box_write = Call {
call_type: crate::ir::CallType::LowLevel {
op: roc_module::low_level::LowLevel::PtrWrite,
update_mode: env.next_update_mode_id(),
op: LowLevel::PtrStore,
// update_mode: env.next_update_mode_id(),
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([ptr, value]),
};
@ -438,16 +570,13 @@ impl<'a> TrmcEnv<'a> {
Stmt::Let(
env.named_unique_symbol("_ptr_write_unit"),
Expr::Call(box_write),
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
// interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)),
Layout::UNIT,
next,
)
}
pub fn init(
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
proc: &Proc<'a>,
) -> Proc<'a> {
pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> {
let arena = env.arena;
let return_layout = proc.ret_layout;
@ -475,8 +604,9 @@ impl<'a> TrmcEnv<'a> {
let null_symbol = env.named_unique_symbol("null");
let let_null = |next| Stmt::Let(null_symbol, Expr::NullPointer, return_layout, next);
let box_return_layout =
interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
let box_return_layout = env
.interner
.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout));
let box_null = Expr::ExprBox {
symbol: null_symbol,
};
@ -508,7 +638,7 @@ impl<'a> TrmcEnv<'a> {
let joinpoint = Stmt::Join {
id: joinpoint_id,
parameters: joinpoint_parameters.into_bump_slice(),
body: arena.alloc(this.walk_stmt(env, interner, &proc.body)),
body: arena.alloc(this.walk_stmt(env, &proc.body)),
remainder: arena.alloc(jump_stmt),
};
@ -534,24 +664,19 @@ impl<'a> TrmcEnv<'a> {
}
}
fn walk_stmt(
&mut self,
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
stmt: &Stmt<'a>,
) -> Stmt<'a> {
fn walk_stmt(&mut self, env: &mut Env<'a, '_>, stmt: &Stmt<'a>) -> Stmt<'a> {
let arena = env.arena;
match stmt {
Stmt::Let(symbol, expr, layout, next) => {
if self.recursive_call.is_none() {
if let Some(call) = self.is_recursive_expr(expr) {
if let Some(call) = Self::is_recursive_expr(expr, self.function_name) {
self.recursive_call = Some((*symbol, call));
return self.walk_stmt(env, interner, next);
return self.walk_stmt(env, next);
}
}
if let Some(cons_info) = self.is_terminal_constructor(stmt) {
if let Some(cons_info) = Self::is_terminal_constructor(stmt) {
match &self.recursive_call {
None => {
// this control flow path did not encounter a recursive call. Just
@ -561,7 +686,7 @@ impl<'a> TrmcEnv<'a> {
let output = define_tag(arena.alloc(
//
self.non_trmc_return(env, interner, *symbol),
self.non_trmc_return(env, *symbol),
));
return output;
@ -571,11 +696,21 @@ impl<'a> TrmcEnv<'a> {
// branch.
// TODO remove unwrap. also what if the symbol occurs more than once?
let recursive_field_index = cons_info
.arguments
.iter()
.position(|s| *s == *call_symbol)
.unwrap();
let opt_recursive_field_index =
cons_info.arguments.iter().position(|s| *s == *call_symbol);
let recursive_field_index = match opt_recursive_field_index {
None => {
let next = self.walk_stmt(env, next);
return Stmt::Let(
*symbol,
expr.clone(),
*layout,
arena.alloc(next),
);
}
Some(v) => v,
};
let mut arguments =
Vec::from_iter_in(cons_info.arguments.iter().copied(), env.arena);
@ -589,8 +724,11 @@ impl<'a> TrmcEnv<'a> {
let let_tag = |next| Stmt::Let(*symbol, tag_expr, *layout, next);
let get_reference_expr = Expr::ExprBox {
symbol: self.null_symbol,
let get_reference_expr = Expr::UnionFieldPtrAtIndex {
structure: *symbol,
tag_id: cons_info.tag_id,
union_layout: cons_info.tag_layout,
index: recursive_field_index as _,
};
let new_hole_symbol = env.named_unique_symbol("newHole");
@ -616,7 +754,6 @@ impl<'a> TrmcEnv<'a> {
//
Self::ptr_write(
env,
interner,
*layout,
self.hole_symbol,
*symbol,
@ -630,7 +767,7 @@ impl<'a> TrmcEnv<'a> {
}
}
let next = self.walk_stmt(env, interner, next);
let next = self.walk_stmt(env, next);
Stmt::Let(*symbol, expr.clone(), *layout, arena.alloc(next))
}
Stmt::Switch {
@ -646,14 +783,13 @@ impl<'a> TrmcEnv<'a> {
for (id, info, stmt) in branches.iter() {
self.recursive_call = opt_recursive_call.clone();
let new_stmt = self.walk_stmt(env, interner, stmt);
let new_stmt = self.walk_stmt(env, stmt);
new_branches.push((*id, info.clone(), new_stmt));
}
self.recursive_call = opt_recursive_call;
let new_default_branch =
&*arena.alloc(self.walk_stmt(env, interner, default_branch.1));
let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1));
Stmt::Switch {
cond_symbol: *cond_symbol,
@ -666,42 +802,92 @@ impl<'a> TrmcEnv<'a> {
Stmt::Ret(symbol) => {
// write the symbol we're supposed to return into the hole
// then read initial_symbol and return its contents
self.non_trmc_return(env, interner, *symbol)
self.non_trmc_return(env, *symbol)
}
Stmt::Refcounting(_, _) => todo!(),
Stmt::Expect { .. } => todo!(),
Stmt::ExpectFx { .. } => todo!(),
Stmt::Dbg { .. } => todo!(),
Stmt::Join { .. } => todo!(),
Stmt::Jump(_, _) => todo!(),
Stmt::Crash(_, _) => todo!(),
Stmt::Refcounting(op, next) => {
let new_next = self.walk_stmt(env, next);
Stmt::Refcounting(*op, arena.alloc(new_next))
}
Stmt::Expect {
condition,
region,
lookups,
variables,
remainder,
} => Stmt::Expect {
condition: *condition,
region: *region,
lookups,
variables,
remainder: arena.alloc(self.walk_stmt(env, remainder)),
},
Stmt::ExpectFx {
condition,
region,
lookups,
variables,
remainder,
} => Stmt::Expect {
condition: *condition,
region: *region,
lookups,
variables,
remainder: arena.alloc(self.walk_stmt(env, remainder)),
},
Stmt::Dbg {
symbol,
variable,
remainder,
} => Stmt::Dbg {
symbol: *symbol,
variable: *variable,
remainder: arena.alloc(self.walk_stmt(env, remainder)),
},
Stmt::Join {
id,
parameters,
body,
remainder,
} => {
let new_body = self.walk_stmt(env, body);
let new_remainder = self.walk_stmt(env, remainder);
Stmt::Join {
id: *id,
parameters,
body: arena.alloc(new_body),
remainder: arena.alloc(new_remainder),
}
}
Stmt::Jump(id, arguments) => Stmt::Jump(*id, arguments),
Stmt::Crash(symbol, crash_tag) => Stmt::Crash(*symbol, *crash_tag),
}
}
fn non_trmc_return(
&mut self,
env: &mut Env<'a, '_>,
interner: &mut impl LayoutInterner<'a>,
value_symbol: Symbol,
) -> Stmt<'a> {
fn non_trmc_return(&mut self, env: &mut Env<'a, '_>, value_symbol: Symbol) -> Stmt<'a> {
let arena = env.arena;
let layout = self.return_layout;
let unbox_expr = Expr::ExprUnbox {
symbol: self.initial_box_symbol,
};
let final_symbol = env.named_unique_symbol("final");
let unbox = |next| Stmt::Let(final_symbol, unbox_expr, layout, next);
let call = Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrLoad,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: &*arena.alloc([self.initial_box_symbol]),
};
let ptr_load = |next| Stmt::Let(final_symbol, Expr::Call(call), layout, next);
Self::ptr_write(
env,
interner,
layout,
self.hole_symbol,
value_symbol,
arena.alloc(
//
unbox(arena.alloc(Stmt::Ret(final_symbol))),
ptr_load(arena.alloc(Stmt::Ret(final_symbol))),
),
)
}

View File

@ -3,4 +3,29 @@ app "rocLovesZig"
imports []
provides [main] to pf
main = "Roc <3 Zig!\n"
LinkedList a : [Nil, Cons a (LinkedList a)]
map : LinkedList a, (a -> b) -> LinkedList b
map = \list, f ->
when list is
Nil -> Nil
Cons x xs -> Cons (f x) (map xs f)
unfold : a, Nat -> LinkedList a
unfold = \value, n ->
when n is
0 -> Nil
_ -> Cons value (unfold value (n - 1))
length : LinkedList a -> I64
length = \list ->
when list is
Nil -> 0
Cons _ rest -> 1 + length rest
main : Str
main =
unfold 42 5
|> map (\x -> x + 1i64)
|> length
|> Num.toStr