From 0bfe5fa36101779f3f0beae35e3e6564fcfffabe Mon Sep 17 00:00:00 2001 From: Nicolas Abril Date: Thu, 5 Oct 2023 22:16:34 +0200 Subject: [PATCH] Complete type inference pass (not working) --- Cargo.lock | 7 + Cargo.toml | 1 + src/semantic/mod.rs | 1 - src/semantic/pattern.rs | 168 --------------- src/semantic/{ => pattern}/flatten.rs | 0 src/semantic/pattern/mod.rs | 28 +++ src/semantic/pattern/type_inference.rs | 194 ++++++++++++++++++ tests/golden_tests.rs | 62 ++++-- tests/golden_tests/type_inference/bool.golden | 7 + tests/golden_tests/type_inference/bool.hvm | 5 + .../type_inference/records_used_once.golden | 15 ++ .../type_inference/records_used_once.hvm | 11 + 12 files changed, 316 insertions(+), 183 deletions(-) delete mode 100644 src/semantic/pattern.rs rename src/semantic/{ => pattern}/flatten.rs (100%) create mode 100644 src/semantic/pattern/mod.rs create mode 100644 src/semantic/pattern/type_inference.rs create mode 100644 tests/golden_tests/type_inference/bool.golden create mode 100644 tests/golden_tests/type_inference/bool.hvm create mode 100644 tests/golden_tests/type_inference/records_used_once.golden create mode 100644 tests/golden_tests/type_inference/records_used_once.hvm diff --git a/Cargo.lock b/Cargo.lock index facfc3b4..36b49c95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -252,6 +252,7 @@ dependencies = [ "logos", "pretty_assertions", "shrinkwraprs", + "stdext", "walkdir", ] @@ -436,6 +437,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "stdext" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f3b6b32ae82412fb897ef134867d53a294f57ba5b758f06d71e865352c3e207" + [[package]] name = "strsim" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index 3dfdd368..4b565ce3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,4 +31,5 @@ shrinkwraprs = "0.3.0" [dev-dependencies] pretty_assertions = "1.4.0" +stdext = "0.3.1" walkdir = "2.3.3" diff --git a/src/semantic/mod.rs b/src/semantic/mod.rs index 8f955221..0b2cbe9b 100644 --- a/src/semantic/mod.rs +++ b/src/semantic/mod.rs @@ -2,7 +2,6 @@ use crate::ast::{DefinitionBook, Name}; pub mod combinators; -pub mod flatten; pub mod pattern; pub mod vars; diff --git a/src/semantic/pattern.rs b/src/semantic/pattern.rs deleted file mode 100644 index db1b32f8..00000000 --- a/src/semantic/pattern.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::ast::{ - hvm_lang::{DefNames, Pattern}, - DefId, Definition, DefinitionBook, Name, -}; -use hvm_core::Val; -use itertools::Itertools; -use std::{cell::RefCell, collections::HashMap, rc::Rc}; - -/// Semantic passes for pattern matching on defiinition rules. -/// Extract ADTs from patterns in a book, then convert them into lambda calculus. - -impl DefinitionBook { - /// Checks whether all rules of a definition have the same number of arguments - pub fn check_rule_arities(&self) -> anyhow::Result<()> { - for def in &self.defs { - let expected_arity = def.arity(); - // TODO: Return all errors, don't stop at the first one - for rule in &def.rules { - let found_arity = rule.arity(); - if expected_arity != found_arity { - return Err(anyhow::anyhow!( - "Inconsistent arity on definition '{}'. Expected {} patterns, found {}", - self.def_names.name(&def.def_id).unwrap(), - expected_arity, - found_arity - )); - } - } - } - Ok(()) - } - - /// Infers ADTs from the patterns of the rules in a book. - /// Returns the infered type of the patterns of each definition. - /// Returns an error if rules use the types in a inconsistent way. - /// These could be same name in different types, different arities or mixing numbers and ADTs. - /// Precondition: Rules have been flattened, rule arity is correct. - pub fn get_types_from_patterns(&self) -> anyhow::Result<(Vec, HashMap>)> { - // For every pattern in every rule in every definition: - // Collect all types in pattern - // For each collected ADT and their subtypes: (maybe we can flatten first?) - // If any constructors have an already used name: - // Try to merge this type with the one that was using the same constructor - // Update the old definitions to point to the new type - // Or, update types that depend on the updated one and update definitions that use it directly - // Or, set a redirection system - - let mut adts: Vec>> = vec![]; - let mut types: Vec> = vec![]; // The type of each - let mut ctr_name_to_adt: HashMap = HashMap::new(); - for def in &self.defs { - let mut pat_types = get_types_from_def_patterns(def, &self.def_names)?; - // Check if the types in this def share some ctr names with previous types. - // Try to merge them if they do - for typ in pat_types.clone() { - if let Type::Adt(crnt_adt) = typ { - for ctr_name in crnt_adt.borrow().ctrs.keys() { - if let Some(&old_adt_idx) = ctr_name_to_adt.get(ctr_name) { - if crnt_adt.borrow_mut().merge(&adts[old_adt_idx].borrow()) { - // TODO: Make both point at the same adt somehow? - } else { - // TODO: Differentiate between wrong arity and different adts infered for same name. - return Err(anyhow::anyhow!( - "Inconsistent use of constructor '{}' on definition '{}'.", - ctr_name, - self.def_names.name(&def.def_id).unwrap(), - )); - } - } - } - } - } - types.push(pat_types); - } - todo!() - } -} - -#[derive(Debug, Clone)] -pub enum Type { - Any, - Adt(Rc>), - #[cfg(feature = "nums")] - U32, - #[cfg(feature = "nums")] - I32, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct Adt { - // Constructor names and their arities - ctrs: HashMap, - others: bool, -} - -impl From<&Pattern> for Type { - fn from(value: &Pattern) -> Self { - match value { - Pattern::Ctr(name, args) => Type::Adt(Rc::new( - Adt { ctrs: HashMap::from_iter([(name.clone(), args.len())]), others: false }.into(), - )), - Pattern::Var(_) => Type::Any, - #[cfg(feature = "nums")] - Pattern::U32(_) => Type::Number, - #[cfg(feature = "nums")] - Pattern::UI32(_) => Type::Number, - } - } -} - -impl Type { - pub fn join(&self, other: &Type) -> Option { - match (self, other) { - (Type::Any, Type::Any) => Some(Type::Any), - (Type::Any, Type::Adt(t)) => Some(Type::Adt(t.clone())), - (Type::Adt(adt), Type::Any) => { - adt.borrow_mut().others = true; - Some(Type::Adt(adt.clone())) - } - (Type::Adt(adt_a), Type::Adt(adt_b)) => { - todo!() - } - } - } -} - -impl Adt { - pub fn merge(&mut self, other: &Adt) -> bool { - // TODO: Don't copy these names so much - let accept_different = self.others || other.others; - if accept_different { - self.others = accept_different; - for (ctr_name, ctr_arity) in &other.ctrs { - if let Some(old_arity) = self.ctrs.insert(ctr_name.clone(), *ctr_arity) { - if old_arity != *ctr_arity { - return false; - } - } - } - true - } else { - self == other - } - } -} - -fn get_types_from_def_patterns(def: &Definition, def_names: &DefNames) -> anyhow::Result> { - let arity = def.arity(); - let mut pat_types = vec![]; - for pat_idx in 0 .. arity { - let pats = def.rules.iter().map(|x| &x.pats[pat_idx]); - let mut pat_type = Type::Any; - for pat in pats { - if let Some(t) = pat_type.join(&Type::from(pat)) { - pat_type = t; - } else { - // TODO: Improve error reporting. - return Err(anyhow::anyhow!( - "Incompatible types in patterns for definition '{}'", - def_names.name(&def.def_id).unwrap() - )); - } - } - pat_types.push(pat_type); - } - - Ok(pat_types) -} diff --git a/src/semantic/flatten.rs b/src/semantic/pattern/flatten.rs similarity index 100% rename from src/semantic/flatten.rs rename to src/semantic/pattern/flatten.rs diff --git a/src/semantic/pattern/mod.rs b/src/semantic/pattern/mod.rs new file mode 100644 index 00000000..cf3806af --- /dev/null +++ b/src/semantic/pattern/mod.rs @@ -0,0 +1,28 @@ +/// Semantic passes for pattern matching on defiinition rules. +/// Extract ADTs from patterns in a book, then convert them into lambda calculus. +use crate::ast::DefinitionBook; + +pub mod flatten; +pub mod type_inference; + +impl DefinitionBook { + /// Checks whether all rules of a definition have the same number of arguments + pub fn check_rule_arities(&self) -> anyhow::Result<()> { + for def in &self.defs { + let expected_arity = def.arity(); + // TODO: Return all errors, don't stop at the first one + for rule in &def.rules { + let found_arity = rule.arity(); + if expected_arity != found_arity { + return Err(anyhow::anyhow!( + "Inconsistent arity on definition '{}'. Expected {} patterns, found {}", + self.def_names.name(&def.def_id).unwrap(), + expected_arity, + found_arity + )); + } + } + } + Ok(()) + } +} diff --git a/src/semantic/pattern/type_inference.rs b/src/semantic/pattern/type_inference.rs new file mode 100644 index 00000000..3fc937a6 --- /dev/null +++ b/src/semantic/pattern/type_inference.rs @@ -0,0 +1,194 @@ +use crate::ast::{hvm_lang::Pattern, DefId, Definition, DefinitionBook, Name}; +use anyhow::anyhow; +use itertools::Itertools; +use std::{collections::HashMap, fmt}; + +impl DefinitionBook { + /// Infers ADTs from the patterns of the rules in a book. + /// Returns the infered type of the patterns of each definition. + /// Returns an error if rules use the types in a inconsistent way. + /// These could be same name in different types, different arities or mixing numbers and ADTs. + /// Precondition: Rules have been flattened, rule arity is correct. + pub fn get_types_from_patterns(&self) -> anyhow::Result<(HashMap, Vec>)> { + let mut adts: HashMap = HashMap::new(); + let mut pats_using_adt: HashMap> = HashMap::new(); + let mut ctr_name_to_adt: HashMap = HashMap::new(); + let mut types: Vec> = vec![]; // The type of each + let mut adt_counter = 0; + for def in &self.defs { + let pat_types = get_types_from_def_patterns(def)?; + // Check if the types in this def share some ctr names with previous types. + // Try to merge them if they do + for (i, typ) in pat_types.iter().enumerate() { + if let Type::Adt(_) = typ { + let mut crnt_adt = def.get_adt_from_pat(i)?; + // Gather the existing types that share constructor names. + // We will try to merge them all together. + let mut to_merge = vec![]; + for ctr_name in crnt_adt.ctrs.keys() { + if let Some(old_adt_idx) = ctr_name_to_adt.get(ctr_name) { + to_merge.push(*old_adt_idx); + } + } + if to_merge.is_empty() { + // If nothing to merge, just add the new type and update the control vars + for ctr in crnt_adt.ctrs.keys() { + pats_using_adt.insert(adt_counter, vec![(def.def_id, i)]); + ctr_name_to_adt.insert(ctr.clone(), adt_counter); + } + adts.insert(adt_counter, crnt_adt); + adt_counter += 1; + } else { + // If this adt has to be merged with others, we use the id of the first existing adt to store it. + // We merge all the adts sharing constructor names with our current adt into the one with the first id. + // All the other adts are removed in favor of the merged one that has all the constructors. + // The control variables are updated to now point everything to the merged adt. + let dest_id = to_merge[0]; + for to_merge in to_merge { + merge_adts( + &mut crnt_adt, + dest_id, + &mut adts, + to_merge, + &mut pats_using_adt, + &mut ctr_name_to_adt, + )?; + } + adts.insert(dest_id, crnt_adt); + } + } + } + types.push(pat_types); + } + // Point the definition types to the correct adts. + for (adt_id, uses) in pats_using_adt { + for (def_id, pat_id) in uses { + if let Type::Adt(type_adt) = &mut types[*def_id as usize][pat_id] { + *type_adt = adt_id; + } + } + } + Ok((adts, types)) + } +} + +#[derive(Debug, Clone)] +pub enum Type { + Any, + Adt(AdtId), + #[cfg(feature = "nums")] + U32, + #[cfg(feature = "nums")] + I32, +} + +type AdtId = usize; + +#[derive(Debug, Clone, PartialEq, Default)] +pub struct Adt { + // Constructor names and their arities + ctrs: HashMap, + others: bool, +} + +impl Adt { + pub fn from_ctr(nam: Name, arity: usize) -> Self { + Adt { ctrs: HashMap::from_iter([(nam, arity)]), others: false } + } + + pub fn new() -> Self { + Default::default() + } +} + +fn get_types_from_def_patterns(def: &Definition) -> anyhow::Result> { + let arity = def.arity(); + let mut pat_types = vec![]; + for pat_idx in 0 .. arity { + let pats = def.rules.iter().map(|x| &x.pats[pat_idx]); + let mut pat_type = Type::Any; + for pat in pats { + if let Pattern::Ctr(..) = pat { + pat_type = Type::Adt(usize::MAX); + } + } + pat_types.push(pat_type); + } + + Ok(pat_types) +} + +impl Definition { + fn get_adt_from_pat(&self, pat_idx: usize) -> anyhow::Result { + let mut adt = Adt::new(); + for rule in &self.rules { + match &rule.pats[pat_idx] { + Pattern::Ctr(nam, args) => { + if let Some(expected_arity) = adt.ctrs.get(nam) { + if *expected_arity == self.arity() { + } else { + return Err(anyhow::anyhow!("Inconsistent arity used for constructor {nam}")); + } + } else { + adt.ctrs.insert(nam.clone(), args.len()); + } + } + Pattern::Var(_) => adt.others = true, + #[cfg(feature = "nums")] + _ => panic!("Expected only Ctr and Var patterns to be called here"), + } + } + Ok(adt) + } +} + +fn merge_adts( + this_adt: &mut Adt, + this_id: AdtId, + adts: &mut HashMap, + other_id: AdtId, + pats_using_adt: &mut HashMap>, + ctr_name_to_adt: &mut HashMap, +) -> anyhow::Result<()> { + let other_adt = adts.get_mut(&other_id).unwrap(); + if this_adt != other_adt { + let accept_different = this_adt.others || other_adt.others; + if accept_different { + this_adt.others = accept_different; + for (ctr_name, ctr_arity) in &other_adt.ctrs { + *ctr_name_to_adt.get_mut(ctr_name).unwrap() = this_id; + if let Some(old_arity) = this_adt.ctrs.get(ctr_name) { + if old_arity != ctr_arity { + return Err(anyhow!("Inconsistent arity used for constructor {ctr_name}")); + } + } else { + this_adt.ctrs.insert(ctr_name.clone(), *ctr_arity); + } + } + } else { + return Err(anyhow!("Found same constructor being used in incompatible types")); + } + } + if this_id != other_id { + let mut moved_pats = pats_using_adt.remove(&other_id).unwrap(); + pats_using_adt.get_mut(&this_id).unwrap().append(&mut moved_pats); + adts.remove(&other_id); + } + + Ok(()) +} + +impl fmt::Display for Adt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "[{}{}]", + self + .ctrs + .iter() + .map(|(nam, arity)| format!("({}{})", nam, (0 .. *arity).map(|_| format!(" _")).join(""))) + .join(", "), + if self.others { ", ..." } else { "" } + ) + } +} diff --git a/tests/golden_tests.rs b/tests/golden_tests.rs index 72bc46f4..4ce48f37 100644 --- a/tests/golden_tests.rs +++ b/tests/golden_tests.rs @@ -1,5 +1,6 @@ -use hvm_core::{parse_lnet, show_lnet}; +use hvm_core::{parse_lnet, show_lnet, Val}; use hvm_lang::{ + ast::DefId, compile_book, from_core::readback_net, loader::display_err_for_text, @@ -9,7 +10,12 @@ use hvm_lang::{ }; use itertools::Itertools; use pretty_assertions::assert_eq; -use std::{fs, io::Write, path::Path}; +use std::{ + fs, + io::Write, + path::{Path, PathBuf}, +}; +use stdext::function_name; use walkdir::WalkDir; fn run_single_golden_test( @@ -32,8 +38,13 @@ fn run_single_golden_test( } } -fn run_golden_test_dir(root: &Path, run: &dyn Fn(&Path, &str) -> anyhow::Result) { - let walker = WalkDir::new(root).sort_by_file_name().max_depth(2).into_iter().filter_entry(|e| { +fn run_golden_test_dir(test_name: &str, run: &dyn Fn(&Path, &str) -> anyhow::Result) { + let root = PathBuf::from(format!( + "{}/tests/golden_tests/{}", + env!("CARGO_MANIFEST_DIR"), + test_name.rsplit_once(":").unwrap().1 + )); + let walker = WalkDir::new(&root).sort_by_file_name().max_depth(2).into_iter().filter_entry(|e| { let path = e.path(); if path == root { true @@ -66,8 +77,7 @@ fn run_golden_test_dir(root: &Path, run: &dyn Fn(&Path, &str) -> anyhow::Result< #[test] fn compile_single_terms() { - let root = format!("{}/tests/golden_tests/compile_single_terms", env!("CARGO_MANIFEST_DIR")); - run_golden_test_dir(Path::new(&root), &|_, code| { + run_golden_test_dir(function_name!(), &|_, code| { let term = parse_term(code).map_err(|errs| { let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n"); anyhow::anyhow!(msg) @@ -80,8 +90,7 @@ fn compile_single_terms() { #[test] fn compile_single_files() { - let root = format!("{}/tests/golden_tests/compile_single_files", env!("CARGO_MANIFEST_DIR")); - run_golden_test_dir(Path::new(&root), &|_, code| { + run_golden_test_dir(function_name!(), &|_, code| { let book = parse_definition_book(code).map_err(|errs| { let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n"); anyhow::anyhow!(msg) @@ -93,8 +102,7 @@ fn compile_single_files() { #[test] fn run_single_files() { - let root = format!("{}/tests/golden_tests/run_single_files", env!("CARGO_MANIFEST_DIR")); - run_golden_test_dir(Path::new(&root), &|_, code| { + run_golden_test_dir(function_name!(), &|_, code| { let book = parse_definition_book(code).map_err(|errs| { let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n"); anyhow::anyhow!(msg) @@ -112,8 +120,7 @@ fn run_single_files() { #[test] fn readback_lnet() { - let root = format!("{}/tests/golden_tests/readback_lnet", env!("CARGO_MANIFEST_DIR")); - run_golden_test_dir(Path::new(&root), &|_, code| { + run_golden_test_dir(function_name!(), &|_, code| { let lnet = parse_lnet(&mut code.chars().peekable()); let def_names = Default::default(); let (term, valid) = readback_net(&lnet)?; @@ -127,8 +134,7 @@ fn readback_lnet() { #[test] fn flatten_rules() { - let root = format!("{}/tests/golden_tests/flatten_rules", env!("CARGO_MANIFEST_DIR")); - run_golden_test_dir(Path::new(&root), &|_, code| { + run_golden_test_dir(function_name!(), &|_, code| { let mut book = parse_definition_book(code).map_err(|errs| { let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n"); anyhow::anyhow!(msg) @@ -137,3 +143,31 @@ fn flatten_rules() { Ok(book.to_string()) }) } + +#[test] +fn type_inference() { + run_golden_test_dir(function_name!(), &|_, code| { + let mut book = parse_definition_book(code).map_err(|errs| { + let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n"); + anyhow::anyhow!(msg) + })?; + book.check_rule_arities()?; + book.flatten_rules(); + let (adts, def_types) = book.get_types_from_patterns()?; + let mut out = String::new(); + out.push_str("Adts: [\n"); + for (adt_id, adt) in adts.iter().sorted_by_key(|x| x.0) { + out.push_str(&format!(" {adt_id} => {adt}\n")); + } + out.push_str("]\nTypes: [\n"); + for (def_id, def_types) in def_types.iter().enumerate() { + out.push_str(&format!( + " {} => [{}]\n", + book.def_names.name(&DefId::from(def_id as Val)).unwrap(), + def_types.iter().map(|x| format!("{x:?}")).join(", "), + )); + } + out.push_str("]"); + Ok(out) + }) +} diff --git a/tests/golden_tests/type_inference/bool.golden b/tests/golden_tests/type_inference/bool.golden new file mode 100644 index 00000000..8f8194f8 --- /dev/null +++ b/tests/golden_tests/type_inference/bool.golden @@ -0,0 +1,7 @@ +Adts: [ + 0 => [(True), (False)] +] +Types: [ + Not => [Adt(0)] + And => [Adt(18446744073709551615), Any] +] \ No newline at end of file diff --git a/tests/golden_tests/type_inference/bool.hvm b/tests/golden_tests/type_inference/bool.hvm new file mode 100644 index 00000000..55f526d6 --- /dev/null +++ b/tests/golden_tests/type_inference/bool.hvm @@ -0,0 +1,5 @@ +Not (True) = False +Not (False) = True + +And (True) b = b +And (False) b = False \ No newline at end of file diff --git a/tests/golden_tests/type_inference/records_used_once.golden b/tests/golden_tests/type_inference/records_used_once.golden new file mode 100644 index 00000000..9a03c0cc --- /dev/null +++ b/tests/golden_tests/type_inference/records_used_once.golden @@ -0,0 +1,15 @@ +Adts: [ + 0 => [(Ctr0)] + 1 => [(Ctr1 _)] + 2 => [(Ctr2 _ _)] + 3 => [(Ctr3 _ _ _)] + 4 => [(Ctr4 _ _ _ _)] +] +Types: [ + Id => [Any] + Record0 => [Adt(0)] + Record1 => [Adt(1)] + Record2 => [Adt(2)] + Record3 => [Adt(3)] + Record4 => [Adt(4)] +] \ No newline at end of file diff --git a/tests/golden_tests/type_inference/records_used_once.hvm b/tests/golden_tests/type_inference/records_used_once.hvm new file mode 100644 index 00000000..80d12d1a --- /dev/null +++ b/tests/golden_tests/type_inference/records_used_once.hvm @@ -0,0 +1,11 @@ +Id x = x + +Record0 (Ctr0) = λx x + +Record1 (Ctr1 a) = a + +Record2 (Ctr2 a b) = (a b) + +Record3 (Ctr3 a b c) = (a b c) + +Record4 (Ctr4 a b c d) = (a b c d)