mirror of
https://github.com/HigherOrderCO/Bend.git
synced 2024-08-15 14:50:42 +03:00
Implement 'def' term
This commit is contained in:
parent
82afbc0d89
commit
d3cce02500
@ -167,6 +167,13 @@ impl fmt::Display for Term {
|
||||
}
|
||||
Term::List { els } => write!(f, "[{}]", DisplayJoin(|| els.iter(), ", "),),
|
||||
Term::Open { typ, var, bod } => write!(f, "open {typ} {var}; {bod}"),
|
||||
Term::Def { nam, rules, nxt } => {
|
||||
write!(f, "def ")?;
|
||||
for rule in rules.iter() {
|
||||
write!(f, "{}", rule.display(nam))?;
|
||||
}
|
||||
write!(f, "{nxt}")
|
||||
}
|
||||
Term::Err => write!(f, "<Invalid>"),
|
||||
})
|
||||
}
|
||||
@ -320,6 +327,16 @@ impl Rule {
|
||||
self.body.display_pretty(2)
|
||||
)
|
||||
}
|
||||
|
||||
pub fn display_def_aux<'a>(&'a self, def_name: &'a Name, tab: usize) -> impl fmt::Display + 'a {
|
||||
display!(
|
||||
"({}{}) =\n {:tab$}{}",
|
||||
def_name,
|
||||
DisplayJoin(|| self.pats.iter().map(|x| display!(" {x}")), ""),
|
||||
"",
|
||||
self.body.display_pretty(tab + 2)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Term {
|
||||
@ -483,6 +500,17 @@ impl Term {
|
||||
Term::Num { val: Num::F24(val) } => write!(f, "{val:.3}"),
|
||||
Term::Str { val } => write!(f, "{val:?}"),
|
||||
Term::Ref { nam } => write!(f, "{nam}"),
|
||||
Term::Def { nam, rules, nxt } => {
|
||||
write!(f, "def ")?;
|
||||
for (i, rule) in rules.iter().enumerate() {
|
||||
if i == 0 {
|
||||
writeln!(f, "{}", rule.display_def_aux(nam, tab + 4))?;
|
||||
} else {
|
||||
writeln!(f, "{:tab$}{}", "", rule.display_def_aux(nam, tab + 4), tab = tab + 4)?;
|
||||
}
|
||||
}
|
||||
write!(f, "{:tab$}{}", "", nxt.display_pretty(tab))
|
||||
}
|
||||
Term::Era => write!(f, "*"),
|
||||
Term::Err => write!(f, "<Error>"),
|
||||
})
|
||||
|
@ -73,7 +73,7 @@ pub struct HvmDefinition {
|
||||
}
|
||||
|
||||
/// A pattern matching rule of a definition.
|
||||
#[derive(Debug, Clone, Default, PartialEq)]
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
|
||||
pub struct Rule {
|
||||
pub pats: Vec<Pattern>,
|
||||
pub body: Term,
|
||||
@ -179,6 +179,11 @@ pub enum Term {
|
||||
Ref {
|
||||
nam: Name,
|
||||
},
|
||||
Def {
|
||||
nam: Name,
|
||||
rules: Vec<Rule>,
|
||||
nxt: Box<Term>,
|
||||
},
|
||||
Era,
|
||||
#[default]
|
||||
Err,
|
||||
@ -378,6 +383,7 @@ impl Clone for Term {
|
||||
Self::Open { typ: typ.clone(), var: var.clone(), bod: nxt.clone() }
|
||||
}
|
||||
Self::Ref { nam } => Self::Ref { nam: nam.clone() },
|
||||
Self::Def { nam, rules, nxt } => Self::Def { nam: nam.clone(), rules: rules.clone(), nxt: nxt.clone() },
|
||||
Self::Era => Self::Era,
|
||||
Self::Err => Self::Err,
|
||||
})
|
||||
@ -545,6 +551,7 @@ impl Term {
|
||||
| Term::Nat { .. }
|
||||
| Term::Str { .. }
|
||||
| Term::Ref { .. }
|
||||
| Term::Def { .. }
|
||||
| Term::Era
|
||||
| Term::Err => ChildrenIter::Zero([]),
|
||||
}
|
||||
@ -580,6 +587,7 @@ impl Term {
|
||||
| Term::Nat { .. }
|
||||
| Term::Str { .. }
|
||||
| Term::Ref { .. }
|
||||
| Term::Def { .. }
|
||||
| Term::Era
|
||||
| Term::Err => ChildrenIter::Zero([]),
|
||||
}
|
||||
@ -647,6 +655,7 @@ impl Term {
|
||||
| Term::Nat { .. }
|
||||
| Term::Str { .. }
|
||||
| Term::Ref { .. }
|
||||
| Term::Def { .. }
|
||||
| Term::Era
|
||||
| Term::Err => ChildrenIter::Zero([]),
|
||||
Term::Open { .. } => unreachable!("Open should be removed in earlier pass"),
|
||||
@ -710,6 +719,7 @@ impl Term {
|
||||
| Term::Nat { .. }
|
||||
| Term::Str { .. }
|
||||
| Term::Ref { .. }
|
||||
| Term::Def { .. }
|
||||
| Term::Era
|
||||
| Term::Err => ChildrenIter::Zero([]),
|
||||
Term::Open { .. } => unreachable!("Open should be removed in earlier pass"),
|
||||
|
@ -490,6 +490,38 @@ impl<'a> TermParser<'a> {
|
||||
return Ok(Term::Ask { pat: Box::new(pat), val: Box::new(val), nxt: Box::new(nxt) });
|
||||
}
|
||||
|
||||
// Def
|
||||
if self.try_parse_keyword("def") {
|
||||
self.skip_trivia();
|
||||
let (cur_name, rule) = self.parse_rule()?;
|
||||
let mut rules = vec![rule];
|
||||
let mut cur_idx = *self.index();
|
||||
loop {
|
||||
self.skip_trivia();
|
||||
let back_term = *self.index();
|
||||
match self.parse_rule() {
|
||||
Ok((name, rule)) => {
|
||||
if name == "def" {
|
||||
self.index = back_term;
|
||||
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(self.parse_term()?) });
|
||||
}
|
||||
if name == cur_name {
|
||||
rules.push(rule);
|
||||
cur_idx = *self.index();
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
self.index = cur_idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let nxt = self.parse_term()?;
|
||||
return Ok(Term::Def { nam: cur_name, rules, nxt: Box::new(nxt) });
|
||||
}
|
||||
|
||||
// If
|
||||
if self.try_parse_keyword("if") {
|
||||
let mut chain = Vec::new();
|
||||
@ -781,9 +813,9 @@ impl<'a> TermParser<'a> {
|
||||
self.check_top_level_redefinition(&def.name, book, span)?;
|
||||
def.order_kwargs(book)?;
|
||||
def.gen_map_get();
|
||||
let locals = def.lift_local_defs(&mut 0)?;
|
||||
// let locals = def.lift_local_defs(&mut 0)?;
|
||||
let def = def.to_fun(builtin)?;
|
||||
book.defs.extend(locals);
|
||||
// book.defs.extend(locals);
|
||||
book.defs.insert(def.name.clone(), def);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -235,6 +235,7 @@ impl<'t, 'l> EncodeTermState<'t, 'l> {
|
||||
| Term::Nat { .. } // Removed in encode_nat
|
||||
| Term::Str { .. } // Removed in encode_str
|
||||
| Term::List { .. } // Removed in encode_list
|
||||
| Term::Def { .. } // Removed in earlier pass
|
||||
| Term::Err => unreachable!(),
|
||||
}
|
||||
while let Some((pat, val)) = self.lets.pop() {
|
||||
|
@ -68,6 +68,7 @@ impl Term {
|
||||
| Term::Swt { .. }
|
||||
| Term::Fold { .. }
|
||||
| Term::Bend { .. }
|
||||
| Term::Def { .. }
|
||||
| Term::Era
|
||||
| Term::Err => {}
|
||||
})
|
||||
|
@ -219,6 +219,7 @@ impl Term {
|
||||
| Term::With { .. }
|
||||
| Term::Ask { .. }
|
||||
| Term::Open { .. }
|
||||
| Term::Def { .. }
|
||||
| Term::Err => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -264,7 +265,12 @@ impl Term {
|
||||
| Term::Ref { .. }
|
||||
| Term::Era
|
||||
| Term::Err => FloatIter::Zero([]),
|
||||
Term::With { .. } | Term::Ask { .. } | Term::Bend { .. } | Term::Fold { .. } | Term::Open { .. } => {
|
||||
Term::With { .. }
|
||||
| Term::Ask { .. }
|
||||
| Term::Bend { .. }
|
||||
| Term::Fold { .. }
|
||||
| Term::Open { .. }
|
||||
| Term::Def { .. } => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
92
src/fun/transform/lift_defs.rs
Normal file
92
src/fun/transform/lift_defs.rs
Normal file
@ -0,0 +1,92 @@
|
||||
use std::collections::BTreeSet;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
|
||||
use crate::fun::{Book, Definition, Name, Pattern, Rule, Term};
|
||||
|
||||
impl Book {
|
||||
pub fn lift_defs(&mut self) {
|
||||
let mut defs = IndexMap::new();
|
||||
for (name, def) in self.defs.iter_mut() {
|
||||
let mut gen = 0;
|
||||
for rule in def.rules.iter_mut() {
|
||||
rule.body.lift_defs(name, &mut defs, &mut gen);
|
||||
}
|
||||
}
|
||||
self.defs.extend(defs);
|
||||
}
|
||||
}
|
||||
|
||||
impl Rule {
|
||||
pub fn binds(&self) -> impl DoubleEndedIterator<Item = &Option<Name>> + Clone {
|
||||
self.pats.iter().flat_map(Pattern::binds)
|
||||
}
|
||||
}
|
||||
|
||||
impl Term {
|
||||
pub fn lift_defs(&mut self, parent: &Name, defs: &mut IndexMap<Name, Definition>, gen: &mut usize) {
|
||||
match self {
|
||||
Term::Def { nam, rules, nxt } => {
|
||||
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, nam));
|
||||
for rule in rules.iter_mut() {
|
||||
rule.body.lift_defs(&local_name, defs, gen);
|
||||
}
|
||||
nxt.lift_defs(parent, defs, gen);
|
||||
*gen += 1;
|
||||
|
||||
let inner_defs =
|
||||
defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
|
||||
let (r#use, fvs, mut rules) = gen_use(inner_defs, &local_name, nam, nxt, std::mem::take(rules));
|
||||
*self = r#use;
|
||||
|
||||
apply_closure(&mut rules, &fvs);
|
||||
|
||||
let new_def = Definition { name: local_name.clone(), rules, builtin: false };
|
||||
defs.insert(local_name.clone(), new_def);
|
||||
}
|
||||
_ => {
|
||||
for child in self.children_mut() {
|
||||
child.lift_defs(parent, defs, gen);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_use(
|
||||
inner_defs: BTreeSet<Name>,
|
||||
local_name: &Name,
|
||||
nam: &Name,
|
||||
nxt: &mut Box<Term>,
|
||||
rules: Vec<Rule>,
|
||||
) -> (Term, BTreeSet<Name>, Vec<Rule>) {
|
||||
let mut fvs = BTreeSet::<Name>::new();
|
||||
for rule in rules.iter() {
|
||||
fvs.extend(rule.body.free_vars().into_keys().collect::<BTreeSet<_>>());
|
||||
}
|
||||
// fvs = fvs.into_iter().filter(|fv| !inner_defs.contains(fv)).collect();
|
||||
fvs.retain(|fv| !inner_defs.contains(fv));
|
||||
for rule in rules.iter() {
|
||||
for bind in rule.binds().flatten() {
|
||||
fvs.remove(bind);
|
||||
}
|
||||
}
|
||||
|
||||
let call = Term::call(
|
||||
Term::Ref { nam: local_name.clone() },
|
||||
fvs.iter().cloned().map(|nam| Term::Var { nam }).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let r#use = Term::Use { nam: Some(nam.clone()), val: Box::new(call), nxt: std::mem::take(nxt) };
|
||||
|
||||
(r#use, fvs, rules)
|
||||
}
|
||||
|
||||
fn apply_closure(rules: &mut [Rule], fvs: &BTreeSet<Name>) {
|
||||
for rule in rules.iter_mut() {
|
||||
let pats = std::mem::take(&mut rule.pats);
|
||||
let mut captured = fvs.iter().cloned().map(|nam| Pattern::Var(Some(nam))).collect::<Vec<_>>();
|
||||
captured.extend(pats);
|
||||
rule.pats = captured;
|
||||
}
|
||||
}
|
@ -168,6 +168,7 @@ impl Term {
|
||||
Term::Fold { .. } => unreachable!("'fold' should be removed in earlier pass"),
|
||||
Term::Bend { .. } => unreachable!("'bend' should be removed in earlier pass"),
|
||||
Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
|
||||
Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ pub mod expand_main;
|
||||
pub mod fix_match_defs;
|
||||
pub mod fix_match_terms;
|
||||
pub mod float_combinators;
|
||||
pub mod lift_defs;
|
||||
pub mod linearize_matches;
|
||||
pub mod linearize_vars;
|
||||
pub mod resolve_refs;
|
||||
|
@ -162,6 +162,7 @@ impl UniqueNameGenerator {
|
||||
| Term::Era
|
||||
| Term::Err => {}
|
||||
Term::Open { .. } => unreachable!("'open' should be removed in earlier pass"),
|
||||
Term::Def { .. } => unreachable!("'def' should be removed in earlier pass"),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1,131 +0,0 @@
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use indexmap::IndexMap;
|
||||
|
||||
use crate::fun::{self, Name, Pattern};
|
||||
|
||||
use super::{Definition, Expr, Stmt};
|
||||
|
||||
impl Definition {
|
||||
pub fn lift_local_defs(&mut self, gen: &mut usize) -> Result<IndexMap<Name, fun::Definition>, String> {
|
||||
let mut defs = IndexMap::new();
|
||||
self.body.lift_local_defs(&self.name, &mut defs, gen)?;
|
||||
Ok(defs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stmt {
|
||||
pub fn lift_local_defs(
|
||||
&mut self,
|
||||
parent: &Name,
|
||||
defs: &mut IndexMap<Name, fun::Definition>,
|
||||
gen: &mut usize,
|
||||
) -> Result<(), String> {
|
||||
match self {
|
||||
Stmt::LocalDef { .. } => {
|
||||
let Stmt::LocalDef { mut def, mut nxt } = std::mem::take(self) else { unreachable!() };
|
||||
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
|
||||
def.body.lift_local_defs(&local_name, defs, gen)?;
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
*gen += 1;
|
||||
|
||||
let inner_defs =
|
||||
defs.keys().filter(|name| name.starts_with(local_name.as_ref())).cloned().collect::<BTreeSet<_>>();
|
||||
let (r#use, mut def, fvs) = gen_use(local_name.clone(), *def, nxt, inner_defs)?;
|
||||
*self = r#use;
|
||||
apply_closure(&mut def, fvs);
|
||||
|
||||
defs.insert(def.name.clone(), def);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Stmt::Assign { pat: _, val: _, nxt } => {
|
||||
if let Some(nxt) = nxt {
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Stmt::If { cond: _, then, otherwise, nxt } => {
|
||||
then.lift_local_defs(parent, defs, gen)?;
|
||||
otherwise.lift_local_defs(parent, defs, gen)?;
|
||||
if let Some(nxt) = nxt {
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Stmt::Match { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt }
|
||||
| Stmt::Fold { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => {
|
||||
for arm in arms.iter_mut() {
|
||||
arm.rgt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
if let Some(nxt) = nxt {
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Stmt::Switch { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => {
|
||||
for arm in arms.iter_mut() {
|
||||
arm.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
if let Some(nxt) = nxt {
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Stmt::Bend { bnd: _, arg: _, cond: _, step, base, nxt } => {
|
||||
step.lift_local_defs(parent, defs, gen)?;
|
||||
base.lift_local_defs(parent, defs, gen)?;
|
||||
if let Some(nxt) = nxt {
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Stmt::With { typ: _, bod, nxt } => {
|
||||
bod.lift_local_defs(parent, defs, gen)?;
|
||||
if let Some(nxt) = nxt {
|
||||
nxt.lift_local_defs(parent, defs, gen)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Stmt::InPlace { op: _, pat: _, val: _, nxt }
|
||||
| Stmt::Ask { pat: _, val: _, nxt }
|
||||
| Stmt::Open { typ: _, var: _, nxt }
|
||||
| Stmt::Use { nam: _, val: _, nxt } => nxt.lift_local_defs(parent, defs, gen),
|
||||
|
||||
Stmt::Return { .. } | Stmt::Err => Ok(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gen_use(
|
||||
local_name: Name,
|
||||
def: Definition,
|
||||
nxt: Box<Stmt>,
|
||||
inner_defs: BTreeSet<Name>,
|
||||
) -> Result<(Stmt, fun::Definition, Vec<Name>), String> {
|
||||
let params = def.params.clone();
|
||||
let ignored: BTreeSet<Name> = params.into_iter().chain(inner_defs).collect();
|
||||
let mut def = def.to_fun(false)?;
|
||||
|
||||
let fvs = BTreeMap::from_iter(def.rules[0].body.free_vars());
|
||||
let fvs = fvs.into_keys().filter(|fv| !ignored.contains(fv)).collect::<Vec<_>>();
|
||||
let val = Expr::Call {
|
||||
fun: Box::new(Expr::Var { nam: local_name.clone() }),
|
||||
args: fvs.iter().cloned().map(|nam| Expr::Var { nam }).collect(),
|
||||
kwargs: vec![],
|
||||
};
|
||||
|
||||
let r#use = Stmt::Use { nam: def.name.clone(), val: Box::new(val), nxt };
|
||||
def.name = local_name;
|
||||
|
||||
Ok((r#use, def, fvs))
|
||||
}
|
||||
|
||||
fn apply_closure(def: &mut fun::Definition, fvs: Vec<Name>) {
|
||||
let rule = &mut def.rules[0];
|
||||
let mut captured = fvs.into_iter().map(|x| Pattern::Var(Some(x))).collect::<Vec<_>>();
|
||||
let rule_pats = std::mem::take(&mut rule.pats);
|
||||
captured.extend(rule_pats);
|
||||
rule.pats = captured;
|
||||
}
|
@ -1,5 +1,4 @@
|
||||
pub mod gen_map_get;
|
||||
pub mod lift_local_defs;
|
||||
mod order_kwargs;
|
||||
pub mod parser;
|
||||
pub mod to_fun;
|
||||
|
@ -52,7 +52,6 @@ impl Stmt {
|
||||
// TODO: Refactor this to not repeat everything.
|
||||
// TODO: When we have an error with an assignment, we should show the offending assignment (eg. "{pat} = ...").
|
||||
let stmt_to_fun = match self {
|
||||
Stmt::LocalDef { .. } => todo!(),
|
||||
Stmt::Assign { pat: AssignPattern::MapSet(map, key), val, nxt: Some(nxt) } => {
|
||||
let (nxt_pat, nxt) = match nxt.into_fun()? {
|
||||
StmtToFun::Return(term) => (None, term),
|
||||
@ -367,6 +366,19 @@ impl Stmt {
|
||||
}
|
||||
}
|
||||
Stmt::Return { term } => StmtToFun::Return(term.to_fun()),
|
||||
Stmt::LocalDef { def, nxt } => {
|
||||
let (nxt_pat, nxt) = match nxt.into_fun()? {
|
||||
StmtToFun::Return(term) => (None, term),
|
||||
StmtToFun::Assign(pat, term) => (Some(pat), term),
|
||||
};
|
||||
let def = def.to_fun(false)?;
|
||||
let term = fun::Term::Def { nam: def.name, rules: def.rules, nxt: Box::new(nxt) };
|
||||
if let Some(pat) = nxt_pat {
|
||||
StmtToFun::Assign(pat, term)
|
||||
} else {
|
||||
StmtToFun::Return(term)
|
||||
}
|
||||
}
|
||||
Stmt::Err => unreachable!(),
|
||||
};
|
||||
Ok(stmt_to_fun)
|
||||
|
@ -90,6 +90,8 @@ pub fn desugar_book(
|
||||
|
||||
ctx.set_entrypoint();
|
||||
|
||||
ctx.book.lift_defs();
|
||||
|
||||
ctx.book.encode_adts(opts.adt_encoding);
|
||||
|
||||
ctx.fix_match_defs()?;
|
||||
|
@ -2,6 +2,8 @@
|
||||
source: tests/golden_tests.rs
|
||||
input_file: tests/golden_tests/desugar_file/local_def_shadow.bend
|
||||
---
|
||||
(main) = 1
|
||||
|
||||
(main__local_0_A__local_0_B) = 0
|
||||
|
||||
(main__local_1_A__local_1_B) = 1
|
||||
@ -9,5 +11,3 @@ input_file: tests/golden_tests/desugar_file/local_def_shadow.bend
|
||||
(main__local_1_A) = main__local_1_A__local_1_B
|
||||
|
||||
(main__local_0_A) = main__local_0_A__local_0_B
|
||||
|
||||
(main) = 1
|
||||
|
@ -2,10 +2,10 @@
|
||||
source: tests/golden_tests.rs
|
||||
input_file: tests/golden_tests/desugar_file/main_aux.bend
|
||||
---
|
||||
(main) = (main__local_0_aux 89 2)
|
||||
|
||||
(main__local_0_aux__local_0_aux__local_0_aux) = λa λb (+ b a)
|
||||
|
||||
(main__local_0_aux__local_0_aux) = λa λb (main__local_0_aux__local_0_aux__local_0_aux a b)
|
||||
|
||||
(main__local_0_aux) = λa λb (main__local_0_aux__local_0_aux a b)
|
||||
|
||||
(main) = (main__local_0_aux 89 2)
|
||||
|
Loading…
Reference in New Issue
Block a user