Add local def statement

This commit is contained in:
imaqtkatt 2024-06-27 15:23:38 -03:00
parent 10610c3774
commit 1785b392fe
9 changed files with 184 additions and 6 deletions

View File

@ -781,7 +781,9 @@ impl<'a> TermParser<'a> {
self.check_top_level_redefinition(&def.name, book, span)?;
def.order_kwargs(book)?;
def.gen_map_get();
let locals = def.lift_local_defs(&mut 0)?;
let def = def.to_fun(builtin)?;
book.defs.extend(locals);
book.defs.insert(def.name.clone(), def);
Ok(())
}

View File

@ -15,6 +15,10 @@ impl Definition {
impl Stmt {
fn gen_map_get(&mut self, id: &mut usize) {
match self {
Stmt::LocalDef { def, nxt } => {
nxt.gen_map_get(id);
def.gen_map_get()
}
Stmt::Assign { pat, val, nxt } => {
let key_substitutions =
if let AssignPattern::MapSet(_, key) = pat { key.substitute_map_gets(id) } else { Vec::new() };

129
src/imp/lift_local_defs.rs Normal file
View File

@ -0,0 +1,129 @@
use std::collections::BTreeMap;
use indexmap::IndexMap;
use crate::fun::{self, Name, Pattern};
use super::{Definition, Expr, Stmt};
impl Definition {
pub fn lift_local_defs(&mut self, gen: &mut usize) -> Result<IndexMap<Name, fun::Definition>, String> {
let mut defs = IndexMap::new();
self.body.lift_local_defs(&self.name, &mut defs, gen)?;
Ok(defs)
}
}
impl Stmt {
pub fn lift_local_defs(
&mut self,
parent: &Name,
defs: &mut IndexMap<Name, fun::Definition>,
gen: &mut usize,
) -> Result<(), String> {
match self {
Stmt::LocalDef { .. } => {
let Stmt::LocalDef { mut def, mut nxt } = std::mem::take(self) else { unreachable!() };
let children = def.lift_local_defs(gen)?;
nxt.lift_local_defs(parent, defs, gen)?;
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
*gen += 1;
let (r#use, mut def, fvs) = gen_use(local_name.clone(), *def, nxt)?;
*self = r#use;
apply_closure(&mut def, fvs);
defs.extend(children);
defs.insert(def.name.clone(), def);
Ok(())
}
Stmt::Assign { pat: _, val: _, nxt } => {
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::If { cond: _, then, otherwise, nxt } => {
then.lift_local_defs(parent, defs, gen)?;
otherwise.lift_local_defs(parent, defs, gen)?;
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::Match { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt }
| Stmt::Fold { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => {
for arm in arms.iter_mut() {
arm.rgt.lift_local_defs(parent, defs, gen)?;
}
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::Switch { arg: _, bnd: _, with_bnd: _, with_arg: _, arms, nxt } => {
for arm in arms.iter_mut() {
arm.lift_local_defs(parent, defs, gen)?;
}
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::Bend { bnd: _, arg: _, cond: _, step, base, nxt } => {
step.lift_local_defs(parent, defs, gen)?;
base.lift_local_defs(parent, defs, gen)?;
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::With { typ: _, bod, nxt } => {
bod.lift_local_defs(parent, defs, gen)?;
if let Some(nxt) = nxt {
nxt.lift_local_defs(parent, defs, gen)?;
}
Ok(())
}
Stmt::InPlace { op: _, pat: _, val: _, nxt }
| Stmt::Ask { pat: _, val: _, nxt }
| Stmt::Open { typ: _, var: _, nxt }
| Stmt::Use { nam: _, val: _, nxt } => nxt.lift_local_defs(parent, defs, gen),
Stmt::Return { .. } | Stmt::Err => Ok(()),
}
}
}
fn gen_use(
local_name: Name,
def: Definition,
nxt: Box<Stmt>,
) -> Result<(Stmt, fun::Definition, Vec<Name>), String> {
let params = def.params.clone();
let mut def = def.to_fun(false)?;
let fvs = BTreeMap::from_iter(def.rules[0].body.free_vars());
let fvs = fvs.into_keys().filter(|fv| !params.contains(fv)).collect::<Vec<_>>();
let val = Expr::Call {
fun: Box::new(Expr::Var { nam: local_name.clone() }),
args: fvs.iter().cloned().map(|nam| Expr::Var { nam }).collect(),
kwargs: vec![],
};
let r#use = Stmt::Use { nam: def.name.clone(), val: Box::new(val), nxt };
def.name = local_name;
Ok((r#use, def, fvs))
}
fn apply_closure(def: &mut fun::Definition, fvs: Vec<Name>) {
let rule = &mut def.rules[0];
let mut n_pats = fvs.into_iter().map(|x| Pattern::Var(Some(x))).collect::<Vec<_>>();
let rule_pats = std::mem::take(&mut rule.pats);
n_pats.extend(rule_pats);
rule.pats = n_pats;
}

View File

@ -1,4 +1,5 @@
pub mod gen_map_get;
pub mod lift_local_defs;
mod order_kwargs;
pub mod parser;
pub mod to_fun;
@ -105,7 +106,7 @@ pub enum Stmt {
otherwise: Box<Stmt>,
nxt: Option<Box<Stmt>>,
},
// "match" {arg} ":" ("as" {bind})?
// "match" ({bind} "=")? {arg} ({with_clause})? ":"
// case {lft} ":" {rgt}
// ...
// <nxt>?
@ -117,7 +118,7 @@ pub enum Stmt {
arms: Vec<MatchArm>,
nxt: Option<Box<Stmt>>,
},
// "switch" {arg} ("as" {bind})?
// "switch" ({bind} "=")? {arg} ({with_clause})? ":"
// case 0..wildcard ":" {rgt}
// ...
// <nxt>?
@ -129,7 +130,7 @@ pub enum Stmt {
arms: Vec<Stmt>,
nxt: Option<Box<Stmt>>,
},
// "bend" ({bind} ("="" {init})?)* "while" {cond} ":"
// "bend" ({bind} ("=" {init})?)* "while" {cond} ":"
// {step}
// "then" ":"
// {base}
@ -142,7 +143,7 @@ pub enum Stmt {
base: Box<Stmt>,
nxt: Option<Box<Stmt>>,
},
// "fold" {arg} ("as" {bind})? ":" {arms}
// "fold" ({bind} "=")? {arg} ({with_clause})? ":" {arms}
// case {lft} ":" {rgt}
// ...
// <nxt>?
@ -184,11 +185,16 @@ pub enum Stmt {
val: Box<Expr>,
nxt: Box<Stmt>,
},
// {def} {nxt}
LocalDef {
def: Box<Definition>,
nxt: Box<Stmt>,
},
#[default]
Err,
}
// Name "(" {fields}* ")"
// Name "{" {fields}* "}"
#[derive(Clone, Debug)]
pub struct Variant {
pub name: Name,

View File

@ -15,6 +15,10 @@ impl Definition {
impl Stmt {
fn order_kwargs(&mut self, book: &Book) -> Result<(), String> {
match self {
Stmt::LocalDef { def, nxt } => {
def.order_kwargs(book)?;
nxt.order_kwargs(book)?;
}
Stmt::Assign { val, nxt, .. } => {
val.order_kwargs(book)?;
if let Some(nxt) = nxt {

View File

@ -423,6 +423,8 @@ impl<'a> PyParser<'a> {
maybe_grow(|| {
if self.try_parse_keyword("return") {
self.parse_return()
} else if self.try_parse_keyword("def") {
self.parse_local_def(indent)
} else if self.try_parse_keyword("if") {
self.parse_if(indent)
} else if self.try_parse_keyword("match") {
@ -963,13 +965,23 @@ impl<'a> PyParser<'a> {
Ok((stmt, nxt_indent))
}
pub fn parse_def(&mut self, mut indent: Indent) -> ParseResult<(Definition, Indent)> {
fn parse_local_def(&mut self, indent: &mut Indent) -> ParseResult<(Stmt, Indent)> {
let (def, mut nxt_indent) = self.parse_def_aux(*indent)?;
let (nxt, nxt_indent) = self.parse_statement(&mut nxt_indent)?;
let stmt = Stmt::LocalDef { def: Box::new(def), nxt: Box::new(nxt) };
Ok((stmt, nxt_indent))
}
pub fn parse_def(&mut self, indent: Indent) -> ParseResult<(Definition, Indent)> {
if indent != Indent::Val(0) {
let msg = "Indentation error. Functions defined with 'def' must be at the start of the line.";
let idx = *self.index();
return self.with_ctx(Err(msg), idx..idx + 1);
}
self.parse_def_aux(indent)
}
fn parse_def_aux(&mut self, mut indent: Indent) -> ParseResult<(Definition, Indent)> {
self.skip_trivia_inline()?;
let name = self.parse_top_level_name()?;
self.skip_trivia_inline()?;

View File

@ -52,6 +52,7 @@ impl Stmt {
// TODO: Refactor this to not repeat everything.
// TODO: When we have an error with an assignment, we should show the offending assignment (eg. "{pat} = ...").
let stmt_to_fun = match self {
Stmt::LocalDef { .. } => todo!(),
Stmt::Assign { pat: AssignPattern::MapSet(map, key), val, nxt: Some(nxt) } => {
let (nxt_pat, nxt) = match nxt.into_fun()? {
StmtToFun::Return(term) => (None, term),

View File

@ -0,0 +1,9 @@
def main:
y = 89
def aux(x):
def aux(x):
def aux(x):
return x + y
return aux(x)
return aux(x)
return aux(2)

View File

@ -0,0 +1,11 @@
---
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/main_aux.bend
---
(aux__local_0_aux) = λa λb (+ b a)
(aux__local_1_aux) = λa λb λc (a b c)
(main__local_2_aux) = λa λb λc λd (b a c d)
(main) = (main__local_2_aux aux__local_0_aux aux__local_1_aux 89 2)