Implement 'def' term

This commit is contained in:
imaqtkatt 2024-07-01 18:21:49 -03:00
parent 82afbc0d89
commit d3cce02500
16 changed files with 196 additions and 141 deletions

View File

@ -167,6 +167,13 @@ impl fmt::Display for Term {
}
Term::List { els } => write!(f, "[{}]", DisplayJoin(|| els.iter(), ", "),),
Term::Open { typ, var, bod } => write!(f, "open {typ} {var}; {bod}"),
Term::Def { nam, rules, nxt } => {
write!(f, "def ")?;
for rule in rules.iter() {
write!(f, "{}", rule.display(nam))?;
}
write!(f, "{nxt}")
}
Term::Err => write!(f, "<Invalid>"),
})
}
@ -320,6 +327,16 @@ impl Rule {
self.body.display_pretty(2)
)
}
pub fn display_def_aux<'a>(&'a self, def_name: &'a Name, tab: usize) -> impl fmt::Display + 'a {
display!(
"({}{}) =\n {:tab$}{}",
def_name,
DisplayJoin(|| self.pats.iter().map(|x| display!(" {x}")), ""),
"",
self.body.display_pretty(tab + 2)
)
}
}
impl Term {
@ -483,6 +500,17 @@ impl Term {
Term::Num { val: Num::F24(val) } => write!(f, "{val:.3}"),
Term::Str { val } => write!(f, "{val:?}"),
Term::Ref { nam } => write!(f, "{nam}"),
Term::Def { nam, rules, nxt } => {
write!(f, "def ")?;
for (i, rule) in rules.iter().enumerate() {
if i == 0 {
writeln!(f, "{}", rule.display_def_aux(nam, tab + 4))?;
} else {
writeln!(f, "{:tab$}{}", "", rule.display_def_aux(nam, tab + 4), tab = tab + 4)?;
}
}
write!(f, "{:tab$}{}", "", nxt.display_pretty(tab))
}
Term::Era => write!(f, "*"),
Term::Err => write!(f, "<Error>"),
})

View File

