[sc-517] Refactor Fix match terms, add 'concat' to dictionary

This commit is contained in:
Nicolas Abril 2024-04-01 15:54:08 +02:00
parent d60935d3de
commit afd0f7d68b
6 changed files with 77 additions and 60 deletions

View File

@ -12,6 +12,7 @@
"callcc", "callcc",
"chumsky", "chumsky",
"combinators", "combinators",
"concat",
"ctrs", "ctrs",
"Dall", "Dall",
"datatypes", "datatypes",

View File

@ -106,18 +106,18 @@ To ensure that recursive pattern matching functions don't loop in strict mode, i
Pattern matching equations also support matching on non-consecutive numbers: Pattern matching equations also support matching on non-consecutive numbers:
```rust ```rust
Parse '(' = Token.LParen Parse '(' = Token.LParenthesis
Parse ')' = Token.RParen Parse ')' = Token.RParenthesis
Parse 'λ' = Token.Lambda Parse 'λ' = Token.Lambda
Parse n = (Token.Name n) Parse n = (Token.Name n)
``` ```
This is compiled to a cascade of `switch` expressions, from smallest value to largest. This is compiled to a cascade of `switch` expressions, from smallest value to largest.
```rust ```rust
Parse = λarg0 switch matched = (- arg0 '(') { Parse = λarg0 switch matched = (- arg0 '(') {
0: Token.LParen 0: Token.LParenthesis
// ')' - 1 - '(' is resolved at compile time // ')' - 1 - '(' is resolved at compile time
_: switch matched = (- matched-1 ( ')'-1-'(' ) { _: switch matched = (- matched-1 ( ')'-1-'(' ) {
0: Token.RParen 0: Token.RParenthesis
_: switch matched = (- matched-1 ( 'λ'-1-')' ) { _: switch matched = (- matched-1 ( 'λ'-1-')' ) {
0: Token.Lambda 0: Token.Lambda
_: use n = (+ 1 matched-1); (Token.Name n) _: use n = (+ 1 matched-1); (Token.Name n)

View File

@ -142,12 +142,12 @@ pub enum Term {
/// Pattern matching on an ADT. /// Pattern matching on an ADT.
Mat { Mat {
arg: Box<Term>, arg: Box<Term>,
rules: Vec<(Option<Name>, Vec<Option<Name>>, Term)>, rules: Vec<MatchRule>,
}, },
/// Native pattern matching on numbers /// Native pattern matching on numbers
Swt { Swt {
arg: Box<Term>, arg: Box<Term>,
rules: Vec<(NumCtr, Term)>, rules: Vec<SwitchRule>,
}, },
Ref { Ref {
nam: Name, nam: Name,
@ -157,6 +157,9 @@ pub enum Term {
Err, Err,
} }
/// One arm of a `Term::Mat`, as stored in its `rules` vector.
///
/// The first element is the optional name at the head of the arm: either the
/// name of a constructor or the name of a variable that catches everything.
/// The last element is the body of the arm.
// NOTE(review): the middle `Vec<Option<Name>>` presumably holds the names bound
// to the constructor's fields — confirm against the code that builds match arms.
pub type MatchRule = (Option<Name>, Vec<Option<Name>>, Term);
/// One arm of a `Term::Swt`, as stored in its `rules` vector.
// NOTE(review): presumably the number pattern paired with the arm's body —
// confirm against the definition of `NumCtr`.
pub type SwitchRule = (NumCtr, Term);
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Pattern { pub enum Pattern {
Var(Option<Name>), Var(Option<Name>),

View File

@ -194,7 +194,7 @@ fn tup_rule(
let pat = rule.pats[0].clone(); let pat = rule.pats[0].clone();
let old_pats = rule.pats.split_off(1); let old_pats = rule.pats.split_off(1);
// Extract subpats from the tuple pattern // Extract subpatterns from the tuple pattern
let mut new_pats = match pat { let mut new_pats = match pat {
Pattern::Tup(sub_pats) => sub_pats, Pattern::Tup(sub_pats) => sub_pats,
Pattern::Var(var) => { Pattern::Var(var) => {
@ -387,7 +387,7 @@ fn switch_rule(
for rule in &rules { for rule in &rules {
let old_pats = rule.pats[1 ..].to_vec(); let old_pats = rule.pats[1 ..].to_vec();
match &rule.pats[0] { match &rule.pats[0] {
// Same ctr, extract subpats. // Same ctr, extract subpatterns.
// (Ctr pat0_0 ... pat0_m) pat1 ... patN: body // (Ctr pat0_0 ... pat0_m) pat1 ... patN: body
// becomes // becomes
// pat0_0 ... pat0_m pat1 ... patN: body // pat0_0 ... pat0_m pat1 ... patN: body

View File

@ -1,4 +1,4 @@
use crate::term::{AdtEncoding, Book, Constructors, Name, NumCtr, Tag, Term}; use crate::term::{AdtEncoding, Book, Constructors, MatchRule, Name, NumCtr, SwitchRule, Tag, Term};
impl Book { impl Book {
/// Encodes pattern matching expressions in the book into their /// Encodes pattern matching expressions in the book into their
@ -38,12 +38,7 @@ impl Term {
} }
} }
fn encode_match( fn encode_match(arg: Term, rules: Vec<MatchRule>, ctrs: &Constructors, adt_encoding: AdtEncoding) -> Term {
arg: Term,
rules: Vec<(Option<Name>, Vec<Option<Name>>, Term)>,
ctrs: &Constructors,
adt_encoding: AdtEncoding,
) -> Term {
let adt = ctrs.get(rules[0].0.as_ref().unwrap()).unwrap(); let adt = ctrs.get(rules[0].0.as_ref().unwrap()).unwrap();
// ADT Encoding depends on compiler option // ADT Encoding depends on compiler option
@ -73,7 +68,7 @@ fn encode_match(
/// Convert into a sequence of native matches, decrementing by 1 each match. /// Convert into a sequence of native matches, decrementing by 1 each match.
/// match n {0: A; 1: B; 2+: (C n-2)} converted to /// match n {0: A; 1: B; 2+: (C n-2)} converted to
/// match n {0: A; 1+: @%x match %x {0: B; 1+: @n-2 (C n-2)}} /// match n {0: A; 1+: @%x match %x {0: B; 1+: @n-2 (C n-2)}}
fn encode_switch(arg: Term, mut rules: Vec<(NumCtr, Term)>) -> Term { fn encode_switch(arg: Term, mut rules: Vec<SwitchRule>) -> Term {
let last_rule = rules.pop().unwrap(); let last_rule = rules.pop().unwrap();
let match_var = Name::from("%x"); let match_var = Name::from("%x");

View File

@ -1,6 +1,6 @@
use crate::{ use crate::{
diagnostics::{Diagnostics, ToStringVerbose, WarningType}, diagnostics::{Diagnostics, ToStringVerbose, WarningType},
term::{Adts, Constructors, Ctx, Name, NumCtr, Term}, term::{Adts, Constructors, Ctx, MatchRule, Name, NumCtr, Term},
}; };
use std::collections::HashMap; use std::collections::HashMap;
@ -99,54 +99,14 @@ impl Term {
let (arg_nam, arg) = extract_match_arg(arg); let (arg_nam, arg) = extract_match_arg(arg);
// Normalize arms // Normalize arms, making one arm for each constructor of the matched adt.
if let Some(ctr_nam) = &rules[0].0 if let Some(ctr_nam) = &rules[0].0
&& let Some(adt_nam) = ctrs.get(ctr_nam) && let Some(adt_nam) = ctrs.get(ctr_nam)
{ {
let adt_ctrs = &adts[adt_nam].ctrs; let adt_ctrs = &adts[adt_nam].ctrs;
// For each arm, decide to which arm of the fixed match they map to. // Decide which constructor corresponds to which arm of the match.
let mut bodies = HashMap::<&Name, Option<Term>>::from_iter(adt_ctrs.iter().map(|(ctr, _)| (ctr, None))); let mut bodies = fixed_match_arms(rules, &arg_nam, adt_nam, adt_ctrs.keys(), ctrs, adts, errs);
for rule_idx in 0 .. rules.len() {
if let Some(ctr_nam) = &rules[rule_idx].0
&& let Some(found_adt) = ctrs.get(ctr_nam)
{
// Ctr arm
if found_adt == adt_nam {
let body = bodies.get_mut(ctr_nam).unwrap();
if body.is_none() {
// Use this rule for this constructor
*body = Some(rules[rule_idx].2.clone());
} else {
errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() });
}
} else {
errs.push(FixMatchErr::AdtMismatch {
expected: adt_nam.clone(),
found: found_adt.clone(),
ctr: ctr_nam.clone(),
})
}
} else {
// Var arm
if rule_idx != rules.len() - 1 {
errs.push(FixMatchErr::UnreachableMatchArms { var: rules[rule_idx].0.clone() });
rules.truncate(rule_idx + 1);
}
// Use this rule for all remaining constructors
for (ctr, body) in bodies.iter_mut() {
if body.is_none() {
let mut new_body = rules[rule_idx].2.clone();
if let Some(var) = &rules[rule_idx].0 {
new_body.subst(var, &rebuild_ctr(&arg_nam, ctr, &adts[adt_nam].ctrs[&**ctr]));
}
*body = Some(new_body);
}
}
break;
}
}
// Build the match arms, with all constructors // Build the match arms, with all constructors
let mut new_rules = vec![]; let mut new_rules = vec![];
@ -162,6 +122,7 @@ impl Term {
} }
*rules = new_rules; *rules = new_rules;
} else { } else {
// First arm was not matching a constructor, convert into a use term.
errs.push(FixMatchErr::IrrefutableMatch { var: rules[0].0.clone() }); errs.push(FixMatchErr::IrrefutableMatch { var: rules[0].0.clone() });
let match_var = rules[0].0.take(); let match_var = rules[0].0.take();
@ -235,6 +196,63 @@ fn apply_arg(term: Term, arg_nam: Name, arg: Option<Term>) -> Term {
} }
} }
/// Assigns a body to every constructor of the matched ADT, based on the
/// arms the user wrote.
///
/// Rules are visited in order:
/// - An arm whose head names a constructor of `adt_nam` provides the body for
///   that constructor, unless an earlier arm already did, in which case a
///   `RedundantArm` error is recorded.
/// - An arm whose head names a constructor of another ADT records an
///   `AdtMismatch` error and is otherwise ignored.
/// - Any other arm is a variable arm: its body is used for every constructor
///   still uncovered (with the variable, if named, substituted by the rebuilt
///   constructor), and the scan stops there. Arms following it are dropped
///   from `rules` and reported as `UnreachableMatchArms`.
///
/// A constructor left with `None` in the returned map was not covered by any
/// arm, which indicates a non-exhaustive match.
fn fixed_match_arms<'a>(
  rules: &mut Vec<MatchRule>,
  arg_nam: &Name,
  adt_nam: &Name,
  adt_ctrs: impl Iterator<Item = &'a Name>,
  ctrs: &Constructors,
  adts: &Adts,
  errs: &mut Vec<FixMatchErr>,
) -> HashMap<&'a Name, Option<Term>> {
  // Every constructor starts out uncovered.
  let mut arms: HashMap<&'a Name, Option<Term>> = adt_ctrs.map(|ctr| (ctr, None)).collect();

  for idx in 0 .. rules.len() {
    // Resolve the head of this arm to (constructor, owning ADT), if it is one.
    let head = rules[idx].0.as_ref();
    let ctr_arm = head.and_then(|nam| ctrs.get(nam).map(|adt| (nam, adt)));

    let Some((ctr_nam, found_adt)) = ctr_arm else {
      // Var arm: it handles every constructor that has no body yet.
      let (var, _, fallback) = &rules[idx];
      for (ctr, slot) in arms.iter_mut().filter(|(_, slot)| slot.is_none()) {
        let mut new_body = fallback.clone();
        if let Some(var) = var {
          new_body.subst(var, &rebuild_ctr(arg_nam, ctr, &adts[adt_nam].ctrs[&**ctr]));
        }
        *slot = Some(new_body);
      }
      // Nothing after a var arm can ever be reached, so drop those arms.
      if idx + 1 != rules.len() {
        errs.push(FixMatchErr::UnreachableMatchArms { var: rules[idx].0.clone() });
        rules.truncate(idx + 1);
      }
      break;
    };

    // Ctr arm that belongs to a different ADT than the one being matched.
    if found_adt != adt_nam {
      errs.push(FixMatchErr::AdtMismatch {
        expected: adt_nam.clone(),
        found: found_adt.clone(),
        ctr: ctr_nam.clone(),
      });
      continue;
    }

    // Ctr arm of the matched ADT: the first arm for a constructor wins.
    let slot = arms.get_mut(ctr_nam).unwrap();
    if slot.is_some() {
      errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() });
    } else {
      *slot = Some(rules[idx].2.clone());
    }
  }

  arms
}
fn match_field(arg: &Name, field: &Name) -> Name { fn match_field(arg: &Name, field: &Name) -> Name {
Name::new(format!("{arg}.{field}")) Name::new(format!("{arg}.{field}"))
} }