From fda44b8910b045d670a431c55bf9a6ede6924bb5 Mon Sep 17 00:00:00 2001 From: imaqtkatt Date: Mon, 8 Jul 2024 17:26:34 -0300 Subject: [PATCH] #623 Desugar local defs properly --- src/fun/builtins.rs | 7 +++ src/fun/display.rs | 14 ++--- src/fun/mod.rs | 12 ++-- src/fun/parser.rs | 6 +- src/fun/transform/desugar_match_defs.rs | 23 +++++++- src/fun/transform/desugar_open.rs | 43 ++++++++------ src/fun/transform/fix_match_defs.rs | 53 +++++++++++++----- src/fun/transform/fix_match_terms.rs | 8 ++- src/fun/transform/lift_local_defs.rs | 24 ++++---- src/fun/transform/resolve_refs.rs | 56 ++++++++++++------- src/imp/to_fun.rs | 2 +- src/lib.rs | 4 +- tests/snapshots/parse_file__fun_def.bend.snap | 2 +- 13 files changed, 175 insertions(+), 79 deletions(-) diff --git a/src/fun/builtins.rs b/src/fun/builtins.rs index 23aa8740..882f750c 100644 --- a/src/fun/builtins.rs +++ b/src/fun/builtins.rs @@ -70,6 +70,13 @@ impl Term { Term::List { els } => *self = Term::encode_list(std::mem::take(els)), Term::Str { val } => *self = Term::encode_str(val), Term::Nat { val } => *self = Term::encode_nat(*val), + Term::Def { def, nxt } => { + for rule in def.rules.iter_mut() { + rule.pats.iter_mut().for_each(Pattern::encode_builtins); + rule.body.encode_builtins(); + } + nxt.encode_builtins(); + } _ => { for child in self.children_mut() { child.encode_builtins(); diff --git a/src/fun/display.rs b/src/fun/display.rs index 0330dd01..ad54a761 100644 --- a/src/fun/display.rs +++ b/src/fun/display.rs @@ -167,10 +167,10 @@ impl fmt::Display for Term { } Term::List { els } => write!(f, "[{}]", DisplayJoin(|| els.iter(), ", "),), Term::Open { typ, var, bod } => write!(f, "open {typ} {var}; {bod}"), - Term::Def { nam, rules, nxt } => { + Term::Def { def, nxt } => { write!(f, "def ")?; - for rule in rules.iter() { - write!(f, "{}", rule.display(nam))?; + for rule in def.rules.iter() { + write!(f, "{}", rule.display(&def.name))?; } write!(f, "{nxt}") } @@ -500,13 +500,13 @@ impl Term { Term::Num { val: Num::F24(val) } => write!(f, "{val:.3}"), Term::Str { val } => write!(f, "{val:?}"), Term::Ref { nam } => write!(f, "{nam}"), - Term::Def { nam, rules, nxt } => { + Term::Def { def, nxt } => { write!(f, "def ")?; - for (i, rule) in rules.iter().enumerate() { + for (i, rule) in def.rules.iter().enumerate() { if i == 0 { - writeln!(f, "{}", rule.display_def_aux(nam, tab + 4))?; + writeln!(f, "{}", rule.display_def_aux(&def.name, tab + 4))?; } else { - writeln!(f, "{:tab$}{}", "", rule.display_def_aux(nam, tab + 4), tab = tab + 4)?; + writeln!(f, "{:tab$}{}", "", rule.display_def_aux(&def.name, tab + 4), tab = tab + 4)?; } } write!(f, "{:tab$}{}", "", nxt.display_pretty(tab)) diff --git a/src/fun/mod.rs b/src/fun/mod.rs index 04a14d1b..6fc02ad4 100644 --- a/src/fun/mod.rs +++ b/src/fun/mod.rs @@ -65,14 +65,14 @@ pub type Adts = IndexMap; pub type Constructors = IndexMap; /// A pattern matching function definition. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Definition { pub name: Name, pub rules: Vec, pub source: Source, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Source { Builtin, /// Was generated by the compiler. @@ -199,9 +199,11 @@ pub enum Term { nam: Name, }, Def { - nam: Name, - rules: Vec, + def: Definition, nxt: Box, + // nam: Name, + // rules: Vec, + // nxt: Box, }, Era, #[default] @@ -402,7 +404,7 @@ impl Clone for Term { Self::Open { typ: typ.clone(), var: var.clone(), bod: nxt.clone() } } Self::Ref { nam } => Self::Ref { nam: nam.clone() }, - Self::Def { nam, rules, nxt } => Self::Def { nam: nam.clone(), rules: rules.clone(), nxt: nxt.clone() }, + Self::Def { def, nxt } => Self::Def { def: def.clone(), nxt: nxt.clone() }, Self::Era => Self::Era, Self::Err => Self::Err, }) diff --git a/src/fun/parser.rs b/src/fun/parser.rs index 72f4aca9..4a3a8d6c 100644 --- a/src/fun/parser.rs +++ b/src/fun/parser.rs @@ -631,7 +631,8 @@ impl<'a> TermParser<'a> { if name == "def" { // parse the nxt def term. self.index = nxt_def; - return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(self.parse_term()?) }); + let def = FunDefinition::new(name, rules, Source::Local(nxt_def..*self.index())); + return Ok(Term::Def { def, nxt: Box::new(self.parse_term()?) }); } if name == cur_name { rules.push(rule); @@ -648,7 +649,8 @@ impl<'a> TermParser<'a> { } } let nxt = self.parse_term()?; - return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(nxt) }); + let def = FunDefinition::new(cur_name, rules, Source::Local(nxt_term..*self.index())); + return Ok(Term::Def { def, nxt: Box::new(nxt) }); } // If diff --git a/src/fun/transform/desugar_match_defs.rs b/src/fun/transform/desugar_match_defs.rs index 36abcf69..72e147bb 100644 --- a/src/fun/transform/desugar_match_defs.rs +++ b/src/fun/transform/desugar_match_defs.rs @@ -38,7 +38,9 @@ impl Ctx<'_> { impl Definition { pub fn desugar_match_def(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec { let mut errs = vec![]; - + for rule in self.rules.iter_mut() { + desugar_inner_match_defs(&mut rule.body, ctrs, adts, &mut errs); + } let repeated_bind_errs = fix_repeated_binds(&mut self.rules); errs.extend(repeated_bind_errs); @@ -55,6 +57,25 @@ impl Definition { } } +fn desugar_inner_match_defs( + term: &mut Term, + ctrs: &Constructors, + adts: &Adts, + errs: &mut Vec, +) { + maybe_grow(|| match term { + Term::Def { def, nxt } => { + errs.extend(def.desugar_match_def(ctrs, adts)); + desugar_inner_match_defs(nxt, ctrs, adts, errs); + } + _ => { + for child in term.children_mut() { + desugar_inner_match_defs(child, ctrs, adts, errs); + } + } + }) +} + /// When a rule has repeated bind, the only one that is actually useful is the last one. /// /// Example: In `(Foo x x x x) = x`, the function should return the fourth argument. diff --git a/src/fun/transform/desugar_open.rs b/src/fun/transform/desugar_open.rs index 5210f648..016f1f73 100644 --- a/src/fun/transform/desugar_open.rs +++ b/src/fun/transform/desugar_open.rs @@ -23,26 +23,37 @@ impl Ctx<'_> { impl Term { fn desugar_open(&mut self, adts: &Adts) -> Result<(), String> { maybe_grow(|| { - if let Term::Open { typ, var, bod } = self { - if let Some(adt) = adts.get(&*typ) { - if adt.ctrs.len() == 1 { - let ctr = adt.ctrs.keys().next().unwrap(); - *self = Term::Mat { - arg: Box::new(Term::Var { nam: var.clone() }), - bnd: Some(std::mem::take(var)), - with_bnd: vec![], - with_arg: vec![], - arms: vec![(Some(ctr.clone()), vec![], std::mem::take(bod))], + match self { + Term::Open { typ, var, bod } => { + bod.desugar_open(adts)?; + if let Some(adt) = adts.get(&*typ) { + if adt.ctrs.len() == 1 { + let ctr = adt.ctrs.keys().next().unwrap(); + *self = Term::Mat { + arg: Box::new(Term::Var { nam: var.clone() }), + bnd: Some(std::mem::take(var)), + with_bnd: vec![], + with_arg: vec![], + arms: vec![(Some(ctr.clone()), vec![], std::mem::take(bod))], + } + } else { + return Err(format!("Type '{typ}' of an 'open' has more than one constructor")); } } else { - return Err(format!("Type '{typ}' of an 'open' has more than one constructor")); + return Err(format!("Type '{typ}' of an 'open' is not defined")); + } + } + Term::Def { def, nxt } => { + for rule in def.rules.iter_mut() { + rule.body.desugar_open(adts)?; + } + nxt.desugar_open(adts)?; + } + _ => { + for child in self.children_mut() { + child.desugar_open(adts)?; } - } else { - return Err(format!("Type '{typ}' of an 'open' is not defined")); } - } - for child in self.children_mut() { - child.desugar_open(adts)?; } Ok(()) }) diff --git a/src/fun/transform/fix_match_defs.rs b/src/fun/transform/fix_match_defs.rs index f5415ff3..299fef47 100644 --- a/src/fun/transform/fix_match_defs.rs +++ b/src/fun/transform/fix_match_defs.rs @@ -1,6 +1,6 @@ use crate::{ diagnostics::Diagnostics, - fun::{Adts, Constructors, Ctx, Pattern}, + fun::{Adts, Constructors, Ctx, Pattern, Rule, Term}, }; impl Ctx<'_> { @@ -15,18 +15,7 @@ impl Ctx<'_> { let def_arity = def.arity(); for rule in &mut def.rules { - if rule.arity() != def_arity { - errs.push(format!( - "Incorrect pattern matching rule arity. Expected {} args, found {}.", - def_arity, - rule.arity() - )); - } - - for pat in &mut rule.pats { - pat.resolve_pat(&self.book.ctrs); - pat.check_good_ctr(&self.book.ctrs, &self.book.adts, &mut errs); - } + rule.fix_match_defs(def_arity, &self.book.ctrs, &self.book.adts, &mut errs); } for err in errs { @@ -38,6 +27,44 @@ impl Ctx<'_> { } } +impl Rule { + fn fix_match_defs(&mut self, def_arity: usize, ctrs: &Constructors, adts: &Adts, errs: &mut Vec) { + if self.arity() != def_arity { + errs.push(format!( + "Incorrect pattern matching rule arity. Expected {} args, found {}.", + def_arity, + self.arity() + )); + } + + for pat in &mut self.pats { + pat.resolve_pat(ctrs); + pat.check_good_ctr(ctrs, adts, errs); + } + + self.body.fix_match_defs(ctrs, adts, errs); + } +} + +impl Term { + fn fix_match_defs(&mut self, ctrs: &Constructors, adts: &Adts, errs: &mut Vec) { + match self { + Term::Def { def, nxt } => { + let def_arity = def.arity(); + for rule in &mut def.rules { + rule.fix_match_defs(def_arity, ctrs, adts, errs); + } + nxt.fix_match_defs(ctrs, adts, errs); + } + _ => { + for children in self.children_mut() { + children.fix_match_defs(ctrs, adts, errs); + } + } + } + } +} + impl Pattern { /// If a var pattern actually refers to an ADT constructor, convert it into a constructor pattern. fn resolve_pat(&mut self, ctrs: &Constructors) { diff --git a/src/fun/transform/fix_match_terms.rs b/src/fun/transform/fix_match_terms.rs index f655bf35..14974119 100644 --- a/src/fun/transform/fix_match_terms.rs +++ b/src/fun/transform/fix_match_terms.rs @@ -84,8 +84,14 @@ impl Term { if matches!(self, Term::Mat { .. } | Term::Fold { .. }) { self.fix_match(&mut errs, ctrs, adts); } - // Add a use term to each arm rebuilding the matched variable match self { + Term::Def { def, nxt } => { + for rule in def.rules.iter_mut() { + errs.extend(rule.body.fix_match_terms(ctrs, adts)); + } + errs.extend(nxt.fix_match_terms(ctrs, adts)); + } + // Add a use term to each arm rebuilding the matched variable Term::Mat { arg: _, bnd, with_bnd: _, with_arg: _, arms } | Term::Fold { bnd, arg: _, with_bnd: _, with_arg: _, arms } => { for (ctr, fields, body) in arms { diff --git a/src/fun/transform/lift_local_defs.rs b/src/fun/transform/lift_local_defs.rs index c0fa3d88..4f9545e3 100644 --- a/src/fun/transform/lift_local_defs.rs +++ b/src/fun/transform/lift_local_defs.rs @@ -2,7 +2,10 @@ use std::collections::BTreeSet; use indexmap::IndexMap; -use crate::fun::{Book, Definition, Name, Pattern, Rule, Term}; +use crate::{ + fun::{Book, Definition, Name, Pattern, Rule, Term}, + maybe_grow, +}; impl Book { pub fn lift_local_defs(&mut self) { @@ -25,10 +28,10 @@ impl Rule { impl Term { pub fn lift_local_defs(&mut self, parent: &Name, defs: &mut IndexMap, gen: &mut usize) { - match self { - Term::Def { nam, rules, nxt } => { - let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, nam)); - for rule in rules.iter_mut() { + maybe_grow(|| match self { + Term::Def { def, nxt } => { + let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name)); + for rule in def.rules.iter_mut() { rule.body.lift_local_defs(&local_name, defs, gen); } nxt.lift_local_defs(parent, defs, gen); @@ -36,7 +39,8 @@ impl Term { let inner_defs = defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::>(); - let (r#use, fvs, mut rules) = gen_use(inner_defs, &local_name, nam, nxt, std::mem::take(rules)); + let (r#use, fvs, mut rules) = + gen_use(inner_defs, &local_name, &def.name, nxt, std::mem::take(&mut def.rules)); *self = r#use; apply_closure(&mut rules, &fvs); @@ -49,7 +53,7 @@ impl Term { child.lift_local_defs(parent, defs, gen); } } - } + }) } } @@ -89,9 +93,7 @@ fn gen_use( fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet) { for rule in rules.iter_mut() { - let pats = std::mem::take(&mut rule.pats); - let mut captured = fvs.iter().cloned().map(|nam| Pattern::Var(Some(nam))).collect::>(); - captured.extend(pats); - rule.pats = captured; + let captured = fvs.iter().cloned().map(|nam| Some(nam)).collect::>(); + rule.body = Term::rfold_lams(std::mem::take(&mut rule.body), captured.into_iter()); } } diff --git a/src/fun/transform/resolve_refs.rs b/src/fun/transform/resolve_refs.rs index 11f2ae7c..48722b89 100644 --- a/src/fun/transform/resolve_refs.rs +++ b/src/fun/transform/resolve_refs.rs @@ -30,7 +30,8 @@ impl Ctx<'_> { push_scope(name.as_ref(), &mut scope); } - let res = rule.body.resolve_refs(&def_names, self.book.entrypoint.as_ref(), &mut scope); + let res = + rule.body.resolve_refs(&def_names, self.book.entrypoint.as_ref(), &mut scope, &mut self.info); self.info.take_rule_err(res, def_name.clone()); } } @@ -45,31 +46,48 @@ impl Term { def_names: &HashSet, main: Option<&Name>, scope: &mut HashMap<&'a Name, usize>, + info: &mut Diagnostics, ) -> Result<(), String> { maybe_grow(move || { - if let Term::Var { nam } = self { - if is_var_in_scope(nam, scope) { - // If the variable is actually a reference to main, don't swap and return an error. - if let Some(main) = main { - if nam == main { - return Err("Main definition can't be referenced inside the program.".to_string()); + match self { + Term::Var { nam } => { + if is_var_in_scope(nam, scope) { + // If the variable is actually a reference to main, don't swap and return an error. + if let Some(main) = main { + if nam == main { + return Err("Main definition can't be referenced inside the program.".to_string()); + } + } + + // If the variable is actually a reference to a function, swap the term. + if def_names.contains(nam) { + *self = Term::r#ref(nam); } } + } + Term::Def { def, nxt } => { + for rule in def.rules.iter_mut() { + let mut scope = HashMap::new(); - // If the variable is actually a reference to a function, swap the term. - if def_names.contains(nam) { - *self = Term::r#ref(nam); + for name in rule.pats.iter().flat_map(Pattern::binds) { + push_scope(name.as_ref(), &mut scope); + } + + let res = rule.body.resolve_refs(&def_names, main, &mut scope, info); + info.take_rule_err(res, def.name.clone()); } + nxt.resolve_refs(def_names, main, scope, info)?; } - } - - for (child, binds) in self.children_mut_with_binds() { - for bind in binds.clone() { - push_scope(bind.as_ref(), scope); - } - child.resolve_refs(def_names, main, scope)?; - for bind in binds.rev() { - pop_scope(bind.as_ref(), scope); + _ => { + for (child, binds) in self.children_mut_with_binds() { + for bind in binds.clone() { + push_scope(bind.as_ref(), scope); + } + child.resolve_refs(def_names, main, scope, info)?; + for bind in binds.rev() { + pop_scope(bind.as_ref(), scope); + } + } } } Ok(()) diff --git a/src/imp/to_fun.rs b/src/imp/to_fun.rs index 0f14a04e..b8163119 100644 --- a/src/imp/to_fun.rs +++ b/src/imp/to_fun.rs @@ -391,7 +391,7 @@ impl Stmt { StmtToFun::Assign(pat, term) => (Some(pat), term), }; let def = def.to_fun()?; - let term = fun::Term::Def { nam: def.name, rules: def.rules, nxt: Box::new(nxt) }; + let term = fun::Term::Def { def, nxt: Box::new(nxt) }; if let Some(pat) = nxt_pat { StmtToFun::Assign(pat, term) } else { diff --git a/src/lib.rs b/src/lib.rs index 25e1b27c..0789d3f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,8 +91,6 @@ pub fn desugar_book( ctx.set_entrypoint(); - ctx.book.lift_local_defs(); - ctx.book.encode_adts(opts.adt_encoding); ctx.fix_match_defs()?; @@ -109,6 +107,8 @@ pub fn desugar_book( ctx.fix_match_terms()?; + ctx.book.lift_local_defs(); + ctx.desugar_bend()?; ctx.desugar_fold()?; ctx.desugar_with_blocks()?; diff --git a/tests/snapshots/parse_file__fun_def.bend.snap b/tests/snapshots/parse_file__fun_def.bend.snap index e6107bf0..98c4381f 100644 --- a/tests/snapshots/parse_file__fun_def.bend.snap +++ b/tests/snapshots/parse_file__fun_def.bend.snap @@ -2,4 +2,4 @@ source: tests/golden_tests.rs input_file: tests/golden_tests/parse_file/fun_def.bend --- -(main) = let base = 0; def (aux []) = base(aux (List/Cons head tail)) = (+ head (aux tail))(aux [1, 2, 3]) +(main) = let base = 0; def (aux) = λ%arg0 match %arg0 = %arg0 { List/Nil: base; List/Cons %arg0.head %arg0.tail: use tail = %arg0.tail; use head = %arg0.head; (+ head (aux tail)); }(aux (List/Cons 1 (List/Cons 2 (List/Cons 3 List/Nil))))