Complete type inference pass (not working)

This commit is contained in:
Nicolas Abril 2023-10-05 22:16:34 +02:00
parent 9bc0dccd87
commit 0bfe5fa361
12 changed files with 316 additions and 183 deletions

7
Cargo.lock generated
View File

@ -252,6 +252,7 @@ dependencies = [
"logos",
"pretty_assertions",
"shrinkwraprs",
"stdext",
"walkdir",
]
@ -436,6 +437,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "stdext"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f3b6b32ae82412fb897ef134867d53a294f57ba5b758f06d71e865352c3e207"
[[package]]
name = "strsim"
version = "0.10.0"

View File

@ -31,4 +31,5 @@ shrinkwraprs = "0.3.0"
[dev-dependencies]
pretty_assertions = "1.4.0"
stdext = "0.3.1"
walkdir = "2.3.3"

View File

@ -2,7 +2,6 @@
use crate::ast::{DefinitionBook, Name};
pub mod combinators;
pub mod flatten;
pub mod pattern;
pub mod vars;

View File

@ -1,168 +0,0 @@
use crate::ast::{
hvm_lang::{DefNames, Pattern},
DefId, Definition, DefinitionBook, Name,
};
use hvm_core::Val;
use itertools::Itertools;
use std::{cell::RefCell, collections::HashMap, rc::Rc};
/// Semantic passes for pattern matching on definition rules.
/// Extract ADTs from patterns in a book, then convert them into lambda calculus.
impl DefinitionBook {
/// Checks whether all rules of a definition have the same number of arguments.
///
/// Compares `rule.arity()` of every rule against `def.arity()` and returns an
/// error for the first mismatch found; returns `Ok(())` when none is found.
pub fn check_rule_arities(&self) -> anyhow::Result<()> {
for def in &self.defs {
let expected_arity = def.arity();
// TODO: Return all errors, don't stop at the first one
for rule in &def.rules {
let found_arity = rule.arity();
if expected_arity != found_arity {
// The `unwrap` panics if the definition id has no registered name.
return Err(anyhow::anyhow!(
"Inconsistent arity on definition '{}'. Expected {} patterns, found {}",
self.def_names.name(&def.def_id).unwrap(),
expected_arity,
found_arity
));
}
}
}
Ok(())
}
/// Infers ADTs from the patterns of the rules in a book.
/// Returns the inferred type of the patterns of each definition.
/// Returns an error if rules use the types in an inconsistent way.
/// These could be same name in different types, different arities or mixing numbers and ADTs.
/// Precondition: Rules have been flattened, rule arity is correct.
///
/// NOTE: this function is unfinished. When no error is returned earlier it
/// reaches the trailing `todo!()` and panics, so it never returns `Ok`.
pub fn get_types_from_patterns(&self) -> anyhow::Result<(Vec<Adt>, HashMap<DefId, Vec<Type>>)> {
// For every pattern in every rule in every definition:
// Collect all types in pattern
// For each collected ADT and their subtypes: (maybe we can flatten first?)
// If any constructors have an already used name:
// Try to merge this type with the one that was using the same constructor
// Update the old definitions to point to the new type
// Or, update types that depend on the updated one and update definitions that use it directly
// Or, set a redirection system
// NOTE(review): nothing below ever pushes to `adts` or inserts into
// `ctr_name_to_adt`, so the merge branch is never reached as written.
let mut adts: Vec<Rc<RefCell<Adt>>> = vec![];
let mut types: Vec<Vec<Type>> = vec![]; // The pattern types of each definition, in `self.defs` order
let mut ctr_name_to_adt: HashMap<Name, usize> = HashMap::new();
for def in &self.defs {
let mut pat_types = get_types_from_def_patterns(def, &self.def_names)?;
// Check if the types in this def share some ctr names with previous types.
// Try to merge them if they do
for typ in pat_types.clone() {
if let Type::Adt(crnt_adt) = typ {
for ctr_name in crnt_adt.borrow().ctrs.keys() {
if let Some(&old_adt_idx) = ctr_name_to_adt.get(ctr_name) {
// NOTE(review): `crnt_adt` is already immutably borrowed by the enclosing
// loop, so this `borrow_mut` would panic at runtime if ever executed.
if crnt_adt.borrow_mut().merge(&adts[old_adt_idx].borrow()) {
// TODO: Make both point at the same adt somehow?
} else {
// TODO: Differentiate between wrong arity and different adts inferred for same name.
return Err(anyhow::anyhow!(
"Inconsistent use of constructor '{}' on definition '{}'.",
ctr_name,
self.def_names.name(&def.def_id).unwrap(),
));
}
}
}
}
}
types.push(pat_types);
}
todo!()
}
}
/// The type inferred for one pattern position of a definition.
#[derive(Debug, Clone)]
pub enum Type {
// Produced for variable patterns; also the starting value when joining pattern types.
Any,
// A shared, mutable handle to the ADT inferred from constructor patterns.
Adt(Rc<RefCell<Adt>>),
// The numeric variants only exist when the "nums" feature is enabled.
#[cfg(feature = "nums")]
U32,
#[cfg(feature = "nums")]
I32,
}
/// An algebraic data type inferred from constructor patterns.
#[derive(Debug, Clone, PartialEq)]
pub struct Adt {
// Constructor names and their arities
ctrs: HashMap<Name, usize>,
// Set when the ADT is joined with `Type::Any`; when set on either side,
// `Adt::merge` accepts differing constructor sets.
others: bool,
}
/// Builds the type implied by a single pattern.
impl From<&Pattern> for Type {
fn from(value: &Pattern) -> Self {
match value {
// A constructor pattern yields a fresh single-constructor ADT whose arity is
// the number of constructor arguments.
Pattern::Ctr(name, args) => Type::Adt(Rc::new(
Adt { ctrs: HashMap::from_iter([(name.clone(), args.len())]), others: false }.into(),
)),
// A variable pattern puts no constraint on the type.
Pattern::Var(_) => Type::Any,
// NOTE(review): `Type::Number` is not a variant of the `Type` enum declared in this
// file (only `U32`/`I32` exist under "nums"), so these arms look like they cannot
// compile with the feature enabled. `Pattern::UI32` may also be a typo - confirm
// against the `Pattern` definition.
#[cfg(feature = "nums")]
Pattern::U32(_) => Type::Number,
#[cfg(feature = "nums")]
Pattern::UI32(_) => Type::Number,
}
}
}
impl Type {
/// Combines the type accumulated so far (`self`) with the type of another pattern.
///
/// Every implemented arm returns `Some`; the ADT/ADT case is not implemented
/// yet and panics through `todo!()`.
/// NOTE(review): there are no arms for the `U32`/`I32` variants, so the match is
/// presumably non-exhaustive when the "nums" feature is enabled - confirm.
pub fn join(&self, other: &Type) -> Option<Type> {
match (self, other) {
(Type::Any, Type::Any) => Some(Type::Any),
// NOTE(review): `others` is left untouched here, while the mirrored arm below sets
// it, so the result depends on whether the variable or the constructor came first.
(Type::Any, Type::Adt(t)) => Some(Type::Adt(t.clone())),
(Type::Adt(adt), Type::Any) => {
// Mutates the shared ADT in place: every handle to it observes `others = true`.
adt.borrow_mut().others = true;
Some(Type::Adt(adt.clone()))
}
(Type::Adt(adt_a), Type::Adt(adt_b)) => {
todo!()
}
}
}
}
impl Adt {
  /// Tries to fold the constructors of `other` into `self`, returning whether
  /// the two ADTs are compatible.
  ///
  /// When neither side has `others` set nothing is merged: the ADTs are
  /// compatible only if they are equal. Otherwise `self` is marked with
  /// `others` and receives every constructor of `other`; the merge fails on the
  /// first constructor whose arity differs from the one `self` already had.
  /// On failure `self` may already have been partially updated.
  pub fn merge(&mut self, other: &Adt) -> bool {
    // TODO: Don't copy these names so much
    if !self.others && !other.others {
      return *self == *other;
    }
    self.others = true;
    for (ctr_name, ctr_arity) in other.ctrs.iter() {
      // `insert` hands back the arity previously stored under this name, if any.
      let previous = self.ctrs.insert(ctr_name.clone(), *ctr_arity);
      if previous.map_or(false, |old_arity| old_arity != *ctr_arity) {
        return false;
      }
    }
    true
  }
}
/// Infers one type per pattern position of `def`.
///
/// For each position, the types of that pattern in every rule are joined
/// together, starting from `Type::Any`. Fails with an error naming the
/// definition as soon as a join reports incompatible types.
fn get_types_from_def_patterns(def: &Definition, def_names: &DefNames) -> anyhow::Result<Vec<Type>> {
  (0 .. def.arity())
    .map(|pat_idx| {
      def.rules.iter().try_fold(Type::Any, |joined, rule| {
        // TODO: Improve error reporting.
        joined.join(&Type::from(&rule.pats[pat_idx])).ok_or_else(|| {
          anyhow::anyhow!(
            "Incompatible types in patterns for definition '{}'",
            def_names.name(&def.def_id).unwrap()
          )
        })
      })
    })
    .collect()
}

