mirror of
https://github.com/AleoHQ/leo.git
synced 2024-12-12 06:53:10 +03:00
refactor type assertion solving
This commit is contained in:
parent
8e75e36532
commit
40d26dce7f
@ -16,6 +16,7 @@
|
||||
|
||||
use crate::{DynamicCheckError, FrameError, ScopeError, TypeAssertionError, VariableTableError};
|
||||
use leo_static_check::{
|
||||
flatten_array_type,
|
||||
Attribute,
|
||||
CircuitFunctionType,
|
||||
CircuitType,
|
||||
@ -1253,12 +1254,16 @@ impl Frame {
|
||||
|
||||
println!("assertion: {:?}", type_assertion);
|
||||
|
||||
// Solve the `TypeAssertion`.
|
||||
//
|
||||
// If the `TypeAssertion` has a solution, then continue the loop.
|
||||
// If the `TypeAssertion` returns a `TypeVariablePair`, then substitute the `TypeVariable`
|
||||
// for it's paired `Type` in all `TypeAssertion`s.
|
||||
if let Some(pair) = type_assertion.solve()? {
|
||||
// Collect `TypeVariablePairs` from the `TypeAssertion`.
|
||||
let pairs = type_assertion.pairs()?;
|
||||
|
||||
// If no pairs are found, attempt to evaluate the `TypeAssertion`.
|
||||
if pairs.is_empty() {
|
||||
// Evaluate the `TypeAssertion`.
|
||||
type_assertion.evaluate()?
|
||||
} else {
|
||||
// Iterate over each `TypeVariable` -> `Type` pair.
|
||||
for pair in pairs.get_pairs() {
|
||||
// Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s.
|
||||
for original in &mut unsolved {
|
||||
original.substitute(&pair.0, &pair.1)
|
||||
@ -1267,7 +1272,24 @@ impl Frame {
|
||||
for original in &mut unsolved_membership {
|
||||
original.substitute(&pair.0, &pair.1)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Solve the `TypeAssertion`.
|
||||
//
|
||||
// If the `TypeAssertion` has a solution, then continue the loop.
|
||||
// If the `TypeAssertion` returns a `TypeVariablePair`, then substitute the `TypeVariable`
|
||||
// for it's paired `Type` in all `TypeAssertion`s.
|
||||
// if let Some(pair) = type_assertion.solve()? {
|
||||
// // Substitute the `TypeVariable` for it's paired `Type` in all `TypeAssertion`s.
|
||||
// for original in &mut unsolved {
|
||||
// original.substitute(&pair.0, &pair.1)
|
||||
// }
|
||||
//
|
||||
// for original in &mut unsolved_membership {
|
||||
// original.substitute(&pair.0, &pair.1)
|
||||
// }
|
||||
// };
|
||||
}
|
||||
|
||||
// Solve all type membership assertions.
|
||||
@ -1276,7 +1298,7 @@ impl Frame {
|
||||
let type_assertion = unsolved_membership.pop().unwrap();
|
||||
|
||||
// Solve the membership assertion.
|
||||
type_assertion.solve()?;
|
||||
type_assertion.evaluate()?;
|
||||
}
|
||||
|
||||
// for type_assertion in unsolved.pop() {
|
||||
@ -1420,9 +1442,6 @@ impl VariableTable {
|
||||
}
|
||||
}
|
||||
|
||||
/// A type variable -> type pair.
|
||||
pub struct TypeVariablePair(TypeVariable, Type);
|
||||
|
||||
/// A predicate that evaluates equality between two `Types`s.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub enum TypeAssertion {
|
||||
@ -1445,8 +1464,19 @@ impl TypeAssertion {
|
||||
Self::Membership(TypeMembership::new(given, set, span))
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns one or more `TypeVariablePairs` generated by the given `TypeAssertion`.
|
||||
///
|
||||
pub fn pairs(&self) -> Result<TypeVariablePairs, TypeAssertionError> {
|
||||
match self {
|
||||
TypeAssertion::Equality(equality) => equality.pairs(),
|
||||
TypeAssertion::Membership(_) => unimplemented!("Cannot generate pairs from type membership"),
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Substitutes the given type for self if self is equal to the type variable.
|
||||
///
|
||||
pub fn substitute(&mut self, variable: &TypeVariable, type_: &Type) {
|
||||
match self {
|
||||
TypeAssertion::Equality(equality) => equality.substitute(variable, type_),
|
||||
@ -1455,19 +1485,12 @@ impl TypeAssertion {
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns `None` if the `TypeAssertion` is solvable.
|
||||
/// Checks if the `TypeAssertion` is satisfied.
|
||||
///
|
||||
/// If the `TypeAssertion` is not solvable, throw a `TypeAssertionError`.
|
||||
/// If the `TypeAssertion` contains a `TypeVariable`, then return `Some(TypeVariable, Type)`.
|
||||
///
|
||||
pub fn solve(&self) -> Result<Option<TypeVariablePair>, TypeAssertionError> {
|
||||
pub fn evaluate(&self) -> Result<(), TypeAssertionError> {
|
||||
match self {
|
||||
TypeAssertion::Equality(equality) => equality.solve(),
|
||||
TypeAssertion::Membership(membership) => {
|
||||
membership.solve()?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
TypeAssertion::Equality(equality) => equality.evaluate(),
|
||||
TypeAssertion::Membership(membership) => membership.evaluate(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1502,7 +1525,7 @@ impl TypeMembership {
|
||||
///
|
||||
/// Returns true if the given type is equal to a member of the set.
|
||||
///
|
||||
pub fn solve(&self) -> Result<(), TypeAssertionError> {
|
||||
pub fn evaluate(&self) -> Result<(), TypeAssertionError> {
|
||||
if self.set.contains(&self.given) {
|
||||
Ok(())
|
||||
} else {
|
||||
@ -1555,22 +1578,120 @@ impl TypeEquality {
|
||||
self.right.substitute(variable, type_);
|
||||
}
|
||||
|
||||
///
|
||||
/// Checks if the `self.left` == `self.right`.
|
||||
///
|
||||
pub fn evaluate(&self) -> Result<(), TypeAssertionError> {
|
||||
if self.left.eq(&self.right) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(TypeAssertionError::equality_failed(&self.left, &self.right, &self.span))
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the (type variable, type) pair from this assertion.
|
||||
///
|
||||
pub fn solve(&self) -> Result<Option<TypeVariablePair>, TypeAssertionError> {
|
||||
Ok(match (&self.left, &self.right) {
|
||||
(Type::TypeVariable(variable), type_) => Some(TypeVariablePair(variable.clone(), type_.clone())),
|
||||
(type_, Type::TypeVariable(variable)) => Some(TypeVariablePair(variable.clone(), type_.clone())),
|
||||
(type1, type2) => {
|
||||
// Compare types.
|
||||
if type1.eq(type2) {
|
||||
// Return None if the two types are equal (the equality is satisfied).
|
||||
None
|
||||
} else {
|
||||
return Err(TypeAssertionError::equality_failed(type1, type2, &self.span));
|
||||
}
|
||||
}
|
||||
})
|
||||
pub fn pairs(&self) -> Result<TypeVariablePairs, TypeAssertionError> {
|
||||
TypeVariablePairs::new(&self.left, &self.right, &self.span)
|
||||
}
|
||||
}
|
||||
|
||||
/// A type variable -> type pair.
|
||||
pub struct TypeVariablePair(TypeVariable, Type);
|
||||
|
||||
/// A vector of `TypeVariablePair`s.
|
||||
pub struct TypeVariablePairs(Vec<TypeVariablePair>);
|
||||
|
||||
impl Default for TypeVariablePairs {
|
||||
fn default() -> Self {
|
||||
Self(Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl TypeVariablePairs {
|
||||
///
|
||||
/// Returns a new `TypeVariablePairs` struct from the given left and right types.
|
||||
///
|
||||
pub fn new(left: &Type, right: &Type, span: &Span) -> Result<Self, TypeAssertionError> {
|
||||
let mut pairs = Self::default();
|
||||
|
||||
// Push all `TypeVariablePair`s.
|
||||
pairs.push_pairs(left, right, span)?;
|
||||
|
||||
Ok(pairs)
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns true if the self vector has no pairs.
|
||||
///
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
///
|
||||
/// Returns the self vector of pairs.
|
||||
///
|
||||
pub fn get_pairs(&self) -> &Vec<TypeVariablePair> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
///
|
||||
/// Pushes a new `TypeVariablePair` struct to self.
|
||||
///
|
||||
pub fn push(&mut self, variable: &TypeVariable, type_: &Type) {
|
||||
// Create a new type variable -> type pair.
|
||||
let pair = TypeVariablePair(variable.clone(), type_.clone());
|
||||
|
||||
// Push the pair to the self vector.
|
||||
self.0.push(pair);
|
||||
}
|
||||
|
||||
///
|
||||
/// Checks if the given left or right type contains a `TypeVariable`.
|
||||
/// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left
|
||||
/// and right type.
|
||||
///
|
||||
pub fn push_pairs(&mut self, left: &Type, right: &Type, span: &Span) -> Result<(), TypeAssertionError> {
|
||||
match (left, right) {
|
||||
(Type::TypeVariable(variable), type_) => Ok(self.push(variable, type_)),
|
||||
(type_, Type::TypeVariable(variable)) => Ok(self.push(variable, type_)),
|
||||
(Type::Array(type1, dimensions1), Type::Array(type2, dimensions2)) => {
|
||||
self.push_pairs_array(type1, dimensions1, type2, dimensions2, span)
|
||||
}
|
||||
(_, _) => Ok(()), // No `TypeVariable` found so we do not push any pairs.
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// Checks if the given left or right array type contains a `TypeVariable`.
|
||||
/// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left
|
||||
/// and right type.
|
||||
///
|
||||
fn push_pairs_array(
|
||||
&mut self,
|
||||
left_type: &Type,
|
||||
left_dimensions: &Vec<usize>,
|
||||
right_type: &Type,
|
||||
right_dimensions: &Vec<usize>,
|
||||
span: &Span,
|
||||
) -> Result<(), TypeAssertionError> {
|
||||
// Flatten the array types to get the element types.
|
||||
let (left_type_flat, left_dimensions_flat) = flatten_array_type(left_type, left_dimensions.to_owned());
|
||||
let (right_type_flat, right_dimensions_flat) = flatten_array_type(right_type, right_dimensions.to_owned());
|
||||
|
||||
// If the dimensions do not match, then throw an error.
|
||||
if left_dimensions_flat.ne(&right_dimensions_flat) {
|
||||
return Err(TypeAssertionError::array_dimensions(
|
||||
left_dimensions_flat,
|
||||
right_dimensions_flat,
|
||||
span,
|
||||
));
|
||||
}
|
||||
|
||||
// Compare the array element types.
|
||||
self.push_pairs(left_type_flat, right_type_flat, span);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -63,4 +63,16 @@ impl TypeAssertionError {
|
||||
|
||||
Self::new_from_span(message, span.to_owned())
|
||||
}
|
||||
|
||||
///
|
||||
/// Mismatched array type dimensions.
|
||||
///
|
||||
pub fn array_dimensions(dimensions1: Vec<usize>, dimensions2: Vec<usize>, span: &Span) -> Self {
|
||||
let message = format!(
|
||||
"Expected array with dimensions `{:?}`, found array with dimensions `{:?}`.",
|
||||
dimensions1, dimensions2
|
||||
);
|
||||
|
||||
Self::new_from_span(message, span.to_owned())
|
||||
}
|
||||
}
|
||||
|
@ -311,7 +311,7 @@ impl Eq for Type {}
|
||||
///
|
||||
/// Will flatten an array type `[[[u8; 1]; 2]; 3]` into `[u8; (3, 2, 1)]`.
|
||||
///
|
||||
fn flatten_array_type(type_: &Type, mut dimensions: Vec<usize>) -> (&Type, Vec<usize>) {
|
||||
pub fn flatten_array_type(type_: &Type, mut dimensions: Vec<usize>) -> (&Type, Vec<usize>) {
|
||||
if let Type::Array(element_type, element_dimensions) = type_ {
|
||||
dimensions.append(&mut element_dimensions.to_owned());
|
||||
flatten_array_type(element_type, dimensions)
|
||||
|
Loading…
Reference in New Issue
Block a user