Refactor var-use counting when linearizing

This commit is contained in:
LunaAmora 2024-02-16 15:31:46 -03:00
parent 3ad3e3233c
commit fe057eab51
3 changed files with 44 additions and 146 deletions

View File

@ -1,6 +1,6 @@
use crate::term::{Book, MatchNum, Name, Pattern, Tag, Term};
use hvmc::run::Val;
use std::collections::{hash_map::Entry, HashMap};
use std::collections::HashMap;
/// Erases variables that weren't used, dups the ones that were used more than once.
/// Substitutes lets into their variable use.
@ -24,9 +24,7 @@ impl Book {
impl Term {
pub fn linearize_vars(&mut self) {
let mut var_uses = HashMap::new();
count_var_uses_in_term(self, &mut var_uses);
term_to_affine(self, &mut var_uses, &mut HashMap::new(), &mut HashMap::new());
term_to_affine(self, &mut HashMap::new());
}
/// Returns false wether the term has no unscoped terms,
@ -37,130 +35,40 @@ impl Term {
}
}
fn count_var_uses_in_term(term: &Term, uses: &mut HashMap<Name, Val>) {
match term {
// Var users
Term::Var { nam } => {
*uses.entry(nam.clone()).or_default() += 1;
}
// Var producers
Term::Lam { nam, bod, .. } => {
add_var(nam.as_ref(), uses);
count_var_uses_in_term(bod, uses);
}
Term::Dup { fst, snd, val, nxt, .. }
| Term::Let { pat: Pattern::Tup(box Pattern::Var(fst), box Pattern::Var(snd)), val, nxt } => {
add_var(fst.as_ref(), uses);
add_var(snd.as_ref(), uses);
count_var_uses_in_term(val, uses);
count_var_uses_in_term(nxt, uses);
}
Term::Let { pat: Pattern::Var(nam), val, nxt } => {
add_var(nam.as_ref(), uses);
count_var_uses_in_term(val, uses);
count_var_uses_in_term(nxt, uses);
}
Term::Let { .. } => unreachable!(),
// Others
Term::Chn { bod, .. } => count_var_uses_in_term(bod, uses),
Term::App { fun: fst, arg: snd, .. }
| Term::Sup { fst, snd, .. }
| Term::Tup { fst, snd }
| Term::Opx { fst, snd, .. } => {
count_var_uses_in_term(fst, uses);
count_var_uses_in_term(snd, uses);
}
Term::Mat { matched, arms } => {
count_var_uses_in_term(matched, uses);
for (rule, term) in arms {
if let Pattern::Num(MatchNum::Succ(Some(nam))) = rule {
add_var(nam.as_ref(), uses);
}
count_var_uses_in_term(term, uses);
}
}
Term::Lst { .. } => unreachable!("Should have been desugared already"),
Term::Lnk { .. } | Term::Ref { .. } | Term::Num { .. } | Term::Str { .. } | Term::Era | Term::Err => {}
}
}
/// Var-declaring terms
fn term_with_bind_to_affine(
term: &mut Term,
nam: &mut Option<Name>,
var_uses: &mut HashMap<Name, Val>,
inst_count: &mut HashMap<Name, Val>,
let_bodies: &mut HashMap<Name, Term>,
) {
if let Some(name) = nam {
if var_uses.contains_key(name) {
term_to_affine(term, var_uses, inst_count, let_bodies);
let instantiated_count = get_var_uses(Some(name), inst_count);
duplicate_lam(nam, term, instantiated_count);
return;
}
}
fn term_with_bind_to_affine(term: &mut Term, nam: &mut Option<Name>, inst_count: &mut HashMap<Name, Val>) {
term_to_affine(term, inst_count);
term_to_affine(term, var_uses, inst_count, let_bodies);
if nam.is_some() {
let instantiated_count = get_var_uses(nam.as_ref(), inst_count);
duplicate_lam(nam, term, instantiated_count);
}
}
fn term_to_affine(
term: &mut Term,
var_uses: &mut HashMap<Name, Val>,
// Count to number of times a `Term::Var { nam }` has been reached without being linearized out
inst_count: &mut HashMap<Name, Val>,
let_bodies: &mut HashMap<Name, Term>,
) {
fn term_to_affine(term: &mut Term, inst_count: &mut HashMap<Name, Val>) {
match term {
Term::Lam { nam, bod, .. } => term_with_bind_to_affine(bod, nam, var_uses, inst_count, let_bodies),
Term::Lam { nam, bod, .. } => term_with_bind_to_affine(bod, nam, inst_count),
Term::Let { pat: Pattern::Var(Some(nam)), val, nxt } => {
let uses = var_uses[nam];
match uses {
term_to_affine(nxt, inst_count);
match get_var_uses(Some(nam), inst_count) {
0 => {
if val.has_unscoped() {
term_to_affine(val, var_uses, inst_count, let_bodies);
term_to_affine(nxt, var_uses, inst_count, let_bodies);
term_to_affine(val, inst_count);
let Term::Let { val, nxt, .. } = std::mem::take(term) else { unreachable!() };
*term = Term::Let { pat: Pattern::Var(None), val, nxt };
return;
}
// We are going to remove the val term,
// so we need to remove the free variables it uses from the vars count
for (var, used) in val.free_vars() {
let Entry::Occupied(mut entry) = var_uses.entry(var) else { unreachable!() };
if *entry.get() <= used {
entry.remove();
} else {
*entry.get_mut() -= used;
}
}
term_to_affine(nxt, var_uses, inst_count, let_bodies);
}
1 => {
term_to_affine(val, var_uses, inst_count, let_bodies);
let_bodies.insert(nam.clone(), std::mem::take(val.as_mut()));
term_to_affine(nxt, var_uses, inst_count, let_bodies);
term_to_affine(val, inst_count);
nxt.subst(&dup_name(nam, 1), val.as_ref());
}
uses => {
term_to_affine(val, var_uses, inst_count, let_bodies);
term_to_affine(nxt, var_uses, inst_count, let_bodies);
let mut instantiated_count = get_var_uses(Some(nam), inst_count);
if uses != instantiated_count {
// TODO: This is done because the number of uses changed (because a term was linearized out)
// It creates an extra half-erased-dup `let {nam_n *} = nam_m_dup; nxt` to match the correct labels.
// Should be refactored.
instantiated_count += 1;
};
instantiated_count => {
term_to_affine(val, inst_count);
duplicate_let(nam, nxt, instantiated_count, val);
}
}
@ -169,8 +77,8 @@ fn term_to_affine(
Term::Let { pat: Pattern::Var(None), val, nxt } => {
if val.has_unscoped() {
term_to_affine(val, var_uses, inst_count, let_bodies);
term_to_affine(nxt, var_uses, inst_count, let_bodies);
term_to_affine(val, inst_count);
term_to_affine(nxt, inst_count);
} else {
let Term::Let { nxt, .. } = std::mem::take(term) else { unreachable!() };
*term = *nxt;
@ -179,10 +87,10 @@ fn term_to_affine(
Term::Dup { fst, snd, val, nxt, .. }
| Term::Let { pat: Pattern::Tup(box Pattern::Var(fst), box Pattern::Var(snd)), val, nxt } => {
let uses_fst = get_var_uses(fst.as_ref(), var_uses);
let uses_snd = get_var_uses(snd.as_ref(), var_uses);
term_to_affine(val, var_uses, inst_count, let_bodies);
term_to_affine(nxt, var_uses, inst_count, let_bodies);
term_to_affine(val, inst_count);
term_to_affine(nxt, inst_count);
let uses_fst = get_var_uses(fst.as_ref(), inst_count);
let uses_snd = get_var_uses(snd.as_ref(), inst_count);
duplicate_lam(fst, nxt, uses_fst);
duplicate_lam(snd, nxt, uses_snd);
}
@ -191,38 +99,35 @@ fn term_to_affine(
// Var-using terms
Term::Var { nam } => {
*var_uses.get_mut(nam).unwrap() -= 1;
if let Some(subst) = let_bodies.remove(nam) {
*term = subst;
} else {
let instantiated_count = inst_count.entry(nam.clone()).or_default();
*instantiated_count += 1;
let instantiated_count = inst_count.entry(nam.clone()).or_default();
*instantiated_count += 1;
*nam = dup_name(nam, *instantiated_count);
}
*nam = dup_name(nam, *instantiated_count);
}
// Others
Term::Chn { bod, .. } => term_to_affine(bod, var_uses, inst_count, let_bodies),
Term::Chn { bod, .. } => term_to_affine(bod, inst_count),
Term::App { fun: fst, arg: snd, .. }
| Term::Sup { fst, snd, .. }
| Term::Tup { fst, snd }
| Term::Opx { fst, snd, .. } => {
term_to_affine(fst, var_uses, inst_count, let_bodies);
term_to_affine(snd, var_uses, inst_count, let_bodies);
term_to_affine(fst, inst_count);
term_to_affine(snd, inst_count);
}
Term::Mat { matched, arms } => {
term_to_affine(matched, var_uses, inst_count, let_bodies);
term_to_affine(matched, inst_count);
for (rule, term) in arms {
match rule {
Pattern::Num(MatchNum::Succ(Some(nam))) => {
term_with_bind_to_affine(term, nam, var_uses, inst_count, let_bodies);
term_with_bind_to_affine(term, nam, inst_count);
}
Pattern::Num(_) => term_to_affine(term, var_uses, inst_count, let_bodies),
Pattern::Num(_) => term_to_affine(term, inst_count),
_ => unreachable!(),
}
}
}
Term::Lst { .. } => unreachable!("Should have been desugared already"),
Term::Era | Term::Lnk { .. } | Term::Ref { .. } | Term::Num { .. } | Term::Str { .. } | Term::Err => {}
};
@ -243,24 +148,15 @@ fn get_var_uses(nam: Option<&Name>, var_uses: &HashMap<Name, Val>) -> Val {
/// nxt
/// ```
fn make_dup_tree(nam: &Name, nxt: &mut Term, uses: Val, mut dup_body: Option<&mut Term>) {
let free_vars = &mut nxt.free_vars();
if let Some(ref body) = dup_body {
free_vars.extend(body.free_vars());
};
let make_name = |uses| {
let dup_name = dup_name(nam, uses);
free_vars.contains_key(&dup_name).then_some(dup_name)
};
for i in (1 .. uses).rev() {
*nxt = Term::Dup {
tag: Tag::Auto,
fst: make_name(i),
snd: if i == uses - 1 { make_name(uses) } else { Some(internal_dup_name(nam, i)) },
fst: Some(dup_name(nam, i)),
snd: if i == uses - 1 { Some(dup_name(nam, uses)) } else { Some(internal_dup_name(nam, i)) },
val: if i == 1 {
Box::new(dup_body.as_mut().map_or_else(|| Term::Var { nam: nam.clone() }, |x| std::mem::take(*x)))
Box::new(
dup_body.as_deref_mut().map_or_else(|| Term::Var { nam: nam.clone() }, |x| std::mem::take(x)),
)
} else {
Box::new(Term::Var { nam: internal_dup_name(nam, i - 1) })
},
@ -277,6 +173,7 @@ fn duplicate_lam(nam: &mut Option<Name>, nxt: &mut Term, uses: Val) {
}
}
#[allow(dead_code)]
fn duplicate_let(nam: &Name, nxt: &mut Term, uses: Val, let_body: &mut Term) {
make_dup_tree(nam, nxt, uses, Some(let_body));
}
@ -289,6 +186,7 @@ fn internal_dup_name(nam: &Name, uses: Val) -> Name {
format!("{}_dup", dup_name(nam, uses)).into()
}
#[allow(dead_code)]
fn add_var(nam: Option<&Name>, uses: &mut HashMap<Name, Val>) {
if let Some(nam) = nam {
uses.entry(nam.clone()).or_insert(0);

View File

@ -3,5 +3,5 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/compile_file/unused_dup_var_linearization.hvm
---
@main = a
& (b b) ~ {3 * {5 {9 c *} {7 (c a) *}}}
& (b b) ~ {3 (c a) c}

View File

@ -2,4 +2,4 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/dup_linearization.hvm
---
(main) = let {a_1 a_1_dup} = *; let {a_2 a_2_dup} = a_1_dup; let {a_3 a_3_dup} = a_2_dup; let {a_4 a_5} = a_3_dup; ((a_5, a_1), (a_2, (a_3, a_4)))
(main) = let {a_1 a_1_dup} = *; let {a_2 a_2_dup} = a_1_dup; let {a_3 a_3_dup} = a_2_dup; let {a_4 a_5} = a_3_dup; ((a_1, a_5), (a_4, (a_3, a_2)))