View File

@ -0,0 +1,28 @@
/// Semantic passes for pattern matching on definition rules.
/// Extract ADTs from patterns in a book, then convert them into lambda calculus.
use crate::ast::DefinitionBook;
pub mod flatten;
pub mod type_inference;
impl DefinitionBook {
  /// Verifies that every rule of a definition takes as many patterns as the
  /// definition itself.
  ///
  /// Stops at the first rule whose arity differs from `def.arity()` and reports
  /// it in the returned error; returns `Ok(())` when every rule agrees.
  pub fn check_rule_arities(&self) -> anyhow::Result<()> {
    for def in &self.defs {
      let expected_arity = def.arity();
      // TODO: Return all errors, don't stop at the first one
      let mismatch = def.rules.iter().map(|rule| rule.arity()).find(|found| *found != expected_arity);
      if let Some(found_arity) = mismatch {
        anyhow::bail!(
          "Inconsistent arity on definition '{}'. Expected {} patterns, found {}",
          self.def_names.name(&def.def_id).unwrap(),
          expected_arity,
          found_arity
        );
      }
    }
    Ok(())
  }
}

View File

@ -0,0 +1,194 @@
use crate::ast::{hvm_lang::Pattern, DefId, Definition, DefinitionBook, Name};
use anyhow::anyhow;
use itertools::Itertools;
use std::{collections::HashMap, fmt};
impl DefinitionBook {
  /// Infers ADTs from the patterns of the rules in a book.
  ///
  /// Returns the inferred ADTs keyed by id, plus the inferred type of every
  /// pattern of each definition (one inner `Vec` per definition, in `self.defs`
  /// order).
  /// Returns an error if rules use the types in an inconsistent way.
  /// These could be same name in different types, different arities or mixing numbers and ADTs.
  /// Precondition: Rules have been flattened, rule arity is correct.
  pub fn get_types_from_patterns(&self) -> anyhow::Result<(HashMap<AdtId, Adt>, Vec<Vec<Type>>)> {
    let mut adts: HashMap<AdtId, Adt> = HashMap::new();
    // Every (definition, pattern index) whose type is the given ADT.
    let mut pats_using_adt: HashMap<AdtId, Vec<(DefId, usize)>> = HashMap::new();
    // The ADT that currently owns each constructor name.
    let mut ctr_name_to_adt: HashMap<Name, AdtId> = HashMap::new();
    let mut types: Vec<Vec<Type>> = vec![]; // The pattern types of each definition
    let mut adt_counter = 0;

    for def in &self.defs {
      let pat_types = get_types_from_def_patterns(def)?;
      for (i, typ) in pat_types.iter().enumerate() {
        if !matches!(typ, Type::Adt(_)) {
          continue;
        }
        let mut crnt_adt = def.get_adt_from_pat(i)?;

        // Gather the existing ADTs that share constructor names with this one.
        // Several constructors may point at the same ADT, so the ids are
        // deduplicated: merging an id twice would look up an ADT that was already
        // removed. Sorting also makes the chosen destination id independent of the
        // hash map's iteration order.
        let mut to_merge: Vec<AdtId> =
          crnt_adt.ctrs.keys().filter_map(|ctr_name| ctr_name_to_adt.get(ctr_name).copied()).collect();
        to_merge.sort_unstable();
        to_merge.dedup();

        let adt_id = match to_merge.first().copied() {
          // Nothing to merge with: this is a brand new ADT.
          None => {
            let new_id = adt_counter;
            adt_counter += 1;
            new_id
          }
          // Merge every ADT sharing constructor names into the one with the smallest
          // id; the others are removed in favor of the merged one.
          Some(dest_id) => {
            for other_id in to_merge {
              merge_adts(&mut crnt_adt, dest_id, &mut adts, other_id, &mut pats_using_adt, &mut ctr_name_to_adt)?;
            }
            dest_id
          }
        };

        // Record the current pattern as a user of the resulting ADT. Without this a
        // merged pattern would keep its placeholder id.
        pats_using_adt.entry(adt_id).or_default().push((def.def_id, i));
        // Constructors introduced by this pattern must be findable by later definitions.
        for ctr in crnt_adt.ctrs.keys() {
          ctr_name_to_adt.insert(ctr.clone(), adt_id);
        }
        adts.insert(adt_id, crnt_adt);
      }
      types.push(pat_types);
    }

    // Point the definition types to the correct adts.
    // This indexes `types` by definition id, so ids are assumed to match the
    // position of the definition in `self.defs`.
    for (adt_id, uses) in pats_using_adt {
      for (def_id, pat_id) in uses {
        if let Type::Adt(type_adt) = &mut types[*def_id as usize][pat_id] {
          *type_adt = adt_id;
        }
      }
    }
    Ok((adts, types))
  }
}
/// The type inferred for one pattern position of a definition.
#[derive(Debug, Clone)]
pub enum Type {
// Used for positions where no rule has a constructor pattern.
Any,
// Id of the inferred ADT. `get_types_from_def_patterns` fills in `usize::MAX`
// as a placeholder that `get_types_from_patterns` later overwrites.
Adt(AdtId),
// The numeric variants only exist when the "nums" feature is enabled.
#[cfg(feature = "nums")]
U32,
#[cfg(feature = "nums")]
I32,
}
// Key identifying an inferred ADT in the map returned by `get_types_from_patterns`.
type AdtId = usize;
/// An algebraic data type inferred from constructor patterns.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Adt {
// Constructor names and their arities
ctrs: HashMap<Name, usize>,
// Set when a variable pattern appears in the same position as the constructors;
// rendered as a trailing ", ..." by the `Display` impl.
others: bool,
}
impl Adt {
  /// Creates an ADT holding the single constructor `nam` with the given arity
  /// and `others` unset.
  pub fn from_ctr(nam: Name, arity: usize) -> Self {
    let mut ctrs = HashMap::new();
    ctrs.insert(nam, arity);
    Adt { ctrs, others: false }
  }

