Implement constraint generation for Expr2::LetFunction

We do this by treating function definition bodies as equivalent to
closures, and piggy-backing on existing work to generate constraints
over closures. Then, we just bind the function name with the resolved
type of the function body.

Support for constraint generation in the presence of annotated functions
will be added later.
This commit is contained in:
ayazhafiz 2021-11-16 15:25:42 -05:00
parent b824302ab3
commit 3fe29c9949
4 changed files with 150 additions and 5 deletions

View File

@ -20,6 +20,7 @@ use crate::{
expr2::{ClosureExtra, Expr2, ExprId, WhenBranch},
record_field::RecordField,
},
fun_def::FunctionDef,
pattern::{DestructType, Pattern2, PatternId, PatternState2, RecordDestruct},
types::{Type2, TypeId},
val_def::ValueDef,
@ -818,6 +819,89 @@ pub fn constrain_expr<'a>(
}
}
}
// In an expression like
// id = \x -> x
//
// id 1
// The `def_id` refers to the definition `id = \x -> x`,
// and the body refers to `id 1`.
Expr2::LetFunction {
def_id,
body_id,
body_var,
} => {
let body = env.pool.get(*body_id);
let body_con = constrain_expr(arena, env, body, expected.shallow_clone(), region);
let function_def = env.pool.get(*def_id);
match function_def {
FunctionDef::WithAnnotation { .. } => {
todo!("implement constraint generation for {:?}", function_def)
}
FunctionDef::NoAnnotation {
name,
arguments,
body_id: expr_id,
return_var,
} => {
// A function definition is equivalent to a named value definition, where the
// value is a closure. So, we create a closure definition in correspondence
// with the function definition, generate type constraints for it, and demand
// that type of the function is just the type of the resolved closure.
let fn_var = env.var_store.fresh();
let fn_ty = Type2::Variable(fn_var);
let clos_var = env.var_store.fresh();
let clos_ty = Type2::Variable(clos_var);
let extra = ClosureExtra {
return_type: env.var_store.fresh(),
captured_symbols: PoolVec::empty(env.pool),
closure_type: env.var_store.fresh(),
closure_ext_var: env.var_store.fresh(),
};
let clos = Expr2::Closure {
args: arguments.shallow_clone(),
uniq_symbol: *name,
body_id: *expr_id,
function_type: clos_var,
extra: env.pool.add(extra),
recursive: roc_can::expr::Recursive::Recursive,
};
let clos_con = constrain_expr(
arena,
env,
&clos,
Expected::NoExpectation(fn_ty.shallow_clone()),
region,
);
// This is the `foo` part in `foo = \...`. We want to bind the name of the
// function with its type, whose constraints we generated above.
let mut def_pattern_state = PatternState2 {
headers: BumpMap::new_in(arena),
vars: BumpVec::new_in(arena),
constraints: BumpVec::new_in(arena),
};
def_pattern_state.headers.insert(*name, fn_ty);
def_pattern_state.vars.push(fn_var);
Let(arena.alloc(LetConstraint {
rigid_vars: BumpVec::new_in(arena), // The function def is unannotated, so there are no rigid type vars
flex_vars: def_pattern_state.vars,
def_types: def_pattern_state.headers, // Binding function name -> its type
defs_constraint: Let(arena.alloc(LetConstraint {
rigid_vars: BumpVec::new_in(arena), // always empty
flex_vars: BumpVec::new_in(arena), // empty, because our functions have no arguments
def_types: BumpMap::new_in(arena), // empty, because our functions have no arguments
defs_constraint: And(def_pattern_state.constraints),
ret_constraint: clos_con,
})),
ret_constraint: body_con,
}))
}
}
}
Expr2::Update {
symbol,
updates,
@ -1031,7 +1115,6 @@ pub fn constrain_expr<'a>(
exists(arena, vars, And(and_constraints))
}
Expr2::LetRec { .. } => todo!(),
Expr2::LetFunction { .. } => todo!(),
}
}
@ -2244,4 +2327,66 @@ pub mod test_constrain {
"{}* -> Num *",
)
}
#[test]
fn recursive_identity() {
infer_eq(
indoc!(
r#"
identity = \val -> val
identity
"#
),
"a -> a",
);
}
#[test]
fn use_apply() {
infer_eq(
indoc!(
r#"
identity = \a -> a
apply = \f, x -> f x
apply identity 5
"#
),
"Num *",
);
}
#[test]
fn nested_let_function() {
infer_eq(
indoc!(
r#"
curryPair = \a ->
getB = \b -> Pair a b
getB
curryPair
"#
),
"a -> (b -> [ Pair a b ]*)",
);
}
#[test]
fn record_with_bound_var() {
infer_eq(
indoc!(
r#"
fn = \rec ->
x = rec.x
rec
fn
"#
),
"{ x : a }b -> { x : a }b",
);
}
}

View File

@ -689,14 +689,14 @@ fn canonicalize_pending_def<'a>(
// parent commit for the bug this fixed!
let refs = References::new();
let arguments: PoolVec<(PatternId, Variable)> =
let arguments: PoolVec<(Variable, PatternId)> =
PoolVec::with_capacity(closure_args.len() as u32, env.pool);
let it: Vec<_> = closure_args.iter(env.pool).map(|(x, y)| (*x, *y)).collect();
for (node_id, (_, pattern_id)) in arguments.iter_node_ids().zip(it.into_iter())
{
env.pool[node_id] = (pattern_id, env.var_store.fresh());
env.pool[node_id] = (env.var_store.fresh(), pattern_id);
}
let function_def = FunctionDef::NoAnnotation {

View File

@ -22,7 +22,7 @@ pub enum FunctionDef {
},
NoAnnotation {
name: Symbol, // 8B
arguments: PoolVec<(PatternId, Variable)>, // 8B
arguments: PoolVec<(Variable, PatternId)>, // 8B
return_var: Variable, // 4B
body_id: ExprId, // 4B
},

View File

@ -2,7 +2,7 @@ pub mod ast;
mod declaration;
pub mod def;
pub mod expr;
mod fun_def;
pub mod fun_def;
pub mod header;
pub mod pattern;
pub mod str;