how to build an erased fn

This commit is contained in:
Ayaz Hafiz 2023-06-25 20:42:18 -05:00
parent cf30f02e01
commit a1eb641bb6
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -4,11 +4,12 @@ use roc_types::subs::Variable;
use crate::{
borrow::Ownership,
layout::{ErasedIndex, FunctionPointer, InLayout, Layout, LayoutCache, LayoutRepr},
layout::{ErasedIndex, FunctionPointer, InLayout, LambdaName, Layout, LayoutCache, LayoutRepr},
};
use super::{
with_hole, BranchInfo, Call, CallType, Env, Expr, JoinPointId, Param, Procs, Stmt, UpdateModeId,
with_hole, BranchInfo, Call, CallType, CapturedSymbols, Env, Expr, JoinPointId, Param, Procs,
Stmt, UpdateModeId,
};
const ERASED_FUNCTION_FIELD_LAYOUTS: &[InLayout] =
@ -264,3 +265,124 @@ pub fn call_erased_function<'a>(
env.arena.alloc(joinpoint),
)
}
/// Given
///
/// ```
/// f = \{} -> s
/// ```
///
/// We generate
///
/// ```
/// value = Expr::Box({s})
/// callee = Expr::FunctionPointer(f)
/// refcounter = TODO
/// f = Expr::Struct({ value, callee, refcounter })
/// ```
pub fn build_erased_function<'a>(
env: &mut Env<'a, '_>,
layout_cache: &mut LayoutCache<'a>,
lambda_name: LambdaName<'a>,
captures: CapturedSymbols<'a>,
assigned: Symbol,
hole: &'a Stmt<'a>,
) -> Stmt<'a> {
let value = env.unique_symbol();
let callee = env.unique_symbol();
let refcounter = env.unique_symbol();
// assigned = Expr::Struct({ value, callee, refcounter })
// hole <assigned>
let result = Stmt::Let(
assigned,
Expr::Struct(env.arena.alloc([value, callee, refcounter])),
Layout::ERASED,
hole,
);
// refcounter = TODO
// <hole>
let result = Stmt::Let(
refcounter,
Expr::NullPointer,
Layout::OPAQUE_PTR,
env.arena.alloc(result),
);
// callee = Expr::FunctionPointer(f)
let result = Stmt::Let(
callee,
Expr::FunctionPointer { lambda_name },
Layout::OPAQUE_PTR,
env.arena.alloc(result),
);
// value = Expr::Box({s})
match captures {
CapturedSymbols::None => {
// value = nullptr
// <hole>
Stmt::Let(
value,
Expr::NullPointer,
Layout::OPAQUE_PTR,
env.arena.alloc(result),
)
}
CapturedSymbols::Captured(captures) => {
// captures = {...captures}
// captures = Box(captures)
// value = Cast(captures, void*)
// <hole>
let stack_captures = env.unique_symbol();
let stack_captures_layout = {
let layouts = captures
.iter()
.map(|(_, var)| layout_cache.from_var(env.arena, *var, env.subs).unwrap());
let layouts = env.arena.alloc_slice_fill_iter(layouts);
layout_cache.put_in_direct_no_semantic(LayoutRepr::Struct(layouts))
};
let stack_captures_symbols = {
let symbols = captures.iter().map(|(sym, _)| *sym);
env.arena.alloc_slice_fill_iter(symbols)
};
let boxed_captures = env.unique_symbol();
let boxed_captures_layout =
layout_cache.put_in_direct_no_semantic(LayoutRepr::Boxed(stack_captures_layout));
let result = Stmt::Let(
stack_captures,
Expr::Struct(stack_captures_symbols),
stack_captures_layout,
env.arena.alloc(result),
);
let result = Stmt::Let(
boxed_captures,
Expr::ExprBox {
symbol: stack_captures,
},
boxed_captures_layout,
env.arena.alloc(result),
);
let result = Stmt::Let(
value,
Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::PtrCast,
update_mode: UpdateModeId::BACKEND_DUMMY,
},
arguments: env.arena.alloc([boxed_captures]),
}),
Layout::OPAQUE_PTR,
env.arena.alloc(result),
);
result
}
}
}