  /// Creates an ADT with no constructors and `others` unset.
  pub fn new() -> Self {
    Self::default()
  }
}
/// Classifies every pattern position of `def`.
///
/// A position is `Type::Adt` when at least one rule has a constructor pattern
/// there, and `Type::Any` otherwise. The ADT id is the placeholder `usize::MAX`;
/// the caller is expected to replace it with the real id. Never returns `Err`.
fn get_types_from_def_patterns(def: &Definition) -> anyhow::Result<Vec<Type>> {
  let pat_types = (0 .. def.arity())
    .map(|pat_idx| {
      // Every rule is visited; once a constructor was seen the position stays an ADT.
      def.rules.iter().fold(Type::Any, |inferred, rule| match &rule.pats[pat_idx] {
        Pattern::Ctr(..) => Type::Adt(usize::MAX),
        _ => inferred,
      })
    })
    .collect();
  Ok(pat_types)
}
impl Definition {
  /// Builds the ADT described by the patterns at position `pat_idx` across all
  /// rules of this definition.
  ///
  /// Each constructor pattern contributes its name and number of arguments; a
  /// variable pattern marks the ADT as open (`others = true`).
  ///
  /// # Errors
  /// Returns an error when the same constructor is used with different numbers
  /// of arguments in two rules.
  ///
  /// # Panics
  /// Panics if a rule has no pattern at `pat_idx` and, with the "nums" feature,
  /// if the pattern is neither a constructor nor a variable.
  fn get_adt_from_pat(&self, pat_idx: usize) -> anyhow::Result<Adt> {
    let mut adt = Adt::new();
    for rule in &self.rules {
      match &rule.pats[pat_idx] {
        Pattern::Ctr(nam, args) => match adt.ctrs.get(nam) {
          // The arity to compare against is the constructor's own argument count,
          // not the arity of the definition.
          Some(expected_arity) if *expected_arity != args.len() => {
            return Err(anyhow::anyhow!("Inconsistent arity used for constructor {nam}"));
          }
          Some(_) => {}
          None => {
            adt.ctrs.insert(nam.clone(), args.len());
          }
        },
        Pattern::Var(_) => adt.others = true,
        #[cfg(feature = "nums")]
        _ => panic!("Expected only Ctr and Var patterns to be called here"),
      }
    }
    Ok(adt)
  }
}
/// Merges the ADT stored under `other_id` into `this_adt`, which the caller will
/// store under `this_id`.
///
/// - The constructors of the other ADT are copied into `this_adt`.
/// - All its constructor names are repointed to `this_id` in `ctr_name_to_adt`.
/// - When the ids differ, the patterns that used `other_id` are moved to
///   `this_id` and the other ADT is removed from `adts`.
///
/// An `other_id` that is no longer present in `adts` is treated as already
/// merged and is a no-op.
///
/// # Errors
/// Returns an error when the two ADTs differ and neither is open (`others`), or
/// when a shared constructor has different arities in the two ADTs.
fn merge_adts(
  this_adt: &mut Adt,
  this_id: AdtId,
  adts: &mut HashMap<AdtId, Adt>,
  other_id: AdtId,
  pats_using_adt: &mut HashMap<AdtId, Vec<(DefId, usize)>>,
  ctr_name_to_adt: &mut HashMap<Name, AdtId>,
) -> anyhow::Result<()> {
  let other_adt = match adts.get(&other_id) {
    Some(adt) => adt,
    // Already merged (and removed) earlier: nothing left to do.
    None => return Ok(()),
  };

  if *this_adt != *other_adt {
    // Two closed ADTs with different constructor sets cannot be the same type.
    if !(this_adt.others || other_adt.others) {
      return Err(anyhow!("Found same constructor being used in incompatible types"));
    }
    this_adt.others = true;
    for (ctr_name, ctr_arity) in &other_adt.ctrs {
      match this_adt.ctrs.get(ctr_name) {
        Some(old_arity) if old_arity != ctr_arity => {
          return Err(anyhow!("Inconsistent arity used for constructor {ctr_name}"));
        }
        Some(_) => {}
        None => {
          this_adt.ctrs.insert(ctr_name.clone(), *ctr_arity);
        }
      }
    }
  }

  // Repoint every constructor of the merged ADT, even when both ADTs were equal.
  // `insert` also covers names that were never registered, where a lookup followed
  // by `unwrap` would panic.
  for ctr_name in other_adt.ctrs.keys() {
    ctr_name_to_adt.insert(ctr_name.clone(), this_id);
  }

  if this_id != other_id {
    let moved_pats = pats_using_adt.remove(&other_id).unwrap_or_default();
    pats_using_adt.entry(this_id).or_default().extend(moved_pats);
    adts.remove(&other_id);
  }
  Ok(())
}
impl fmt::Display for Adt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"[{}{}]",
self
.ctrs
.iter()
.map(|(nam, arity)| format!("({}{})", nam, (0 .. *arity).map(|_| format!(" _")).join("")))
.join(", "),
if self.others { ", ..." } else { "" }
)
}
}

