refactor type assertion solving

This commit is contained in:
collin 2020-10-24 00:51:04 -07:00
parent 8e75e36532
commit 40d26dce7f
3 changed files with 174 additions and 41 deletions

View File

@ -16,6 +16,7 @@
use crate::{DynamicCheckError, FrameError, ScopeError, TypeAssertionError, VariableTableError}; use crate::{DynamicCheckError, FrameError, ScopeError, TypeAssertionError, VariableTableError};
use leo_static_check::{ use leo_static_check::{
flatten_array_type,
Attribute, Attribute,
CircuitFunctionType, CircuitFunctionType,
CircuitType, CircuitType,
@ -1253,12 +1254,16 @@ impl Frame {
println!("assertion: {:?}", type_assertion); println!("assertion: {:?}", type_assertion);
// Solve the `TypeAssertion`. // Collect `TypeVariablePairs` from the `TypeAssertion`.
// let pairs = type_assertion.pairs()?;
// If the `TypeAssertion` has a solution, then continue the loop.
// If the `TypeAssertion` returns a `TypeVariablePair`, then substitute the `TypeVariable` // If no pairs are found, attempt to evaluate the `TypeAssertion`.
// for it's paired `Type` in all `TypeAssertion`s. if pairs.is_empty() {
if let Some(pair) = type_assertion.solve()? { // Evaluate the `TypeAssertion`.
type_assertion.evaluate()?
} else {
// Iterate over each `TypeVariable` -> `Type` pair.
for pair in pairs.get_pairs() {
// Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s. // Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s.
for original in &mut unsolved { for original in &mut unsolved {
original.substitute(&pair.0, &pair.1) original.substitute(&pair.0, &pair.1)
@ -1267,7 +1272,24 @@ impl Frame {
for original in &mut unsolved_membership { for original in &mut unsolved_membership {
original.substitute(&pair.0, &pair.1) original.substitute(&pair.0, &pair.1)
} }
}; }
}
// Solve the `TypeAssertion`.
//
// If the `TypeAssertion` has a solution, then continue the loop.
// If the `TypeAssertion` returns a `TypeVariablePair`, then substitute the `TypeVariable`
// for it's paired `Type` in all `TypeAssertion`s.
// if let Some(pair) = type_assertion.solve()? {
// // Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s.
// for original in &mut unsolved {
// original.substitute(&pair.0, &pair.1)
// }
//
// for original in &mut unsolved_membership {
// original.substitute(&pair.0, &pair.1)
// }
// };
} }
// Solve all type membership assertions. // Solve all type membership assertions.
@ -1276,7 +1298,7 @@ impl Frame {
let type_assertion = unsolved_membership.pop().unwrap(); let type_assertion = unsolved_membership.pop().unwrap();
// Solve the membership assertion. // Solve the membership assertion.
type_assertion.solve()?; type_assertion.evaluate()?;
} }
// for type_assertion in unsolved.pop() { // for type_assertion in unsolved.pop() {
@ -1420,9 +1442,6 @@ impl VariableTable {
} }
} }
/// A type variable -> type pair.
pub struct TypeVariablePair(TypeVariable, Type);
/// A predicate that evaluates equality between two `Types`s. /// A predicate that evaluates equality between two `Types`s.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum TypeAssertion { pub enum TypeAssertion {
@ -1445,8 +1464,19 @@ impl TypeAssertion {
Self::Membership(TypeMembership::new(given, set, span)) Self::Membership(TypeMembership::new(given, set, span))
} }
///
/// Returns one or more `TypeVariablePairs` generated by the given `TypeAssertion`.
///
pub fn pairs(&self) -> Result<TypeVariablePairs, TypeAssertionError> {
match self {
TypeAssertion::Equality(equality) => equality.pairs(),
TypeAssertion::Membership(_) => unimplemented!("Cannot generate pairs from type membership"),
}
}
/// ///
/// Substitutes the given type for self if self is equal to the type variable. /// Substitutes the given type for self if self is equal to the type variable.
///
pub fn substitute(&mut self, variable: &TypeVariable, type_: &Type) { pub fn substitute(&mut self, variable: &TypeVariable, type_: &Type) {
match self { match self {
TypeAssertion::Equality(equality) => equality.substitute(variable, type_), TypeAssertion::Equality(equality) => equality.substitute(variable, type_),
@ -1455,19 +1485,12 @@ impl TypeAssertion {
} }
/// ///
/// Returns `None` if the `TypeAssertion` is solvable. /// Checks if the `TypeAssertion` is satisfied.
/// ///
/// If the `TypeAssertion` is not solvable, throw a `TypeAssertionError`. pub fn evaluate(&self) -> Result<(), TypeAssertionError> {
/// If the `TypeAssertion` contains a `TypeVariable`, then return `Some(TypeVariable, Type)`.
///
pub fn solve(&self) -> Result<Option<TypeVariablePair>, TypeAssertionError> {
match self { match self {
TypeAssertion::Equality(equality) => equality.solve(), TypeAssertion::Equality(equality) => equality.evaluate(),
TypeAssertion::Membership(membership) => { TypeAssertion::Membership(membership) => membership.evaluate(),
membership.solve()?;
Ok(None)
}
} }
} }
} }
@ -1502,7 +1525,7 @@ impl TypeMembership {
/// ///
/// Returns true if the given type is equal to a member of the set. /// Returns true if the given type is equal to a member of the set.
/// ///
pub fn solve(&self) -> Result<(), TypeAssertionError> { pub fn evaluate(&self) -> Result<(), TypeAssertionError> {
if self.set.contains(&self.given) { if self.set.contains(&self.given) {
Ok(()) Ok(())
} else { } else {
@ -1555,22 +1578,120 @@ impl TypeEquality {
self.right.substitute(variable, type_); self.right.substitute(variable, type_);
} }
///
/// Checks if the `self.left` == `self.right`.
///
pub fn evaluate(&self) -> Result<(), TypeAssertionError> {
if self.left.eq(&self.right) {
Ok(())
} else {
Err(TypeAssertionError::equality_failed(&self.left, &self.right, &self.span))
}
}
/// ///
/// Returns the (type variable, type) pair from this assertion. /// Returns the (type variable, type) pair from this assertion.
/// ///
pub fn solve(&self) -> Result<Option<TypeVariablePair>, TypeAssertionError> { pub fn pairs(&self) -> Result<TypeVariablePairs, TypeAssertionError> {
Ok(match (&self.left, &self.right) { TypeVariablePairs::new(&self.left, &self.right, &self.span)
(Type::TypeVariable(variable), type_) => Some(TypeVariablePair(variable.clone(), type_.clone())), }
(type_, Type::TypeVariable(variable)) => Some(TypeVariablePair(variable.clone(), type_.clone())), }
(type1, type2) => {
// Compare types. /// A type variable -> type pair.
if type1.eq(type2) { pub struct TypeVariablePair(TypeVariable, Type);
// Return None if the two types are equal (the equality is satisfied).
None /// A vector of `TypeVariablePair`s.
} else { pub struct TypeVariablePairs(Vec<TypeVariablePair>);
return Err(TypeAssertionError::equality_failed(type1, type2, &self.span));
} impl Default for TypeVariablePairs {
} fn default() -> Self {
}) Self(Vec::new())
}
}
impl TypeVariablePairs {
///
/// Returns a new `TypeVariablePairs` struct from the given left and right types.
///
pub fn new(left: &Type, right: &Type, span: &Span) -> Result<Self, TypeAssertionError> {
let mut pairs = Self::default();
// Push all `TypeVariablePair`s.
pairs.push_pairs(left, right, span)?;
Ok(pairs)
}
///
/// Returns true if the self vector has no pairs.
///
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
///
/// Returns the self vector of pairs.
///
pub fn get_pairs(&self) -> &Vec<TypeVariablePair> {
&self.0
}
///
/// Pushes a new `TypeVariablePair` struct to self.
///
pub fn push(&mut self, variable: &TypeVariable, type_: &Type) {
// Create a new type variable -> type pair.
let pair = TypeVariablePair(variable.clone(), type_.clone());
// Push the pair to the self vector.
self.0.push(pair);
}
///
/// Checks if the given left or right type contains a `TypeVariable`.
/// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left
/// and right type.
///
pub fn push_pairs(&mut self, left: &Type, right: &Type, span: &Span) -> Result<(), TypeAssertionError> {
match (left, right) {
(Type::TypeVariable(variable), type_) => Ok(self.push(variable, type_)),
(type_, Type::TypeVariable(variable)) => Ok(self.push(variable, type_)),
(Type::Array(type1, dimensions1), Type::Array(type2, dimensions2)) => {
self.push_pairs_array(type1, dimensions1, type2, dimensions2, span)
}
(_, _) => Ok(()), // No `TypeVariable` found so we do not push any pairs.
}
}
///
/// Checks if the given left or right array type contains a `TypeVariable`.
/// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left
/// and right type.
///
fn push_pairs_array(
&mut self,
left_type: &Type,
left_dimensions: &Vec<usize>,
right_type: &Type,
right_dimensions: &Vec<usize>,
span: &Span,
) -> Result<(), TypeAssertionError> {
// Flatten the array types to get the element types.
let (left_type_flat, left_dimensions_flat) = flatten_array_type(left_type, left_dimensions.to_owned());
let (right_type_flat, right_dimensions_flat) = flatten_array_type(right_type, right_dimensions.to_owned());
// If the dimensions do not match, then throw an error.
if left_dimensions_flat.ne(&right_dimensions_flat) {
return Err(TypeAssertionError::array_dimensions(
left_dimensions_flat,
right_dimensions_flat,
span,
));
}
// Compare the array element types.
self.push_pairs(left_type_flat, right_type_flat, span);
Ok(())
} }
} }

View File

@ -63,4 +63,16 @@ impl TypeAssertionError {
Self::new_from_span(message, span.to_owned()) Self::new_from_span(message, span.to_owned())
} }
///
/// Mismatched array type dimensions.
///
pub fn array_dimensions(dimensions1: Vec<usize>, dimensions2: Vec<usize>, span: &Span) -> Self {
let message = format!(
"Expected array with dimensions `{:?}`, found array with dimensions `{:?}`.",
dimensions1, dimensions2
);
Self::new_from_span(message, span.to_owned())
}
} }

View File

@ -311,7 +311,7 @@ impl Eq for Type {}
/// ///
/// Will flatten an array type `[[[u8; 1]; 2]; 3]` into `[u8; (3, 2, 1)]`. /// Will flatten an array type `[[[u8; 1]; 2]; 3]` into `[u8; (3, 2, 1)]`.
/// ///
fn flatten_array_type(type_: &Type, mut dimensions: Vec<usize>) -> (&Type, Vec<usize>) { pub fn flatten_array_type(type_: &Type, mut dimensions: Vec<usize>) -> (&Type, Vec<usize>) {
if let Type::Array(element_type, element_dimensions) = type_ { if let Type::Array(element_type, element_dimensions) = type_ {
dimensions.append(&mut element_dimensions.to_owned()); dimensions.append(&mut element_dimensions.to_owned());
flatten_array_type(element_type, dimensions) flatten_array_type(element_type, dimensions)