Add inference for do notation

This commit is contained in:
imaqtkatt 2024-05-07 11:47:07 -03:00
parent 6eb2a45b03
commit 16fcf89798
8 changed files with 82 additions and 46 deletions

View File

@ -95,7 +95,7 @@ pub fn desugar_book(
ctx.resolve_refs()?;
ctx.book.apply_bnd();
ctx.apply_bnd()?;
ctx.desugar_match_defs()?;

View File

@ -51,11 +51,11 @@ impl fmt::Display for Term {
Term::Let { pat, val, nxt } => {
write!(f, "let {} = {}; {}", pat, val, nxt)
}
Term::Bnd { fun, ask, val, nxt } => {
write!(f, "do {fun} {{ ")?;
Term::Bnd { typ, ask, val, nxt } => {
write!(f, "do {typ} {{ ")?;
write!(f, "ask {} = {}; ", ask, val)?;
let mut cur = nxt;
while let Term::Bnd { fun: _, ask, val, nxt } = &**cur {
while let Term::Bnd { typ: _, ask, val, nxt } = &**cur {
cur = nxt;
write!(f, "ask {} = {}; ", ask, val)?;
}
@ -277,11 +277,11 @@ impl Term {
write!(f, "let {} = {};\n{:tab$}{}", pat, val.display_pretty(tab), "", nxt.display_pretty(tab))
}
Term::Bnd { fun, ask, val, nxt } => {
writeln!(f, "do {fun} {{")?;
Term::Bnd { typ, ask, val, nxt } => {
writeln!(f, "do {typ} {{")?;
writeln!(f, "{:tab$}ask {} = {};", "", ask, val.display_pretty(tab + 2), tab = tab + 2)?;
let mut cur = nxt;
while let Term::Bnd { fun: _, ask, val, nxt } = &**cur {
while let Term::Bnd { typ: _, ask, val, nxt } = &**cur {
cur = nxt;
writeln!(f, "{:tab$}ask {} = {};", "", ask, val.display_pretty(tab + 2), tab = tab + 2)?;
}

View File

@ -88,7 +88,7 @@ pub enum Term {
nxt: Box<Term>,
},
Bnd {
fun: Name,
typ: Name,
ask: Box<Pattern>,
val: Box<Term>,
nxt: Box<Term>,
@ -321,8 +321,8 @@ impl Clone for Term {
Self::Var { nam } => Self::Var { nam: nam.clone() },
Self::Lnk { nam } => Self::Lnk { nam: nam.clone() },
Self::Let { pat, val, nxt } => Self::Let { pat: pat.clone(), val: val.clone(), nxt: nxt.clone() },
Self::Bnd { fun, ask, val, nxt } => {
Self::Bnd { fun: fun.clone(), ask: ask.clone(), val: val.clone(), nxt: nxt.clone() }
Self::Bnd { typ: fun, ask, val, nxt } => {
Self::Bnd { typ: fun.clone(), ask: ask.clone(), val: val.clone(), nxt: nxt.clone() }
}
Self::Use { nam, val, nxt } => Self::Use { nam: nam.clone(), val: val.clone(), nxt: nxt.clone() },
Self::App { tag, fun, arg } => Self::App { tag: tag.clone(), fun: fun.clone(), arg: arg.clone() },

View File

@ -412,9 +412,9 @@ impl<'a> TermParser<'a> {
// Do
if self.try_consume_keyword("do") {
unexpected_tag(self)?;
let fun = self.parse_name()?;
let typ = self.parse_name()?;
self.consume("{")?;
let ask = self.parse_ask(Name::new(fun))?;
let ask = self.parse_ask(Name::new(typ))?;
self.consume("}")?;
return Ok(ask);
}
@ -426,15 +426,15 @@ impl<'a> TermParser<'a> {
})
}
fn parse_ask(&mut self, fun: Name) -> Result<Term, String> {
fn parse_ask(&mut self, typ: Name) -> Result<Term, String> {
maybe_grow(|| {
if self.try_consume_keyword("ask") {
let ask = self.parse_pattern(true)?;
self.consume("=")?;
let val = self.parse_term()?;
self.try_consume(";");
let nxt = self.parse_ask(fun.clone())?;
Ok(Term::Bnd { fun, ask: Box::new(ask), val: Box::new(val), nxt: Box::new(nxt) })
let nxt = self.parse_ask(typ.clone())?;
Ok(Term::Bnd { typ, ask: Box::new(ask), val: Box::new(val), nxt: Box::new(nxt) })
} else {
self.parse_term()
}

View File

@ -1,35 +1,71 @@
use core::fmt;
use std::collections::HashSet;
use crate::{
diagnostics::Diagnostics,
maybe_grow,
term::{Book, Pattern, Term},
term::{Ctx, Name, Pattern, Term},
};
impl Book {
pub fn apply_bnd(&mut self) {
for def in self.defs.values_mut() {
pub struct MonadicBindError {
expected_def: Name,
type_name: Name,
}
impl Ctx<'_> {
pub fn apply_bnd(&mut self) -> Result<(), Diagnostics> {
self.info.start_pass();
let def_names = self.book.defs.keys().cloned().collect::<HashSet<_>>();
for def in self.book.defs.values_mut() {
for rule in def.rules.iter_mut() {
rule.body.apply_bnd();
if let Err(e) = rule.body.apply_bnd(&def_names) {
self.info.add_rule_error(e, def.name.clone());
}
}
}
self.info.fatal(())
}
}
impl Term {
pub fn apply_bnd(&mut self) {
pub fn apply_bnd(&mut self, def_names: &HashSet<Name>) -> Result<(), MonadicBindError> {
maybe_grow(|| {
for children in self.children_mut() {
children.apply_bnd();
}
});
if let Term::Bnd { typ, ask, val, nxt } = self {
let fun = make_fun_name(typ);
if let Term::Bnd { fun, ask, val, nxt } = self {
let mut fvs = nxt.free_vars();
ask.binds().flatten().for_each(|bind| _ = fvs.remove(bind));
let fvs = fvs.into_keys().collect::<Vec<_>>();
let nxt =
fvs.iter().fold(*nxt.clone(), |nxt, nam| Term::lam(Pattern::Var(Some(nam.clone())), nxt.clone()));
let nxt = Term::lam(*ask.clone(), nxt);
let term = Term::call(Term::Ref { nam: fun.clone() }, [*val.clone(), nxt]);
*self = Term::call(term, fvs.into_iter().map(|nam| Term::Var { nam }));
}
if def_names.contains(&fun) {
let mut fvs = nxt.free_vars();
ask.binds().flatten().for_each(|bind| _ = fvs.remove(bind));
let fvs = fvs.into_keys().collect::<Vec<_>>();
let nxt =
fvs.iter().fold(*nxt.clone(), |nxt, nam| Term::lam(Pattern::Var(Some(nam.clone())), nxt.clone()));
let nxt = Term::lam(*ask.clone(), nxt);
let term = Term::call(Term::Ref { nam: fun.clone() }, [*val.clone(), nxt]);
*self = Term::call(term, fvs.into_iter().map(|nam| Term::Var { nam }));
} else {
return Err(MonadicBindError { expected_def: fun, type_name: typ.clone() });
}
}
for children in self.children_mut() {
children.apply_bnd(def_names)?;
}
Ok(())
})
}
}
fn make_fun_name(typ: &mut Name) -> Name {
let name: String = [typ, "/", "bind"].into_iter().collect();
Name::new(name)
}
impl fmt::Display for MonadicBindError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Could not find definition {} for type {}.", self.expected_def, self.type_name)
}
}

View File

@ -1,5 +1,5 @@
Result.bind (Result.ok val) f = (f val)
Result.bind err _ = err
Result/bind (Result.ok val) f = (f val)
Result/bind err _ = err
safe_div a b = switch b {
0: (Result.err "Div by 0")
@ -11,7 +11,7 @@ safe_rem a b = switch b {
_: (Result.ok (% a b))
}
Main = do Result.bind {
Main = do Result {
ask y = (safe_div 3 2)
ask x = (safe_rem y 0);
x

View File

@ -1,9 +1,9 @@
Result.bind (Result.ok val) f = (f val)
Result.bind err _ = err
Result/bind (Result.ok val) f = (f val)
Result/bind err _ = err
Bar x = (Result.err 0)
Foo x y = do Result.bind {
Foo x y = do Result {
ask x = (Bar x);
(Foo x y)
}

View File

@ -2,13 +2,13 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/bind_syntax.hvm
---
(Result.bind) = λa λb (a Result.bind__C1 Result.bind__C0 b)
(Result/bind) = λa λb (a Result/bind__C1 Result/bind__C0 b)
(safe_div) = λa λb (switch b { 0: λ* (Result.err (String.cons 68 (String.cons 105 (String.cons 118 (String.cons 32 (String.cons 98 (String.cons 121 (String.cons 32 (String.cons 48 String.nil))))))))); _: safe_div__C0; } a)
(safe_rem) = λa λb (switch b { 0: λ* (Result.err (String.cons 77 (String.cons 111 (String.cons 100 (String.cons 32 (String.cons 98 (String.cons 121 (String.cons 32 (String.cons 48 String.nil))))))))); _: safe_rem__C0; } a)
(Main) = (Result.bind Main__C1 Main__C0)
(Main) = (Result/bind Main__C1 Main__C0)
(String.cons) = λa λb λc λ* (c a b)
@ -18,13 +18,13 @@ input_file: tests/golden_tests/desugar_file/bind_syntax.hvm
(Result.err) = λa λ* λb (b a)
(Main__C0) = λa (Result.bind (safe_rem a 0) λb b)
(Main__C0) = λa (Result/bind (safe_rem a 0) λb b)
(Main__C1) = (safe_div 3 2)
(Result.bind__C0) = λa λ* (Result.err a)
(Result/bind__C0) = λa λ* (Result.err a)
(Result.bind__C1) = λa λb (b a)
(Result/bind__C1) = λa λb (b a)
(safe_div__C0) = λa λb (Result.ok (/ b (+ a 1)))