View File

@ -1,5 +1,6 @@
use hvm_core::{parse_lnet, show_lnet};
use hvm_core::{parse_lnet, show_lnet, Val};
use hvm_lang::{
ast::DefId,
compile_book,
from_core::readback_net,
loader::display_err_for_text,
@ -9,7 +10,12 @@ use hvm_lang::{
};
use itertools::Itertools;
use pretty_assertions::assert_eq;
use std::{fs, io::Write, path::Path};
use std::{
fs,
io::Write,
path::{Path, PathBuf},
};
use stdext::function_name;
use walkdir::WalkDir;
fn run_single_golden_test(
@ -32,8 +38,13 @@ fn run_single_golden_test(
}
}
fn run_golden_test_dir(root: &Path, run: &dyn Fn(&Path, &str) -> anyhow::Result<String>) {
let walker = WalkDir::new(root).sort_by_file_name().max_depth(2).into_iter().filter_entry(|e| {
fn run_golden_test_dir(test_name: &str, run: &dyn Fn(&Path, &str) -> anyhow::Result<String>) {
let root = PathBuf::from(format!(
"{}/tests/golden_tests/{}",
env!("CARGO_MANIFEST_DIR"),
test_name.rsplit_once(":").unwrap().1
));
let walker = WalkDir::new(&root).sort_by_file_name().max_depth(2).into_iter().filter_entry(|e| {
let path = e.path();
if path == root {
true
@ -66,8 +77,7 @@ fn run_golden_test_dir(root: &Path, run: &dyn Fn(&Path, &str) -> anyhow::Result<
#[test]
fn compile_single_terms() {
let root = format!("{}/tests/golden_tests/compile_single_terms", env!("CARGO_MANIFEST_DIR"));
run_golden_test_dir(Path::new(&root), &|_, code| {
run_golden_test_dir(function_name!(), &|_, code| {
let term = parse_term(code).map_err(|errs| {
let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n");
anyhow::anyhow!(msg)
@ -80,8 +90,7 @@ fn compile_single_terms() {
#[test]
fn compile_single_files() {
let root = format!("{}/tests/golden_tests/compile_single_files", env!("CARGO_MANIFEST_DIR"));
run_golden_test_dir(Path::new(&root), &|_, code| {
run_golden_test_dir(function_name!(), &|_, code| {
let book = parse_definition_book(code).map_err(|errs| {
let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n");
anyhow::anyhow!(msg)
@ -93,8 +102,7 @@ fn compile_single_files() {
#[test]
fn run_single_files() {
let root = format!("{}/tests/golden_tests/run_single_files", env!("CARGO_MANIFEST_DIR"));
run_golden_test_dir(Path::new(&root), &|_, code| {
run_golden_test_dir(function_name!(), &|_, code| {
let book = parse_definition_book(code).map_err(|errs| {
let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n");
anyhow::anyhow!(msg)
@ -112,8 +120,7 @@ fn run_single_files() {
#[test]
fn readback_lnet() {
let root = format!("{}/tests/golden_tests/readback_lnet", env!("CARGO_MANIFEST_DIR"));
run_golden_test_dir(Path::new(&root), &|_, code| {
run_golden_test_dir(function_name!(), &|_, code| {
let lnet = parse_lnet(&mut code.chars().peekable());
let def_names = Default::default();
let (term, valid) = readback_net(&lnet)?;
@ -127,8 +134,7 @@ fn readback_lnet() {
#[test]
fn flatten_rules() {
let root = format!("{}/tests/golden_tests/flatten_rules", env!("CARGO_MANIFEST_DIR"));
run_golden_test_dir(Path::new(&root), &|_, code| {
run_golden_test_dir(function_name!(), &|_, code| {
let mut book = parse_definition_book(code).map_err(|errs| {
let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n");
anyhow::anyhow!(msg)
@ -137,3 +143,31 @@ fn flatten_rules() {
Ok(book.to_string())
})
}
/// Golden test for the type inference pass: parses each book, checks rule
/// arities, flattens the rules and dumps the inferred ADTs (sorted by id)
/// followed by the pattern types of every definition.
#[test]
fn type_inference() {
  run_golden_test_dir(function_name!(), &|_, code| {
    let mut book = match parse_definition_book(code) {
      Ok(book) => book,
      Err(errs) => {
        let msg = errs.into_iter().map(|e| display_err_for_text(e)).join("\n");
        return Err(anyhow::anyhow!(msg));
      }
    };
    book.check_rule_arities()?;
    book.flatten_rules();
    let (adts, def_types) = book.get_types_from_patterns()?;

    let mut out = String::from("Adts: [\n");
    for (adt_id, adt) in adts.iter().sorted_by_key(|entry| entry.0) {
      out += &format!(" {adt_id} => {adt}\n");
    }
    out += "]\nTypes: [\n";
    for (idx, pat_types) in def_types.iter().enumerate() {
      // The position in `def_types` is used as the definition id.
      let def_name = book.def_names.name(&DefId::from(idx as Val)).unwrap();
      let rendered = pat_types.iter().map(|typ| format!("{typ:?}")).join(", ");
      out += &format!(" {} => [{}]\n", def_name, rendered);
    }
    out += "]";
    Ok(out)
  })
}

View File

@ -0,0 +1,7 @@
Adts: [
0 => [(True), (False)]
]
Types: [
Not => [Adt(0)]
And => [Adt(18446744073709551615), Any]
]

View File

@ -0,0 +1,5 @@
Not (True) = False
Not (False) = True
And (True) b = b
And (False) b = False

View File

@ -0,0 +1,15 @@
Adts: [
0 => [(Ctr0)]
1 => [(Ctr1 _)]
2 => [(Ctr2 _ _)]
3 => [(Ctr3 _ _ _)]
4 => [(Ctr4 _ _ _ _)]
]
Types: [
Id => [Any]
Record0 => [Adt(0)]
Record1 => [Adt(1)]
Record2 => [Adt(2)]
Record3 => [Adt(3)]
Record4 => [Adt(4)]
]

View File

@ -0,0 +1,11 @@
Id x = x
Record0 (Ctr0) = λx x
Record1 (Ctr1 a) = a
Record2 (Ctr2 a b) = (a b)
Record3 (Ctr3 a b c) = (a b c)
Record4 (Ctr4 a b c d) = (a b c d)