From 4e55a4bf9280bdaf5847aa5d641e39e853eb535a Mon Sep 17 00:00:00 2001 From: Folkert Date: Wed, 5 Aug 2020 16:10:45 +0200 Subject: [PATCH] more pattern matching fidling --- cli/src/repl.rs | 2 +- compiler/build/src/program.rs | 2 +- compiler/gen/src/llvm/build.rs | 810 ++++++++++++++++++++++++++- compiler/gen/tests/gen_list.rs | 138 ++--- compiler/gen/tests/gen_num.rs | 200 +++---- compiler/gen/tests/gen_primitives.rs | 54 +- compiler/gen/tests/gen_records.rs | 56 +- compiler/gen/tests/gen_tags.rs | 124 ++-- compiler/gen/tests/helpers/eval.rs | 230 +++++++- compiler/mono/src/decision_tree2.rs | 182 ++++-- compiler/mono/src/experiment.rs | 394 +++++++++++-- compiler/mono/tests/test_mono.rs | 344 +++++++++++- 12 files changed, 2141 insertions(+), 395 deletions(-) diff --git a/cli/src/repl.rs b/cli/src/repl.rs index 46e06d24be..4714066c68 100644 --- a/cli/src/repl.rs +++ b/cli/src/repl.rs @@ -333,7 +333,7 @@ pub fn gen(src: &str, target: Triple, opt_level: OptLevel) -> Result<(String, St let ret = roc_gen::llvm::build::build_expr( &env, &mut layout_ids, - &ImMap::default(), + &roc_gen::llvm::build::Scope::default(), main_fn, &main_body, ); diff --git a/compiler/build/src/program.rs b/compiler/build/src/program.rs index cd0f56e697..a2fda7c9b3 100644 --- a/compiler/build/src/program.rs +++ b/compiler/build/src/program.rs @@ -302,7 +302,7 @@ pub fn gen( let ret = roc_gen::llvm::build::build_expr( &env, &mut layout_ids, - &ImMap::default(), + &roc_gen::llvm::build::Scope::default(), main_fn, &main_body, ); diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index cce6bf6f8e..fa591dd281 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -5,6 +5,7 @@ use crate::llvm::convert::{ }; use bumpalo::collections::Vec; use bumpalo::Bump; +use inkwell::basic_block::BasicBlock; use inkwell::builder::Builder; use inkwell::context::Context; use inkwell::memory_buffer::MemoryBuffer; @@ -12,12 +13,13 @@ use inkwell::module::{Linkage, Module}; use inkwell::passes::{PassManager, PassManagerBuilder}; use inkwell::types::{BasicTypeEnum, FunctionType, IntType, PointerType, StructType}; use inkwell::values::BasicValueEnum::{self, *}; -use inkwell::values::{FloatValue, FunctionValue, IntValue, PointerValue, StructValue}; +use inkwell::values::{FloatValue, FunctionValue, IntValue, PhiValue, PointerValue, StructValue}; use inkwell::AddressSpace; use inkwell::{IntPredicate, OptimizationLevel}; use roc_collections::all::ImMap; use roc_module::low_level::LowLevel; use roc_module::symbol::{Interns, Symbol}; +use roc_mono::experiment::JoinPointId; use roc_mono::expr::{Expr, Proc}; use roc_mono::layout::{Builtin, Layout, Ownership}; use target_lexicon::CallingConvention; @@ -35,7 +37,38 @@ pub enum OptLevel { Optimize, } -pub type Scope<'a, 'ctx> = ImMap, PointerValue<'ctx>)>; +// pub type Scope<'a, 'ctx> = ImMap, PointerValue<'ctx>)>; +#[derive(Default, Debug, Clone, PartialEq)] +pub struct Scope<'a, 'ctx> { + symbols: ImMap, PointerValue<'ctx>)>, + join_points: ImMap>, +} + +impl<'a, 'ctx> Scope<'a, 'ctx> { + fn get(&self, symbol: &Symbol) -> Option<&(Layout<'a>, PointerValue<'ctx>)> { + self.symbols.get(symbol) + } + fn insert(&mut self, symbol: Symbol, value: (Layout<'a>, PointerValue<'ctx>)) { + self.symbols.insert(symbol, value); + } + fn remove(&mut self, symbol: &Symbol) { + self.symbols.remove(symbol); + } + /* + fn get_join_point(&self, symbol: &JoinPointId) -> Option<&PhiValue<'ctx>> { + self.join_points.get(symbol) + } + fn remove_join_point(&mut self, symbol: &JoinPointId) { + self.join_points.remove(symbol); + } + fn get_mut_join_point(&mut self, symbol: &JoinPointId) -> Option<&mut PhiValue<'ctx>> { + self.join_points.get_mut(symbol) + } + fn insert_join_point(&mut self, symbol: JoinPointId, value: PhiValue<'ctx>) { + self.join_points.insert(symbol, value); + } + */ +} pub struct Env<'a, 'ctx, 'env> { pub arena: &'a Bump, @@ -159,6 +192,524 @@ pub fn add_passes(fpm: &PassManager>, opt_level: OptLevel) { pmb.populate_function_pass_manager(&fpm); } +pub fn build_exp_literal<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + scope: &Scope<'a, 'ctx>, + parent: FunctionValue<'ctx>, + literal: &roc_mono::experiment::Literal<'a>, +) -> BasicValueEnum<'ctx> { + use roc_mono::experiment::Literal::*; + + match literal { + Int(num) => env.context.i64_type().const_int(*num as u64, true).into(), + Float(num) => env.context.f64_type().const_float(*num).into(), + Bool(b) => env.context.bool_type().const_int(*b as u64, false).into(), + Byte(b) => env.context.i8_type().const_int(*b as u64, false).into(), + _ => todo!("unsupported literal {:?}", literal), + } +} + +pub fn build_exp_expr<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + scope: &Scope<'a, 'ctx>, + parent: FunctionValue<'ctx>, + expr: &roc_mono::experiment::Expr<'a>, +) -> BasicValueEnum<'ctx> { + use roc_mono::experiment::CallType::*; + use roc_mono::experiment::Expr::*; + + match expr { + Literal(literal) => build_exp_literal(env, layout_ids, scope, parent, literal), + RunLowLevel(op, symbols) => { + let mut args = Vec::with_capacity_in(symbols.len(), env.arena); + + for symbol in symbols.iter() { + match scope.get(symbol) { + Some((layout, _)) => { + args.push((roc_mono::expr::Expr::Load(*symbol), layout.clone())) + } + None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope), + } + } + + run_low_level(env, layout_ids, scope, parent, *op, args.into_bump_slice()) + } + + FunctionCall { + call_type: ByName(name), + layout, + args, + } => { + let mut arg_tuples: Vec = Vec::with_capacity_in(args.len(), env.arena); + + for symbol in args.iter() { + arg_tuples.push(load_symbol(env, scope, symbol)); + } + + call_with_args_ir( + env, + layout_ids, + layout, + *name, + parent, + arg_tuples.into_bump_slice(), + ) + } + + Struct(sorted_fields) => { + let ctx = env.context; + let builder = env.builder; + let ptr_bytes = env.ptr_bytes; + + // Determine types + let num_fields = sorted_fields.len(); + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); + + for symbol in sorted_fields.iter() { + // Zero-sized fields have no runtime representation. + // The layout of the struct expects them to be dropped! + let (field_layout, field_expr) = load_symbol_and_layout(env, scope, symbol); + if field_layout.stack_size(ptr_bytes) != 0 { + field_types.push(basic_type_from_layout( + env.arena, + env.context, + &field_layout, + env.ptr_bytes, + )); + + field_vals.push(field_expr); + } + } + + // If the record has only one field that isn't zero-sized, + // unwrap it. This is what the layout expects us to do. + if field_vals.len() == 1 { + field_vals.pop().unwrap() + } else { + // Create the struct_type + let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); + let mut struct_val = struct_type.const_zero().into(); + + // Insert field exprs into struct_val + for (index, field_val) in field_vals.into_iter().enumerate() { + struct_val = builder + .build_insert_value(struct_val, field_val, index as u32, "insert_field") + .unwrap(); + } + + BasicValueEnum::StructValue(struct_val.into_struct_value()) + } + } + + Tag { + union_size, + arguments, + .. + } if *union_size == 1 => { + let it = arguments.iter(); + + let ctx = env.context; + let ptr_bytes = env.ptr_bytes; + let builder = env.builder; + + // Determine types + let num_fields = arguments.len() + 1; + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); + + for (field_symbol) in it { + let (field_layout, val) = load_symbol_and_layout(env, scope, field_symbol); + // Zero-sized fields have no runtime representation. + // The layout of the struct expects them to be dropped! + if field_layout.stack_size(ptr_bytes) != 0 { + let field_type = basic_type_from_layout( + env.arena, + env.context, + &field_layout, + env.ptr_bytes, + ); + + field_types.push(field_type); + field_vals.push(val); + } + } + + dbg!(&field_vals); + // If the struct has only one field that isn't zero-sized, + // unwrap it. This is what the layout expects us to do. + if field_vals.len() == 1 { + field_vals.pop().unwrap() + } else { + // Create the struct_type + let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); + let mut struct_val = struct_type.const_zero().into(); + + // Insert field exprs into struct_val + for (index, field_val) in field_vals.into_iter().enumerate() { + struct_val = builder + .build_insert_value(struct_val, field_val, index as u32, "insert_field") + .unwrap(); + } + + BasicValueEnum::StructValue(struct_val.into_struct_value()) + } + } + + Tag { + arguments, + tag_layout, + union_size, + .. + } => { + debug_assert!(*union_size > 1); + let ptr_size = env.ptr_bytes; + + let whole_size = tag_layout.stack_size(ptr_size); + let mut filler = tag_layout.stack_size(ptr_size); + + let ctx = env.context; + let builder = env.builder; + + // Determine types + let num_fields = arguments.len() + 1; + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); + + for field_symbol in arguments.iter() { + let (field_layout, val) = load_symbol_and_layout(env, scope, field_symbol); + let field_size = field_layout.stack_size(ptr_size); + + // Zero-sized fields have no runtime representation. + // The layout of the struct expects them to be dropped! + if field_size != 0 { + let field_type = + basic_type_from_layout(env.arena, env.context, field_layout, ptr_size); + + field_types.push(field_type); + field_vals.push(val); + + filler -= field_size; + } + } + + // TODO verify that this is required (better safe than sorry) + if filler > 0 { + field_types.push(env.context.i8_type().array_type(filler).into()); + } + + // Create the struct_type + let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); + let mut struct_val = struct_type.const_zero().into(); + + // Insert field exprs into struct_val + for (index, field_val) in field_vals.into_iter().enumerate() { + struct_val = builder + .build_insert_value(struct_val, field_val, index as u32, "insert_field") + .unwrap(); + } + + // How we create tag values + // + // The memory layout of tags can be different. e.g. in + // + // [ Ok Int, Err Str ] + // + // the `Ok` tag stores a 64-bit integer, the `Err` tag stores a struct. + // All tags of a union must have the same length, for easy addressing (e.g. array lookups). + // So we need to ask for the maximum of all tag's sizes, even if most tags won't use + // all that memory, and certainly won't use it in the same way (the tags have fields of + // different types/sizes) + // + // In llvm, we must be explicit about the type of value we're creating: we can't just + // make a unspecified block of memory. So what we do is create a byte array of the + // desired size. Then when we know which tag we have (which is here, in this function), + // we need to cast that down to the array of bytes that llvm expects + // + // There is the bitcast instruction, but it doesn't work for arrays. So we need to jump + // through some hoops using store and load to get this to work: the array is put into a + // one-element struct, which can be cast to the desired type. + // + // This tricks comes from + // https://github.com/raviqqe/ssf/blob/bc32aae68940d5bddf5984128e85af75ca4f4686/ssf-llvm/src/expression_compiler.rs#L116 + + let array_type = ctx.i8_type().array_type(whole_size); + + let result = cast_basic_basic( + builder, + struct_val.into_struct_value().into(), + array_type.into(), + ); + + // For unclear reasons, we can't cast an array to a struct on the other side. + // the solution is to wrap the array in a struct (yea...) + let wrapper_type = ctx.struct_type(&[array_type.into()], false); + let mut wrapper_val = wrapper_type.const_zero().into(); + wrapper_val = builder + .build_insert_value(wrapper_val, result, 0, "insert_field") + .unwrap(); + + wrapper_val.into_struct_value().into() + } + AccessAtIndex { + index, + structure, + is_unwrapped, + .. + } if *is_unwrapped => { + use inkwell::values::BasicValueEnum::*; + + let builder = env.builder; + + // Get Struct val + // Since this is a one-element tag union, we get the underlying value + // right away. However, that struct might have only one field which + // is not zero-sized, which would make it unwrapped. If that happens, + // we must be + match load_symbol(env, scope, structure) { + StructValue(argument) => builder + .build_extract_value( + argument, + *index as u32, + env.arena.alloc(format!("tag_field_access_{}_", index)), + ) + .unwrap(), + other => { + // If it's not a Struct, that means it was unwrapped, + // so we should return it directly. + other + } + } + } + + AccessAtIndex { + index, + structure, + field_layouts, + .. + } => { + let builder = env.builder; + + // Determine types, assumes the descriminant is in the field layouts + let num_fields = field_layouts.len(); + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let ptr_bytes = env.ptr_bytes; + + for field_layout in field_layouts.iter() { + let field_type = + basic_type_from_layout(env.arena, env.context, &field_layout, ptr_bytes); + field_types.push(field_type); + } + + // Create the struct_type + let struct_type = env + .context + .struct_type(field_types.into_bump_slice(), false); + + // cast the argument bytes into the desired shape for this tag + let argument = load_symbol(env, scope, structure).into_struct_value(); + + let struct_value = cast_struct_struct(builder, argument, struct_type); + + builder + .build_extract_value(struct_value, *index as u32, "") + .expect("desired field did not decode") + } + _ => todo!("unsupported literal {:?}", expr), + } +} + +pub fn build_exp_stmt<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + scope: &mut Scope<'a, 'ctx>, + parent: FunctionValue<'ctx>, + stmt: &roc_mono::experiment::Stmt<'a>, +) -> BasicValueEnum<'ctx> { + use roc_mono::experiment::Stmt::*; + + match stmt { + Let(symbol, expr, layout, cont) => { + println!("{} {:?}", symbol, expr); + let context = &env.context; + + let val = build_exp_expr(env, layout_ids, &scope, parent, &expr); + let expr_bt = basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes); + let alloca = + create_entry_block_alloca(env, parent, expr_bt, symbol.ident_string(&env.interns)); + + env.builder.build_store(alloca, val); + + // Make a new scope which includes the binding we just encountered. + // This should be done *after* compiling the bound expr, since any + // recursive (in the LetRec sense) bindings should already have + // been extracted as procedures. Nothing in here should need to + // access itself! + // scope = scope.clone(); + + scope.insert(*symbol, (layout.clone(), alloca)); + let result = build_exp_stmt(env, layout_ids, scope, parent, cont); + scope.remove(symbol); + + result + } + Ret(symbol) => { + dbg!(symbol, &scope); + + load_symbol(env, scope, symbol) + } + + Cond { + branching_symbol, + pass: pass_stmt, + fail: fail_stmt, + ret_layout, + .. + } => { + let ret_type = + basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); + + let cond_expr = load_symbol(env, scope, branching_symbol); + + match cond_expr { + IntValue(value) => { + // This is a call tobuild_basic_phi2, except inlined to prevent + // problems with lifetimes and closures involving layout_ids. + let builder = env.builder; + let context = env.context; + + // build blocks + let then_block = context.append_basic_block(parent, "then"); + let else_block = context.append_basic_block(parent, "else"); + let mut blocks: std::vec::Vec<( + &dyn inkwell::values::BasicValue<'_>, + inkwell::basic_block::BasicBlock<'_>, + )> = std::vec::Vec::with_capacity(2); + let cont_block = context.append_basic_block(parent, "branchcont"); + + builder.build_conditional_branch(value, then_block, else_block); + + // build then block + builder.position_at_end(then_block); + let then_val = build_exp_stmt(env, layout_ids, scope, parent, pass_stmt); + if then_block.get_terminator().is_none() { + builder.build_unconditional_branch(cont_block); + let then_block = builder.get_insert_block().unwrap(); + blocks.push((&then_val, then_block)); + } + + // build else block + builder.position_at_end(else_block); + let else_val = build_exp_stmt(env, layout_ids, scope, parent, fail_stmt); + if else_block.get_terminator().is_none() { + let else_block = builder.get_insert_block().unwrap(); + builder.build_unconditional_branch(cont_block); + blocks.push((&else_val, else_block)); + } + + // emit merge block + if blocks.is_empty() { + // SAFETY there are no other references to this block in this case + unsafe { + cont_block.delete().unwrap(); + } + + // return garbage value + context.i64_type().const_int(0, false).into() + } else { + builder.position_at_end(cont_block); + + let phi = builder.build_phi(ret_type, "branch"); + + // phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); + phi.add_incoming(&blocks); + + phi.as_basic_value() + } + } + _ => panic!( + "Tried to make a branch out of an invalid condition: cond_expr = {:?}", + cond_expr, + ), + } + } + + Switch { + branches, + default_branch, + ret_layout, + cond_layout, + cond_symbol, + } => { + let ret_type = + basic_type_from_layout(env.arena, env.context, &ret_layout, env.ptr_bytes); + + let switch_args = SwitchArgsIr { + cond_layout: cond_layout.clone(), + cond_symbol: *cond_symbol, + branches, + default_branch, + ret_type, + }; + + build_switch_ir(env, layout_ids, scope, parent, switch_args) + } + Join { + id, + arguments, + remainder, + continuation, + } => { + let builder = env.builder; + let context = env.context; + + // create new block + let cont_block = context.append_basic_block(parent, "joinpointcont"); + + // store this join point + scope.join_points.insert(*id, cont_block); + + // construct the blocks that may jump to this join point + build_exp_stmt(env, layout_ids, scope, parent, remainder); + + // remove this join point again + scope.join_points.remove(&id); + + // Assumptions + // + // - `remainder` is either a Cond or Switch where + // - all branches jump to this join point + // + // we should improve this in the future! + let phi_block = builder.get_insert_block().unwrap(); + //builder.build_unconditional_branch(cont_block); + + // put the cont block at the back + builder.position_at_end(cont_block); + + // put the continuation in + let result = build_exp_stmt(env, layout_ids, scope, parent, continuation); + + cont_block.move_after(phi_block).unwrap(); + + result + } + Jump(join_point, _arguments) => { + let builder = env.builder; + let context = env.context; + let cont_block = scope.join_points.get(join_point).unwrap(); + let jmp = builder.build_unconditional_branch(*cont_block); + // builder.insert_instruction(&jmp, None); + + // This doesn't currently do anything + context.i64_type().const_int(0, false).into() + } + _ => todo!("unsupported expr {:?}", stmt), + } +} + #[allow(clippy::cognitive_complexity)] pub fn build_expr<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, @@ -262,7 +813,7 @@ pub fn build_expr<'a, 'ctx, 'env>( build_switch(env, layout_ids, scope, parent, switch_args) } Store(stores, ret) => { - let mut scope = im_rc::HashMap::clone(scope); + let mut scope = scope.clone(); let context = &env.context; for (symbol, layout, expr) in stores.iter() { @@ -282,7 +833,7 @@ pub fn build_expr<'a, 'ctx, 'env>( // recursive (in the LetRec sense) bindings should already have // been extracted as procedures. Nothing in here should need to // access itself! - scope = im_rc::HashMap::clone(&scope); + scope = scope.clone(); scope.insert(*symbol, (layout.clone(), alloca)); } @@ -842,13 +1393,28 @@ fn load_symbol<'a, 'ctx, 'env>( symbol: &Symbol, ) -> BasicValueEnum<'ctx> { match scope.get(symbol) { - Some((_, ptr)) => env + Some((layout, ptr)) => env .builder .build_load(*ptr, symbol.ident_string(&env.interns)), None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope), } } +fn load_symbol_and_layout<'a, 'ctx, 'env, 'b>( + env: &Env<'a, 'ctx, 'env>, + scope: &'b Scope<'a, 'ctx>, + symbol: &Symbol, +) -> (&'b Layout<'a>, BasicValueEnum<'ctx>) { + match scope.get(symbol) { + Some((layout, ptr)) => ( + layout, + env.builder + .build_load(*ptr, symbol.ident_string(&env.interns)), + ), + None => panic!("There was no entry for {:?} in scope {:?}", symbol, scope), + } +} + /// Cast a struct to another struct of the same (or smaller?) size fn cast_struct_struct<'ctx>( builder: &Builder<'ctx>, @@ -1015,6 +1581,126 @@ fn build_switch<'a, 'ctx, 'env>( phi.as_basic_value() } +struct SwitchArgsIr<'a, 'ctx> { + pub cond_symbol: Symbol, + pub cond_layout: Layout<'a>, + pub branches: &'a [(u64, roc_mono::experiment::Stmt<'a>)], + pub default_branch: &'a roc_mono::experiment::Stmt<'a>, + pub ret_type: BasicTypeEnum<'ctx>, +} + +fn build_switch_ir<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + scope: &Scope<'a, 'ctx>, + parent: FunctionValue<'ctx>, + switch_args: SwitchArgsIr<'a, 'ctx>, +) -> BasicValueEnum<'ctx> { + let arena = env.arena; + let builder = env.builder; + let context = env.context; + let SwitchArgsIr { + branches, + cond_symbol, + mut cond_layout, + default_branch, + ret_type, + .. + } = switch_args; + + let mut copy = scope.clone(); + let scope = &mut copy; + + let cond_symbol = &cond_symbol; + + let cont_block = context.append_basic_block(parent, "cont"); + + // Build the condition + let cond = match cond_layout { + Layout::Builtin(Builtin::Float64) => { + // float matches are done on the bit pattern + cond_layout = Layout::Builtin(Builtin::Int64); + let full_cond = load_symbol(env, scope, cond_symbol); + + builder + .build_bitcast(full_cond, env.context.i64_type(), "") + .into_int_value() + } + Layout::Union(_) => { + // we match on the discriminant, not the whole Tag + cond_layout = Layout::Builtin(Builtin::Int64); + let full_cond = load_symbol(env, scope, cond_symbol).into_struct_value(); + + extract_tag_discriminant(env, full_cond) + } + Layout::Builtin(_) => load_symbol(env, scope, cond_symbol).into_int_value(), + other => todo!("Build switch value from layout: {:?}", other), + }; + + // Build the cases + let mut incoming = Vec::with_capacity_in(branches.len(), arena); + let mut cases = Vec::with_capacity_in(branches.len(), arena); + + for (int, _) in branches.iter() { + // Switch constants must all be same type as switch value! + // e.g. this is incorrect, and will trigger a LLVM warning: + // + // switch i8 %apple1, label %default [ + // i64 2, label %branch2 + // i64 0, label %branch0 + // i64 1, label %branch1 + // ] + // + // they either need to all be i8, or i64 + let int_val = match cond_layout { + Layout::Builtin(Builtin::Int128) => context.i128_type().const_int(*int as u64, false), /* TODO file an issue: you can't currently have an int literal bigger than 64 bits long, and also (as we see here), you can't currently have (at least in Inkwell) a when-branch with an i128 literal in its pattren */ + Layout::Builtin(Builtin::Int64) => context.i64_type().const_int(*int as u64, false), + Layout::Builtin(Builtin::Int32) => context.i32_type().const_int(*int as u64, false), + Layout::Builtin(Builtin::Int16) => context.i16_type().const_int(*int as u64, false), + Layout::Builtin(Builtin::Int8) => context.i8_type().const_int(*int as u64, false), + Layout::Builtin(Builtin::Int1) => context.bool_type().const_int(*int as u64, false), + _ => panic!("Can't cast to cond_layout = {:?}", cond_layout), + }; + let block = context.append_basic_block(parent, format!("branch{}", int).as_str()); + + cases.push((int_val, block)); + } + + let default_block = context.append_basic_block(parent, "default"); + + builder.build_switch(cond, default_block, &cases); + + for ((_, branch_expr), (_, block)) in branches.iter().zip(cases) { + builder.position_at_end(block); + + let branch_val = build_exp_stmt(env, layout_ids, scope, parent, branch_expr); + + builder.build_unconditional_branch(cont_block); + + incoming.push((branch_val, block)); + } + + // The block for the conditional's default branch. + builder.position_at_end(default_block); + + let default_val = build_exp_stmt(env, layout_ids, scope, parent, default_branch); + + builder.build_unconditional_branch(cont_block); + + incoming.push((default_val, default_block)); + + // emit merge block + builder.position_at_end(cont_block); + + let phi = builder.build_phi(ret_type, "branch"); + + for (branch_val, block) in incoming { + phi.add_incoming(&[(&Into::::into(branch_val), block)]); + } + + phi.as_basic_value() +} + fn build_basic_phi2<'a, 'ctx, 'env, PassFn, FailFn>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, @@ -1142,7 +1828,7 @@ pub fn build_proc<'a, 'ctx, 'env>( builder.position_at_end(entry); - let mut scope = ImMap::default(); + let mut scope = Scope::default(); // Add args to scope for ((arg_val, arg_type), (layout, arg_symbol)) in @@ -1163,6 +1849,78 @@ pub fn build_proc<'a, 'ctx, 'env>( builder.build_return(Some(&body)); } +pub fn build_proc_header_ir<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + symbol: Symbol, + layout: &Layout<'a>, + proc: &roc_mono::experiment::Proc<'a>, +) -> (FunctionValue<'ctx>, Vec<'a, BasicTypeEnum<'ctx>>) { + let args = proc.args; + let arena = env.arena; + let context = &env.context; + let ret_type = basic_type_from_layout(arena, context, &proc.ret_layout, env.ptr_bytes); + let mut arg_basic_types = Vec::with_capacity_in(args.len(), arena); + let mut arg_symbols = Vec::new_in(arena); + + for (layout, arg_symbol) in args.iter() { + let arg_type = basic_type_from_layout(arena, env.context, &layout, env.ptr_bytes); + + arg_basic_types.push(arg_type); + arg_symbols.push(arg_symbol); + } + + let fn_type = get_fn_type(&ret_type, &arg_basic_types); + + let fn_name = layout_ids + .get(symbol, layout) + .to_symbol_string(symbol, &env.interns); + let fn_val = env + .module + .add_function(fn_name.as_str(), fn_type, Some(Linkage::Private)); + + fn_val.set_call_conventions(fn_val.get_call_conventions()); + + (fn_val, arg_basic_types) +} + +pub fn build_proc_ir<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + proc: roc_mono::experiment::Proc<'a>, + fn_val: FunctionValue<'ctx>, + arg_basic_types: Vec<'a, BasicTypeEnum<'ctx>>, +) { + let args = proc.args; + let context = &env.context; + + // Add a basic block for the entry point + let entry = context.append_basic_block(fn_val, "entry"); + let builder = env.builder; + + builder.position_at_end(entry); + + let mut scope = Scope::default(); + + // Add args to scope + for ((arg_val, arg_type), (layout, arg_symbol)) in + fn_val.get_param_iter().zip(arg_basic_types).zip(args) + { + set_name(arg_val, arg_symbol.ident_string(&env.interns)); + + let alloca = + create_entry_block_alloca(env, fn_val, arg_type, arg_symbol.ident_string(&env.interns)); + + builder.build_store(alloca, arg_val); + + scope.insert(*arg_symbol, (layout.clone(), alloca)); + } + + let body = build_exp_stmt(env, layout_ids, &mut scope, fn_val, &proc.body); + + builder.build_return(Some(&body)); +} + pub fn verify_fn(fn_val: FunctionValue<'_>) { if !fn_val.verify(PRINT_FN_VERIFICATION_OUTPUT) { unsafe { @@ -1447,6 +2205,46 @@ fn call_with_args<'a, 'ctx, 'env>( .unwrap_or_else(|| panic!("LLVM error: Invalid call by name for name {:?}", symbol)) } +#[inline(always)] +#[allow(clippy::cognitive_complexity)] +fn call_with_args_ir<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + layout: &Layout<'a>, + symbol: Symbol, + _parent: FunctionValue<'ctx>, + args: &[BasicValueEnum<'ctx>], +) -> BasicValueEnum<'ctx> { + let fn_name = layout_ids + .get(symbol, layout) + .to_symbol_string(symbol, &env.interns); + let fn_val = env + .module + .get_function(fn_name.as_str()) + .unwrap_or_else(|| { + if symbol.is_builtin() { + panic!("Unrecognized builtin function: {:?}", symbol) + } else { + panic!("Unrecognized non-builtin function: {:?}", symbol) + } + }); + let mut arg_vals: Vec = Vec::with_capacity_in(args.len(), env.arena); + + for (arg) in args.iter() { + arg_vals.push(*arg); + } + + let call = env + .builder + .build_call(fn_val, arg_vals.into_bump_slice(), "call"); + + call.set_call_convention(fn_val.get_call_conventions()); + + call.try_as_basic_value() + .left() + .unwrap_or_else(|| panic!("LLVM error: Invalid call by name for name {:?}", symbol)) +} + fn call_intrinsic<'a, 'ctx, 'env>( intrinsic_name: &'static str, env: &Env<'a, 'ctx, 'env>, diff --git a/compiler/gen/tests/gen_list.rs b/compiler/gen/tests/gen_list.rs index 4becd16a46..724b7d7545 100644 --- a/compiler/gen/tests/gen_list.rs +++ b/compiler/gen/tests/gen_list.rs @@ -29,20 +29,20 @@ mod gen_list { #[test] fn empty_list_literal() { - assert_evals_to!("[]", &[], &'static [i64]); + assert_evals_to_ir!("[]", &[], &'static [i64]); } #[test] fn int_list_literal() { - assert_evals_to!("[ 12, 9, 6, 3 ]", &[12, 9, 6, 3], &'static [i64]); + assert_evals_to_ir!("[ 12, 9, 6, 3 ]", &[12, 9, 6, 3], &'static [i64]); } #[test] fn list_push() { - assert_evals_to!("List.push [1] 2", &[1, 2], &'static [i64]); - assert_evals_to!("List.push [1, 1] 2", &[1, 1, 2], &'static [i64]); - assert_evals_to!("List.push [] 3", &[3], &'static [i64]); - assert_evals_to!( + assert_evals_to_ir!("List.push [1] 2", &[1, 2], &'static [i64]); + assert_evals_to_ir!("List.push [1, 1] 2", &[1, 1, 2], &'static [i64]); + assert_evals_to_ir!("List.push [] 3", &[3], &'static [i64]); + assert_evals_to_ir!( indoc!( r#" initThrees : List Int @@ -55,12 +55,12 @@ mod gen_list { &[3, 3], &'static [i64] ); - assert_evals_to!( + assert_evals_to_ir!( "List.push [ True, False ] True", &[true, false, true], &'static [bool] ); - assert_evals_to!( + assert_evals_to_ir!( "List.push [ 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22 ] 23", &[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], &'static [i64] @@ -69,17 +69,17 @@ mod gen_list { #[test] fn list_single() { - assert_evals_to!("List.single 1", &[1], &'static [i64]); - assert_evals_to!("List.single 5.6", &[5.6], &'static [f64]); + assert_evals_to_ir!("List.single 1", &[1], &'static [i64]); + assert_evals_to_ir!("List.single 5.6", &[5.6], &'static [f64]); } #[test] fn list_repeat() { - assert_evals_to!("List.repeat 5 1", &[1, 1, 1, 1, 1], &'static [i64]); - assert_evals_to!("List.repeat 4 2", &[2, 2, 2, 2], &'static [i64]); + assert_evals_to_ir!("List.repeat 5 1", &[1, 1, 1, 1, 1], &'static [i64]); + assert_evals_to_ir!("List.repeat 4 2", &[2, 2, 2, 2], &'static [i64]); - assert_evals_to!("List.repeat 2 []", &[&[], &[]], &'static [&'static [i64]]); - assert_evals_to!( + assert_evals_to_ir!("List.repeat 2 []", &[&[], &[]], &'static [&'static [i64]]); + assert_evals_to_ir!( indoc!( r#" noStrs : List Str @@ -93,7 +93,7 @@ mod gen_list { &'static [&'static [i64]] ); - assert_evals_to!( + assert_evals_to_ir!( "List.repeat 15 4", &[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], &'static [i64] @@ -102,14 +102,14 @@ mod gen_list { #[test] fn list_reverse() { - assert_evals_to!( + assert_evals_to_ir!( "List.reverse [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ]", &[12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1], &'static [i64] ); - assert_evals_to!("List.reverse [1, 2, 3]", &[3, 2, 1], &'static [i64]); - assert_evals_to!("List.reverse [4]", &[4], &'static [i64]); - assert_evals_to!( + assert_evals_to_ir!("List.reverse [1, 2, 3]", &[3, 2, 1], &'static [i64]); + assert_evals_to_ir!("List.reverse [4]", &[4], &'static [i64]); + assert_evals_to_ir!( indoc!( r#" emptyList : List Int @@ -122,14 +122,14 @@ mod gen_list { &[], &'static [i64] ); - assert_evals_to!("List.reverse []", &[], &'static [i64]); + assert_evals_to_ir!("List.reverse []", &[], &'static [i64]); } #[test] fn list_append() { - assert_evals_to!("List.append [] []", &[], &'static [i64]); + assert_evals_to_ir!("List.append [] []", &[], &'static [i64]); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" firstList : List Int @@ -147,16 +147,16 @@ mod gen_list { &'static [i64] ); - assert_evals_to!("List.append [ 12, 13 ] []", &[12, 13], &'static [i64]); - assert_evals_to!( + assert_evals_to_ir!("List.append [ 12, 13 ] []", &[12, 13], &'static [i64]); + assert_evals_to_ir!( "List.append [ 34, 43 ] [ 64, 55, 66 ]", &[34, 43, 64, 55, 66], &'static [i64] ); - assert_evals_to!("List.append [] [ 23, 24 ]", &[23, 24], &'static [i64]); + assert_evals_to_ir!("List.append [] [ 23, 24 ]", &[23, 24], &'static [i64]); - assert_evals_to!( + assert_evals_to_ir!( "List.append [ 1, 2 ] [ 3, 4 ]", &[1, 2, 3, 4], &'static [i64] @@ -178,7 +178,7 @@ mod gen_list { let expected_slice: &[i64] = expected.as_ref(); - assert_evals_to!( + assert_evals_to_ir!( &format!("List.append {} {}", slice_str1, slice_str2), expected_slice, &'static [i64] @@ -228,17 +228,17 @@ mod gen_list { #[test] fn empty_list_len() { - assert_evals_to!("List.len []", 0, usize); + assert_evals_to_ir!("List.len []", 0, usize); } #[test] fn basic_int_list_len() { - assert_evals_to!("List.len [ 12, 9, 6, 3 ]", 4, usize); + assert_evals_to_ir!("List.len [ 12, 9, 6, 3 ]", 4, usize); } #[test] fn loaded_int_list_len() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" nums = [ 2, 4, 6 ] @@ -253,7 +253,7 @@ mod gen_list { #[test] fn fn_int_list_len() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" getLen = \list -> List.len list @@ -270,17 +270,17 @@ mod gen_list { #[test] fn int_list_is_empty() { - assert_evals_to!("List.isEmpty [ 12, 9, 6, 3 ]", false, bool); + assert_evals_to_ir!("List.isEmpty [ 12, 9, 6, 3 ]", false, bool); } #[test] fn empty_list_is_empty() { - assert_evals_to!("List.isEmpty []", true, bool); + assert_evals_to_ir!("List.isEmpty []", true, bool); } #[test] fn first_int_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.first [ 12, 9, 6, 3 ] is @@ -295,7 +295,7 @@ mod gen_list { #[test] fn first_wildcard_empty_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.first [] is @@ -314,7 +314,7 @@ mod gen_list { // // #[test] // fn first_empty_list() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // when List.first [] is @@ -329,7 +329,7 @@ mod gen_list { #[test] fn get_empty_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.get [] 0 is @@ -344,7 +344,7 @@ mod gen_list { #[test] fn get_wildcard_empty_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.get [] 0 is @@ -359,7 +359,7 @@ mod gen_list { #[test] fn get_int_list_ok() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.get [ 12, 9, 6 ] 1 is @@ -374,7 +374,7 @@ mod gen_list { #[test] fn get_int_list_oob() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.get [ 12, 9, 6 ] 1000 is @@ -389,7 +389,7 @@ mod gen_list { #[test] fn get_set_unique_int_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when List.get (List.set [ 12, 9, 7, 3 ] 1 42) 1 is @@ -404,7 +404,7 @@ mod gen_list { #[test] fn set_unique_int_list() { - assert_evals_to!( + assert_evals_to_ir!( "List.set [ 12, 9, 7, 1, 5 ] 2 33", &[12, 9, 33, 1, 5], &'static [i64] @@ -413,7 +413,7 @@ mod gen_list { #[test] fn set_unique_list_oob() { - assert_evals_to!( + assert_evals_to_ir!( "List.set [ 3, 17, 4.1 ] 1337 9.25", &[3.0, 17.0, 4.1], &'static [f64] @@ -422,7 +422,7 @@ mod gen_list { #[test] fn set_shared_int_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" shared = [ 2.1, 4.3 ] @@ -448,7 +448,7 @@ mod gen_list { #[test] fn set_shared_list_oob() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" shared = [ 2, 4 ] @@ -474,7 +474,7 @@ mod gen_list { #[test] fn get_unique_int_list() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" unique = [ 2, 4 ] @@ -491,7 +491,7 @@ mod gen_list { #[test] fn gen_wrap_len() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" wrapLen = \list -> @@ -507,7 +507,7 @@ mod gen_list { #[test] fn gen_wrap_first() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" wrapFirst = \list -> @@ -523,7 +523,7 @@ mod gen_list { #[test] fn gen_duplicate() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" # Duplicate the first element into the second index @@ -545,7 +545,7 @@ mod gen_list { #[test] fn gen_swap() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" swap : Int, Int, List a -> List a @@ -568,7 +568,7 @@ mod gen_list { // #[test] // fn gen_partition() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // swap : Int, Int, List a -> List a @@ -620,7 +620,7 @@ mod gen_list { // #[test] // fn gen_partition() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // swap : Int, Int, List a -> List a @@ -660,7 +660,7 @@ mod gen_list { #[test] fn gen_quicksort() { with_larger_debug_stack(|| { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" quicksort : List (Num a) -> List (Num a) @@ -734,7 +734,7 @@ mod gen_list { #[test] fn foobar2() { with_larger_debug_stack(|| { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" quicksort : List (Num a) -> List (Num a) @@ -809,7 +809,7 @@ mod gen_list { #[test] fn foobar() { with_larger_debug_stack(|| { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" quicksort : List (Num a) -> List (Num a) @@ -883,7 +883,7 @@ mod gen_list { #[test] fn empty_list_increment_decrement() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : List Int @@ -899,7 +899,7 @@ mod gen_list { #[test] fn list_literal_increment_decrement() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : List Int @@ -915,7 +915,7 @@ mod gen_list { #[test] fn list_pass_to_function() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : List Int @@ -933,7 +933,7 @@ mod gen_list { } // fn bad() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // id : List Int -> [ Id (List Int) ] @@ -950,30 +950,14 @@ mod gen_list { // #[test] fn list_wrap_in_tag() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" id : List Int -> [ Pair (List Int) Int, Nil ] id = \y -> Pair y 4 when id [1,2,3] is - Pair _ _ -> v - "# - ), - &[1, 2, 3], - &'static [i64] - ); - } - - #[test] - fn list_wrap_in_tag() { - assert_evals_to!( - indoc!( - r#" - x = [1,2,3] - - when id [1,2,3] is - Pair _ _ -> v + Pair v _ -> v "# ), &[1, 2, 3], diff --git a/compiler/gen/tests/gen_num.rs b/compiler/gen/tests/gen_num.rs index cb0afe9ecd..59b409bdcb 100644 --- a/compiler/gen/tests/gen_num.rs +++ b/compiler/gen/tests/gen_num.rs @@ -30,7 +30,7 @@ mod gen_num { #[test] fn f64_sqrt() { // FIXME this works with normal types, but fails when checking uniqueness types - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when Num.sqrt 100 is @@ -44,31 +44,31 @@ mod gen_num { } #[test] - fn f64_round() { - assert_evals_to!("Num.round 3.6", 4, i64); + fn f64_round_old() { + assert_evals_to_ir!("Num.round 3.6", 4, i64); } #[test] fn f64_abs() { - assert_evals_to!("Num.abs -4.7", 4.7, f64); - assert_evals_to!("Num.abs 5.8", 5.8, f64); + assert_evals_to_ir!("Num.abs -4.7", 4.7, f64); + assert_evals_to_ir!("Num.abs 5.8", 5.8, f64); } #[test] fn i64_abs() { - assert_evals_to!("Num.abs -6", 6, i64); - assert_evals_to!("Num.abs 7", 7, i64); - assert_evals_to!("Num.abs 0", 0, i64); - assert_evals_to!("Num.abs -0", 0, i64); - assert_evals_to!("Num.abs -1", 1, i64); - assert_evals_to!("Num.abs 1", 1, i64); - assert_evals_to!("Num.abs 9_000_000_000_000", 9_000_000_000_000, i64); - assert_evals_to!("Num.abs -9_000_000_000_000", 9_000_000_000_000, i64); + assert_evals_to_ir!("Num.abs -6", 6, i64); + assert_evals_to_ir!("Num.abs 7", 7, i64); + assert_evals_to_ir!("Num.abs 0", 0, i64); + assert_evals_to_ir!("Num.abs -0", 0, i64); + assert_evals_to_ir!("Num.abs -1", 1, i64); + assert_evals_to_ir!("Num.abs 1", 1, i64); + assert_evals_to_ir!("Num.abs 9_000_000_000_000", 9_000_000_000_000, i64); + assert_evals_to_ir!("Num.abs -9_000_000_000_000", 9_000_000_000_000, i64); } #[test] fn gen_if_fn() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" limitedNegate = \num -> @@ -89,7 +89,7 @@ mod gen_num { #[test] fn gen_float_eq() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 1.0 == 1.0 @@ -102,7 +102,7 @@ mod gen_num { #[test] fn gen_add_f64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 1.1 + 2.4 + 3 @@ -115,7 +115,7 @@ mod gen_num { #[test] fn gen_wrap_add_nums() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" add2 = \num1, num2 -> num1 + num2 @@ -131,7 +131,7 @@ mod gen_num { #[test] fn gen_div_f64() { // FIXME this works with normal types, but fails when checking uniqueness types - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 48 / 2 is @@ -146,7 +146,7 @@ mod gen_num { #[test] fn gen_int_eq() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 4 == 4 @@ -159,7 +159,7 @@ mod gen_num { #[test] fn gen_int_neq() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 4 != 5 @@ -172,7 +172,7 @@ mod gen_num { #[test] fn gen_wrap_int_neq() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" wrappedNotEq : a, a -> Bool @@ -189,7 +189,7 @@ mod gen_num { #[test] fn gen_add_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 1 + 2 + 3 @@ -202,7 +202,7 @@ mod gen_num { #[test] fn gen_sub_f64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 1.5 - 2.4 - 3 @@ -215,7 +215,7 @@ mod gen_num { #[test] fn gen_sub_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 1 - 2 - 3 @@ -228,7 +228,7 @@ mod gen_num { #[test] fn gen_mul_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 2 * 4 * 6 @@ -241,7 +241,7 @@ mod gen_num { #[test] fn gen_div_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 1000 // 10 is @@ -256,7 +256,7 @@ mod gen_num { #[test] fn gen_div_by_zero_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 1000 // 0 is @@ -271,7 +271,7 @@ mod gen_num { #[test] fn gen_rem_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when Num.rem 8 3 is @@ -286,7 +286,7 @@ mod gen_num { #[test] fn gen_rem_div_by_zero_i64() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when Num.rem 8 0 is @@ -301,143 +301,143 @@ mod gen_num { #[test] fn gen_is_zero_i64() { - assert_evals_to!("Num.isZero 0", true, bool); - assert_evals_to!("Num.isZero 1", false, bool); + assert_evals_to_ir!("Num.isZero 0", true, bool); + assert_evals_to_ir!("Num.isZero 1", false, bool); } #[test] fn gen_is_positive_i64() { - assert_evals_to!("Num.isPositive 0", false, bool); - assert_evals_to!("Num.isPositive 1", true, bool); - assert_evals_to!("Num.isPositive -5", false, bool); + assert_evals_to_ir!("Num.isPositive 0", false, bool); + assert_evals_to_ir!("Num.isPositive 1", true, bool); + assert_evals_to_ir!("Num.isPositive -5", false, bool); } #[test] fn gen_is_negative_i64() { - assert_evals_to!("Num.isNegative 0", false, bool); - assert_evals_to!("Num.isNegative 3", false, bool); - assert_evals_to!("Num.isNegative -2", true, bool); + assert_evals_to_ir!("Num.isNegative 0", false, bool); + assert_evals_to_ir!("Num.isNegative 3", false, bool); + assert_evals_to_ir!("Num.isNegative -2", true, bool); } #[test] fn gen_is_positive_f64() { - assert_evals_to!("Num.isPositive 0.0", false, bool); - assert_evals_to!("Num.isPositive 4.7", true, bool); - assert_evals_to!("Num.isPositive -8.5", false, bool); + assert_evals_to_ir!("Num.isPositive 0.0", false, bool); + assert_evals_to_ir!("Num.isPositive 4.7", true, bool); + assert_evals_to_ir!("Num.isPositive -8.5", false, bool); } #[test] fn gen_is_negative_f64() { - assert_evals_to!("Num.isNegative 0.0", false, bool); - assert_evals_to!("Num.isNegative 9.9", false, bool); - assert_evals_to!("Num.isNegative -4.4", true, bool); + assert_evals_to_ir!("Num.isNegative 0.0", false, bool); + assert_evals_to_ir!("Num.isNegative 9.9", false, bool); + assert_evals_to_ir!("Num.isNegative -4.4", true, bool); } #[test] fn gen_is_zero_f64() { - assert_evals_to!("Num.isZero 0", true, bool); - assert_evals_to!("Num.isZero 0_0", true, bool); - assert_evals_to!("Num.isZero 0.0", true, bool); - assert_evals_to!("Num.isZero 1", false, bool); + assert_evals_to_ir!("Num.isZero 0", true, bool); + assert_evals_to_ir!("Num.isZero 0_0", true, bool); + assert_evals_to_ir!("Num.isZero 0.0", true, bool); + assert_evals_to_ir!("Num.isZero 1", false, bool); } #[test] fn gen_is_odd() { - assert_evals_to!("Num.isOdd 4", false, bool); - assert_evals_to!("Num.isOdd 5", true, bool); + assert_evals_to_ir!("Num.isOdd 4", false, bool); + assert_evals_to_ir!("Num.isOdd 5", true, bool); } #[test] fn gen_is_even() { - assert_evals_to!("Num.isEven 6", true, bool); - assert_evals_to!("Num.isEven 7", false, bool); + assert_evals_to_ir!("Num.isEven 6", true, bool); + assert_evals_to_ir!("Num.isEven 7", false, bool); } #[test] fn sin() { - assert_evals_to!("Num.sin 0", 0.0, f64); - assert_evals_to!("Num.sin 1.41421356237", 0.9877659459922529, f64); + assert_evals_to_ir!("Num.sin 0", 0.0, f64); + assert_evals_to_ir!("Num.sin 1.41421356237", 0.9877659459922529, f64); } #[test] fn cos() { - assert_evals_to!("Num.cos 0", 1.0, f64); - assert_evals_to!("Num.cos 3.14159265359", -1.0, f64); + assert_evals_to_ir!("Num.cos 0", 1.0, f64); + assert_evals_to_ir!("Num.cos 3.14159265359", -1.0, f64); } #[test] fn tan() { - assert_evals_to!("Num.tan 0", 0.0, f64); - assert_evals_to!("Num.tan 1", 1.557407724654902, f64); + assert_evals_to_ir!("Num.tan 0", 0.0, f64); + assert_evals_to_ir!("Num.tan 1", 1.557407724654902, f64); } #[test] fn lt_i64() { - assert_evals_to!("1 < 2", true, bool); - assert_evals_to!("1 < 1", false, bool); - assert_evals_to!("2 < 1", false, bool); - assert_evals_to!("0 < 0", false, bool); + assert_evals_to_ir!("1 < 2", true, bool); + assert_evals_to_ir!("1 < 1", false, bool); + assert_evals_to_ir!("2 < 1", false, bool); + assert_evals_to_ir!("0 < 0", false, bool); } #[test] fn lte_i64() { - assert_evals_to!("1 <= 1", true, bool); - assert_evals_to!("2 <= 1", false, bool); - assert_evals_to!("1 <= 2", true, bool); - assert_evals_to!("0 <= 0", true, bool); + assert_evals_to_ir!("1 <= 1", true, bool); + assert_evals_to_ir!("2 <= 1", false, bool); + assert_evals_to_ir!("1 <= 2", true, bool); + assert_evals_to_ir!("0 <= 0", true, bool); } #[test] fn gt_i64() { - assert_evals_to!("2 > 1", true, bool); - assert_evals_to!("2 > 2", false, bool); - assert_evals_to!("1 > 1", false, bool); - assert_evals_to!("0 > 0", false, bool); + assert_evals_to_ir!("2 > 1", true, bool); + assert_evals_to_ir!("2 > 2", false, bool); + assert_evals_to_ir!("1 > 1", false, bool); + assert_evals_to_ir!("0 > 0", false, bool); } #[test] fn gte_i64() { - assert_evals_to!("1 >= 1", true, bool); - assert_evals_to!("1 >= 2", false, bool); - assert_evals_to!("2 >= 1", true, bool); - assert_evals_to!("0 >= 0", true, bool); + assert_evals_to_ir!("1 >= 1", true, bool); + assert_evals_to_ir!("1 >= 2", false, bool); + assert_evals_to_ir!("2 >= 1", true, bool); + assert_evals_to_ir!("0 >= 0", true, bool); } #[test] fn lt_f64() { - assert_evals_to!("1.1 < 1.2", true, bool); - assert_evals_to!("1.1 < 1.1", false, bool); - assert_evals_to!("1.2 < 1.1", false, bool); - assert_evals_to!("0.0 < 0.0", false, bool); + assert_evals_to_ir!("1.1 < 1.2", true, bool); + assert_evals_to_ir!("1.1 < 1.1", false, bool); + assert_evals_to_ir!("1.2 < 1.1", false, bool); + assert_evals_to_ir!("0.0 < 0.0", false, bool); } #[test] fn lte_f64() { - assert_evals_to!("1.1 <= 1.1", true, bool); - assert_evals_to!("1.2 <= 1.1", false, bool); - assert_evals_to!("1.1 <= 1.2", true, bool); - assert_evals_to!("0.0 <= 0.0", true, bool); + assert_evals_to_ir!("1.1 <= 1.1", true, bool); + assert_evals_to_ir!("1.2 <= 1.1", false, bool); + assert_evals_to_ir!("1.1 <= 1.2", true, bool); + assert_evals_to_ir!("0.0 <= 0.0", true, bool); } #[test] fn gt_f64() { - assert_evals_to!("2.2 > 1.1", true, bool); - assert_evals_to!("2.2 > 2.2", false, bool); - assert_evals_to!("1.1 > 2.2", false, bool); - assert_evals_to!("0.0 > 0.0", false, bool); + assert_evals_to_ir!("2.2 > 1.1", true, bool); + assert_evals_to_ir!("2.2 > 2.2", false, bool); + assert_evals_to_ir!("1.1 > 2.2", false, bool); + assert_evals_to_ir!("0.0 > 0.0", false, bool); } #[test] fn gte_f64() { - assert_evals_to!("1.1 >= 1.1", true, bool); - assert_evals_to!("1.1 >= 1.2", false, bool); - assert_evals_to!("1.2 >= 1.1", true, bool); - assert_evals_to!("0.0 >= 0.0", true, bool); + assert_evals_to_ir!("1.1 >= 1.1", true, bool); + assert_evals_to_ir!("1.1 >= 1.2", false, bool); + assert_evals_to_ir!("1.2 >= 1.1", true, bool); + assert_evals_to_ir!("0.0 >= 0.0", true, bool); } #[test] fn gen_order_of_arithmetic_ops() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 1 + 3 * 7 - 2 @@ -450,7 +450,7 @@ mod gen_num { #[test] fn gen_order_of_arithmetic_ops_complex_float() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" 3 - 48 * 2.0 @@ -463,7 +463,7 @@ mod gen_num { #[test] fn if_guard_bind_variable() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 10 is @@ -475,7 +475,7 @@ mod gen_num { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 10 is @@ -490,7 +490,7 @@ mod gen_num { #[test] fn tail_call_elimination() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" sum = \n, accum -> @@ -508,12 +508,12 @@ mod gen_num { #[test] fn int_negate() { - assert_evals_to!("Num.neg 123", -123, i64); + assert_evals_to_ir!("Num.neg 123", -123, i64); } #[test] fn gen_wrap_int_neg() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" wrappedNeg = \num -> -num @@ -528,7 +528,7 @@ mod gen_num { #[test] fn gen_basic_fn() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" always42 : Num.Num Num.Integer -> Num.Num Num.Integer @@ -544,16 +544,16 @@ mod gen_num { #[test] fn int_to_float() { - assert_evals_to!("Num.toFloat 0x9", 9.0, f64); + assert_evals_to_ir!("Num.toFloat 0x9", 9.0, f64); } #[test] fn num_to_float() { - assert_evals_to!("Num.toFloat 9", 9.0, f64); + assert_evals_to_ir!("Num.toFloat 9", 9.0, f64); } #[test] fn float_to_float() { - assert_evals_to!("Num.toFloat 0.5", 0.5, f64); + assert_evals_to_ir!("Num.toFloat 0.5", 0.5, f64); } } diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 8b5b9037cd..db024315fb 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -31,7 +31,7 @@ mod gen_primitives { #[test] fn basic_str() { - assert_evals_to!( + assert_evals_to_ir!( "\"shirt and hat\"", CString::new("shirt and hat").unwrap().as_c_str(), *const c_char, @@ -41,17 +41,17 @@ mod gen_primitives { #[test] fn basic_int() { - assert_evals_to!("123", 123, i64); + assert_evals_to_ir!("123", 123, i64); } #[test] fn basic_float() { - assert_evals_to!("1234.0", 1234.0, f64); + assert_evals_to_ir!("1234.0", 1234.0, f64); } #[test] fn branch_first_float() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 1.23 is @@ -66,7 +66,7 @@ mod gen_primitives { #[test] fn branch_second_float() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 2.34 is @@ -81,7 +81,7 @@ mod gen_primitives { #[test] fn branch_third_float() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 10.0 is @@ -97,7 +97,7 @@ mod gen_primitives { #[test] fn branch_first_int() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 1 is @@ -112,7 +112,7 @@ mod gen_primitives { #[test] fn branch_second_int() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 2 is @@ -127,7 +127,7 @@ mod gen_primitives { #[test] fn branch_third_int() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 10 is @@ -143,7 +143,7 @@ mod gen_primitives { #[test] fn branch_store_variable() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 0 is @@ -158,7 +158,7 @@ mod gen_primitives { #[test] fn when_one_element_tag() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : [ Pair Int Int ] @@ -175,7 +175,7 @@ mod gen_primitives { #[test] fn when_two_element_tag_first() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : [A Int, B Int] @@ -193,7 +193,7 @@ mod gen_primitives { #[test] fn when_two_element_tag_second() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : [A Int, B Int] @@ -211,7 +211,7 @@ mod gen_primitives { #[test] fn gen_when_one_branch() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 3.14 is @@ -225,7 +225,7 @@ mod gen_primitives { #[test] fn gen_large_when_int() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" foo = \num -> @@ -247,7 +247,7 @@ mod gen_primitives { // #[test] // fn gen_large_when_float() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // foo = \num -> @@ -269,7 +269,7 @@ mod gen_primitives { #[test] fn or_pattern() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 2 is @@ -284,7 +284,7 @@ mod gen_primitives { #[test] fn apply_identity() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" identity = \a -> a @@ -299,7 +299,7 @@ mod gen_primitives { #[test] fn apply_unnamed_identity() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" (\a -> a) 5 @@ -312,7 +312,7 @@ mod gen_primitives { #[test] fn return_unnamed_fn() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" alwaysFloatIdentity : Int -> (Float -> Float) @@ -329,7 +329,7 @@ mod gen_primitives { #[test] fn gen_when_fn() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" limitedNegate = \num -> @@ -348,7 +348,7 @@ mod gen_primitives { #[test] fn gen_basic_def() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" answer = 42 @@ -360,7 +360,7 @@ mod gen_primitives { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" pi = 3.14 @@ -375,7 +375,7 @@ mod gen_primitives { #[test] fn gen_multiple_defs() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" answer = 42 @@ -389,7 +389,7 @@ mod gen_primitives { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" answer = 42 @@ -406,7 +406,7 @@ mod gen_primitives { #[test] fn gen_chained_defs() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x = i1 @@ -424,7 +424,7 @@ mod gen_primitives { } #[test] fn gen_nested_defs() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x = 5 diff --git a/compiler/gen/tests/gen_records.rs b/compiler/gen/tests/gen_records.rs index 7066594b44..5187e5e6e7 100644 --- a/compiler/gen/tests/gen_records.rs +++ b/compiler/gen/tests/gen_records.rs @@ -29,7 +29,7 @@ mod gen_records { #[test] fn basic_record() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { y: 17, x: 15, z: 19 }.x @@ -39,7 +39,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { x: 15, y: 17, z: 19 }.y @@ -49,7 +49,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { x: 15, y: 17, z: 19 }.z @@ -62,7 +62,7 @@ mod gen_records { #[test] fn f64_record() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { y: 17.2, x: 15.1, z: 19.3 } @@ -74,7 +74,7 @@ mod gen_records { f64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { y: 17.2, x: 15.1, z: 19.3 } @@ -86,7 +86,7 @@ mod gen_records { f64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { y: 17.2, x: 15.1, z: 19.3 } @@ -101,7 +101,7 @@ mod gen_records { #[test] fn fn_record() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" getRec = \x -> { y: 17, x, z: 19 } @@ -113,7 +113,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { x: 15, y: 17, z: 19 } @@ -125,7 +125,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { x: 15, y: 17, z: 19 } @@ -137,7 +137,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { x: 15, y: 17, z: 19 } @@ -152,7 +152,7 @@ mod gen_records { #[test] fn def_record() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { y: 17, x: 15, z: 19 } @@ -164,7 +164,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { x: 15, y: 17, z: 19 } @@ -176,7 +176,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" rec = { x: 15, y: 17, z: 19 } @@ -191,7 +191,7 @@ mod gen_records { #[test] fn when_on_record() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when { x: 0x2 } is @@ -202,7 +202,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when { x: 0x2, y: 3.14 } is @@ -213,7 +213,7 @@ mod gen_records { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { x } = { x: 0x2, y: 3.14 } @@ -228,7 +228,7 @@ mod gen_records { #[test] fn record_guard_pattern() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when { x: 0x2, y: 3.14 } is @@ -243,7 +243,7 @@ mod gen_records { #[test] fn twice_record_access() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x = {a: 0x2, b: 0x3 } @@ -257,7 +257,7 @@ mod gen_records { } #[test] fn empty_record() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" v = {} @@ -271,7 +271,7 @@ mod gen_records { } #[test] fn i64_record2_literal() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { x: 3, y: 5 } @@ -284,7 +284,7 @@ mod gen_records { // #[test] // fn i64_record3_literal() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // { x: 3, y: 5, z: 17 } @@ -297,7 +297,7 @@ mod gen_records { #[test] fn f64_record2_literal() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { x: 3.1, y: 5.1 } @@ -310,7 +310,7 @@ mod gen_records { // #[test] // fn f64_record3_literal() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // { x: 3.1, y: 5.1, z: 17.1 } @@ -323,7 +323,7 @@ mod gen_records { // #[test] // fn bool_record4_literal() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // record : { a : Bool, b : Bool, c : Bool, d : Bool } @@ -339,7 +339,7 @@ mod gen_records { #[test] fn i64_record1_literal() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" { a: 3 } @@ -352,7 +352,7 @@ mod gen_records { // #[test] // fn i64_record9_literal() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // { a: 3, b: 5, c: 17, d: 1, e: 9, f: 12, g: 13, h: 14, i: 15 } @@ -365,7 +365,7 @@ mod gen_records { // #[test] // fn f64_record3_literal() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // { x: 3.1, y: 5.1, z: 17.1 } @@ -378,7 +378,7 @@ mod gen_records { #[test] fn bool_literal() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : Bool diff --git a/compiler/gen/tests/gen_tags.rs b/compiler/gen/tests/gen_tags.rs index 96ae1a3477..841f56be26 100644 --- a/compiler/gen/tests/gen_tags.rs +++ b/compiler/gen/tests/gen_tags.rs @@ -27,9 +27,29 @@ mod gen_tags { use roc_mono::layout::Layout; use roc_types::subs::Subs; + #[test] + fn applied_tag_nothing_ir() { + assert_llvm_ir_evals_to!( + indoc!( + r#" + Maybe a : [ Just a, Nothing ] + + x : Maybe Int + x = Nothing + + 0x1 + "# + ), + 1, + i64, + |x| x, + false + ); + } + #[test] fn applied_tag_nothing() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Maybe a : [ Just a, Nothing ] @@ -47,7 +67,7 @@ mod gen_tags { #[test] fn applied_tag_just() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Maybe a : [ Just a, Nothing ] @@ -63,9 +83,29 @@ mod gen_tags { ); } + #[test] + fn applied_tag_just_ir() { + assert_llvm_ir_evals_to!( + indoc!( + r#" + Maybe a : [ Just a, Nothing ] + + y : Maybe Int + y = Just 0x4 + + 0x1 + "# + ), + 1, + i64, + |x| x, + false + ); + } + #[test] fn applied_tag_just_unit() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Fruit : [ Orange, Apple, Banana ] @@ -87,7 +127,7 @@ mod gen_tags { // #[test] // fn raw_result() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // x : Result Int Int @@ -103,7 +143,7 @@ mod gen_tags { #[test] fn true_is_true() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" bool : [True, False] @@ -119,7 +159,7 @@ mod gen_tags { #[test] fn false_is_false() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" bool : [True, False] @@ -135,7 +175,7 @@ mod gen_tags { #[test] fn basic_enum() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Fruit : [ Apple, Orange, Banana ] @@ -156,7 +196,7 @@ mod gen_tags { // #[test] // fn linked_list_empty() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // LinkedList a : [ Cons a (LinkedList a), Nil ] @@ -174,7 +214,7 @@ mod gen_tags { // // #[test] // fn linked_list_singleton() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // LinkedList a : [ Cons a (LinkedList a), Nil ] @@ -192,7 +232,7 @@ mod gen_tags { // // #[test] // fn linked_list_is_empty() { - // assert_evals_to!( + // assert_evals_to_ir!( // indoc!( // r#" // LinkedList a : [ Cons a (LinkedList a), Nil ] @@ -213,7 +253,7 @@ mod gen_tags { #[test] fn even_odd() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" even = \n -> @@ -238,7 +278,7 @@ mod gen_tags { #[test] fn gen_literal_true() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" if True then -1 else 1 @@ -251,7 +291,7 @@ mod gen_tags { #[test] fn gen_if_float() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" if True then -1.0 else 1.0 @@ -263,7 +303,7 @@ mod gen_tags { } #[test] fn when_on_nothing() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : [ Nothing, Just Int ] @@ -281,7 +321,7 @@ mod gen_tags { #[test] fn when_on_just() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : [ Nothing, Just Int ] @@ -299,7 +339,7 @@ mod gen_tags { #[test] fn when_on_result() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : Result Int Int @@ -317,7 +357,7 @@ mod gen_tags { #[test] fn when_on_these() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" These a b : [ This a, That b, These a b ] @@ -339,7 +379,7 @@ mod gen_tags { #[test] fn match_on_two_values() { // this will produce a Chain internally - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when Pair 2 3 is @@ -354,7 +394,7 @@ mod gen_tags { #[test] fn pair_with_guard_pattern() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when Pair 2 3 is @@ -371,7 +411,7 @@ mod gen_tags { #[test] fn result_with_guard_pattern() { // This test revealed an issue with hashing Test values - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : Result Int Int @@ -390,7 +430,7 @@ mod gen_tags { #[test] fn maybe_is_just() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Maybe a : [ Just a, Nothing ] @@ -411,7 +451,7 @@ mod gen_tags { #[test] fn nested_pattern_match() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Maybe a : [ Nothing, Just a ] @@ -430,7 +470,7 @@ mod gen_tags { } #[test] fn if_guard_pattern_false() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 2 is @@ -445,7 +485,7 @@ mod gen_tags { #[test] fn if_guard_pattern_true() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 2 is @@ -460,7 +500,7 @@ mod gen_tags { #[test] fn if_guard_exhaustiveness() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when 2 is @@ -475,7 +515,7 @@ mod gen_tags { #[test] fn when_on_enum() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Fruit : [ Apple, Orange, Banana ] @@ -496,7 +536,7 @@ mod gen_tags { #[test] fn pattern_matching_unit() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Unit : [ Unit ] @@ -511,7 +551,7 @@ mod gen_tags { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Unit : [ Unit ] @@ -527,7 +567,7 @@ mod gen_tags { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" f : {} -> Int @@ -540,7 +580,7 @@ mod gen_tags { i64 ); - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" when {} is @@ -554,7 +594,7 @@ mod gen_tags { #[test] fn one_element_tag() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" x : [ Pair Int ] @@ -570,7 +610,7 @@ mod gen_tags { #[test] fn nested_tag_union() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Maybe a : [ Nothing, Just a ] @@ -587,7 +627,7 @@ mod gen_tags { } #[test] fn unit_type() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Unit : [ Unit ] @@ -605,7 +645,7 @@ mod gen_tags { #[test] fn nested_record_load() { - assert_evals_to!( + assert_evals_to_ir!( indoc!( r#" Maybe a : [ Nothing, Just a ] @@ -621,4 +661,20 @@ mod gen_tags { i64 ); } + + #[test] + fn join_points() { + assert_evals_to_ir!( + indoc!( + r#" + x = + if True then 1 else 2 + + 5 + "# + ), + 5, + i64 + ); + } } diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index 5849f5ff0d..9f54cbbfaa 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -1,4 +1,6 @@ -#[macro_export] +use roc_gen::llvm::build::Scope; + +// #[macro_export] macro_rules! assert_llvm_evals_to { ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { let target = target_lexicon::Triple::host(); @@ -147,7 +149,7 @@ macro_rules! assert_llvm_evals_to { let ret = roc_gen::llvm::build::build_expr( &env, &mut layout_ids, - &ImMap::default(), + &mut Scope::default(), main_fn, &main_body, ); @@ -155,7 +157,7 @@ macro_rules! assert_llvm_evals_to { builder.build_return(Some(&ret)); // Uncomment this to see the module's un-optimized LLVM instruction output: - // env.module.print_to_stderr(); + env.module.print_to_stderr(); if main_fn.verify(true) { fpm.run_on(&main_fn); @@ -190,7 +192,7 @@ macro_rules! assert_llvm_evals_to { // TODO this is almost all code duplication with assert_llvm_evals_to // the only difference is that this calls uniq_expr instead of can_expr. // Should extract the common logic into test helpers. -#[macro_export] +// #[macro_export] macro_rules! assert_opt_evals_to { ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { let arena = Bump::new(); @@ -338,7 +340,7 @@ macro_rules! assert_opt_evals_to { let ret = roc_gen::llvm::build::build_expr( &env, &mut layout_ids, - &ImMap::default(), + &mut Scope::default(), main_fn, &main_body, ); @@ -378,7 +380,7 @@ macro_rules! assert_opt_evals_to { }; } -#[macro_export] +// #[macro_export] macro_rules! assert_evals_to { ($src:expr, $expected:expr, $ty:ty) => { // Run un-optimized tests, and then optimized tests, in separate scopes. @@ -411,3 +413,219 @@ macro_rules! assert_evals_to { } }; } + +#[macro_export] +macro_rules! assert_llvm_ir_evals_to { + ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { + let target = target_lexicon::Triple::host(); + let ptr_bytes = target.pointer_width().unwrap().bytes() as u32; + let arena = Bump::new(); + let CanExprOut { loc_expr, var_store, var, constraint, home, interns, problems, .. } = can_expr($src); + let errors = problems.into_iter().filter(|problem| { + use roc_problem::can::Problem::*; + + // Ignore "unused" problems + match problem { + UnusedDef(_, _) | UnusedArgument(_, _, _) | UnusedImport(_, _) => false, + _ => true, + } + }).collect::>(); + + assert_eq!(errors, Vec::new(), "Encountered errors: {:?}", errors); + + let subs = Subs::new(var_store.into()); + let mut unify_problems = Vec::new(); + let (content, mut subs) = infer_expr(subs, &mut unify_problems, &constraint, var); + + assert_eq!(unify_problems, Vec::new(), "Encountered type mismatches: {:?}", unify_problems); + + let context = Context::create(); + let module = roc_gen::llvm::build::module_from_builtins(&context, "app"); + let builder = context.create_builder(); + let opt_level = if cfg!(debug_assertions) { + roc_gen::llvm::build::OptLevel::Normal + } else { + roc_gen::llvm::build::OptLevel::Optimize + }; + let fpm = PassManager::create(&module); + + roc_gen::llvm::build::add_passes(&fpm, opt_level); + + fpm.initialize(); + + // Compute main_fn_type before moving subs to Env + let layout = Layout::new(&arena, content, &subs, ptr_bytes) + .unwrap_or_else(|err| panic!("Code gen error in NON-OPTIMIZED test: could not convert to layout. Err was {:?}", err)); + let execution_engine = + module + .create_jit_execution_engine(OptimizationLevel::None) + .expect("Error creating JIT execution engine for test"); + + let main_fn_type = basic_type_from_layout(&arena, &context, &layout, ptr_bytes) + .fn_type(&[], false); + let main_fn_name = "$Test.main"; + + // Compile and add all the Procs before adding main + let mut env = roc_gen::llvm::build::Env { + arena: &arena, + builder: &builder, + context: &context, + interns, + module: arena.alloc(module), + ptr_bytes, + leak: $leak + + }; + let mut procs = roc_mono::experiment::Procs::default(); + let mut ident_ids = env.interns.all_ident_ids.remove(&home).unwrap(); + let mut layout_ids = roc_gen::layout_id::LayoutIds::default(); + + // Populate Procs and get the low-level Expr from the canonical Expr + let mut mono_problems = Vec::new(); + let mut mono_env = roc_mono::experiment::Env { + arena: &arena, + subs: &mut subs, + problems: &mut mono_problems, + home, + ident_ids: &mut ident_ids, + pointer_size: ptr_bytes, + jump_counter: arena.alloc(0), + }; + + let main_body = roc_mono::experiment::Stmt::new(&mut mono_env, loc_expr.value, &mut procs); + let mut headers = { + let num_headers = match &procs.pending_specializations { + Some(map) => map.len(), + None => 0 + }; + + Vec::with_capacity(num_headers) + }; + let mut layout_cache = roc_mono::layout::LayoutCache::default(); + let mut procs = roc_mono::experiment::specialize_all(&mut mono_env, procs, &mut layout_cache); + + assert_eq!(procs.runtime_errors, roc_collections::all::MutMap::default()); + + // Put this module's ident_ids back in the interns, so we can use them in env. + // This must happen *after* building the headers, because otherwise there's + // a conflicting mutable borrow on ident_ids. + env.interns.all_ident_ids.insert(home, ident_ids); + + use roc_gen::llvm::build::{build_proc_header_ir, build_proc_ir }; + // Add all the Proc headers to the module. + // We have to do this in a separate pass first, + // because their bodies may reference each other. + for ((symbol, layout), proc) in procs.specialized.drain() { + use roc_mono::experiment::InProgressProc::*; + + match proc { + InProgress => { + panic!("A specialization was still marked InProgress after monomorphization had completed: {:?} with layout {:?}", symbol, layout); + } + Done(proc) => { + let (fn_val, arg_basic_types) = + build_proc_header_ir(&env, &mut layout_ids, symbol, &layout, &proc); + + headers.push((proc, fn_val, arg_basic_types)); + } + } + } + + // Build each proc using its header info. + for (proc, fn_val, arg_basic_types) in headers { + build_proc_ir(&env, &mut layout_ids, proc, fn_val, arg_basic_types); + + if fn_val.verify(true) { + fpm.run_on(&fn_val); + } else { + eprintln!( + "\n\nFunction {:?} failed LLVM verification in NON-OPTIMIZED build. Its content was:\n", fn_val.get_name().to_str().unwrap() + ); + + fn_val.print_to_stderr(); + + panic!( + "The preceding code was from {:?}, which failed LLVM verification in NON-OPTIMIZED build.", fn_val.get_name().to_str().unwrap() + ); + } + } + + // Add main to the module. + let main_fn = env.module.add_function(main_fn_name, main_fn_type, None); + let cc = roc_gen::llvm::build::get_call_conventions(target.default_calling_convention().unwrap()); + + main_fn.set_call_conventions(cc); + + // Add main's body + let basic_block = context.append_basic_block(main_fn, "entry"); + + builder.position_at_end(basic_block); + + use roc_gen::llvm::build::Scope; + let ret = roc_gen::llvm::build::build_exp_stmt( + &env, + &mut layout_ids, + &mut Scope::default(), + main_fn, + &main_body, + ); + + builder.build_return(Some(&ret)); + + // Uncomment this to see the module's un-optimized LLVM instruction output: + env.module.print_to_stderr(); + + if main_fn.verify(true) { + fpm.run_on(&main_fn); + } else { + panic!("Main function {} failed LLVM verification in NON-OPTIMIZED build. Uncomment things nearby to see more details.", main_fn_name); + } + + // Verify the module + if let Err(errors) = env.module.verify() { + panic!("Errors defining module: {:?}", errors); + } + + // Uncomment this to see the module's optimized LLVM instruction output: + // env.module.print_to_stderr(); + + unsafe { + let main: JitFunction $ty> = execution_engine + .get_function(main_fn_name) + .ok() + .ok_or(format!("Unable to JIT compile `{}`", main_fn_name)) + .expect("errored"); + + assert_eq!($transform(main.call()), $expected); + } + }; + + ($src:expr, $expected:expr, $ty:ty, $transform:expr) => { + assert_llvm_ir_evals_to!($src, $expected, $ty, $transform, false); + }; +} + +#[macro_export] +macro_rules! assert_evals_to_ir { + ($src:expr, $expected:expr, $ty:ty) => { + // Run un-optimized tests, and then optimized tests, in separate scopes. + // These each rebuild everything from scratch, starting with + // parsing the source, so that there's no chance their passing + // or failing depends on leftover state from the previous one. + { + assert_llvm_ir_evals_to!($src, $expected, $ty, (|val| val)); + } + }; + ($src:expr, $expected:expr, $ty:ty, $transform:expr) => { + // Same as above, except with an additional transformation argument. + { + assert_llvm_ir_evals_to!($src, $expected, $ty, $transform); + } + }; + ($src:expr, $expected:expr, $ty:ty, $transform:expr, $leak:expr) => { + // Same as above, except with an additional transformation argument. + { + assert_llvm_ir_evals_to!($src, $expected, $ty, $transform, $leak); + } + }; +} diff --git a/compiler/mono/src/decision_tree2.rs b/compiler/mono/src/decision_tree2.rs index 66965e167d..1d2f41e444 100644 --- a/compiler/mono/src/decision_tree2.rs +++ b/compiler/mono/src/decision_tree2.rs @@ -1,4 +1,4 @@ -use crate::experiment::{DestructType, Env, Expr, Literal, Pattern, Stmt}; +use crate::experiment::{DestructType, Env, Expr, JoinPointId, Literal, Pattern, Stmt}; use crate::layout::{Builtin, Layout}; use crate::pattern2::{Ctor, RenderAs, TagId, Union}; use bumpalo::Bump; @@ -31,8 +31,12 @@ pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> Decision pub enum Guard<'a> { NoGuard, Guard { - stores: &'a [(Symbol, Layout<'a>, Expr<'a>)], - expr: Stmt<'a>, + /// Symbol that stores a boolean + /// when true this branch is picked, otherwise skipped + symbol: Symbol, + /// after assigning to symbol, the stmt jumps to this label + id: JoinPointId, + stmt: Stmt<'a>, }, } @@ -72,8 +76,12 @@ pub enum Test<'a> { // A pattern that always succeeds (like `_`) can still have a guard Guarded { opt_test: Option>>, - stores: &'a [(Symbol, Layout<'a>, Expr<'a>)], - expr: Stmt<'a>, + /// Symbol that stores a boolean + /// when true this branch is picked, otherwise skipped + symbol: Symbol, + /// after assigning to symbol, the stmt jumps to this label + id: JoinPointId, + stmt: Stmt<'a>, }, } use std::hash::{Hash, Hasher}; @@ -353,11 +361,12 @@ fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>, all_tests: &mut Ve None => {} Some((_, guard, pattern)) => { let guarded = |test| { - if let Guard::Guard { stores, expr } = guard { + if let Guard::Guard { symbol, id, stmt } = guard { Guarded { opt_test: Some(Box::new(test)), - stores, - expr: expr.clone(), + stmt: stmt.clone(), + symbol: *symbol, + id: *id, } } else { test @@ -367,11 +376,12 @@ fn test_at_path<'a>(selected_path: &Path, branch: Branch<'a>, all_tests: &mut Ve match pattern { // TODO use guard! Identifier(_) | Underscore | Shadowed(_, _) | UnsupportedPattern(_) => { - if let Guard::Guard { stores, expr } = guard { + if let Guard::Guard { symbol, id, stmt } = guard { all_tests.push(Guarded { opt_test: None, - stores, - expr: expr.clone(), + stmt: stmt.clone(), + symbol: *symbol, + id: *id, }); } } @@ -994,8 +1004,7 @@ fn test_to_equality<'a>( cond_layout: &Layout<'a>, path: &Path, test: Test<'a>, - tests: &mut Vec<(StoresVec<'a>, Symbol, Symbol, Layout<'a>)>, -) { +) -> (StoresVec<'a>, Symbol, Symbol, Layout<'a>) { match test { Test::IsCtor { tag_id, @@ -1034,24 +1043,25 @@ fn test_to_equality<'a>( stores.push((lhs_symbol, Layout::Builtin(Builtin::Int64), lhs)); stores.push((rhs_symbol, Layout::Builtin(Builtin::Int64), rhs)); - tests.push(( + ( stores, lhs_symbol, rhs_symbol, Layout::Builtin(Builtin::Int64), - )); + ) } Test::IsInt(test_int) => { let lhs = Expr::Literal(Literal::Int(test_int)); let lhs_symbol = env.unique_symbol(); let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); stores.push((lhs_symbol, Layout::Builtin(Builtin::Int64), lhs)); - tests.push(( + + ( stores, lhs_symbol, rhs_symbol, Layout::Builtin(Builtin::Int64), - )) + ) } Test::IsFloat(test_int) => { @@ -1061,12 +1071,13 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); stores.push((lhs_symbol, Layout::Builtin(Builtin::Float64), lhs)); - tests.push(( + + ( stores, lhs_symbol, rhs_symbol, Layout::Builtin(Builtin::Float64), - )) + ) } Test::IsByte { @@ -1076,12 +1087,13 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); stores.push((lhs_symbol, Layout::Builtin(Builtin::Int8), lhs)); - tests.push(( + + ( stores, lhs_symbol, rhs_symbol, Layout::Builtin(Builtin::Int8), - )); + ) } Test::IsBit(test_bit) => { @@ -1089,12 +1101,12 @@ fn test_to_equality<'a>( let lhs_symbol = env.unique_symbol(); let (mut stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); - tests.push(( + ( stores, lhs_symbol, rhs_symbol, Layout::Builtin(Builtin::Int1), - )); + ) } Test::IsStr(test_str) => { @@ -1104,31 +1116,15 @@ fn test_to_equality<'a>( stores.push((lhs_symbol, Layout::Builtin(Builtin::Str), lhs)); - tests.push(( + ( stores, lhs_symbol, rhs_symbol, Layout::Builtin(Builtin::Str), - )); + ) } - Test::Guarded { - opt_test, - stores, - expr, - } => { - if let Some(nested) = opt_test { - test_to_equality(env, cond_symbol, cond_layout, path, *nested, tests); - } - - /* - let lhs = Expr::Bool(true); - let rhs = Expr::Store(stores, env.arena.alloc(expr)); - - tests.push((lhs, rhs, Layout::Builtin(Builtin::Int1))); - */ - todo!("pattern with guard, what to do?") - } + Test::Guarded { .. } => unreachable!("should be handled elsewhere"), } } @@ -1160,12 +1156,6 @@ fn decide_to_branching<'a>( } => { // generate a switch based on the test chain - let mut tests = Vec::with_capacity(test_chain.len()); - - for (path, test) in test_chain { - test_to_equality(env, cond_symbol, &cond_layout, &path, test, &mut tests); - } - let (pass_stores, mut pass_expr) = decide_to_branching( env, cond_symbol, @@ -1175,8 +1165,11 @@ fn decide_to_branching<'a>( jumps, ); + dbg!(&pass_expr); + // TODO remove clone - for (symbol, layout, expr) in pass_stores.iter().cloned() { + for (symbol, layout, expr) in pass_stores.iter().cloned().rev() { + println!("{} {:?}", symbol, expr); pass_expr = Stmt::Let(symbol, expr, layout, env.arena.alloc(pass_expr)); } @@ -1190,7 +1183,7 @@ fn decide_to_branching<'a>( ); // TODO remove clone - for (symbol, layout, expr) in fail_stores.iter().cloned() { + for (symbol, layout, expr) in fail_stores.iter().cloned().rev() { fail_expr = Stmt::Let(symbol, expr, layout, env.arena.alloc(fail_expr)); } @@ -1211,7 +1204,7 @@ fn decide_to_branching<'a>( let mut cond = Stmt::Cond { cond_symbol, - cond_layout, + cond_layout: cond_layout.clone(), branching_symbol, branching_layout, pass, @@ -1221,21 +1214,64 @@ fn decide_to_branching<'a>( let true_symbol = env.unique_symbol(); - // let condition = boolean_all(env.arena, tests); + let mut tests = Vec::with_capacity(test_chain.len()); + + let mut guard = None; + + // Assumption: there is at most 1 guard, and it is the outer layer. + for (path, test) in test_chain { + match test { + Test::Guarded { + opt_test, + id, + symbol, + stmt, + } => { + if let Some(nested) = opt_test { + tests.push(test_to_equality( + env, + cond_symbol, + &cond_layout, + &path, + *nested, + )); + } + + // let (stores, rhs_symbol) = path_to_expr(env, cond_symbol, &path, &cond_layout); + + guard = Some((symbol, id, stmt)); + } + + _ => tests.push(test_to_equality( + env, + cond_symbol, + &cond_layout, + &path, + test, + )), + } + } + debug_assert!(!tests.is_empty()); let mut current_symbol = branching_symbol; let mut condition_symbol = true_symbol; - for (new_stores, lhs, rhs, layout) in tests.into_iter().rev() { + + let accum_symbols = std::iter::once(true_symbol) + .chain((0..tests.len() - 1).map(|_| env.unique_symbol())) + .rev() + .collect::>(); + + for ((new_stores, lhs, rhs, layout), accum) in + tests.into_iter().rev().zip(accum_symbols) + { let test_symbol = env.unique_symbol(); let test = Expr::RunLowLevel( LowLevel::Eq, bumpalo::vec![in env.arena; lhs, rhs].into_bump_slice(), ); - let and_expr = Expr::RunLowLevel( - LowLevel::And, - env.arena.alloc([test_symbol, condition_symbol]), - ); + let and_expr = + Expr::RunLowLevel(LowLevel::And, env.arena.alloc([test_symbol, accum])); // write to the branching symbol cond = Stmt::Let( @@ -1257,9 +1293,40 @@ fn decide_to_branching<'a>( cond = Stmt::Let(symbol, expr, layout, env.arena.alloc(cond)); } + condition_symbol = current_symbol; + current_symbol = accum; + } + + /* + + // the guard is the final thing that we check, so needs to be layered on first! + if let Some((symbol, id, stmt)) = guard { + let test_symbol = symbol; + let and_expr = Expr::RunLowLevel( + LowLevel::And, + env.arena.alloc([test_symbol, condition_symbol]), + ); + + // write to the branching symbol + cond = Stmt::Let( + current_symbol, + and_expr, + Layout::Builtin(Builtin::Int1), + env.arena.alloc(cond), + ); + + // calculate the guard value + cond = Stmt::Join { + id, + arguments: &[], + remainder: env.arena.alloc(stmt), + continuation: env.arena.alloc(cond), + }; + condition_symbol = current_symbol; current_symbol = env.unique_symbol(); } + */ cond = Stmt::Let( true_symbol, @@ -1324,6 +1391,7 @@ fn decide_to_branching<'a>( branches.push((tag, branch)); } + dbg!(&branches, &default_branch); let mut switch = Stmt::Switch { cond_layout, diff --git a/compiler/mono/src/experiment.rs b/compiler/mono/src/experiment.rs index fcbb7023ff..425c9dcd1c 100644 --- a/compiler/mono/src/experiment.rs +++ b/compiler/mono/src/experiment.rs @@ -304,7 +304,7 @@ impl<'a, 'i> Env<'a, 'i> { } } -#[derive(Clone, Debug, PartialEq, Copy)] +#[derive(Clone, Debug, PartialEq, Copy, Eq, Hash)] pub struct JoinPointId(Symbol); pub type Stores<'a> = &'a [(Symbol, Layout<'a>, Expr<'a>)]; @@ -348,8 +348,10 @@ pub enum Stmt<'a> { Join { id: JoinPointId, arguments: &'a [Symbol], - result: &'a Stmt<'a>, + /// does not contain jumps to this id continuation: &'a Stmt<'a>, + /// contains the jumps to this id + remainder: &'a Stmt<'a>, }, Jump(JoinPointId, &'a [Symbol]), RuntimeError(&'a str), @@ -441,6 +443,15 @@ where alloc.text(format!("{}", symbol)) } +fn join_point_to_doc<'b, D, A>(alloc: &'b D, symbol: JoinPointId) -> DocBuilder<'b, D, A> +where + D: DocAllocator<'b, A>, + D::Doc: Clone, + A: Clone, +{ + alloc.text(format!("{}", symbol.0)) +} + impl<'a> Expr<'a> { pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, parens: bool) -> DocBuilder<'b, D, A> where @@ -528,6 +539,15 @@ impl<'a> Expr<'a> { } impl<'a> Stmt<'a> { + pub fn new( + env: &mut Env<'a, '_>, + can_expr: roc_can::expr::Expr, + procs: &mut Procs<'a>, + ) -> Self { + let mut layout_cache = LayoutCache::default(); + + from_can(env, can_expr, procs, &mut layout_cache) + } pub fn to_doc<'b, D, A>(&'b self, alloc: &'b D, parens: bool) -> DocBuilder<'b, D, A> where D: DocAllocator<'b, A>, @@ -597,6 +617,34 @@ impl<'a> Stmt<'a> { .append(alloc.hardline()) .append(fail.to_doc(alloc, false).indent(4)), RuntimeError(s) => alloc.text(format!("Error {}", s)), + + Join { + id, + arguments, + continuation, + remainder, + } => alloc.intersperse( + vec![ + remainder.to_doc(alloc, false), + alloc + .text("joinpoint ") + .append(join_point_to_doc(alloc, *id)) + .append(":"), + continuation.to_doc(alloc, false).indent(4), + ], + alloc.hardline(), + ), + Jump(id, arguments) => { + let it = arguments.iter().map(|s| symbol_to_doc(alloc, *s)); + + alloc + .text("jump ") + .append(join_point_to_doc(alloc, *id)) + .append(" ".repeat(arguments.len().min(1))) + .append(alloc.intersperse(it, alloc.space())) + .append(";") + } + _ => todo!(), } /* @@ -888,11 +936,14 @@ fn specialize<'a>( .from_var(&env.arena, ret_var, env.subs, env.pointer_size) .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); + // TODO WRONG + let closes_over_layout = Layout::Struct(&[]); + let proc = Proc { name: proc_name, args: proc_args.into_bump_slice(), body: specialized_body, - closes_over: Layout::Struct(&[]), + closes_over: closes_over_layout, ret_layout, }; @@ -968,7 +1019,7 @@ pub fn with_hole<'a>( } } Var(symbol) => Stmt::Ret(symbol), - + // Var(symbol) => panic!("reached Var {}", symbol), Tag { variant_var, name: tag_name, @@ -987,12 +1038,7 @@ pub fn with_hole<'a>( match variant { Never => unreachable!("The `[]` type has no constructors"), - Unit => Stmt::Let( - assigned, - Expr::Struct(&[]), - Layout::Builtin(Builtin::Float64), - hole, - ), + Unit => Stmt::Let(assigned, Expr::Struct(&[]), Layout::Struct(&[]), hole), BoolUnion { ttrue, .. } => Stmt::Let( assigned, Expr::Literal(Literal::Bool(tag_name == ttrue)), @@ -1016,12 +1062,22 @@ pub fn with_hole<'a>( Unwrapped(field_layouts) => { let mut field_symbols = Vec::with_capacity_in(field_layouts.len(), env.arena); - for _ in 0..field_layouts.len() { - field_symbols.push(env.unique_symbol()); + for (_, arg) in args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg.value { + field_symbols.push(symbol); + } else { + field_symbols.push(env.unique_symbol()); + } } - let layout = Layout::Struct(field_layouts.into_bump_slice()); + // Layout will unpack this unwrapped tack if it only has one (non-zero-sized) field + let layout = layout_cache + .from_var(env.arena, variant_var, env.subs, env.pointer_size) + .unwrap_or_else(|err| { + panic!("TODO turn fn_var into a RuntimeError {:?}", err) + }); + // even though this was originally a Tag, we treat it as a Struct from now on let mut stmt = Stmt::Let( assigned, Expr::Struct(field_symbols.clone().into_bump_slice()), @@ -1031,6 +1087,10 @@ pub fn with_hole<'a>( for ((_, arg), symbol) in args.into_iter().rev().zip(field_symbols.iter().rev()) { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg.value { + continue; + } stmt = with_hole( env, arg.value, @@ -1055,8 +1115,12 @@ pub fn with_hole<'a>( let tag_id_symbol = env.unique_symbol(); field_symbols.push(tag_id_symbol); - for _ in 0..args.len() { - field_symbols.push(env.unique_symbol()); + for (_, arg) in args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg.value { + field_symbols.push(symbol); + } else { + field_symbols.push(env.unique_symbol()); + } } let layout_it = argument_layouts.iter(); @@ -1082,6 +1146,11 @@ pub fn with_hole<'a>( for ((_, arg), symbol) in args.into_iter().rev().zip(field_symbols.iter().rev()) { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg.value { + continue; + } + stmt = with_hole( env, arg.value, @@ -1130,7 +1199,10 @@ pub fn with_hole<'a>( can_fields.push(field); } - let layout = Layout::Struct(field_layouts.into_bump_slice()); + // creating a record from the var will unpack it if it's just a single field. + let layout = layout_cache + .from_var(env.arena, record_var, env.subs, env.pointer_size) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); let field_symbols = field_symbols.into_bump_slice(); let mut stmt = Stmt::Let(assigned, Expr::Struct(field_symbols), layout, hole); @@ -1151,16 +1223,133 @@ pub fn with_hole<'a>( EmptyRecord => Stmt::Let(assigned, Expr::Struct(&[]), Layout::Struct(&[]), hole), + If { + cond_var, + branch_var, + branches, + final_else, + } => { + let arena = env.arena; + + let ret_layout = layout_cache + .from_var(env.arena, branch_var, env.subs, env.pointer_size) + .expect("invalid ret_layout"); + let cond_layout = layout_cache + .from_var(env.arena, cond_var, env.subs, env.pointer_size) + .expect("invalid cond_layout"); + + let id = JoinPointId(env.unique_symbol()); + let jump = env.arena.alloc(Stmt::Jump(id, &[])); + + let mut stmt = with_hole(env, final_else.value, procs, layout_cache, assigned, jump); + + for (loc_cond, loc_then) in branches.into_iter().rev() { + let branching_symbol = env.unique_symbol(); + let then = with_hole(env, loc_then.value, procs, layout_cache, assigned, jump); + + stmt = Stmt::Cond { + cond_symbol: branching_symbol, + branching_symbol, + cond_layout: cond_layout.clone(), + branching_layout: cond_layout.clone(), + pass: env.arena.alloc(then), + fail: env.arena.alloc(stmt), + ret_layout: ret_layout.clone(), + }; + + // add condition + stmt = with_hole( + env, + loc_cond.value, + procs, + layout_cache, + branching_symbol, + env.arena.alloc(stmt), + ); + } + + let join = Stmt::Join { + id, + arguments: &[], + remainder: env.arena.alloc(stmt), + continuation: hole, + }; + + // expr + join + } + When { .. } | If { .. } => todo!("when or if in expression requires join points"), List { .. } => todo!("list"), LetRec(_, _, _, _) | LetNonRec(_, _, _, _) => todo!("lets"), - Access { .. } | Accessor { .. } | Update { .. } => todo!("record access/accessor/update"), + Access { + record_var, + field_var, + field, + loc_expr, + .. + } => { + let arena = env.arena; + + let sorted_fields = crate::layout::sort_record_fields( + env.arena, + record_var, + env.subs, + env.pointer_size, + ); + + let mut index = None; + let mut field_layouts = Vec::with_capacity_in(sorted_fields.len(), env.arena); + + for (current, (label, field_layout)) in sorted_fields.into_iter().enumerate() { + field_layouts.push(field_layout); + + if label == field { + index = Some(current); + } + } + + let record_symbol = if let roc_can::expr::Expr::Var(symbol) = loc_expr.value { + symbol + } else { + env.unique_symbol() + }; + + let expr = Expr::AccessAtIndex { + index: index.expect("field not in its own type") as u64, + field_layouts: field_layouts.into_bump_slice(), + structure: record_symbol, + is_unwrapped: true, + }; + + let layout = layout_cache + .from_var(env.arena, field_var, env.subs, env.pointer_size) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + let mut stmt = Stmt::Let(assigned, expr, layout, hole); + + if let roc_can::expr::Expr::Var(symbol) = loc_expr.value { + // do nothing + } else { + stmt = with_hole( + env, + loc_expr.value, + procs, + layout_cache, + record_symbol, + env.arena.alloc(stmt), + ); + }; + + stmt + } + + Accessor { .. } | Update { .. } => todo!("record access/accessor/update"), Closure(_, _, _, _, _) => todo!("call"), Call(boxed, loc_args, _) => { - dbg!(&boxed, &loc_args); let (fn_var, loc_expr, ret_var) = *boxed; /* @@ -1188,7 +1377,6 @@ pub fn with_hole<'a>( */ // match from_can(env, loc_expr.value, procs, layout_cache) { - dbg!(&procs.module_thunks); match loc_expr.value { roc_can::expr::Expr::Var(proc_name) if procs.module_thunks.contains(&proc_name) => { todo!() @@ -1269,7 +1457,7 @@ pub fn with_hole<'a>( } } - RunLowLevel { op, args, .. } => { + RunLowLevel { op, args, ret_var } => { let op = optimize_low_level(env.subs, op, &args); let mut arg_symbols = Vec::with_capacity_in(args.len(), env.arena); @@ -1283,23 +1471,18 @@ pub fn with_hole<'a>( } let arg_symbols = arg_symbols.into_bump_slice(); - // WRONG!@ - let layout = Layout::Builtin(Builtin::Int64); + // layout of the return type + let layout = layout_cache + .from_var(env.arena, ret_var, env.subs, env.pointer_size) + .unwrap_or_else(|err| todo!("TODO turn fn_var into a RuntimeError {:?}", err)); + let mut result = Stmt::Let(assigned, Expr::RunLowLevel(op, arg_symbols), layout, hole); - dbg!(&args, &arg_symbols); - for ((arg_var, arg_expr), symbol) in + for ((_arg_var, arg_expr), symbol) in args.into_iter().rev().zip(arg_symbols.iter().rev()) { - dbg!(&result); - /* - let arg = from_can(env, arg_expr, procs, layout_cache); - let layout = layout_cache - .from_var(env.arena, arg_var, env.subs, env.pointer_size) - .unwrap_or_else(|err| todo!("TODO turn fn_var into a RuntimeError {:?}", err)); - */ - - if let roc_can::expr::Expr::Var(symbol) = arg_expr { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = arg_expr { continue; } @@ -1313,7 +1496,6 @@ pub fn with_hole<'a>( ); } - dbg!(&result); result } RuntimeError(_) => todo!("runtime error"), @@ -1341,7 +1523,6 @@ pub fn from_can<'a>( let (loc_body, ret_var) = *boxed_body; - dbg!("inserting", *symbol); procs.insert_named( env, layout_cache, @@ -1368,7 +1549,19 @@ pub fn from_can<'a>( ); } - todo!("convert complex pattern to when"); + let (symbol, can_expr) = + pattern_to_when(env, def.expr_var, def.loc_pattern, def.expr_var, *cont); + + let stmt = from_can(env, can_expr.value, procs, layout_cache); + + with_hole( + env, + def.loc_expr.value, + procs, + layout_cache, + symbol, + env.arena.alloc(stmt), + ) } If { @@ -1439,8 +1632,12 @@ pub fn from_can<'a>( procs, ); - let hole = env.arena.alloc(mono_when); - with_hole(env, loc_cond.value, procs, layout_cache, cond_symbol, hole) + if let roc_can::expr::Expr::Var(_) = loc_cond.value { + mono_when + } else { + let hole = env.arena.alloc(mono_when); + with_hole(env, loc_cond.value, procs, layout_cache, cond_symbol, hole) + } } _ => { let symbol = env.unique_symbol(); @@ -1568,11 +1765,27 @@ fn from_can_when<'a>( // // otherwise, we modify the branch's expression to include the stores if let Some(loc_guard) = when_branch.guard.clone() { - let expr = from_can(env, loc_guard.value, procs, layout_cache); + let guard_symbol = env.unique_symbol(); + let id = JoinPointId(env.unique_symbol()); + + let hole = env.arena.alloc(Stmt::Jump(id, &[])); + let mut stmt = with_hole( + env, + loc_guard.value, + procs, + layout_cache, + guard_symbol, + hole, + ); + + for (symbol, expr, layout) in stores.into_iter().rev() { + stmt = Stmt::Let(symbol, layout, expr, env.arena.alloc(stmt)); + } ( crate::decision_tree2::Guard::Guard { - stores: stores.into_bump_slice(), - expr, + stmt, + id, + symbol: guard_symbol, }, &[] as &[_], mono_expr.clone(), @@ -1588,6 +1801,7 @@ fn from_can_when<'a>( Err(message) => { // when the pattern is invalid, a guard must give a runtime error too if when_branch.guard.is_some() { + /* ( crate::decision_tree2::Guard::Guard { stores: &[], @@ -1597,6 +1811,8 @@ fn from_can_when<'a>( // we can never hit this Stmt::RuntimeError(&"invalid pattern with guard: unreachable"), ) + */ + todo!() } else { ( crate::decision_tree2::Guard::NoGuard, @@ -1675,7 +1891,7 @@ fn store_pattern<'a>( Identifier(symbol) => { // let load = Expr::Load(outer_symbol); // stored.push((*symbol, layout, load)) - todo!() + // todo!() } Underscore => { // Since _ is never read, it's safe to reassign it. @@ -1821,14 +2037,23 @@ fn call_by_name<'a>( Ok(layout) => { // Build the CallByName node let arena = env.arena; - let mut args = Vec::with_capacity_in(loc_args.len(), arena); let mut pattern_vars = Vec::with_capacity_in(loc_args.len(), arena); - for (var, loc_arg) in loc_args { + let mut field_symbols = Vec::with_capacity_in(loc_args.len(), env.arena); + + for (_, arg_expr) in loc_args.iter() { + if let roc_can::expr::Expr::Var(symbol) = arg_expr.value { + field_symbols.push(symbol); + } else { + field_symbols.push(env.unique_symbol()); + } + } + let field_symbols = field_symbols.into_bump_slice(); + + for (var, loc_arg) in loc_args.clone() { match layout_cache.from_var(&env.arena, var, &env.subs, env.pointer_size) { Ok(layout) => { pattern_vars.push(var); - args.push((from_can(env, loc_arg.value, procs, layout_cache), layout)); } Err(_) => { // One of this function's arguments code gens to a runtime error, @@ -1838,15 +2063,41 @@ fn call_by_name<'a>( } } + // wrong, clearly. But in general I think the layout here should be just the return type. + let layout = if let Layout::FunctionPointer(_, rlayout) = layout { + rlayout + } else { + todo!() + }; + // If we've already specialized this one, no further work is needed. if procs.specialized.contains_key(&(proc_name, layout.clone())) { let call = Expr::FunctionCall { call_type: CallType::ByName(proc_name), layout: layout.clone(), - args: &[], + args: field_symbols, }; - Stmt::Let(assigned, call, layout.clone(), hole) + let mut result = Stmt::Let(assigned, call, layout.clone(), hole); + + for ((_, loc_arg), symbol) in + loc_args.into_iter().rev().zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = loc_arg.value { + continue; + } + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result } else { let pending = PendingSpecialization { pattern_vars, @@ -1873,10 +2124,29 @@ fn call_by_name<'a>( let call = Expr::FunctionCall { call_type: CallType::ByName(proc_name), layout: layout.clone(), - args: &[], + args: field_symbols, }; - Stmt::Let(assigned, call, layout, hole) + let mut result = Stmt::Let(assigned, call, layout.clone(), hole); + + for ((_, loc_arg), symbol) in + loc_args.into_iter().rev().zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = loc_arg.value { + continue; + } + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result } None => { let opt_partial_proc = procs.partial_procs.get(&proc_name); @@ -1909,10 +2179,32 @@ fn call_by_name<'a>( let call = Expr::FunctionCall { call_type: CallType::ByName(proc_name), layout: layout.clone(), - args: &[], + args: field_symbols, }; - Stmt::Let(assigned, call, layout, hole) + let mut result = + Stmt::Let(assigned, call, layout.clone(), hole); + + for ((_, loc_arg), symbol) in loc_args + .into_iter() + .rev() + .zip(field_symbols.iter().rev()) + { + // if this argument is already a symbol, we don't need to re-define it + if let roc_can::expr::Expr::Var(_) = loc_arg.value { + continue; + } + result = with_hole( + env, + loc_arg.value, + procs, + layout_cache, + *symbol, + env.arena.alloc(result), + ); + } + + result } Err(error) => { let error_msg = env.arena.alloc(format!( diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 2430953acc..7f4dc26d4e 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -1017,6 +1017,11 @@ mod test_mono { let result = procs_string.join("\n"); + let the_same = result == expected; + if !the_same { + println!("{}", result); + } + assert_eq!(result, expected); } @@ -1115,12 +1120,12 @@ mod test_mono { let Test.10 = 0i64; let Test.11 = 3i64; let Test.1 = Just Test.10 Test.11; - let Test.7 = true; - let Test.5 = Index 0 Test.1; - let Test.4 = 0i64; - let Test.8 = lowlevel Eq Test.4 Test.5; - let Test.6 = lowlevel And Test.8 Test.7; - if Test.6 then + let Test.5 = true; + let Test.7 = Index 0 Test.1; + let Test.6 = 0i64; + let Test.8 = lowlevel Eq Test.6 Test.7; + let Test.4 = lowlevel And Test.8 Test.5; + if Test.4 then let Test.0 = Index 1 Test.1; ret Test.0; else @@ -1197,10 +1202,335 @@ mod test_mono { let Test.3 = lowlevel NumAdd #Attr.2 #Attr.3; ret Test.3; - let Test.0 = CallByName Num.14; + let Test.1 = 1i64; + let Test.2 = 2i64; + let Test.0 = CallByName Num.14 Test.1 Test.2; ret Test.0; "# ), ) } + + #[test] + fn ir_round() { + compiles_to_ir( + r#" + Num.round 3.6 + "#, + indoc!( + r#" + procedure Num.36 (#Attr.2): + let Test.2 = lowlevel NumRound #Attr.2; + ret Test.2; + + let Test.1 = 3.6f64; + let Test.0 = CallByName Num.36 Test.1; + ret Test.0; + "# + ), + ) + } + + #[test] + fn ir_when_idiv() { + compiles_to_ir( + r#" + when 1000 // 10 is + Ok val -> val + Err _ -> -1 + "#, + indoc!( + r#" + procedure Num.32 (#Attr.2, #Attr.3): + let Test.19 = 0i64; + let Test.15 = lowlevel NotEq #Attr.3 Test.19; + if Test.15 then + let Test.17 = 1i64; + let Test.18 = lowlevel NumDivUnchecked #Attr.2 #Attr.3; + let Test.16 = Ok Test.17 Test.18; + ret Test.16; + else + let Test.13 = 0i64; + let Test.14 = Struct {}; + let Test.12 = Err Test.13 Test.14; + ret Test.12; + + let Test.10 = 1000i64; + let Test.11 = 10i64; + let Test.1 = CallByName Num.32 Test.10 Test.11; + let Test.5 = true; + let Test.7 = Index 0 Test.1; + let Test.6 = 1i64; + let Test.8 = lowlevel Eq Test.6 Test.7; + let Test.4 = lowlevel And Test.8 Test.5; + if Test.4 then + let Test.0 = Index 1 Test.1; + ret Test.0; + else + let Test.3 = -1i64; + ret Test.3; + "# + ), + ) + } + + #[test] + fn ir_two_defs() { + compiles_to_ir( + r#" + x = 3 + y = 4 + + x + y + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.3 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.3; + + let Test.1 = 4i64; + let Test.0 = 3i64; + let Test.2 = CallByName Num.14 Test.0 Test.1; + ret Test.2; + "# + ), + ) + } + + #[test] + fn ir_when_just() { + compiles_to_ir( + r#" + x : [ Nothing, Just Int ] + x = Just 41 + + when x is + Just v -> v + 0x1 + Nothing -> 0x1 + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.4 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.4; + + let Test.12 = 0i64; + let Test.13 = 41i64; + let Test.0 = Just Test.12 Test.13; + let Test.7 = true; + let Test.9 = Index 0 Test.0; + let Test.8 = 0i64; + let Test.10 = lowlevel Eq Test.8 Test.9; + let Test.6 = lowlevel And Test.10 Test.7; + if Test.6 then + let Test.1 = Index 1 Test.0; + let Test.3 = 1i64; + let Test.2 = CallByName Num.14 Test.1 Test.3; + ret Test.2; + else + let Test.5 = 1i64; + ret Test.5; + "# + ), + ) + } + + #[test] + fn one_element_tag() { + compiles_to_ir( + r#" + x : [ Pair Int ] + x = Pair 2 + + x + "#, + indoc!( + r#" + let Test.2 = 2i64; + let Test.0 = Pair Test.2; + ret Test.0; + "# + ), + ) + } + + #[test] + fn join_points() { + compiles_to_ir( + r#" + x = + if True then 1 else 2 + + x + "#, + indoc!( + r#" + let Test.3 = true; + if Test.3 then + let Test.0 = 1i64; + jump Test.2; + else + let Test.0 = 2i64; + jump Test.2; + joinpoint Test.2: + ret Test.0; + "# + ), + ) + } + + #[test] + fn guard_pattern_true() { + compiles_to_ir( + r#" + when 2 is + 2 if True -> 42 + _ -> 0 + "#, + indoc!( + r#" + let Test.0 = 2i64; + let Test.6 = true; + let Test.10 = lowlevel Eq Test.6 Test.2; + let Test.9 = lowlevel And Test.10 Test.5; + let Test.7 = 2i64; + let Test.8 = lowlevel Eq Test.7 Test.0; + let Test.5 = lowlevel And Test.8 Test.6; + let Test.2 = true; + jump Test.3; + joinpoint Test.3: + if Test.5 then + let Test.1 = 42i64; + ret Test.1; + else + let Test.4 = 0i64; + ret Test.4; + "# + ), + ) + } + + #[test] + fn when_on_record() { + compiles_to_ir( + r#" + when { x: 0x2 } is + { x } -> x + 3 + "#, + indoc!( + r#" + let Test.5 = 2i64; + let Test.1 = Struct {Test.5}; + let Test.0 = Index 0 Test.1; + let Test.3 = 3i64; + let Test.2 = CallByName Num.14 Test.0 Test.3; + ret Test.2; + "# + ), + ) + } + + #[test] + fn let_on_record() { + compiles_to_ir( + r#" + { x } = { x: 0x2, y: 3.14 } + + x + "#, + indoc!( + r#" + let Test.4 = 2i64; + let Test.5 = 3.14f64; + let Test.1 = Struct {Test.4, Test.5}; + let Test.0 = Index 0 Test.1; + ret Test.0; + "# + ), + ) + } + + #[test] + fn when_nested_maybe() { + compiles_to_ir( + r#" + Maybe a : [ Nothing, Just a ] + + x : Maybe (Maybe Int) + x = Just (Just 41) + + when x is + Just (Just v) -> v + 0x1 + _ -> 0x1 + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.5 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.5; + + let Test.16 = 2i64; + let Test.17 = 3i64; + let Test.2 = Struct {Test.16, Test.17}; + let Test.7 = true; + let Test.8 = 4i64; + let Test.9 = Index 0 Test.2; + let Test.14 = lowlevel Eq Test.8 Test.9; + let Test.13 = lowlevel And Test.14 Test.6; + let Test.10 = 3i64; + let Test.11 = Index 1 Test.2; + let Test.12 = lowlevel Eq Test.10 Test.11; + let Test.6 = lowlevel And Test.12 Test.7; + if Test.6 then + let Test.3 = 9i64; + ret Test.3; + else + let Test.1 = Index 1 Test.2; + let Test.0 = Index 0 Test.2; + let Test.4 = CallByName Num.14 Test.0 Test.1; + ret Test.4; + "# + ), + ) + } + + #[test] + fn when_on_two_values() { + compiles_to_ir( + r#" + when Pair 2 3 is + Pair 4 3 -> 9 + Pair a b -> a + b + "#, + indoc!( + r#" + procedure Num.14 (#Attr.2, #Attr.3): + let Test.5 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.5; + + let Test.16 = 2i64; + let Test.17 = 3i64; + let Test.2 = Struct {Test.16, Test.17}; + let Test.7 = true; + let Test.8 = 4i64; + let Test.9 = Index 0 Test.2; + let Test.14 = lowlevel Eq Test.8 Test.9; + let Test.13 = lowlevel And Test.14 Test.6; + let Test.10 = 3i64; + let Test.11 = Index 1 Test.2; + let Test.12 = lowlevel Eq Test.10 Test.11; + let Test.6 = lowlevel And Test.12 Test.7; + if Test.6 then + let Test.3 = 9i64; + ret Test.3; + else + let Test.1 = Index 1 Test.2; + let Test.0 = Index 0 Test.2; + let Test.4 = CallByName Num.14 Test.0 Test.1; + ret Test.4; + "# + ), + ) + } }