From 3fe29c99498ec296e6bbcef738bee38cc4b1d7e3 Mon Sep 17 00:00:00 2001 From: ayazhafiz Date: Tue, 16 Nov 2021 15:25:42 -0500 Subject: [PATCH] Implement constraint generation for Expr2::LetFunction We do this by treating function definition bodies as equivalent to closures, and piggy-backing on existing work to generate constraints over closures. Then, we just bind the function name with the resolved type of the function body. Support for constraint generation in the presence of annotated functions will be added later. --- ast/src/constrain.rs | 147 ++++++++++++++++++++++++++++++++++- ast/src/lang/core/def/def.rs | 4 +- ast/src/lang/core/fun_def.rs | 2 +- ast/src/lang/core/mod.rs | 2 +- 4 files changed, 150 insertions(+), 5 deletions(-) diff --git a/ast/src/constrain.rs b/ast/src/constrain.rs index ff53020294..05cd1908fc 100644 --- a/ast/src/constrain.rs +++ b/ast/src/constrain.rs @@ -20,6 +20,7 @@ use crate::{ expr2::{ClosureExtra, Expr2, ExprId, WhenBranch}, record_field::RecordField, }, + fun_def::FunctionDef, pattern::{DestructType, Pattern2, PatternId, PatternState2, RecordDestruct}, types::{Type2, TypeId}, val_def::ValueDef, @@ -818,6 +819,89 @@ pub fn constrain_expr<'a>( } } } + // In an expression like + // id = \x -> x + // + // id 1 + // The `def_id` refers to the definition `id = \x -> x`, + // and the body refers to `id 1`. + Expr2::LetFunction { + def_id, + body_id, + body_var, + } => { + let body = env.pool.get(*body_id); + let body_con = constrain_expr(arena, env, body, expected.shallow_clone(), region); + + let function_def = env.pool.get(*def_id); + + match function_def { + FunctionDef::WithAnnotation { .. } => { + todo!("implement constraint generation for {:?}", function_def) + } + FunctionDef::NoAnnotation { + name, + arguments, + body_id: expr_id, + return_var, + } => { + // A function definition is equivalent to a named value definition, where the + // value is a closure. So, we create a closure definition in correspondence + // with the function definition, generate type constraints for it, and demand + // that type of the function is just the type of the resolved closure. + let fn_var = env.var_store.fresh(); + let fn_ty = Type2::Variable(fn_var); + + let clos_var = env.var_store.fresh(); + let clos_ty = Type2::Variable(clos_var); + let extra = ClosureExtra { + return_type: env.var_store.fresh(), + captured_symbols: PoolVec::empty(env.pool), + closure_type: env.var_store.fresh(), + closure_ext_var: env.var_store.fresh(), + }; + let clos = Expr2::Closure { + args: arguments.shallow_clone(), + uniq_symbol: *name, + body_id: *expr_id, + function_type: clos_var, + extra: env.pool.add(extra), + recursive: roc_can::expr::Recursive::Recursive, + }; + let clos_con = constrain_expr( + arena, + env, + &clos, + Expected::NoExpectation(fn_ty.shallow_clone()), + region, + ); + + // This is the `foo` part in `foo = \...`. We want to bind the name of the + // function with its type, whose constraints we generated above. + let mut def_pattern_state = PatternState2 { + headers: BumpMap::new_in(arena), + vars: BumpVec::new_in(arena), + constraints: BumpVec::new_in(arena), + }; + def_pattern_state.headers.insert(*name, fn_ty); + def_pattern_state.vars.push(fn_var); + + Let(arena.alloc(LetConstraint { + rigid_vars: BumpVec::new_in(arena), // The function def is unannotated, so there are no rigid type vars + flex_vars: def_pattern_state.vars, + def_types: def_pattern_state.headers, // Binding function name -> its type + defs_constraint: Let(arena.alloc(LetConstraint { + rigid_vars: BumpVec::new_in(arena), // always empty + flex_vars: BumpVec::new_in(arena), // empty, because our functions have no arguments + def_types: BumpMap::new_in(arena), // empty, because our functions have no arguments + defs_constraint: And(def_pattern_state.constraints), + ret_constraint: clos_con, + })), + ret_constraint: body_con, + })) + } + } + } Expr2::Update { symbol, updates, @@ -1031,7 +1115,6 @@ pub fn constrain_expr<'a>( exists(arena, vars, And(and_constraints)) } Expr2::LetRec { .. } => todo!(), - Expr2::LetFunction { .. } => todo!(), } } @@ -2244,4 +2327,66 @@ pub mod test_constrain { "{}* -> Num *", ) } + + #[test] + fn recursive_identity() { + infer_eq( + indoc!( + r#" + identity = \val -> val + + identity + "# + ), + "a -> a", + ); + } + + #[test] + fn use_apply() { + infer_eq( + indoc!( + r#" + identity = \a -> a + apply = \f, x -> f x + + apply identity 5 + "# + ), + "Num *", + ); + } + + #[test] + fn nested_let_function() { + infer_eq( + indoc!( + r#" + curryPair = \a -> + getB = \b -> Pair a b + getB + + curryPair + "# + ), + "a -> (b -> [ Pair a b ]*)", + ); + } + + #[test] + fn record_with_bound_var() { + infer_eq( + indoc!( + r#" + fn = \rec -> + x = rec.x + + rec + + fn + "# + ), + "{ x : a }b -> { x : a }b", + ); + } } diff --git a/ast/src/lang/core/def/def.rs b/ast/src/lang/core/def/def.rs index 00e75c296e..850c20383e 100644 --- a/ast/src/lang/core/def/def.rs +++ b/ast/src/lang/core/def/def.rs @@ -689,14 +689,14 @@ fn canonicalize_pending_def<'a>( // parent commit for the bug this fixed! let refs = References::new(); - let arguments: PoolVec<(PatternId, Variable)> = + let arguments: PoolVec<(Variable, PatternId)> = PoolVec::with_capacity(closure_args.len() as u32, env.pool); let it: Vec<_> = closure_args.iter(env.pool).map(|(x, y)| (*x, *y)).collect(); for (node_id, (_, pattern_id)) in arguments.iter_node_ids().zip(it.into_iter()) { - env.pool[node_id] = (pattern_id, env.var_store.fresh()); + env.pool[node_id] = (env.var_store.fresh(), pattern_id); } let function_def = FunctionDef::NoAnnotation { diff --git a/ast/src/lang/core/fun_def.rs b/ast/src/lang/core/fun_def.rs index 02d3bdbab4..1cc1d78aaa 100644 --- a/ast/src/lang/core/fun_def.rs +++ b/ast/src/lang/core/fun_def.rs @@ -22,7 +22,7 @@ pub enum FunctionDef { }, NoAnnotation { name: Symbol, // 8B - arguments: PoolVec<(PatternId, Variable)>, // 8B + arguments: PoolVec<(Variable, PatternId)>, // 8B return_var: Variable, // 4B body_id: ExprId, // 4B }, diff --git a/ast/src/lang/core/mod.rs b/ast/src/lang/core/mod.rs index 74300dab4f..801f6afa18 100644 --- a/ast/src/lang/core/mod.rs +++ b/ast/src/lang/core/mod.rs @@ -2,7 +2,7 @@ pub mod ast; mod declaration; pub mod def; pub mod expr; -mod fun_def; +pub mod fun_def; pub mod header; pub mod pattern; pub mod str;