@ -73,7 +73,7 @@ pub struct HvmDefinition {
}
/// A pattern matching rule of a definition.
#[derive(Debug, Clone, Default, PartialEq)]
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct Rule {
pub pats: Vec<Pattern>,
pub body: Term,
@ -179,6 +179,11 @@ pub enum Term {
Ref {
nam: Name,
},
Def {
nam: Name,
rules: Vec<Rule>,
nxt: Box<Term>,
},
Era,
#[default]
Err,
@ -378,6 +383,7 @@ impl Clone for Term {
Self::Open { typ: typ.clone(), var: var.clone(), bod: nxt.clone() }
}
Self::Ref { nam } => Self::Ref { nam: nam.clone() },
Self::Def { nam, rules, nxt } => Self::Def { nam: nam.clone(), rules: rules.clone(), nxt: nxt.clone() },
Self::Era => Self::Era,
Self::Err => Self::Err,
})
@ -545,6 +551,7 @@ impl Term {
| Term::Nat { .. }
| Term::Str { .. }
| Term::Ref { .. }
| Term::Def { .. }
| Term::Era
| Term::Err => ChildrenIter::Zero([]),
}
@ -580,6 +587,7 @@ impl Term {
| Term::Nat { .. }
| Term::Str { .. }
| Term::Ref { .. }
| Term::Def { .. }
| Term::Era
| Term::Err => ChildrenIter::Zero([]),
}
@ -647,6 +655,7 @@ impl Term {
| Term::Nat { .. }
| Term::Str { .. }
| Term::Ref { .. }
| Term::Def { .. }
| Term::Era
| Term::Err => ChildrenIter::Zero([]),
Term::Open { .. } => unreachable!("Open should be removed in earlier pass"),
@ -710,6 +719,7 @@ impl Term {
| Term::Nat { .. }
| Term::Str { .. }
| Term::Ref { .. }
| Term::Def { .. }
| Term::Era
| Term::Err => ChildrenIter::Zero([]),
Term::Open { .. } => unreachable!("Open should be removed in earlier pass"),

View File

@ -490,6 +490,38 @@ impl<'a> TermParser<'a> {
return Ok(Term::Ask { pat: Box::new(pat), val: Box::new(val), nxt: Box::new(nxt) });
}
// Def
if self.try_parse_keyword("def") {
self.skip_trivia();
let (cur_name, rule) = self.parse_rule()?;
let mut rules = vec![rule];
let mut cur_idx = *self.index();
loop {
self.skip_trivia();
let back_term = *self.index();
match self.parse_rule() {
Ok((name, rule)) => {
if name == "def" {
self.index = back_term;
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(self.parse_term()?) });
}
if name == cur_name {
rules.push(rule);
cur_idx = *self.index();
} else {
panic!()
}
}
Err(_) => {
self.index = cur_idx;
break;
}
}
}
let nxt = self.parse_term()?;
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(nxt) });
}
// If
if self.try_parse_keyword("if") {
let mut chain = Vec::new();
@ -781,9 +813,9 @@ impl<'a> TermParser<'a> {
self.check_top_level_redefinition(&def.name, book, span)?;
def.order_kwargs(book)?;
def.gen_map_get();
let locals = def.lift_local_defs(&mut 0)?;
// let locals = def.lift_local_defs(&mut 0)?;
let def = def.to_fun(builtin)?;
book.defs.extend(locals);
// book.defs.extend(locals);
book.defs.insert(def.name.clone(), def);
Ok(())
}

View File

@ -235,6 +235,7 @@ impl<'t, 'l> EncodeTermState<'t, 'l> {
| Term::Nat { .. } // Removed in encode_nat
| Term::Str { .. } // Removed in encode_str
| Term::List { .. } // Removed in encode_list
| Term::Def { .. } // Removed in earlier pass
| Term::Err => unreachable!(),
}
while let Some((pat, val)) = self.lets.pop() {

View File

@ -68,6 +68,7 @@ impl Term {
| Term::Swt { .. }
| Term::Fold { .. }
| Term::Bend { .. }
| Term::Def { .. }
| Term::Era
| Term::Err => {}
})

View File

@ -219,6 +219,7 @@ impl Term {
| Term::With { .. }
| Term::Ask { .. }
| Term::Open { .. }
| Term::Def { .. }
| Term::Err => unreachable!(),
}
}
@ -264,7 +265,12 @@ impl Term {
| Term::Ref { .. }
| Term::Era
| Term::Err => FloatIter::Zero([]),
Term::With { .. } | Term::Ask { .. } | Term::Bend { .. } | Term::Fold { .. } | Term::Open { .. } => {
Term::With { .. }
| Term::Ask { .. }
| Term::Bend { .. }
| Term::Fold { .. }
| Term::Open { .. }
| Term::Def { .. } => {
unreachable!()
}
}

View File

@ -0,0 +1,92 @@
use std::collections::BTreeSet;
use indexmap::IndexMap;
use crate::fun::{Book, Definition, Name, Pattern, Rule, Term};
impl Book {
pub fn lift_defs(&mut self) {
let mut defs = IndexMap::new();
for (name, def) in self.defs.iter_mut() {
let mut gen = 0;
for rule in def.rules.iter_mut() {
rule.body.lift_defs(name, &mut defs, &mut gen);
}
}
self.defs.extend(defs);
}
}
impl Rule {
pub fn binds(&self) -> impl DoubleEndedIterator<Item = &Option<Name>> + Clone {
self.pats.iter().flat_map(Pattern::binds)
}
}
impl Term {
pub fn lift_defs(&mut self, parent: &Name, defs: &mut IndexMap<Name, Definition>, gen: &mut usize) {
match self {
Term::Def { nam, rules, nxt } => {
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, nam));
for rule in rules.iter_mut() {
rule.body.lift_defs(&local_name, defs, gen);
}
nxt.lift_defs(parent, defs, gen);
*gen += 1;
let inner_defs =
defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
let (r#use, fvs, mut rules) = gen_use(inner_defs, &local_name, nam, nxt, std::mem::take(rules));
*self = r#use;
apply_closure(&mut rules, &fvs);
let new_def = Definition { name: local_name.clone(), rules, builtin: false };
defs.insert(local_name.clone(), new_def);
}
_ => {
for child in self.children_mut() {
child.lift_defs(parent, defs, gen);
}
}
}
}
}
fn gen_use(
inner_defs: BTreeSet<Name>,
local_name: &Name,
nam: &Name,
nxt: &mut Box<Term>,
rules: Vec<Rule>,
) -> (Term, BTreeSet<Name>, Vec<Rule>) {
let mut fvs = BTreeSet::<Name>::new();
for rule in rules.iter() {
fvs.extend(rule.body.free_vars().into_keys().collect::<BTreeSet<_>>());
}
// fvs = fvs.into_iter().filter(|fv| !inner_defs.contains(fv)).collect();
fvs.retain(|fv| !inner_defs.contains(fv));
for rule in rules.iter() {
for bind in rule.binds().flatten() {
fvs.remove(bind);
}
}
let call = Term::call(
Term::Ref { nam: local_name.clone() },
fvs.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>(),
);
let r#use = Term::Use { nam: Some(nam.clone()), val: Box::new(call), nxt: std::mem::take(nxt) };
(r#use, fvs, rules)
}
fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet<Name>) {
for rule in rules.iter_mut() {
let pats = std::mem::take(&mut rule.pats);
let mut captured = fvs.iter().cloned().map(|nam| Pattern::Var(Some(nam))).collect::<Vec<_>>();
captured.extend(pats);
rule.pats = captured;
}
}

