Reorder local defs lifting

This commit is contained in:
imaqtkatt 2024-06-28 11:08:03 -03:00
parent 945bed6872
commit 0cd4386f32
4 changed files with 32 additions and 11 deletions

View File

@ -24,17 +24,15 @@ impl Stmt {
match self {
Stmt::LocalDef { .. } => {
let Stmt::LocalDef { mut def, mut nxt } = std::mem::take(self) else { unreachable!() };
let children = def.lift_local_defs(gen)?;
nxt.lift_local_defs(parent, defs, gen)?;
let local_name = Name::new(format!("{}__local_{}_{}", parent, gen, def.name));
def.body.lift_local_defs(&local_name, defs, gen)?;
nxt.lift_local_defs(parent, defs, gen)?;
*gen += 1;
let (r#use, mut def, fvs) = gen_use(local_name.clone(), *def, nxt)?;
*self = r#use;
apply_closure(&mut def, fvs);
defs.extend(children);
defs.insert(def.name.clone(), def);
Ok(())
}
@ -122,8 +120,8 @@ fn gen_use(
fn apply_closure(def: &mut fun::Definition, fvs: Vec<Name>) {
let rule = &mut def.rules[0];
let mut n_pats = fvs.into_iter().map(|x| Pattern::Var(Some(x))).collect::<Vec<_>>();
let mut captured = fvs.into_iter().map(|x| Pattern::Var(Some(x))).collect::<Vec<_>>();
let rule_pats = std::mem::take(&mut rule.pats);
n_pats.extend(rule_pats);
rule.pats = n_pats;
captured.extend(rule_pats);
rule.pats = captured;
}

View File

@ -0,0 +1,10 @@
def main:
def A():
def B():
return 0
return B()
def A():
def B():
return 1
return B()
return A()

View File

@ -0,0 +1,13 @@
---
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/local_def_shadow.bend
---
(main__local_0_A__local_0_B) = 0
(main__local_1_A__local_1_B) = 1
(main__local_1_A) = λa a
(main__local_0_A) = λa a
(main) = (main__local_1_A main__local_1_A__local_1_B)

View File

@ -2,10 +2,10 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/main_aux.bend
---
(aux__local_0_aux) = λa λb (+ b a)
(main__local_0_aux__local_0_aux__local_0_aux) = λa λb (+ b a)
(aux__local_1_aux) = λa λb λc (a b c)
(main__local_0_aux__local_0_aux) = λa λb λc (a b c)
(main__local_2_aux) = λa λb λc λd (b a c d)
(main__local_0_aux) = λa λb λc λd (a b c d)
(main) = (main__local_2_aux aux__local_0_aux aux__local_1_aux 89 2)
(main) = (main__local_0_aux main__local_0_aux__local_0_aux main__local_0_aux__local_0_aux__local_0_aux 89 2)