mirror of
https://github.com/HigherOrderCO/Bend.git
synced 2024-10-26 05:50:18 +03:00
#623 Desugar local defs properly
This commit is contained in:
parent
f288ae7f12
commit
fda44b8910
@ -70,6 +70,13 @@ impl Term {
|
||||
Term::List { els } => *self = Term::encode_list(std::mem::take(els)),
|
||||
Term::Str { val } => *self = Term::encode_str(val),
|
||||
Term::Nat { val } => *self = Term::encode_nat(*val),
|
||||
Term::Def { def, nxt } => {
|
||||
for rule in def.rules.iter_mut() {
|
||||
rule.pats.iter_mut().for_each(Pattern::encode_builtins);
|
||||
rule.body.encode_builtins();
|
||||
}
|
||||
nxt.encode_builtins();
|
||||
}
|
||||
_ => {
|
||||
for child in self.children_mut() {
|
||||
child.encode_builtins();
|
||||
|
@ -167,10 +167,10 @@ impl fmt::Display for Term {
|
||||
}
|
||||
Term::List { els } => write!(f, "[{}]", DisplayJoin(|| els.iter(), ", "),),
|
||||
Term::Open { typ, var, bod } => write!(f, "open {typ} {var}; {bod}"),
|
||||
Term::Def { nam, rules, nxt } => {
|
||||
Term::Def { def, nxt } => {
|
||||
write!(f, "def ")?;
|
||||
for rule in rules.iter() {
|
||||
write!(f, "{}", rule.display(nam))?;
|
||||
for rule in def.rules.iter() {
|
||||
write!(f, "{}", rule.display(&def.name))?;
|
||||
}
|
||||
write!(f, "{nxt}")
|
||||
}
|
||||
@ -500,13 +500,13 @@ impl Term {
|
||||
Term::Num { val: Num::F24(val) } => write!(f, "{val:.3}"),
|
||||
Term::Str { val } => write!(f, "{val:?}"),
|
||||
Term::Ref { nam } => write!(f, "{nam}"),
|
||||
Term::Def { nam, rules, nxt } => {
|
||||
Term::Def { def, nxt } => {
|
||||
write!(f, "def ")?;
|
||||
for (i, rule) in rules.iter().enumerate() {
|
||||
for (i, rule) in def.rules.iter().enumerate() {
|
||||
if i == 0 {
|
||||
writeln!(f, "{}", rule.display_def_aux(nam, tab + 4))?;
|
||||
writeln!(f, "{}", rule.display_def_aux(&def.name, tab + 4))?;
|
||||
} else {
|
||||
writeln!(f, "{:tab$}{}", "", rule.display_def_aux(nam, tab + 4), tab = tab + 4)?;
|
||||
writeln!(f, "{:tab$}{}", "", rule.display_def_aux(&def.name, tab + 4), tab = tab + 4)?;
|
||||
}
|
||||
}
|
||||
write!(f, "{:tab$}{}", "", nxt.display_pretty(tab))
|
||||
|
@ -65,14 +65,14 @@ pub type Adts = IndexMap<Name, Adt>;
|
||||
pub type Constructors = IndexMap<Name, Name>;
|
||||
|
||||
/// A pattern matching function definition.
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct Definition {
|
||||
pub name: Name,
|
||||
pub rules: Vec<Rule>,
|
||||
pub source: Source,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
Builtin,
|
||||
/// Was generated by the compiler.
|
||||
@ -199,9 +199,11 @@ pub enum Term {
|
||||
nam: Name,
|
||||
},
|
||||
Def {
|
||||
nam: Name,
|
||||
rules: Vec<Rule>,
|
||||
def: Definition,
|
||||
nxt: Box<Term>,
|
||||
// nam: Name,
|
||||
// rules: Vec<Rule>,
|
||||
// nxt: Box<Term>,
|
||||
},
|
||||
Era,
|
||||
#[default]
|
||||
@ -402,7 +404,7 @@ impl Clone for Term {
|
||||
Self::Open { typ: typ.clone(), var: var.clone(), bod: nxt.clone() }
|
||||
}
|
||||
Self::Ref { nam } => Self::Ref { nam: nam.clone() },
|
||||
Self::Def { nam, rules, nxt } => Self::Def { nam: nam.clone(), rules: rules.clone(), nxt: nxt.clone() },
|
||||
Self::Def { def, nxt } => Self::Def { def: def.clone(), nxt: nxt.clone() },
|
||||
Self::Era => Self::Era,
|
||||
Self::Err => Self::Err,
|
||||
})
|
||||
|
@ -631,7 +631,8 @@ impl<'a> TermParser<'a> {
|
||||
if name == "def" {
|
||||
// parse the nxt def term.
|
||||
self.index = nxt_def;
|
||||
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(self.parse_term()?) });
|
||||
let def = FunDefinition::new(name, rules, Source::Local(nxt_def..*self.index()));
|
||||
return Ok(Term::Def { def, nxt: Box::new(self.parse_term()?) });
|
||||
}
|
||||
if name == cur_name {
|
||||
rules.push(rule);
|
||||
@ -648,7 +649,8 @@ impl<'a> TermParser<'a> {
|
||||
}
|
||||
}
|
||||
let nxt = self.parse_term()?;
|
||||
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(nxt) });
|
||||
let def = FunDefinition::new(cur_name, rules, Source::Local(nxt_term..*self.index()));
|
||||
return Ok(Term::Def { def, nxt: Box::new(nxt) });
|
||||
}
|
||||
|
||||
// If
|
||||
|
@ -38,7 +38,9 @@ impl Ctx<'_> {
|
||||
impl Definition {
|
||||
pub fn desugar_match_def(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<DesugarMatchDefErr> {
|
||||
let mut errs = vec![];
|
||||
|
||||
for rule in self.rules.iter_mut() {
|
||||
desugar_inner_match_defs(&mut rule.body, ctrs, adts, &mut errs);
|
||||
}
|
||||
let repeated_bind_errs = fix_repeated_binds(&mut self.rules);
|
||||
errs.extend(repeated_bind_errs);
|
||||
|
||||
@ -55,6 +57,25 @@ impl Definition {
|
||||
}
|
||||
}
|
||||
|
||||
fn desugar_inner_match_defs(
|
||||
term: &mut Term,
|
||||
ctrs: &Constructors,
|
||||
adts: &Adts,
|
||||
errs: &mut Vec<DesugarMatchDefErr>,
|
||||
) {
|
||||
maybe_grow(|| match term {
|
||||
Term::Def { def, nxt } => {
|
||||
errs.extend(def.desugar_match_def(ctrs, adts));
|
||||
desugar_inner_match_defs(nxt, ctrs, adts, errs);
|
||||
}
|
||||
_ => {
|
||||
for child in term.children_mut() {
|
||||
desugar_inner_match_defs(child, ctrs, adts, errs);
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// When a rule has repeated bind, the only one that is actually useful is the last one.
|
||||
///
|
||||
/// Example: In `(Foo x x x x) = x`, the function should return the fourth argument.
|
||||
|
@ -23,7 +23,9 @@ impl Ctx<'_> {
|
||||
impl Term {
|
||||
fn desugar_open(&mut self, adts: &Adts) -> Result<(), String> {
|
||||
maybe_grow(|| {
|
||||
if let Term::Open { typ, var, bod } = self {
|
||||
match self {
|
||||
Term::Open { typ, var, bod } => {
|
||||
bod.desugar_open(adts)?;
|
||||
if let Some(adt) = adts.get(&*typ) {
|
||||
if adt.ctrs.len() == 1 {
|
||||
let ctr = adt.ctrs.keys().next().unwrap();
|
||||
@ -41,9 +43,18 @@ impl Term {
|
||||
return Err(format!("Type '{typ}' of an 'open' is not defined"));
|
||||
}
|
||||
}
|
||||
Term::Def { def, nxt } => {
|
||||
for rule in def.rules.iter_mut() {
|
||||
rule.body.desugar_open(adts)?;
|
||||
}
|
||||
nxt.desugar_open(adts)?;
|
||||
}
|
||||
_ => {
|
||||
for child in self.children_mut() {
|
||||
child.desugar_open(adts)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{
|
||||
diagnostics::Diagnostics,
|
||||
fun::{Adts, Constructors, Ctx, Pattern},
|
||||
fun::{Adts, Constructors, Ctx, Pattern, Rule, Term},
|
||||
};
|
||||
|
||||
impl Ctx<'_> {
|
||||
@ -15,18 +15,7 @@ impl Ctx<'_> {
|
||||
|
||||
let def_arity = def.arity();
|
||||
for rule in &mut def.rules {
|
||||
if rule.arity() != def_arity {
|
||||
errs.push(format!(
|
||||
"Incorrect pattern matching rule arity. Expected {} args, found {}.",
|
||||
def_arity,
|
||||
rule.arity()
|
||||
));
|
||||
}
|
||||
|
||||
for pat in &mut rule.pats {
|
||||
pat.resolve_pat(&self.book.ctrs);
|
||||
pat.check_good_ctr(&self.book.ctrs, &self.book.adts, &mut errs);
|
||||
}
|
||||
rule.fix_match_defs(def_arity, &self.book.ctrs, &self.book.adts, &mut errs);
|
||||
}
|
||||
|
||||
for err in errs {
|
||||
@ -38,6 +27,44 @@ impl Ctx<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
impl Rule {
|
||||
fn fix_match_defs(&mut self, def_arity: usize, ctrs: &Constructors, adts: &Adts, errs: &mut Vec<String>) {
|
||||
if self.arity() != def_arity {
|
||||
errs.push(format!(
|
||||
"Incorrect pattern matching rule arity. Expected {} args, found {}.",
|
||||
def_arity,
|
||||
self.arity()
|
||||
));
|
||||
}
|
||||
|
||||
for pat in &mut self.pats {
|
||||
pat.resolve_pat(ctrs);
|
||||
pat.check_good_ctr(ctrs, adts, errs);
|
||||
}
|
||||
|
||||
self.body.fix_match_defs(ctrs, adts, errs);
|
||||
}
|
||||
}
|
||||
|
||||
impl Term {
|
||||
fn fix_match_defs(&mut self, ctrs: &Constructors, adts: &Adts, errs: &mut Vec<String>) {
|
||||
match self {
|
||||
Term::Def { def, nxt } => {
|
||||
let def_arity = def.arity();
|
||||
for rule in &mut def.rules {
|
||||
rule.fix_match_defs(def_arity, ctrs, adts, errs);
|
||||
}
|
||||
nxt.fix_match_defs(ctrs, adts, errs);
|
||||
}
|
||||
_ => {
|
||||
for children in self.children_mut() {
|
||||
children.fix_match_defs(ctrs, adts, errs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Pattern {
|
||||
/// If a var pattern actually refers to an ADT constructor, convert it into a constructor pattern.
|
||||
fn resolve_pat(&mut self, ctrs: &Constructors) {
|
||||
|
@ -84,8 +84,14 @@ impl Term {
|
||||
if matches!(self, Term::Mat { .. } | Term::Fold { .. }) {
|
||||
self.fix_match(&mut errs, ctrs, adts);
|
||||
}
|
||||
// Add a use term to each arm rebuilding the matched variable
|
||||
match self {
|
||||
Term::Def { def, nxt } => {
|
||||
for rule in def.rules.iter_mut() {
|
||||
errs.extend(rule.body.fix_match_terms(ctrs, adts));
|
||||
}
|
||||
errs.extend(nxt.fix_match_terms(ctrs, adts));
|
||||
}
|
||||
// Add a use term to each arm rebuilding the matched variable
|
||||
Term::Mat { arg: _, bnd, with_bnd: _, with_arg: _, arms }
|
||||
| Term::Fold { bnd, arg: _, with_bnd: _, with_arg: _, arms } => {
|
||||
for (ctr, fields, body) in arms {
|
||||
|
@ -2,7 +2,10 @@ use std::collections::BTreeSet;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
|
||||
use crate::fun::{Book, Definition, Name, Pattern, Rule, Term};
|
||||
use crate::{
|
||||
fun::{Book, Definition, Name, Pattern, Rule, Term},
|
||||
maybe_grow,
|
||||
};
|
||||
|
||||
impl Book {
|
||||
pub fn lift_local_defs(&mut self) {
|
||||
@ -25,10 +28,10 @@ impl Rule {
|
||||
|
||||
impl Term {
|
||||
pub fn lift_local_defs(&mut self, parent: &Name, defs: &mut IndexMap<Name, Definition>, gen: &mut usize) {
|
||||
match self {
|
||||
Term::Def { nam, rules, nxt } => {
|
||||
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, nam));
|
||||
for rule in rules.iter_mut() {
|
||||
maybe_grow(|| match self {
|
||||
Term::Def { def, nxt } => {
|
||||
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
|
||||
for rule in def.rules.iter_mut() {
|
||||
rule.body.lift_local_defs(&local_name, defs, gen);
|
||||
}
|
||||
nxt.lift_local_defs(parent, defs, gen);
|
||||
@ -36,7 +39,8 @@ impl Term {
|
||||
|
||||
let inner_defs =
|
||||
defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
|
||||
let (r#use, fvs, mut rules) = gen_use(inner_defs, &local_name, nam, nxt, std::mem::take(rules));
|
||||
let (r#use, fvs, mut rules) =
|
||||
gen_use(inner_defs, &local_name, &def.name, nxt, std::mem::take(&mut def.rules));
|
||||
*self = r#use;
|
||||
|
||||
apply_closure(&mut rules, &fvs);
|
||||
@ -49,7 +53,7 @@ impl Term {
|
||||
child.lift_local_defs(parent, defs, gen);
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -89,9 +93,7 @@ fn gen_use(
|
||||
|
||||
fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet<Name>) {
|
||||
for rule in rules.iter_mut() {
|
||||
let pats = std::mem::take(&mut rule.pats);
|
||||
let mut captured = fvs.iter().cloned().map(|nam| Pattern::Var(Some(nam))).collect::<Vec<_>>();
|
||||
captured.extend(pats);
|
||||
rule.pats = captured;
|
||||
let captured = fvs.iter().cloned().map(|nam| Some(nam)).collect::<Vec<_>>();
|
||||
rule.body = Term::rfold_lams(std::mem::take(&mut rule.body), captured.into_iter());
|
||||
}
|
||||
}
|
||||
|
@ -30,7 +30,8 @@ impl Ctx<'_> {
|
||||
push_scope(name.as_ref(), &mut scope);
|
||||
}
|
||||
|
||||
let res = rule.body.resolve_refs(&def_names, self.book.entrypoint.as_ref(), &mut scope);
|
||||
let res =
|
||||
rule.body.resolve_refs(&def_names, self.book.entrypoint.as_ref(), &mut scope, &mut self.info);
|
||||
self.info.take_rule_err(res, def_name.clone());
|
||||
}
|
||||
}
|
||||
@ -45,9 +46,11 @@ impl Term {
|
||||
def_names: &HashSet<Name>,
|
||||
main: Option<&Name>,
|
||||
scope: &mut HashMap<&'a Name, usize>,
|
||||
info: &mut Diagnostics,
|
||||
) -> Result<(), String> {
|
||||
maybe_grow(move || {
|
||||
if let Term::Var { nam } = self {
|
||||
match self {
|
||||
Term::Var { nam } => {
|
||||
if is_var_in_scope(nam, scope) {
|
||||
// If the variable is actually a reference to main, don't swap and return an error.
|
||||
if let Some(main) = main {
|
||||
@ -62,16 +65,31 @@ impl Term {
|
||||
}
|
||||
}
|
||||
}
|
||||
Term::Def { def, nxt } => {
|
||||
for rule in def.rules.iter_mut() {
|
||||
let mut scope = HashMap::new();
|
||||
|
||||
for name in rule.pats.iter().flat_map(Pattern::binds) {
|
||||
push_scope(name.as_ref(), &mut scope);
|
||||
}
|
||||
|
||||
let res = rule.body.resolve_refs(&def_names, main, &mut scope, info);
|
||||
info.take_rule_err(res, def.name.clone());
|
||||
}
|
||||
nxt.resolve_refs(def_names, main, scope, info)?;
|
||||
}
|
||||
_ => {
|
||||
for (child, binds) in self.children_mut_with_binds() {
|
||||
for bind in binds.clone() {
|
||||
push_scope(bind.as_ref(), scope);
|
||||
}
|
||||
child.resolve_refs(def_names, main, scope)?;
|
||||
child.resolve_refs(def_names, main, scope, info)?;
|
||||
for bind in binds.rev() {
|
||||
pop_scope(bind.as_ref(), scope);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
@ -391,7 +391,7 @@ impl Stmt {
|
||||
StmtToFun::Assign(pat, term) => (Some(pat), term),
|
||||
};
|
||||
let def = def.to_fun()?;
|
||||
let term = fun::Term::Def { nam: def.name, rules: def.rules, nxt: Box::new(nxt) };
|
||||
let term = fun::Term::Def { def, nxt: Box::new(nxt) };
|
||||
if let Some(pat) = nxt_pat {
|
||||
StmtToFun::Assign(pat, term)
|
||||
} else {
|
||||
|
@ -91,8 +91,6 @@ pub fn desugar_book(
|
||||
|
||||
ctx.set_entrypoint();
|
||||
|
||||
ctx.book.lift_local_defs();
|
||||
|
||||
ctx.book.encode_adts(opts.adt_encoding);
|
||||
|
||||
ctx.fix_match_defs()?;
|
||||
@ -109,6 +107,8 @@ pub fn desugar_book(
|
||||
|
||||
ctx.fix_match_terms()?;
|
||||
|
||||
ctx.book.lift_local_defs();
|
||||
|
||||
ctx.desugar_bend()?;
|
||||
ctx.desugar_fold()?;
|
||||
ctx.desugar_with_blocks()?;
|
||||
|
@ -2,4 +2,4 @@
|
||||
source: tests/golden_tests.rs
|
||||
input_file: tests/golden_tests/parse_file/fun_def.bend
|
||||
---
|
||||
(main) = let base = 0; def (aux []) = base(aux (List/Cons head tail)) = (+ head (aux tail))(aux [1, 2, 3])
|
||||
(main) = let base = 0; def (aux) = λ%arg0 match %arg0 = %arg0 { List/Nil: base; List/Cons %arg0.head %arg0.tail: use tail = %arg0.tail; use head = %arg0.head; (+ head (aux tail)); }(aux (List/Cons 1 (List/Cons 2 (List/Cons 3 List/Nil))))
|
||||
|
Loading…
Reference in New Issue
Block a user