From afd0f7d68bf5e2a57c77be5038b917b05857b877 Mon Sep 17 00:00:00 2001 From: Nicolas Abril Date: Mon, 1 Apr 2024 15:54:08 +0200 Subject: [PATCH] [sc-517] Refactor Fix match terms, add 'concat' to dictionary --- cspell.json | 1 + docs/pattern-matching.md | 8 +- src/term/mod.rs | 7 +- src/term/transform/desugar_match_defs.rs | 4 +- src/term/transform/encode_match_terms.rs | 11 +-- src/term/transform/fix_match_terms.rs | 106 +++++++++++++---------- 6 files changed, 77 insertions(+), 60 deletions(-) diff --git a/cspell.json b/cspell.json index fa3e07e3..f225b5c1 100644 --- a/cspell.json +++ b/cspell.json @@ -12,6 +12,7 @@ "callcc", "chumsky", "combinators", + "concat", "ctrs", "Dall", "datatypes", diff --git a/docs/pattern-matching.md b/docs/pattern-matching.md index 53ce995b..db853831 100644 --- a/docs/pattern-matching.md +++ b/docs/pattern-matching.md @@ -106,18 +106,18 @@ To ensure that recursive pattern matching functions don't loop in strict mode, i Pattern matching equations also support matching on non-consecutive numbers: ```rust -Parse '(' = Token.LParen -Parse ')' = Token.RParen +Parse '(' = Token.LParenthesis +Parse ')' = Token.RParenthesis Parse 'λ' = Token.Lambda Parse n = (Token.Name n) ``` This is compiled to a cascade of `switch` expressions, from smallest value to largest. ```rust Parse = λarg0 switch matched = (- arg0 '(') { - 0: Token.LParen + 0: Token.LParenthesis // ')' + 1 - '(' is resolved during compile time _: switch matched = (- matched-1 ( ')'-1-'(' ) { - 0: Token.RParen + 0: Token.RParenthesis _: switch matched = (- matched-1 ( 'λ'-1-')' ) { 0: Token.Lambda _: use n = (+ 1 matched-1); (Token.Name n) diff --git a/src/term/mod.rs b/src/term/mod.rs index b96d3657..7daff0d6 100644 --- a/src/term/mod.rs +++ b/src/term/mod.rs @@ -142,12 +142,12 @@ pub enum Term { /// Pattern matching on an ADT. Mat { arg: Box, - rules: Vec<(Option, Vec>, Term)>, + rules: Vec, }, /// Native pattern matching on numbers Swt { arg: Box, - rules: Vec<(NumCtr, Term)>, + rules: Vec, }, Ref { nam: Name, @@ -157,6 +157,9 @@ pub enum Term { Err, } +pub type MatchRule = (Option, Vec>, Term); +pub type SwitchRule = (NumCtr, Term); + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Pattern { Var(Option), diff --git a/src/term/transform/desugar_match_defs.rs b/src/term/transform/desugar_match_defs.rs index f3cd26e3..b846b55c 100644 --- a/src/term/transform/desugar_match_defs.rs +++ b/src/term/transform/desugar_match_defs.rs @@ -194,7 +194,7 @@ fn tup_rule( let pat = rule.pats[0].clone(); let old_pats = rule.pats.split_off(1); - // Extract subpats from the tuple pattern + // Extract subpatterns from the tuple pattern let mut new_pats = match pat { Pattern::Tup(sub_pats) => sub_pats, Pattern::Var(var) => { @@ -387,7 +387,7 @@ fn switch_rule( for rule in &rules { let old_pats = rule.pats[1 ..].to_vec(); match &rule.pats[0] { - // Same ctr, extract subpats. + // Same ctr, extract subpatterns. // (Ctr pat0_0 ... pat0_m) pat1 ... patN: body // becomes // pat0_0 ... pat0_m pat1 ... patN: body diff --git a/src/term/transform/encode_match_terms.rs b/src/term/transform/encode_match_terms.rs index 917a4dae..3bf9a9d3 100644 --- a/src/term/transform/encode_match_terms.rs +++ b/src/term/transform/encode_match_terms.rs @@ -1,4 +1,4 @@ -use crate::term::{AdtEncoding, Book, Constructors, Name, NumCtr, Tag, Term}; +use crate::term::{AdtEncoding, Book, Constructors, MatchRule, Name, NumCtr, SwitchRule, Tag, Term}; impl Book { /// Encodes pattern matching expressions in the book into their @@ -38,12 +38,7 @@ impl Term { } } -fn encode_match( - arg: Term, - rules: Vec<(Option, Vec>, Term)>, - ctrs: &Constructors, - adt_encoding: AdtEncoding, -) -> Term { +fn encode_match(arg: Term, rules: Vec, ctrs: &Constructors, adt_encoding: AdtEncoding) -> Term { let adt = ctrs.get(rules[0].0.as_ref().unwrap()).unwrap(); // ADT Encoding depends on compiler option @@ -73,7 +68,7 @@ fn encode_match( /// Convert into a sequence of native matches, decrementing by 1 each match. /// match n {0: A; 1: B; 2+: (C n-2)} converted to /// match n {0: A; 1+: @%x match %x {0: B; 1+: @n-2 (C n-2)}} -fn encode_switch(arg: Term, mut rules: Vec<(NumCtr, Term)>) -> Term { +fn encode_switch(arg: Term, mut rules: Vec) -> Term { let last_rule = rules.pop().unwrap(); let match_var = Name::from("%x"); diff --git a/src/term/transform/fix_match_terms.rs b/src/term/transform/fix_match_terms.rs index 904d7359..51fc122f 100644 --- a/src/term/transform/fix_match_terms.rs +++ b/src/term/transform/fix_match_terms.rs @@ -1,6 +1,6 @@ use crate::{ diagnostics::{Diagnostics, ToStringVerbose, WarningType}, - term::{Adts, Constructors, Ctx, Name, NumCtr, Term}, + term::{Adts, Constructors, Ctx, MatchRule, Name, NumCtr, Term}, }; use std::collections::HashMap; @@ -99,54 +99,14 @@ impl Term { let (arg_nam, arg) = extract_match_arg(arg); - // Normalize arms + // Normalize arms, making one arm for each constructor of the matched adt. if let Some(ctr_nam) = &rules[0].0 && let Some(adt_nam) = ctrs.get(ctr_nam) { let adt_ctrs = &adts[adt_nam].ctrs; - // For each arm, decide to which arm of the fixed match they map to. - let mut bodies = HashMap::<&Name, Option>::from_iter(adt_ctrs.iter().map(|(ctr, _)| (ctr, None))); - for rule_idx in 0 .. rules.len() { - if let Some(ctr_nam) = &rules[rule_idx].0 - && let Some(found_adt) = ctrs.get(ctr_nam) - { - // Ctr arm - if found_adt == adt_nam { - let body = bodies.get_mut(ctr_nam).unwrap(); - if body.is_none() { - // Use this rule for this constructor - *body = Some(rules[rule_idx].2.clone()); - } else { - errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() }); - } - } else { - errs.push(FixMatchErr::AdtMismatch { - expected: adt_nam.clone(), - found: found_adt.clone(), - ctr: ctr_nam.clone(), - }) - } - } else { - // Var arm - if rule_idx != rules.len() - 1 { - errs.push(FixMatchErr::UnreachableMatchArms { var: rules[rule_idx].0.clone() }); - rules.truncate(rule_idx + 1); - } - // Use this rule for all remaining constructors - for (ctr, body) in bodies.iter_mut() { - if body.is_none() { - let mut new_body = rules[rule_idx].2.clone(); - if let Some(var) = &rules[rule_idx].0 { - new_body.subst(var, &rebuild_ctr(&arg_nam, ctr, &adts[adt_nam].ctrs[&**ctr])); - } - *body = Some(new_body); - } - } - - break; - } - } + // Decide which constructor corresponds to which arm of the match. + let mut bodies = fixed_match_arms(rules, &arg_nam, adt_nam, adt_ctrs.keys(), ctrs, adts, errs); // Build the match arms, with all constructors let mut new_rules = vec![]; @@ -162,6 +122,7 @@ impl Term { } *rules = new_rules; } else { + // First arm was not matching a constructor, convert into a use term. errs.push(FixMatchErr::IrrefutableMatch { var: rules[0].0.clone() }); let match_var = rules[0].0.take(); @@ -235,6 +196,63 @@ fn apply_arg(term: Term, arg_nam: Name, arg: Option) -> Term { } } +/// Given the rules of a match term, return the bodies that match +/// each of the constructors of the matched ADT. +/// +/// If no rules match a certain constructor, return None in the map, +/// indicating a non-exhaustive match. +fn fixed_match_arms<'a>( + rules: &mut Vec, + arg_nam: &Name, + adt_nam: &Name, + adt_ctrs: impl Iterator, + ctrs: &Constructors, + adts: &Adts, + errs: &mut Vec, +) -> HashMap<&'a Name, Option> { + let mut bodies = HashMap::<&Name, Option>::from_iter(adt_ctrs.map(|ctr| (ctr, None))); + for rule_idx in 0 .. rules.len() { + if let Some(ctr_nam) = &rules[rule_idx].0 + && let Some(found_adt) = ctrs.get(ctr_nam) + { + // Ctr arm, use the body of this rule for this constructor. + if found_adt == adt_nam { + let body = bodies.get_mut(ctr_nam).unwrap(); + if body.is_none() { + // Use this rule for this constructor + *body = Some(rules[rule_idx].2.clone()); + } else { + errs.push(FixMatchErr::RedundantArm { ctr: ctr_nam.clone() }); + } + } else { + errs.push(FixMatchErr::AdtMismatch { + expected: adt_nam.clone(), + found: found_adt.clone(), + ctr: ctr_nam.clone(), + }) + } + } else { + // Var arm, use the body of this rule for all non-covered constructors. + for (ctr, body) in bodies.iter_mut() { + if body.is_none() { + let mut new_body = rules[rule_idx].2.clone(); + if let Some(var) = &rules[rule_idx].0 { + new_body.subst(var, &rebuild_ctr(arg_nam, ctr, &adts[adt_nam].ctrs[&**ctr])); + } + *body = Some(new_body); + } + } + + if rule_idx != rules.len() - 1 { + errs.push(FixMatchErr::UnreachableMatchArms { var: rules[rule_idx].0.clone() }); + rules.truncate(rule_idx + 1); + } + break; + } + } + bodies +} + fn match_field(arg: &Name, field: &Name) -> Name { Name::new(format!("{arg}.{field}")) }