diff --git a/ast/src/constrain.rs b/ast/src/constrain.rs index ff53020294..05cd1908fc 100644 --- a/ast/src/constrain.rs +++ b/ast/src/constrain.rs @@ -20,6 +20,7 @@ use crate::{ expr2::{ClosureExtra, Expr2, ExprId, WhenBranch}, record_field::RecordField, }, + fun_def::FunctionDef, pattern::{DestructType, Pattern2, PatternId, PatternState2, RecordDestruct}, types::{Type2, TypeId}, val_def::ValueDef, @@ -818,6 +819,89 @@ pub fn constrain_expr<'a>( } } } + // In an expression like + // id = \x -> x + // + // id 1 + // The `def_id` refers to the definition `id = \x -> x`, + // and the body refers to `id 1`. + Expr2::LetFunction { + def_id, + body_id, + body_var, + } => { + let body = env.pool.get(*body_id); + let body_con = constrain_expr(arena, env, body, expected.shallow_clone(), region); + + let function_def = env.pool.get(*def_id); + + match function_def { + FunctionDef::WithAnnotation { .. } => { + todo!("implement constraint generation for {:?}", function_def) + } + FunctionDef::NoAnnotation { + name, + arguments, + body_id: expr_id, + return_var, + } => { + // A function definition is equivalent to a named value definition, where the + // value is a closure. So, we create a closure definition in correspondence + // with the function definition, generate type constraints for it, and demand + // that type of the function is just the type of the resolved closure. + let fn_var = env.var_store.fresh(); + let fn_ty = Type2::Variable(fn_var); + + let clos_var = env.var_store.fresh(); + let clos_ty = Type2::Variable(clos_var); + let extra = ClosureExtra { + return_type: env.var_store.fresh(), + captured_symbols: PoolVec::empty(env.pool), + closure_type: env.var_store.fresh(), + closure_ext_var: env.var_store.fresh(), + }; + let clos = Expr2::Closure { + args: arguments.shallow_clone(), + uniq_symbol: *name, + body_id: *expr_id, + function_type: clos_var, + extra: env.pool.add(extra), + recursive: roc_can::expr::Recursive::Recursive, + }; + let clos_con = constrain_expr( + arena, + env, + &clos, + Expected::NoExpectation(fn_ty.shallow_clone()), + region, + ); + + // This is the `foo` part in `foo = \...`. We want to bind the name of the + // function with its type, whose constraints we generated above. + let mut def_pattern_state = PatternState2 { + headers: BumpMap::new_in(arena), + vars: BumpVec::new_in(arena), + constraints: BumpVec::new_in(arena), + }; + def_pattern_state.headers.insert(*name, fn_ty); + def_pattern_state.vars.push(fn_var); + + Let(arena.alloc(LetConstraint { + rigid_vars: BumpVec::new_in(arena), // The function def is unannotated, so there are no rigid type vars + flex_vars: def_pattern_state.vars, + def_types: def_pattern_state.headers, // Binding function name -> its type + defs_constraint: Let(arena.alloc(LetConstraint { + rigid_vars: BumpVec::new_in(arena), // always empty + flex_vars: BumpVec::new_in(arena), // empty, because our functions have no arguments + def_types: BumpMap::new_in(arena), // empty, because our functions have no arguments + defs_constraint: And(def_pattern_state.constraints), + ret_constraint: clos_con, + })), + ret_constraint: body_con, + })) + } + } + } Expr2::Update { symbol, updates, @@ -1031,7 +1115,6 @@ pub fn constrain_expr<'a>( exists(arena, vars, And(and_constraints)) } Expr2::LetRec { .. } => todo!(), - Expr2::LetFunction { .. } => todo!(), } } @@ -2244,4 +2327,66 @@ pub mod test_constrain { "{}* -> Num *", ) } + + #[test] + fn recursive_identity() { + infer_eq( + indoc!( + r#" + identity = \val -> val + + identity + "# + ), + "a -> a", + ); + } + + #[test] + fn use_apply() { + infer_eq( + indoc!( + r#" + identity = \a -> a + apply = \f, x -> f x + + apply identity 5 + "# + ), + "Num *", + ); + } + + #[test] + fn nested_let_function() { + infer_eq( + indoc!( + r#" + curryPair = \a -> + getB = \b -> Pair a b + getB + + curryPair + "# + ), + "a -> (b -> [ Pair a b ]*)", + ); + } + + #[test] + fn record_with_bound_var() { + infer_eq( + indoc!( + r#" + fn = \rec -> + x = rec.x + + rec + + fn + "# + ), + "{ x : a }b -> { x : a }b", + ); + } } diff --git a/ast/src/lang/core/def/def.rs b/ast/src/lang/core/def/def.rs index 00e75c296e..850c20383e 100644 --- a/ast/src/lang/core/def/def.rs +++ b/ast/src/lang/core/def/def.rs @@ -689,14 +689,14 @@ fn canonicalize_pending_def<'a>( // parent commit for the bug this fixed! let refs = References::new(); - let arguments: PoolVec<(PatternId, Variable)> = + let arguments: PoolVec<(Variable, PatternId)> = PoolVec::with_capacity(closure_args.len() as u32, env.pool); let it: Vec<_> = closure_args.iter(env.pool).map(|(x, y)| (*x, *y)).collect(); for (node_id, (_, pattern_id)) in arguments.iter_node_ids().zip(it.into_iter()) { - env.pool[node_id] = (pattern_id, env.var_store.fresh()); + env.pool[node_id] = (env.var_store.fresh(), pattern_id); } let function_def = FunctionDef::NoAnnotation { diff --git a/ast/src/lang/core/fun_def.rs b/ast/src/lang/core/fun_def.rs index 02d3bdbab4..1cc1d78aaa 100644 --- a/ast/src/lang/core/fun_def.rs +++ b/ast/src/lang/core/fun_def.rs @@ -22,7 +22,7 @@ pub enum FunctionDef { }, NoAnnotation { name: Symbol, // 8B - arguments: PoolVec<(PatternId, Variable)>, // 8B + arguments: PoolVec<(Variable, PatternId)>, // 8B return_var: Variable, // 4B body_id: ExprId, // 4B }, diff --git a/ast/src/lang/core/mod.rs b/ast/src/lang/core/mod.rs index 74300dab4f..801f6afa18 100644 --- a/ast/src/lang/core/mod.rs +++ b/ast/src/lang/core/mod.rs @@ -2,7 +2,7 @@ pub mod ast; mod declaration; pub mod def; pub mod expr; -mod fun_def; +pub mod fun_def; pub mod header; pub mod pattern; pub mod str;