From d3cce025000b5edb9e8ff6649d8061ac91a89a93 Mon Sep 17 00:00:00 2001 From: imaqtkatt Date: Mon, 1 Jul 2024 18:21:49 -0300 Subject: [PATCH] Implement 'def' term --- src/fun/display.rs | 28 ++++ src/fun/mod.rs | 12 +- src/fun/parser.rs | 36 ++++- src/fun/term_to_net.rs | 1 + src/fun/transform/expand_main.rs | 1 + src/fun/transform/float_combinators.rs | 8 +- src/fun/transform/lift_defs.rs | 92 ++++++++++++ src/fun/transform/linearize_vars.rs | 1 + src/fun/transform/mod.rs | 1 + src/fun/transform/unique_names.rs | 1 + src/imp/lift_local_defs.rs | 131 ------------------ src/imp/mod.rs | 1 - src/imp/to_fun.rs | 14 +- src/lib.rs | 2 + .../desugar_file__local_def_shadow.bend.snap | 4 +- .../desugar_file__main_aux.bend.snap | 4 +- 16 files changed, 196 insertions(+), 141 deletions(-) create mode 100644 src/fun/transform/lift_defs.rs delete mode 100644 src/imp/lift_local_defs.rs diff --git a/src/fun/display.rs b/src/fun/display.rs index e4180c0f..0330dd01 100644 --- a/src/fun/display.rs +++ b/src/fun/display.rs @@ -167,6 +167,13 @@ impl fmt::Display for Term { } Term::List { els } => write!(f, "[{}]", DisplayJoin(|| els.iter(), ", "),), Term::Open { typ, var, bod } => write!(f, "open {typ} {var}; {bod}"), + Term::Def { nam, rules, nxt } => { + write!(f, "def ")?; + for rule in rules.iter() { + write!(f, "{}", rule.display(nam))?; + } + write!(f, "{nxt}") + } Term::Err => write!(f, ""), }) } @@ -320,6 +327,16 @@ impl Rule { self.body.display_pretty(2) ) } + + pub fn display_def_aux<'a>(&'a self, def_name: &'a Name, tab: usize) -> impl fmt::Display + 'a { + display!( + "({}{}) =\n {:tab$}{}", + def_name, + DisplayJoin(|| self.pats.iter().map(|x| display!(" {x}")), ""), + "", + self.body.display_pretty(tab + 2) + ) + } } impl Term { @@ -483,6 +500,17 @@ impl Term { Term::Num { val: Num::F24(val) } => write!(f, "{val:.3}"), Term::Str { val } => write!(f, "{val:?}"), Term::Ref { nam } => write!(f, "{nam}"), + Term::Def { nam, rules, nxt } => { + write!(f, "def ")?; + for (i, rule) in rules.iter().enumerate() { + if i == 0 { + writeln!(f, "{}", rule.display_def_aux(nam, tab + 4))?; + } else { + writeln!(f, "{:tab$}{}", "", rule.display_def_aux(nam, tab + 4), tab = tab + 4)?; + } + } + write!(f, "{:tab$}{}", "", nxt.display_pretty(tab)) + } Term::Era => write!(f, "*"), Term::Err => write!(f, ""), }) diff --git a/src/fun/mod.rs b/src/fun/mod.rs index d9667a24..17b0becf 100644 --- a/src/fun/mod.rs +++ b/src/fun/mod.rs @@ -73,7 +73,7 @@ pub struct HvmDefinition { } /// A pattern matching rule of a definition. -#[derive(Debug, Clone, Default, PartialEq)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)] pub struct Rule { pub pats: Vec, pub body: Term, @@ -179,6 +179,11 @@ pub enum Term { Ref { nam: Name, }, + Def { + nam: Name, + rules: Vec, + nxt: Box, + }, Era, #[default] Err, @@ -378,6 +383,7 @@ impl Clone for Term { Self::Open { typ: typ.clone(), var: var.clone(), bod: nxt.clone() } } Self::Ref { nam } => Self::Ref { nam: nam.clone() }, + Self::Def { nam, rules, nxt } => Self::Def { nam: nam.clone(), rules: rules.clone(), nxt: nxt.clone() }, Self::Era => Self::Era, Self::Err => Self::Err, }) @@ -545,6 +551,7 @@ impl Term { | Term::Nat { .. } | Term::Str { .. } | Term::Ref { .. } + | Term::Def { .. } | Term::Era | Term::Err => ChildrenIter::Zero([]), } @@ -580,6 +587,7 @@ impl Term { | Term::Nat { .. } | Term::Str { .. } | Term::Ref { .. } + | Term::Def { .. } | Term::Era | Term::Err => ChildrenIter::Zero([]), } @@ -647,6 +655,7 @@ impl Term { | Term::Nat { .. } | Term::Str { .. } | Term::Ref { .. } + | Term::Def { .. } | Term::Era | Term::Err => ChildrenIter::Zero([]), Term::Open { .. } => unreachable!("Open should be removed in earlier pass"), @@ -710,6 +719,7 @@ impl Term { | Term::Nat { .. } | Term::Str { .. } | Term::Ref { .. } + | Term::Def { .. } | Term::Era | Term::Err => ChildrenIter::Zero([]), Term::Open { .. } => unreachable!("Open should be removed in earlier pass"), diff --git a/src/fun/parser.rs b/src/fun/parser.rs index 5d0614fb..9d2b4b15 100644 --- a/src/fun/parser.rs +++ b/src/fun/parser.rs @@ -490,6 +490,38 @@ impl<'a> TermParser<'a> { return Ok(Term::Ask { pat: Box::new(pat), val: Box::new(val), nxt: Box::new(nxt) }); } + // Def + if self.try_parse_keyword("def") { + self.skip_trivia(); + let (cur_name, rule) = self.parse_rule()?; + let mut rules = vec![rule]; + let mut cur_idx = *self.index(); + loop { + self.skip_trivia(); + let back_term = *self.index(); + match self.parse_rule() { + Ok((name, rule)) => { + if name == "def" { + self.index = back_term; + return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(self.parse_term()?) }); + } + if name == cur_name { + rules.push(rule); + cur_idx = *self.index(); + } else { + panic!() + } + } + Err(_) => { + self.index = cur_idx; + break; + } + } + } + let nxt = self.parse_term()?; + return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(nxt) }); + } + // If if self.try_parse_keyword("if") { let mut chain = Vec::new(); @@ -781,9 +813,9 @@ impl<'a> TermParser<'a> { self.check_top_level_redefinition(&def.name, book, span)?; def.order_kwargs(book)?; def.gen_map_get(); - let locals = def.lift_local_defs(&mut 0)?; + // let locals = def.lift_local_defs(&mut 0)?; let def = def.to_fun(builtin)?; - book.defs.extend(locals); + // book.defs.extend(locals); book.defs.insert(def.name.clone(), def); Ok(()) } diff --git a/src/fun/term_to_net.rs b/src/fun/term_to_net.rs index 49fa8b5c..808b6331 100644 --- a/src/fun/term_to_net.rs +++ b/src/fun/term_to_net.rs @@ -235,6 +235,7 @@ impl<'t, 'l> EncodeTermState<'t, 'l> { | Term::Nat { .. } // Removed in encode_nat | Term::Str { .. } // Removed in encode_str | Term::List { .. } // Removed in encode_list + | Term::Def { .. } // Removed in earlier pass | Term::Err => unreachable!(), } while let Some((pat, val)) = self.lets.pop() { diff --git a/src/fun/transform/expand_main.rs b/src/fun/transform/expand_main.rs index 8bf3d418..8b07f4b9 100644 --- a/src/fun/transform/expand_main.rs +++ b/src/fun/transform/expand_main.rs @@ -68,6 +68,7 @@ impl Term { | Term::Swt { .. } | Term::Fold { .. } | Term::Bend { .. } + | Term::Def { .. } | Term::Era | Term::Err => {} }) diff --git a/src/fun/transform/float_combinators.rs b/src/fun/transform/float_combinators.rs index 19d1390f..064e9835 100644 --- a/src/fun/transform/float_combinators.rs +++ b/src/fun/transform/float_combinators.rs @@ -219,6 +219,7 @@ impl Term { | Term::With { .. } | Term::Ask { .. } | Term::Open { .. } + | Term::Def { .. } | Term::Err => unreachable!(), } } @@ -264,7 +265,12 @@ impl Term { | Term::Ref { .. } | Term::Era | Term::Err => FloatIter::Zero([]), - Term::With { .. } | Term::Ask { .. } | Term::Bend { .. } | Term::Fold { .. } | Term::Open { .. } => { + Term::With { .. } + | Term::Ask { .. } + | Term::Bend { .. } + | Term::Fold { .. } + | Term::Open { .. } + | Term::Def { .. } => { unreachable!() } } diff --git a/src/fun/transform/lift_defs.rs b/src/fun/transform/lift_defs.rs new file mode 100644 index 00000000..41d3db66 --- /dev/null +++ b/src/fun/transform/lift_defs.rs @@ -0,0 +1,92 @@ +use std::collections::BTreeSet; + +use indexmap::IndexMap; + +use crate::fun::{Book, Definition, Name, Pattern, Rule, Term}; + +impl Book { + pub fn lift_defs(&mut self) { + let mut defs = IndexMap::new(); + for (name, def) in self.defs.iter_mut() { + let mut gen = 0; + for rule in def.rules.iter_mut() { + rule.body.lift_defs(name, &mut defs, &mut gen); + } + } + self.defs.extend(defs); + } +} + +impl Rule { + pub fn binds(&self) -> impl DoubleEndedIterator> + Clone { + self.pats.iter().flat_map(Pattern::binds) + } +} + +impl Term { + pub fn lift_defs(&mut self, parent: &Name, defs: &mut IndexMap, gen: &mut usize) { + match self { + Term::Def { nam, rules, nxt } => { + let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, nam)); + for rule in rules.iter_mut() { + rule.body.lift_defs(&local_name, defs, gen); + } + nxt.lift_defs(parent, defs, gen); + *gen += 1; + + let inner_defs = + defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::>(); + let (r#use, fvs, mut rules) = gen_use(inner_defs, &local_name, nam, nxt, std::mem::take(rules)); + *self = r#use; + + apply_closure(&mut rules, &fvs); + + let new_def = Definition { name: local_name.clone(), rules, builtin: false }; + defs.insert(local_name.clone(), new_def); + } + _ => { + for child in self.children_mut() { + child.lift_defs(parent, defs, gen); + } + } + } + } +} + +fn gen_use( + inner_defs: BTreeSet, + local_name: &Name, + nam: &Name, + nxt: &mut Box, + rules: Vec, +) -> (Term, BTreeSet, Vec) { + let mut fvs = BTreeSet::::new(); + for rule in rules.iter() { + fvs.extend(rule.body.free_vars().into_keys().collect::>()); + } + // fvs = fvs.into_iter().filter(|fv| !inner_defs.contains(fv)).collect(); + fvs.retain(|fv| !inner_defs.contains(fv)); + for rule in rules.iter() { + for bind in rule.binds().flatten() { + fvs.remove(bind); + } + } + + let call = Term::call( + Term::Ref { nam: local_name.clone() }, + fvs.iter().cloned().map(|nam| Term::Var { nam }).collect::>(), + ); + + let r#use = Term::Use { nam: Some(nam.clone()), val: Box::new(call), nxt: std::mem::take(nxt) }; + + (r#use, fvs, rules) +} + +fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet) { + for rule in rules.iter_mut() { + let pats = std::mem::take(&mut rule.pats); + let mut captured = fvs.iter().cloned().map(|nam| Pattern::Var(Some(nam))).collect::>(); + captured.extend(pats); + rule.pats = captured; + } +} diff --git a/src/fun/transform/linearize_vars.rs b/src/fun/transform/linearize_vars.rs index adacbce2..0e804d10 100644 --- a/src/fun/transform/linearize_vars.rs +++ b/src/fun/transform/linearize_vars.rs @@ -168,6 +168,7 @@ impl Term { Term::Fold { .. } => unreachable!("'fold' should be removed in earlier pass"), Term::Bend { .. } => unreachable!("'bend' should be removed in earlier pass"), Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"), + Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"), } } } diff --git a/src/fun/transform/mod.rs b/src/fun/transform/mod.rs index a73b9934..be55c364 100644 --- a/src/fun/transform/mod.rs +++ b/src/fun/transform/mod.rs @@ -14,6 +14,7 @@ pub mod expand_main; pub mod fix_match_defs; pub mod fix_match_terms; pub mod float_combinators; +pub mod lift_defs; pub mod linearize_matches; pub mod linearize_vars; pub mod resolve_refs; diff --git a/src/fun/transform/unique_names.rs b/src/fun/transform/unique_names.rs index 2aa0c7fd..7be96039 100644 --- a/src/fun/transform/unique_names.rs +++ b/src/fun/transform/unique_names.rs @@ -162,6 +162,7 @@ impl UniqueNameGenerator { | Term::Era | Term::Err => {} Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"), + Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"), }) } diff --git a/src/imp/lift_local_defs.rs b/src/imp/lift_local_defs.rs deleted file mode 100644 index c520e17c..00000000 --- a/src/imp/lift_local_defs.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::collections::{BTreeMap, BTreeSet}; - -use indexmap::IndexMap; - -use crate::fun::{self, Name, Pattern}; - -use super::{Definition, Expr, Stmt}; - -impl Definition { - pub fn lift_local_defs(&mut self, gen: &mut usize) -> Result, String> { - let mut defs = IndexMap::new(); - self.body.lift_local_defs(&self.name, &mut defs, gen)?; - Ok(defs) - } -} - -impl Stmt { - pub fn lift_local_defs( - &mut self, - parent: &Name, - defs: &mut IndexMap, - gen: &mut usize, - ) -> Result<(), String> { - match self { - Stmt::LocalDef { .. } => { - let Stmt::LocalDef { mut def, mut nxt } = std::mem::take(self) else { unreachable!() }; - let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name)); - def.body.lift_local_defs(&local_name, defs, gen)?; - nxt.lift_local_defs(parent, defs, gen)?; - *gen += 1; - - let inner_defs = - defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::>(); - let (r#use, mut def, fvs) = gen_use(local_name.clone(), *def, nxt, inner_defs)?; - *self = r#use; - apply_closure(&mut def, fvs); - - defs.insert(def.name.clone(), def); - Ok(()) - } - - Stmt::Assign { pat: _, val: _, nxt } => { - if let Some(nxt) = nxt { - nxt.lift_local_defs(parent, defs, gen)?; - } - Ok(()) - } - Stmt::If { cond: _, then, otherwise, nxt } => { - then.lift_local_defs(parent, defs, gen)?; - otherwise.lift_local_defs(parent, defs, gen)?; - if let Some(nxt) = nxt { - nxt.lift_local_defs(parent, defs, gen)?; - } - Ok(()) - } - Stmt::Match { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } - | Stmt::Fold { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => { - for arm in arms.iter_mut() { - arm.rgt.lift_local_defs(parent, defs, gen)?; - } - if let Some(nxt) = nxt { - nxt.lift_local_defs(parent, defs, gen)?; - } - Ok(()) - } - Stmt::Switch { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => { - for arm in arms.iter_mut() { - arm.lift_local_defs(parent, defs, gen)?; - } - if let Some(nxt) = nxt { - nxt.lift_local_defs(parent, defs, gen)?; - } - Ok(()) - } - Stmt::Bend { bnd: _, arg: _, cond: _, step, base, nxt } => { - step.lift_local_defs(parent, defs, gen)?; - base.lift_local_defs(parent, defs, gen)?; - if let Some(nxt) = nxt { - nxt.lift_local_defs(parent, defs, gen)?; - } - Ok(()) - } - Stmt::With { typ: _, bod, nxt } => { - bod.lift_local_defs(parent, defs, gen)?; - if let Some(nxt) = nxt { - nxt.lift_local_defs(parent, defs, gen)?; - } - Ok(()) - } - - Stmt::InPlace { op: _, pat: _, val: _, nxt } - | Stmt::Ask { pat: _, val: _, nxt } - | Stmt::Open { typ: _, var: _, nxt } - | Stmt::Use { nam: _, val: _, nxt } => nxt.lift_local_defs(parent, defs, gen), - - Stmt::Return { .. } | Stmt::Err => Ok(()), - } - } -} - -fn gen_use( - local_name: Name, - def: Definition, - nxt: Box, - inner_defs: BTreeSet, -) -> Result<(Stmt, fun::Definition, Vec), String> { - let params = def.params.clone(); - let ignored: BTreeSet = params.into_iter().chain(inner_defs).collect(); - let mut def = def.to_fun(false)?; - - let fvs = BTreeMap::from_iter(def.rules[0].body.free_vars()); - let fvs = fvs.into_keys().filter(|fv| !ignored.contains(fv)).collect::>(); - let val = Expr::Call { - fun: Box::new(Expr::Var { nam: local_name.clone() }), - args: fvs.iter().cloned().map(|nam| Expr::Var { nam }).collect(), - kwargs: vec![], - }; - - let r#use = Stmt::Use { nam: def.name.clone(), val: Box::new(val), nxt }; - def.name = local_name; - - Ok((r#use, def, fvs)) -} - -fn apply_closure(def: &mut fun::Definition, fvs: Vec) { - let rule = &mut def.rules[0]; - let mut captured = fvs.into_iter().map(|x| Pattern::Var(Some(x))).collect::>(); - let rule_pats = std::mem::take(&mut rule.pats); - captured.extend(rule_pats); - rule.pats = captured; -} diff --git a/src/imp/mod.rs b/src/imp/mod.rs index 9058863e..e70cc4e7 100644 --- a/src/imp/mod.rs +++ b/src/imp/mod.rs @@ -1,5 +1,4 @@ pub mod gen_map_get; -pub mod lift_local_defs; mod order_kwargs; pub mod parser; pub mod to_fun; diff --git a/src/imp/to_fun.rs b/src/imp/to_fun.rs index 30f439da..f69df463 100644 --- a/src/imp/to_fun.rs +++ b/src/imp/to_fun.rs @@ -52,7 +52,6 @@ impl Stmt { // TODO: Refactor this to not repeat everything. // TODO: When we have an error with an assignment, we should show the offending assignment (eg. "{pat} = ..."). let stmt_to_fun = match self { - Stmt::LocalDef { .. } => todo!(), Stmt::Assign { pat: AssignPattern::MapSet(map, key), val, nxt: Some(nxt) } => { let (nxt_pat, nxt) = match nxt.into_fun()? { StmtToFun::Return(term) => (None, term), @@ -367,6 +366,19 @@ impl Stmt { } } Stmt::Return { term } => StmtToFun::Return(term.to_fun()), + Stmt::LocalDef { def, nxt } => { + let (nxt_pat, nxt) = match nxt.into_fun()? { + StmtToFun::Return(term) => (None, term), + StmtToFun::Assign(pat, term) => (Some(pat), term), + }; + let def = def.to_fun(false)?; + let term = fun::Term::Def { nam: def.name, rules: def.rules, nxt: Box::new(nxt) }; + if let Some(pat) = nxt_pat { + StmtToFun::Assign(pat, term) + } else { + StmtToFun::Return(term) + } + } Stmt::Err => unreachable!(), }; Ok(stmt_to_fun) diff --git a/src/lib.rs b/src/lib.rs index 413453a7..7372aebc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,6 +90,8 @@ pub fn desugar_book( ctx.set_entrypoint(); + ctx.book.lift_defs(); + ctx.book.encode_adts(opts.adt_encoding); ctx.fix_match_defs()?; diff --git a/tests/snapshots/desugar_file__local_def_shadow.bend.snap b/tests/snapshots/desugar_file__local_def_shadow.bend.snap index 963d6e99..551da5ee 100644 --- a/tests/snapshots/desugar_file__local_def_shadow.bend.snap +++ b/tests/snapshots/desugar_file__local_def_shadow.bend.snap @@ -2,6 +2,8 @@ source: tests/golden_tests.rs input_file: tests/golden_tests/desugar_file/local_def_shadow.bend --- +(main) = 1 + (main__local_0_A__local_0_B) = 0 (main__local_1_A__local_1_B) = 1 @@ -9,5 +11,3 @@ input_file: tests/golden_tests/desugar_file/local_def_shadow.bend (main__local_1_A) = main__local_1_A__local_1_B (main__local_0_A) = main__local_0_A__local_0_B - -(main) = 1 diff --git a/tests/snapshots/desugar_file__main_aux.bend.snap b/tests/snapshots/desugar_file__main_aux.bend.snap index f4531a14..4de95035 100644 --- a/tests/snapshots/desugar_file__main_aux.bend.snap +++ b/tests/snapshots/desugar_file__main_aux.bend.snap @@ -2,10 +2,10 @@ source: tests/golden_tests.rs input_file: tests/golden_tests/desugar_file/main_aux.bend --- +(main) = (main__local_0_aux 89 2) + (main__local_0_aux__local_0_aux__local_0_aux) = λa λb (+ b a) (main__local_0_aux__local_0_aux) = λa λb (main__local_0_aux__local_0_aux__local_0_aux a b) (main__local_0_aux) = λa λb (main__local_0_aux__local_0_aux a b) - -(main) = (main__local_0_aux 89 2)