diff --git a/crates/compiler/mono/src/ir.rs b/crates/compiler/mono/src/ir.rs index b6c7bb3f86..80c8bb0702 100644 --- a/crates/compiler/mono/src/ir.rs +++ b/crates/compiler/mono/src/ir.rs @@ -1,6 +1,7 @@ #![allow(clippy::manual_map)] use crate::borrow::Ownership; +use crate::ir::erased::{build_erased_function, ResolvedErasedLambda}; use crate::ir::literal::{make_num_literal, IntOrFloatValue}; use crate::layout::{ self, Builtin, ClosureCallOptions, ClosureDataKind, ClosureRepresentation, EnumDispatch, @@ -3683,6 +3684,15 @@ fn specialize_proc_help<'a>( } } } + (Some(ClosureDataKind::Erased), CapturedSymbols::Captured(captured)) => { + specialized_body = erased::unpack_closure_data( + env, + layout_cache, + Symbol::ARG_CLOSURE, + captured, + specialized_body, + ); + } (None, CapturedSymbols::None) | (None, CapturedSymbols::Captured([])) => {} _ => unreachable!("to closure or not to closure?"), } @@ -5217,7 +5227,34 @@ pub fn with_hole<'a>( RawFunctionLayout::ZeroArgumentThunk(_) => { unreachable!("a closure syntactically always must have at least one argument") } - RawFunctionLayout::ErasedFunction(..) => todo_lambda_erasure!(), + RawFunctionLayout::ErasedFunction(_argument_layouts, _ret_layout) => { + let captured_symbols = Vec::from_iter_in(captured_symbols, env.arena); + let captured_symbols = captured_symbols.into_bump_slice(); + let captured_symbols = CapturedSymbols::Captured(captured_symbols); + let resolved_erased_lambda = + ResolvedErasedLambda::new(env, layout_cache, name, captured_symbols); + + let inserted = procs.insert_anonymous( + env, + resolved_erased_lambda.lambda_name(), + function_type, + arguments, + loc_body, + captured_symbols, + return_type, + layout_cache, + ); + + if let Err(e) = inserted { + return runtime_error( + env, + env.arena.alloc(format!("RuntimeError: {:?}", e,)), + ); + } + drop(inserted); + + build_erased_function(env, layout_cache, resolved_erased_lambda, assigned, hole) + } RawFunctionLayout::Function(_argument_layouts, lambda_set, _ret_layout) => { let mut captured_symbols = Vec::from_iter_in(captured_symbols, env.arena); captured_symbols.sort(); @@ -5250,9 +5287,8 @@ pub fn with_hole<'a>( env, env.arena.alloc(format!("RuntimeError: {e:?}",)), ); - } else { - drop(inserted); } + drop(inserted); // define the closure data diff --git a/crates/compiler/mono/src/ir/erased.rs b/crates/compiler/mono/src/ir/erased.rs index cccb5ef91d..0f52b928f8 100644 --- a/crates/compiler/mono/src/ir/erased.rs +++ b/crates/compiler/mono/src/ir/erased.rs @@ -433,3 +433,81 @@ impl<'a> ResolvedErasedLambda<'a> { self.lambda_name } } + +/// Given +/// +/// ``` +/// captures_symbol : void* +/// captures = { a: A, b: B } +/// ``` +/// +/// We generate +/// +/// ``` +/// heap_captures: Box { A, B } = Expr::Call(Lowlevel { Cast, captures_symbol }) +/// stack_captures = Expr::Unbox(heap_captures) +/// a = Expr::StructAtIndex(stack_captures, 0) +/// b = Expr::StructAtIndex(stack_captures, 1) +/// +/// ``` +pub fn unpack_closure_data<'a>( + env: &mut Env<'a, '_>, + layout_cache: &mut LayoutCache<'a>, + captures_symbol: Symbol, + captures: &[(Symbol, Variable)], + mut hole: Stmt<'a>, +) -> Stmt<'a> { + let heap_captures = env.unique_symbol(); + let stack_captures = env.unique_symbol(); + + let captures_layouts = { + let layouts = captures + .iter() + .map(|(_, var)| layout_cache.from_var(env.arena, *var, env.subs).unwrap()); + &*env.arena.alloc_slice_fill_iter(layouts) + }; + + let stack_captures_layout = + layout_cache.put_in_direct_no_semantic(LayoutRepr::Struct(captures_layouts)); + let heap_captures_layout = + layout_cache.put_in_direct_no_semantic(LayoutRepr::Boxed(stack_captures_layout)); + + for (i, ((capture, _capture_var), &capture_layout)) in + captures.iter().zip(captures_layouts).enumerate().rev() + { + hole = Stmt::Let( + *capture, + Expr::StructAtIndex { + index: i as _, + field_layouts: captures_layouts, + structure: stack_captures, + }, + capture_layout, + env.arena.alloc(hole), + ); + } + + hole = Stmt::Let( + stack_captures, + Expr::ExprUnbox { + symbol: heap_captures, + }, + stack_captures_layout, + env.arena.alloc(hole), + ); + + hole = Stmt::Let( + heap_captures, + Expr::Call(Call { + call_type: CallType::LowLevel { + op: LowLevel::PtrCast, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: env.arena.alloc([captures_symbol]), + }), + heap_captures_layout, + env.arena.alloc(hole), + ); + + hole +}