mirror of
https://github.com/HigherOrderCO/Bend.git
synced 2024-09-17 14:47:21 +03:00
Add inference for do notation
This commit is contained in:
parent
6eb2a45b03
commit
16fcf89798
@ -95,7 +95,7 @@ pub fn desugar_book(
|
||||
|
||||
ctx.resolve_refs()?;
|
||||
|
||||
ctx.book.apply_bnd();
|
||||
ctx.apply_bnd()?;
|
||||
|
||||
ctx.desugar_match_defs()?;
|
||||
|
||||
|
@ -51,11 +51,11 @@ impl fmt::Display for Term {
|
||||
Term::Let { pat, val, nxt } => {
|
||||
write!(f, "let {} = {}; {}", pat, val, nxt)
|
||||
}
|
||||
Term::Bnd { fun, ask, val, nxt } => {
|
||||
write!(f, "do {fun} {{ ")?;
|
||||
Term::Bnd { typ, ask, val, nxt } => {
|
||||
write!(f, "do {typ} {{ ")?;
|
||||
write!(f, "ask {} = {}; ", ask, val)?;
|
||||
let mut cur = nxt;
|
||||
while let Term::Bnd { fun: _, ask, val, nxt } = &**cur {
|
||||
while let Term::Bnd { typ: _, ask, val, nxt } = &**cur {
|
||||
cur = nxt;
|
||||
write!(f, "ask {} = {}; ", ask, val)?;
|
||||
}
|
||||
@ -277,11 +277,11 @@ impl Term {
|
||||
write!(f, "let {} = {};\n{:tab$}{}", pat, val.display_pretty(tab), "", nxt.display_pretty(tab))
|
||||
}
|
||||
|
||||
Term::Bnd { fun, ask, val, nxt } => {
|
||||
writeln!(f, "do {fun} {{")?;
|
||||
Term::Bnd { typ, ask, val, nxt } => {
|
||||
writeln!(f, "do {typ} {{")?;
|
||||
writeln!(f, "{:tab$}ask {} = {};", "", ask, val.display_pretty(tab + 2), tab = tab + 2)?;
|
||||
let mut cur = nxt;
|
||||
while let Term::Bnd { fun: _, ask, val, nxt } = &**cur {
|
||||
while let Term::Bnd { typ: _, ask, val, nxt } = &**cur {
|
||||
cur = nxt;
|
||||
writeln!(f, "{:tab$}ask {} = {};", "", ask, val.display_pretty(tab + 2), tab = tab + 2)?;
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ pub enum Term {
|
||||
nxt: Box<Term>,
|
||||
},
|
||||
Bnd {
|
||||
fun: Name,
|
||||
typ: Name,
|
||||
ask: Box<Pattern>,
|
||||
val: Box<Term>,
|
||||
nxt: Box<Term>,
|
||||
@ -321,8 +321,8 @@ impl Clone for Term {
|
||||
Self::Var { nam } => Self::Var { nam: nam.clone() },
|
||||
Self::Lnk { nam } => Self::Lnk { nam: nam.clone() },
|
||||
Self::Let { pat, val, nxt } => Self::Let { pat: pat.clone(), val: val.clone(), nxt: nxt.clone() },
|
||||
Self::Bnd { fun, ask, val, nxt } => {
|
||||
Self::Bnd { fun: fun.clone(), ask: ask.clone(), val: val.clone(), nxt: nxt.clone() }
|
||||
Self::Bnd { typ: fun, ask, val, nxt } => {
|
||||
Self::Bnd { typ: fun.clone(), ask: ask.clone(), val: val.clone(), nxt: nxt.clone() }
|
||||
}
|
||||
Self::Use { nam, val, nxt } => Self::Use { nam: nam.clone(), val: val.clone(), nxt: nxt.clone() },
|
||||
Self::App { tag, fun, arg } => Self::App { tag: tag.clone(), fun: fun.clone(), arg: arg.clone() },
|
||||
|
@ -412,9 +412,9 @@ impl<'a> TermParser<'a> {
|
||||
// Do
|
||||
if self.try_consume_keyword("do") {
|
||||
unexpected_tag(self)?;
|
||||
let fun = self.parse_name()?;
|
||||
let typ = self.parse_name()?;
|
||||
self.consume("{")?;
|
||||
let ask = self.parse_ask(Name::new(fun))?;
|
||||
let ask = self.parse_ask(Name::new(typ))?;
|
||||
self.consume("}")?;
|
||||
return Ok(ask);
|
||||
}
|
||||
@ -426,15 +426,15 @@ impl<'a> TermParser<'a> {
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_ask(&mut self, fun: Name) -> Result<Term, String> {
|
||||
fn parse_ask(&mut self, typ: Name) -> Result<Term, String> {
|
||||
maybe_grow(|| {
|
||||
if self.try_consume_keyword("ask") {
|
||||
let ask = self.parse_pattern(true)?;
|
||||
self.consume("=")?;
|
||||
let val = self.parse_term()?;
|
||||
self.try_consume(";");
|
||||
let nxt = self.parse_ask(fun.clone())?;
|
||||
Ok(Term::Bnd { fun, ask: Box::new(ask), val: Box::new(val), nxt: Box::new(nxt) })
|
||||
let nxt = self.parse_ask(typ.clone())?;
|
||||
Ok(Term::Bnd { typ, ask: Box::new(ask), val: Box::new(val), nxt: Box::new(nxt) })
|
||||
} else {
|
||||
self.parse_term()
|
||||
}
|
||||
|
@ -1,35 +1,71 @@
|
||||
use core::fmt;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use crate::{
|
||||
diagnostics::Diagnostics,
|
||||
maybe_grow,
|
||||
term::{Book, Pattern, Term},
|
||||
term::{Ctx, Name, Pattern, Term},
|
||||
};
|
||||
|
||||
impl Book {
|
||||
pub fn apply_bnd(&mut self) {
|
||||
for def in self.defs.values_mut() {
|
||||
pub struct MonadicBindError {
|
||||
expected_def: Name,
|
||||
type_name: Name,
|
||||
}
|
||||
|
||||
impl Ctx<'_> {
|
||||
pub fn apply_bnd(&mut self) -> Result<(), Diagnostics> {
|
||||
self.info.start_pass();
|
||||
|
||||
let def_names = self.book.defs.keys().cloned().collect::<HashSet<_>>();
|
||||
|
||||
for def in self.book.defs.values_mut() {
|
||||
for rule in def.rules.iter_mut() {
|
||||
rule.body.apply_bnd();
|
||||
if let Err(e) = rule.body.apply_bnd(&def_names) {
|
||||
self.info.add_rule_error(e, def.name.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.info.fatal(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Term {
|
||||
pub fn apply_bnd(&mut self) {
|
||||
pub fn apply_bnd(&mut self, def_names: &HashSet<Name>) -> Result<(), MonadicBindError> {
|
||||
maybe_grow(|| {
|
||||
for children in self.children_mut() {
|
||||
children.apply_bnd();
|
||||
}
|
||||
});
|
||||
if let Term::Bnd { typ, ask, val, nxt } = self {
|
||||
let fun = make_fun_name(typ);
|
||||
|
||||
if let Term::Bnd { fun, ask, val, nxt } = self {
|
||||
let mut fvs = nxt.free_vars();
|
||||
ask.binds().flatten().for_each(|bind| _ = fvs.remove(bind));
|
||||
let fvs = fvs.into_keys().collect::<Vec<_>>();
|
||||
let nxt =
|
||||
fvs.iter().fold(*nxt.clone(), |nxt, nam| Term::lam(Pattern::Var(Some(nam.clone())), nxt.clone()));
|
||||
let nxt = Term::lam(*ask.clone(), nxt);
|
||||
let term = Term::call(Term::Ref { nam: fun.clone() }, [*val.clone(), nxt]);
|
||||
*self = Term::call(term, fvs.into_iter().map(|nam| Term::Var { nam }));
|
||||
}
|
||||
if def_names.contains(&fun) {
|
||||
let mut fvs = nxt.free_vars();
|
||||
ask.binds().flatten().for_each(|bind| _ = fvs.remove(bind));
|
||||
let fvs = fvs.into_keys().collect::<Vec<_>>();
|
||||
let nxt =
|
||||
fvs.iter().fold(*nxt.clone(), |nxt, nam| Term::lam(Pattern::Var(Some(nam.clone())), nxt.clone()));
|
||||
let nxt = Term::lam(*ask.clone(), nxt);
|
||||
let term = Term::call(Term::Ref { nam: fun.clone() }, [*val.clone(), nxt]);
|
||||
*self = Term::call(term, fvs.into_iter().map(|nam| Term::Var { nam }));
|
||||
} else {
|
||||
return Err(MonadicBindError { expected_def: fun, type_name: typ.clone() });
|
||||
}
|
||||
}
|
||||
|
||||
for children in self.children_mut() {
|
||||
children.apply_bnd(def_names)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn make_fun_name(typ: &mut Name) -> Name {
|
||||
let name: String = [typ, "/", "bind"].into_iter().collect();
|
||||
Name::new(name)
|
||||
}
|
||||
|
||||
impl fmt::Display for MonadicBindError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Could not find definition {} for type {}.", self.expected_def, self.type_name)
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
Result.bind (Result.ok val) f = (f val)
|
||||
Result.bind err _ = err
|
||||
Result/bind (Result.ok val) f = (f val)
|
||||
Result/bind err _ = err
|
||||
|
||||
safe_div a b = switch b {
|
||||
0: (Result.err "Div by 0")
|
||||
@ -11,7 +11,7 @@ safe_rem a b = switch b {
|
||||
_: (Result.ok (% a b))
|
||||
}
|
||||
|
||||
Main = do Result.bind {
|
||||
Main = do Result {
|
||||
ask y = (safe_div 3 2)
|
||||
ask x = (safe_rem y 0);
|
||||
x
|
||||
|
@ -1,9 +1,9 @@
|
||||
Result.bind (Result.ok val) f = (f val)
|
||||
Result.bind err _ = err
|
||||
Result/bind (Result.ok val) f = (f val)
|
||||
Result/bind err _ = err
|
||||
|
||||
Bar x = (Result.err 0)
|
||||
|
||||
Foo x y = do Result.bind {
|
||||
Foo x y = do Result {
|
||||
ask x = (Bar x);
|
||||
(Foo x y)
|
||||
}
|
||||
|
@ -2,13 +2,13 @@
|
||||
source: tests/golden_tests.rs
|
||||
input_file: tests/golden_tests/desugar_file/bind_syntax.hvm
|
||||
---
|
||||
(Result.bind) = λa λb (a Result.bind__C1 Result.bind__C0 b)
|
||||
(Result/bind) = λa λb (a Result/bind__C1 Result/bind__C0 b)
|
||||
|
||||
(safe_div) = λa λb (switch b { 0: λ* (Result.err (String.cons 68 (String.cons 105 (String.cons 118 (String.cons 32 (String.cons 98 (String.cons 121 (String.cons 32 (String.cons 48 String.nil))))))))); _: safe_div__C0; } a)
|
||||
|
||||
(safe_rem) = λa λb (switch b { 0: λ* (Result.err (String.cons 77 (String.cons 111 (String.cons 100 (String.cons 32 (String.cons 98 (String.cons 121 (String.cons 32 (String.cons 48 String.nil))))))))); _: safe_rem__C0; } a)
|
||||
|
||||
(Main) = (Result.bind Main__C1 Main__C0)
|
||||
(Main) = (Result/bind Main__C1 Main__C0)
|
||||
|
||||
(String.cons) = λa λb λc λ* (c a b)
|
||||
|
||||
@ -18,13 +18,13 @@ input_file: tests/golden_tests/desugar_file/bind_syntax.hvm
|
||||
|
||||
(Result.err) = λa λ* λb (b a)
|
||||
|
||||
(Main__C0) = λa (Result.bind (safe_rem a 0) λb b)
|
||||
(Main__C0) = λa (Result/bind (safe_rem a 0) λb b)
|
||||
|
||||
(Main__C1) = (safe_div 3 2)
|
||||
|
||||
(Result.bind__C0) = λa λ* (Result.err a)
|
||||
(Result/bind__C0) = λa λ* (Result.err a)
|
||||
|
||||
(Result.bind__C1) = λa λb (b a)
|
||||
(Result/bind__C1) = λa λb (b a)
|
||||
|
||||
(safe_div__C0) = λa λb (Result.ok (/ b (+ a 1)))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user