From 40d26dce7f4b58008290c81087c8a7e3cb05afd1 Mon Sep 17 00:00:00 2001 From: collin Date: Sat, 24 Oct 2020 00:51:04 -0700 Subject: [PATCH] refactor type assertion solving --- dynamic-check/src/dynamic_check.rs | 201 +++++++++++++++++---- dynamic-check/src/errors/type_assertion.rs | 12 ++ static-check/src/types/type_.rs | 2 +- 3 files changed, 174 insertions(+), 41 deletions(-) diff --git a/dynamic-check/src/dynamic_check.rs b/dynamic-check/src/dynamic_check.rs index 7dfe816090..61308fa892 100644 --- a/dynamic-check/src/dynamic_check.rs +++ b/dynamic-check/src/dynamic_check.rs @@ -16,6 +16,7 @@ use crate::{DynamicCheckError, FrameError, ScopeError, TypeAssertionError, VariableTableError}; use leo_static_check::{ + flatten_array_type, Attribute, CircuitFunctionType, CircuitType, @@ -1253,21 +1254,42 @@ impl Frame { println!("assertion: {:?}", type_assertion); + // Collect `TypeVariablePairs` from the `TypeAssertion`. + let pairs = type_assertion.pairs()?; + + // If no pairs are found, attempt to evaluate the `TypeAssertion`. + if pairs.is_empty() { + // Evaluate the `TypeAssertion`. + type_assertion.evaluate()? + } else { + // Iterate over each `TypeVariable` -> `Type` pair. + for pair in pairs.get_pairs() { + // Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s. + for original in &mut unsolved { + original.substitute(&pair.0, &pair.1) + } + + for original in &mut unsolved_membership { + original.substitute(&pair.0, &pair.1) + } + } + } + // Solve the `TypeAssertion`. // // If the `TypeAssertion` has a solution, then continue the loop. // If the `TypeAssertion` returns a `TypeVariablePair`, then substitute the `TypeVariable` // for it's paired `Type` in all `TypeAssertion`s. - if let Some(pair) = type_assertion.solve()? { - // Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s. - for original in &mut unsolved { - original.substitute(&pair.0, &pair.1) - } - - for original in &mut unsolved_membership { - original.substitute(&pair.0, &pair.1) - } - }; + // if let Some(pair) = type_assertion.solve()? { + // // Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s. + // for original in &mut unsolved { + // original.substitute(&pair.0, &pair.1) + // } + // + // for original in &mut unsolved_membership { + // original.substitute(&pair.0, &pair.1) + // } + // }; } // Solve all type membership assertions. @@ -1276,7 +1298,7 @@ impl Frame { let type_assertion = unsolved_membership.pop().unwrap(); // Solve the membership assertion. - type_assertion.solve()?; + type_assertion.evaluate()?; } // for type_assertion in unsolved.pop() { @@ -1420,9 +1442,6 @@ impl VariableTable { } } -/// A type variable -> type pair. -pub struct TypeVariablePair(TypeVariable, Type); - /// A predicate that evaluates equality between two `Types`s. #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub enum TypeAssertion { @@ -1445,8 +1464,19 @@ impl TypeAssertion { Self::Membership(TypeMembership::new(given, set, span)) } + /// + /// Returns one or more `TypeVariablePairs` generated by the given `TypeAssertion`. + /// + pub fn pairs(&self) -> Result { + match self { + TypeAssertion::Equality(equality) => equality.pairs(), + TypeAssertion::Membership(_) => unimplemented!("Cannot generate pairs from type membership"), + } + } + /// /// Substitutes the given type for self if self is equal to the type variable. + /// pub fn substitute(&mut self, variable: &TypeVariable, type_: &Type) { match self { TypeAssertion::Equality(equality) => equality.substitute(variable, type_), @@ -1455,19 +1485,12 @@ impl TypeAssertion { } /// - /// Returns `None` if the `TypeAssertion` is solvable. + /// Checks if the `TypeAssertion` is satisfied. /// - /// If the `TypeAssertion` is not solvable, throw a `TypeAssertionError`. - /// If the `TypeAssertion` contains a `TypeVariable`, then return `Some(TypeVariable, Type)`. - /// - pub fn solve(&self) -> Result, TypeAssertionError> { + pub fn evaluate(&self) -> Result<(), TypeAssertionError> { match self { - TypeAssertion::Equality(equality) => equality.solve(), - TypeAssertion::Membership(membership) => { - membership.solve()?; - - Ok(None) - } + TypeAssertion::Equality(equality) => equality.evaluate(), + TypeAssertion::Membership(membership) => membership.evaluate(), } } } @@ -1502,7 +1525,7 @@ impl TypeMembership { /// /// Returns true if the given type is equal to a member of the set. /// - pub fn solve(&self) -> Result<(), TypeAssertionError> { + pub fn evaluate(&self) -> Result<(), TypeAssertionError> { if self.set.contains(&self.given) { Ok(()) } else { @@ -1555,22 +1578,120 @@ impl TypeEquality { self.right.substitute(variable, type_); } + /// + /// Checks if the `self.left` == `self.right`. + /// + pub fn evaluate(&self) -> Result<(), TypeAssertionError> { + if self.left.eq(&self.right) { + Ok(()) + } else { + Err(TypeAssertionError::equality_failed(&self.left, &self.right, &self.span)) + } + } + /// /// Returns the (type variable, type) pair from this assertion. /// - pub fn solve(&self) -> Result, TypeAssertionError> { - Ok(match (&self.left, &self.right) { - (Type::TypeVariable(variable), type_) => Some(TypeVariablePair(variable.clone(), type_.clone())), - (type_, Type::TypeVariable(variable)) => Some(TypeVariablePair(variable.clone(), type_.clone())), - (type1, type2) => { - // Compare types. - if type1.eq(type2) { - // Return None if the two types are equal (the equality is satisfied). - None - } else { - return Err(TypeAssertionError::equality_failed(type1, type2, &self.span)); - } - } - }) + pub fn pairs(&self) -> Result { + TypeVariablePairs::new(&self.left, &self.right, &self.span) + } +} + +/// A type variable -> type pair. +pub struct TypeVariablePair(TypeVariable, Type); + +/// A vector of `TypeVariablePair`s. +pub struct TypeVariablePairs(Vec); + +impl Default for TypeVariablePairs { + fn default() -> Self { + Self(Vec::new()) + } +} + +impl TypeVariablePairs { + /// + /// Returns a new `TypeVariablePairs` struct from the given left and right types. + /// + pub fn new(left: &Type, right: &Type, span: &Span) -> Result { + let mut pairs = Self::default(); + + // Push all `TypeVariablePair`s. + pairs.push_pairs(left, right, span)?; + + Ok(pairs) + } + + /// + /// Returns true if the self vector has no pairs. + /// + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// + /// Returns the self vector of pairs. + /// + pub fn get_pairs(&self) -> &Vec { + &self.0 + } + + /// + /// Pushes a new `TypeVariablePair` struct to self. + /// + pub fn push(&mut self, variable: &TypeVariable, type_: &Type) { + // Create a new type variable -> type pair. + let pair = TypeVariablePair(variable.clone(), type_.clone()); + + // Push the pair to the self vector. + self.0.push(pair); + } + + /// + /// Checks if the given left or right type contains a `TypeVariable`. + /// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left + /// and right type. + /// + pub fn push_pairs(&mut self, left: &Type, right: &Type, span: &Span) -> Result<(), TypeAssertionError> { + match (left, right) { + (Type::TypeVariable(variable), type_) => Ok(self.push(variable, type_)), + (type_, Type::TypeVariable(variable)) => Ok(self.push(variable, type_)), + (Type::Array(type1, dimensions1), Type::Array(type2, dimensions2)) => { + self.push_pairs_array(type1, dimensions1, type2, dimensions2, span) + } + (_, _) => Ok(()), // No `TypeVariable` found so we do not push any pairs. + } + } + + /// + /// Checks if the given left or right array type contains a `TypeVariable`. + /// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left + /// and right type. + /// + fn push_pairs_array( + &mut self, + left_type: &Type, + left_dimensions: &Vec, + right_type: &Type, + right_dimensions: &Vec, + span: &Span, + ) -> Result<(), TypeAssertionError> { + // Flatten the array types to get the element types. + let (left_type_flat, left_dimensions_flat) = flatten_array_type(left_type, left_dimensions.to_owned()); + let (right_type_flat, right_dimensions_flat) = flatten_array_type(right_type, right_dimensions.to_owned()); + + // If the dimensions do not match, then throw an error. + if left_dimensions_flat.ne(&right_dimensions_flat) { + return Err(TypeAssertionError::array_dimensions( + left_dimensions_flat, + right_dimensions_flat, + span, + )); + } + + // Compare the array element types. + self.push_pairs(left_type_flat, right_type_flat, span); + + Ok(()) } } diff --git a/dynamic-check/src/errors/type_assertion.rs b/dynamic-check/src/errors/type_assertion.rs index 22b1090714..da61010947 100644 --- a/dynamic-check/src/errors/type_assertion.rs +++ b/dynamic-check/src/errors/type_assertion.rs @@ -63,4 +63,16 @@ impl TypeAssertionError { Self::new_from_span(message, span.to_owned()) } + + /// + /// Mismatched array type dimensions. + /// + pub fn array_dimensions(dimensions1: Vec, dimensions2: Vec, span: &Span) -> Self { + let message = format!( + "Expected array with dimensions `{:?}`, found array with dimensions `{:?}`.", + dimensions1, dimensions2 + ); + + Self::new_from_span(message, span.to_owned()) + } } diff --git a/static-check/src/types/type_.rs b/static-check/src/types/type_.rs index 03b94ba50e..69c7a346da 100644 --- a/static-check/src/types/type_.rs +++ b/static-check/src/types/type_.rs @@ -311,7 +311,7 @@ impl Eq for Type {} /// /// Will flatten an array type `[[[u8; 1]; 2]; 3]` into `[u8; (3, 2, 1)]`. /// -fn flatten_array_type(type_: &Type, mut dimensions: Vec) -> (&Type, Vec) { +pub fn flatten_array_type(type_: &Type, mut dimensions: Vec) -> (&Type, Vec) { if let Type::Array(element_type, element_dimensions) = type_ { dimensions.append(&mut element_dimensions.to_owned()); flatten_array_type(element_type, dimensions)