View File

@ -168,6 +168,7 @@ impl Term {
Term::Fold { .. } => unreachable!("'fold' should be removed in earlier pass"),
Term::Bend { .. } => unreachable!("'bend' should be removed in earlier pass"),
Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
}
}
}

View File

@ -14,6 +14,7 @@ pub mod expand_main;
pub mod fix_match_defs;
pub mod fix_match_terms;
pub mod float_combinators;
pub mod lift_defs;
pub mod linearize_matches;
pub mod linearize_vars;
pub mod resolve_refs;

View File

@ -162,6 +162,7 @@ impl UniqueNameGenerator {
| Term::Era
| Term::Err => {}
Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
})
}

View File

@ -1,131 +0,0 @@
use std::collections::{BTreeMap, BTreeSet};
use indexmap::IndexMap;
use crate::fun::{self, Name, Pattern};
use super::{Definition, Expr, Stmt};
impl Definition {
pub fn lift_local_defs(&mut self, gen: &mut usize) -> Result<IndexMap<Name, fun::Definition>, String> {
let mut defs = IndexMap::new();
self.body.lift_local_defs(&self.name, &mut defs, gen)?;
Ok(defs)
}
}
impl Stmt {
pub fn lift_local_defs(
&mut self,
parent: &Name,
defs: &mut IndexMap<Name, fun::Definition>,
gen: &mut usize,
) -> Result<(), String> {
match self {
Stmt::LocalDef { .. } => {
let Stmt::LocalDef { mut def, mut nxt } = std::mem::take(self) else { unreachable!() };
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
def.body.lift_local_defs(&local_name, defs, gen)?;
nxt.lift_local_defs(parent, defs, gen)?;
*gen += 1;
let inner_defs =
defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
let (r#use, mut def, fvs) = gen_use(local_name.clone(), *def, nxt, inner_defs)?;
*self = r#use;
apply_closure(&mut def, fvs);
defs.insert(def.name.clone(), def);
Ok(())
}
Stmt::Assign { pat: _, val: _, nxt } => {
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::If { cond: _, then, otherwise, nxt } => {
then.lift_local_defs(parent, defs, gen)?;
otherwise.lift_local_defs(parent, defs, gen)?;
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::Match { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt }
| Stmt::Fold { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => {
for arm in arms.iter_mut() {
arm.rgt.lift_local_defs(parent, defs, gen)?;
}
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::Switch { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => {
for arm in arms.iter_mut() {
arm.lift_local_defs(parent, defs, gen)?;
}
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::Bend { bnd: _, arg: _, cond: _, step, base, nxt } => {
step.lift_local_defs(parent, defs, gen)?;
base.lift_local_defs(parent, defs, gen)?;
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::With { typ: _, bod, nxt } => {
bod.lift_local_defs(parent, defs, gen)?;
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::InPlace { op: _, pat: _, val: _, nxt }
| Stmt::Ask { pat: _, val: _, nxt }
| Stmt::Open { typ: _, var: _, nxt }
| Stmt::Use { nam: _, val: _, nxt } => nxt.lift_local_defs(parent, defs, gen),
Stmt::Return { .. } | Stmt::Err => Ok(()),
}
}
}
fn gen_use(
local_name: Name,
def: Definition,
nxt: Box<Stmt>,
inner_defs: BTreeSet<Name>,
) -> Result<(Stmt, fun::Definition, Vec<Name>), String> {
let params = def.params.clone();
let ignored: BTreeSet<Name> = params.into_iter().chain(inner_defs).collect();
let mut def = def.to_fun(false)?;
let fvs = BTreeMap::from_iter(def.rules[0].body.free_vars());
let fvs = fvs.into_keys().filter(|fv| !ignored.contains(fv)).collect::<Vec<_>>();
let val = Expr::Call {
fun: Box::new(Expr::Var { nam: local_name.clone() }),
args: fvs.iter().cloned().map(|nam| Expr::Var { nam }).collect(),
kwargs: vec![],
};
let r#use = Stmt::Use { nam: def.name.clone(), val: Box::new(val), nxt };
def.name = local_name;
Ok((r#use, def, fvs))
}
fn apply_closure(def: &mut fun::Definition, fvs: Vec<Name>) {
let rule = &mut def.rules[0];
let mut captured = fvs.into_iter().map(|x| Pattern::Var(Some(x))).collect::<Vec<_>>();
let rule_pats = std::mem::take(&mut rule.pats);
captured.extend(rule_pats);
rule.pats = captured;
}

View File

@ -1,5 +1,4 @@
pub mod gen_map_get;
pub mod lift_local_defs;
mod order_kwargs;
pub mod parser;
pub mod to_fun;

View File

@ -52,7 +52,6 @@ impl Stmt {
// TODO: Refactor this to not repeat everything.
// TODO: When we have an error with an assignment, we should show the offending assignment (eg. "{pat} = ...").
let stmt_to_fun = match self {
Stmt::LocalDef { .. } => todo!(),
Stmt::Assign { pat: AssignPattern::MapSet(map, key), val, nxt: Some(nxt) } => {
let (nxt_pat, nxt) = match nxt.into_fun()? {
StmtToFun::Return(term) => (None, term),
@ -367,6 +366,19 @@ impl Stmt {
}
}
Stmt::Return { term } => StmtToFun::Return(term.to_fun()),
Stmt::LocalDef { def, nxt } => {
let (nxt_pat, nxt) = match nxt.into_fun()? {
StmtToFun::Return(term) => (None, term),
StmtToFun::Assign(pat, term) => (Some(pat), term),
};
let def = def.to_fun(false)?;
let term = fun::Term::Def { nam: def.name, rules: def.rules, nxt: Box::new(nxt) };
if let Some(pat) = nxt_pat {
StmtToFun::Assign(pat, term)
} else {
StmtToFun::Return(term)
}
}
Stmt::Err => unreachable!(),
};
Ok(stmt_to_fun)

View File

@ -90,6 +90,8 @@ pub fn desugar_book(
ctx.set_entrypoint();
ctx.book.lift_defs();
ctx.book.encode_adts(opts.adt_encoding);
ctx.fix_match_defs()?;

View File

@ -2,6 +2,8 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/local_def_shadow.bend
---
(main) = 1
(main__local_0_A__local_0_B) = 0
(main__local_1_A__local_1_B) = 1
@ -9,5 +11,3 @@ input_file: tests/golden_tests/desugar_file/local_def_shadow.bend
(main__local_1_A) = main__local_1_A__local_1_B
(main__local_0_A) = main__local_0_A__local_0_B
(main) = 1

View File

@ -2,10 +2,10 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/main_aux.bend
---
(main) = (main__local_0_aux 89 2)
(main__local_0_aux__local_0_aux__local_0_aux) = λa λb (+ b a)
(main__local_0_aux__local_0_aux) = λa λb (main__local_0_aux__local_0_aux__local_0_aux a b)
(main__local_0_aux) = λa λb (main__local_0_aux__local_0_aux a b)
(main) = (main__local_0_aux 89 2)