#623 Desugar local defs properly

This commit is contained in:
imaqtkatt 2024-07-08 17:26:34 -03:00
parent f288ae7f12
commit fda44b8910
13 changed files with 175 additions and 79 deletions

View File

@ -70,6 +70,13 @@ impl Term {
Term::List { els } => *self = Term::encode_list(std::mem::take(els)),
Term::Str { val } => *self = Term::encode_str(val),
Term::Nat { val } => *self = Term::encode_nat(*val),
Term::Def { def, nxt } => {
for rule in def.rules.iter_mut() {
rule.pats.iter_mut().for_each(Pattern::encode_builtins);
rule.body.encode_builtins();
}
nxt.encode_builtins();
}
_ => {
for child in self.children_mut() {
child.encode_builtins();

View File

@ -167,10 +167,10 @@ impl fmt::Display for Term {
}
Term::List { els } => write!(f, "[{}]", DisplayJoin(|| els.iter(), ", "),),
Term::Open { typ, var, bod } => write!(f, "open {typ} {var}; {bod}"),
Term::Def { nam, rules, nxt } => {
Term::Def { def, nxt } => {
write!(f, "def ")?;
for rule in rules.iter() {
write!(f, "{}", rule.display(nam))?;
for rule in def.rules.iter() {
write!(f, "{}", rule.display(&def.name))?;
}
write!(f, "{nxt}")
}
@ -500,13 +500,13 @@ impl Term {
Term::Num { val: Num::F24(val) } => write!(f, "{val:.3}"),
Term::Str { val } => write!(f, "{val:?}"),
Term::Ref { nam } => write!(f, "{nam}"),
Term::Def { nam, rules, nxt } => {
Term::Def { def, nxt } => {
write!(f, "def ")?;
for (i, rule) in rules.iter().enumerate() {
for (i, rule) in def.rules.iter().enumerate() {
if i == 0 {
writeln!(f, "{}", rule.display_def_aux(nam, tab + 4))?;
writeln!(f, "{}", rule.display_def_aux(&def.name, tab + 4))?;
} else {
writeln!(f, "{:tab$}{}", "", rule.display_def_aux(nam, tab + 4), tab = tab + 4)?;
writeln!(f, "{:tab$}{}", "", rule.display_def_aux(&def.name, tab + 4), tab = tab + 4)?;
}
}
write!(f, "{:tab$}{}", "", nxt.display_pretty(tab))

View File

@ -65,14 +65,14 @@ pub type Adts = IndexMap<Name, Adt>;
pub type Constructors = IndexMap<Name, Name>;
/// A pattern matching function definition.
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Definition {
pub name: Name,
pub rules: Vec<Rule>,
pub source: Source,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Source {
Builtin,
/// Was generated by the compiler.
@ -199,9 +199,11 @@ pub enum Term {
nam: Name,
},
Def {
nam: Name,
rules: Vec<Rule>,
def: Definition,
nxt: Box<Term>,
// nam: Name,
// rules: Vec<Rule>,
// nxt: Box<Term>,
},
Era,
#[default]
@ -402,7 +404,7 @@ impl Clone for Term {
Self::Open { typ: typ.clone(), var: var.clone(), bod: nxt.clone() }
}
Self::Ref { nam } => Self::Ref { nam: nam.clone() },
Self::Def { nam, rules, nxt } => Self::Def { nam: nam.clone(), rules: rules.clone(), nxt: nxt.clone() },
Self::Def { def, nxt } => Self::Def { def: def.clone(), nxt: nxt.clone() },
Self::Era => Self::Era,
Self::Err => Self::Err,
})

View File

@ -631,7 +631,8 @@ impl<'a> TermParser<'a> {
if name == "def" {
// parse the nxt def term.
self.index = nxt_def;
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(self.parse_term()?) });
let def = FunDefinition::new(name, rules, Source::Local(nxt_def..*self.index()));
return Ok(Term::Def { def, nxt: Box::new(self.parse_term()?) });
}
if name == cur_name {
rules.push(rule);
@ -648,7 +649,8 @@ impl<'a> TermParser<'a> {
}
}
let nxt = self.parse_term()?;
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(nxt) });
let def = FunDefinition::new(cur_name, rules, Source::Local(nxt_term..*self.index()));
return Ok(Term::Def { def, nxt: Box::new(nxt) });
}
// If

View File

@ -38,7 +38,9 @@ impl Ctx<'_> {
impl Definition {
pub fn desugar_match_def(&mut self, ctrs: &Constructors, adts: &Adts) -> Vec<DesugarMatchDefErr> {
let mut errs = vec![];
for rule in self.rules.iter_mut() {
desugar_inner_match_defs(&mut rule.body, ctrs, adts, &mut errs);
}
let repeated_bind_errs = fix_repeated_binds(&mut self.rules);
errs.extend(repeated_bind_errs);
@ -55,6 +57,25 @@ impl Definition {
}
}
fn desugar_inner_match_defs(
term: &mut Term,
ctrs: &Constructors,
adts: &Adts,
errs: &mut Vec<DesugarMatchDefErr>,
) {
maybe_grow(|| match term {
Term::Def { def, nxt } => {
errs.extend(def.desugar_match_def(ctrs, adts));
desugar_inner_match_defs(nxt, ctrs, adts, errs);
}
_ => {
for child in term.children_mut() {
desugar_inner_match_defs(child, ctrs, adts, errs);
}
}
})
}
/// When a rule has repeated bind, the only one that is actually useful is the last one.
///
/// Example: In `(Foo x x x x) = x`, the function should return the fourth argument.

View File

@ -23,26 +23,37 @@ impl Ctx<'_> {
impl Term {
fn desugar_open(&mut self, adts: &Adts) -> Result<(), String> {
maybe_grow(|| {
if let Term::Open { typ, var, bod } = self {
if let Some(adt) = adts.get(&*typ) {
if adt.ctrs.len() == 1 {
let ctr = adt.ctrs.keys().next().unwrap();
*self = Term::Mat {
arg: Box::new(Term::Var { nam: var.clone() }),
bnd: Some(std::mem::take(var)),
with_bnd: vec![],
with_arg: vec![],
arms: vec![(Some(ctr.clone()), vec![], std::mem::take(bod))],
match self {
Term::Open { typ, var, bod } => {
bod.desugar_open(adts)?;
if let Some(adt) = adts.get(&*typ) {
if adt.ctrs.len() == 1 {
let ctr = adt.ctrs.keys().next().unwrap();
*self = Term::Mat {
arg: Box::new(Term::Var { nam: var.clone() }),
bnd: Some(std::mem::take(var)),
with_bnd: vec![],
with_arg: vec![],
arms: vec![(Some(ctr.clone()), vec![], std::mem::take(bod))],
}
} else {
return Err(format!("Type '{typ}' of an 'open' has more than one constructor"));
}
} else {
return Err(format!("Type '{typ}' of an 'open' has more than one constructor"));
return Err(format!("Type '{typ}' of an 'open' is not defined"));
}
}
Term::Def { def, nxt } => {
for rule in def.rules.iter_mut() {
rule.body.desugar_open(adts)?;
}
nxt.desugar_open(adts)?;
}
_ => {
for child in self.children_mut() {
child.desugar_open(adts)?;
}
} else {
return Err(format!("Type '{typ}' of an 'open' is not defined"));
}
}
for child in self.children_mut() {
child.desugar_open(adts)?;
}
Ok(())
})

View File

@ -1,6 +1,6 @@
use crate::{
diagnostics::Diagnostics,
fun::{Adts, Constructors, Ctx, Pattern},
fun::{Adts, Constructors, Ctx, Pattern, Rule, Term},
};
impl Ctx<'_> {
@ -15,18 +15,7 @@ impl Ctx<'_> {
let def_arity = def.arity();
for rule in &mut def.rules {
if rule.arity() != def_arity {
errs.push(format!(
"Incorrect pattern matching rule arity. Expected {} args, found {}.",
def_arity,
rule.arity()
));
}
for pat in &mut rule.pats {
pat.resolve_pat(&self.book.ctrs);
pat.check_good_ctr(&self.book.ctrs, &self.book.adts, &mut errs);
}
rule.fix_match_defs(def_arity, &self.book.ctrs, &self.book.adts, &mut errs);
}
for err in errs {
@ -38,6 +27,44 @@ impl Ctx<'_> {
}
}
impl Rule {
fn fix_match_defs(&mut self, def_arity: usize, ctrs: &Constructors, adts: &Adts, errs: &mut Vec<String>) {
if self.arity() != def_arity {
errs.push(format!(
"Incorrect pattern matching rule arity. Expected {} args, found {}.",
def_arity,
self.arity()
));
}
for pat in &mut self.pats {
pat.resolve_pat(ctrs);
pat.check_good_ctr(ctrs, adts, errs);
}
self.body.fix_match_defs(ctrs, adts, errs);
}
}
impl Term {
fn fix_match_defs(&mut self, ctrs: &Constructors, adts: &Adts, errs: &mut Vec<String>) {
match self {
Term::Def { def, nxt } => {
let def_arity = def.arity();
for rule in &mut def.rules {
rule.fix_match_defs(def_arity, ctrs, adts, errs);
}
nxt.fix_match_defs(ctrs, adts, errs);
}
_ => {
for children in self.children_mut() {
children.fix_match_defs(ctrs, adts, errs);
}
}
}
}
}
impl Pattern {
/// If a var pattern actually refers to an ADT constructor, convert it into a constructor pattern.
fn resolve_pat(&mut self, ctrs: &Constructors) {

View File

@ -84,8 +84,14 @@ impl Term {
if matches!(self, Term::Mat { .. } | Term::Fold { .. }) {
self.fix_match(&mut errs, ctrs, adts);
}
// Add a use term to each arm rebuilding the matched variable
match self {
Term::Def { def, nxt } => {
for rule in def.rules.iter_mut() {
errs.extend(rule.body.fix_match_terms(ctrs, adts));
}
errs.extend(nxt.fix_match_terms(ctrs, adts));
}
// Add a use term to each arm rebuilding the matched variable
Term::Mat { arg: _, bnd, with_bnd: _, with_arg: _, arms }
| Term::Fold { bnd, arg: _, with_bnd: _, with_arg: _, arms } => {
for (ctr, fields, body) in arms {

View File

@ -2,7 +2,10 @@ use std::collections::BTreeSet;
use indexmap::IndexMap;
use crate::fun::{Book, Definition, Name, Pattern, Rule, Term};
use crate::{
fun::{Book, Definition, Name, Pattern, Rule, Term},
maybe_grow,
};
impl Book {
pub fn lift_local_defs(&mut self) {
@ -25,10 +28,10 @@ impl Rule {
impl Term {
pub fn lift_local_defs(&mut self, parent: &Name, defs: &mut IndexMap<Name, Definition>, gen: &mut usize) {
match self {
Term::Def { nam, rules, nxt } => {
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, nam));
for rule in rules.iter_mut() {
maybe_grow(|| match self {
Term::Def { def, nxt } => {
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
for rule in def.rules.iter_mut() {
rule.body.lift_local_defs(&local_name, defs, gen);
}
nxt.lift_local_defs(parent, defs, gen);
@ -36,7 +39,8 @@ impl Term {
let inner_defs =
defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
let (r#use, fvs, mut rules) = gen_use(inner_defs, &local_name, nam, nxt, std::mem::take(rules));
let (r#use, fvs, mut rules) =
gen_use(inner_defs, &local_name, &def.name, nxt, std::mem::take(&mut def.rules));
*self = r#use;
apply_closure(&mut rules, &fvs);
@ -49,7 +53,7 @@ impl Term {
child.lift_local_defs(parent, defs, gen);
}
}
}
})
}
}
@ -89,9 +93,7 @@ fn gen_use(
fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet<Name>) {
for rule in rules.iter_mut() {
let pats = std::mem::take(&mut rule.pats);
let mut captured = fvs.iter().cloned().map(|nam| Pattern::Var(Some(nam))).collect::<Vec<_>>();
captured.extend(pats);
rule.pats = captured;
let captured = fvs.iter().cloned().map(|nam| Some(nam)).collect::<Vec<_>>();
rule.body = Term::rfold_lams(std::mem::take(&mut rule.body), captured.into_iter());
}
}

View File

@ -30,7 +30,8 @@ impl Ctx<'_> {
push_scope(name.as_ref(), &mut scope);
}
let res = rule.body.resolve_refs(&def_names, self.book.entrypoint.as_ref(), &mut scope);
let res =
rule.body.resolve_refs(&def_names, self.book.entrypoint.as_ref(), &mut scope, &mut self.info);
self.info.take_rule_err(res, def_name.clone());
}
}
@ -45,31 +46,48 @@ impl Term {
def_names: &HashSet<Name>,
main: Option<&Name>,
scope: &mut HashMap<&'a Name, usize>,
info: &mut Diagnostics,
) -> Result<(), String> {
maybe_grow(move || {
if let Term::Var { nam } = self {
if is_var_in_scope(nam, scope) {
// If the variable is actually a reference to main, don't swap and return an error.
if let Some(main) = main {
if nam == main {
return Err("Main definition can't be referenced inside the program.".to_string());
match self {
Term::Var { nam } => {
if is_var_in_scope(nam, scope) {
// If the variable is actually a reference to main, don't swap and return an error.
if let Some(main) = main {
if nam == main {
return Err("Main definition can't be referenced inside the program.".to_string());
}
}
// If the variable is actually a reference to a function, swap the term.
if def_names.contains(nam) {
*self = Term::r#ref(nam);
}
}
}
Term::Def { def, nxt } => {
for rule in def.rules.iter_mut() {
let mut scope = HashMap::new();
// If the variable is actually a reference to a function, swap the term.
if def_names.contains(nam) {
*self = Term::r#ref(nam);
for name in rule.pats.iter().flat_map(Pattern::binds) {
push_scope(name.as_ref(), &mut scope);
}
let res = rule.body.resolve_refs(&def_names, main, &mut scope, info);
info.take_rule_err(res, def.name.clone());
}
nxt.resolve_refs(def_names, main, scope, info)?;
}
}
for (child, binds) in self.children_mut_with_binds() {
for bind in binds.clone() {
push_scope(bind.as_ref(), scope);
}
child.resolve_refs(def_names, main, scope)?;
for bind in binds.rev() {
pop_scope(bind.as_ref(), scope);
_ => {
for (child, binds) in self.children_mut_with_binds() {
for bind in binds.clone() {
push_scope(bind.as_ref(), scope);
}
child.resolve_refs(def_names, main, scope, info)?;
for bind in binds.rev() {
pop_scope(bind.as_ref(), scope);
}
}
}
}
Ok(())

View File

@ -391,7 +391,7 @@ impl Stmt {
StmtToFun::Assign(pat, term) => (Some(pat), term),
};
let def = def.to_fun()?;
let term = fun::Term::Def { nam: def.name, rules: def.rules, nxt: Box::new(nxt) };
let term = fun::Term::Def { def, nxt: Box::new(nxt) };
if let Some(pat) = nxt_pat {
StmtToFun::Assign(pat, term)
} else {

View File

@ -91,8 +91,6 @@ pub fn desugar_book(
ctx.set_entrypoint();
ctx.book.lift_local_defs();
ctx.book.encode_adts(opts.adt_encoding);
ctx.fix_match_defs()?;
@ -109,6 +107,8 @@ pub fn desugar_book(
ctx.fix_match_terms()?;
ctx.book.lift_local_defs();
ctx.desugar_bend()?;
ctx.desugar_fold()?;
ctx.desugar_with_blocks()?;

View File

@ -2,4 +2,4 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/parse_file/fun_def.bend
---
(main) = let base = 0; def (aux []) = base(aux (List/Cons head tail)) = (+ head (aux tail))(aux [1, 2, 3])
(main) = let base = 0; def (aux) = λ%arg0 match %arg0 = %arg0 { List/Nil: base; List/Cons %arg0.head %arg0.tail: use tail = %arg0.tail; use head = %arg0.head; (+ head (aux tail)); }(aux (List/Cons 1 (List/Cons 2 (List/Cons 3 List/Nil))))