Refactor symbol table array type comparison

This commit is contained in:
collin 2020-11-02 13:28:03 -08:00
parent 0868ef52d3
commit 6f3a235c76
6 changed files with 33 additions and 66 deletions

View File

@ -100,20 +100,6 @@ impl Type {
} }
} }
fn expand_array_type(_type: &Type, _dimensions: &[usize]) -> (Type, Vec<usize>) {
unimplemented!("deprecated")
// if let Type::Array(nested_type, nested_dimensions) = type_ {
// // Expand nested array type
// let mut expanded_dimensions = dimensions.to_vec();
// expanded_dimensions.append(&mut nested_dimensions.clone());
//
// expand_array_type(nested_type, &expanded_dimensions)
// } else {
// // Array type is fully expanded
// (type_.clone(), dimensions.to_vec())
// }
}
/// pest ast -> Explicit Type for defining circuit members and function params /// pest ast -> Explicit Type for defining circuit members and function params
impl From<DataType> for Type { impl From<DataType> for Type {

View File

@ -243,7 +243,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
), ),
// Arrays // Arrays
Expression::Array(array, span) => { Expression::ArrayInline(array, span) => {
self.enforce_array(cs, file_scope, function_scope, expected_type, array, span) self.enforce_array(cs, file_scope, function_scope, expected_type, array, span)
} }
Expression::ArrayAccess(array_w_index, span) => self.enforce_array_access( Expression::ArrayAccess(array_w_index, span) => self.enforce_array_access(

View File

@ -77,6 +77,12 @@ impl InputParserError {
InputParserError::SyntaxError(InputSyntaxError::from(error)) InputParserError::SyntaxError(InputSyntaxError::from(error))
} }
pub fn array_index(actual: String, span: Span) -> Self {
let message = format!("Expected constant number for array index, found `{}`", actual);
Self::new_from_span(message, span)
}
pub fn implicit_type(data_type: DataType, implicit: NumberValue) -> Self { pub fn implicit_type(data_type: DataType, implicit: NumberValue) -> Self {
let message = format!("expected `{}`, found `{}`", data_type, implicit); let message = format!("expected `{}`, found `{}`", data_type, implicit);

View File

@ -14,11 +14,10 @@
// You should have received a copy of the GNU General Public License // You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>. // along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{SymbolTable, TypeError, TypeVariable}; use crate::{SymbolTable, TypeError, TypeVariable};
use leo_ast::{Identifier, IntegerType, Span, Type as UnresolvedType}; use leo_ast::{ArrayDimensions, Identifier, IntegerType, Span, Type as UnresolvedType};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
borrow::Cow,
cmp::{Eq, PartialEq}, cmp::{Eq, PartialEq},
fmt, fmt,
}; };
@ -34,7 +33,7 @@ pub enum Type {
IntegerType(IntegerType), IntegerType(IntegerType),
// Data type wrappers // Data type wrappers
Array(Box<Type>, Vec<usize>), Array(Box<Type>, ArrayDimensions),
Tuple(Vec<Type>), Tuple(Vec<Type>),
// User defined types // User defined types
@ -145,7 +144,7 @@ impl Type {
/// ///
/// Returns array element type and dimensions if self is an expected array type `Type::Array`. /// Returns array element type and dimensions if self is an expected array type `Type::Array`.
/// ///
pub fn get_type_array(&self, span: Span) -> Result<(&Type, &Vec<usize>), TypeError> { pub fn get_type_array(&self, span: Span) -> Result<(&Type, &ArrayDimensions), TypeError> {
match self { match self {
Type::Array(element_type, dimensions) => Ok((element_type, dimensions)), Type::Array(element_type, dimensions) => Ok((element_type, dimensions)),
// Throw mismatched type error // Throw mismatched type error
@ -275,15 +274,7 @@ impl fmt::Display for Type {
Type::Group => write!(f, "group"), Type::Group => write!(f, "group"),
Type::IntegerType(integer_type) => write!(f, "{}", integer_type), Type::IntegerType(integer_type) => write!(f, "{}", integer_type),
Type::Array(type_, dimensions) => { Type::Array(type_, dimensions) => write!(f, "[{}; {}]", *type_, dimensions),
let dimensions_string = dimensions
.iter()
.map(|dimension| dimension.to_string())
.collect::<Vec<_>>()
.join(", ");
write!(f, "[{}; ({})]", *type_, dimensions_string)
}
Type::Tuple(tuple) => { Type::Tuple(tuple) => {
let tuple_string = tuple.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(", "); let tuple_string = tuple.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(", ");
@ -306,13 +297,13 @@ impl PartialEq for Type {
(Type::Group, Type::Group) => true, (Type::Group, Type::Group) => true,
(Type::IntegerType(integer_type1), Type::IntegerType(integer_type2)) => integer_type1.eq(integer_type2), (Type::IntegerType(integer_type1), Type::IntegerType(integer_type2)) => integer_type1.eq(integer_type2),
(Type::Array(type1, dimensions1), Type::Array(type2, dimensions2)) => { (Type::Array(array1, _), Type::Array(array2, _)) => {
// Flatten both array types before comparison. // Get both array element types before comparison.
let (type1_flat, dimensions1_flat) = flatten_array_type(type1, Cow::from(dimensions1)); let array1_element = get_array_element_type(array1);
let (type2_flat, dimensions2_flat) = flatten_array_type(type2, Cow::from(dimensions2)); let array2_element = get_array_element_type(array2);
// Element types and dimensions must match // Check that both arrays have the same element type.
type1_flat.eq(type2_flat) && dimensions1_flat.eq(&dimensions2_flat) array1_element.eq(array2_element)
} }
(Type::Tuple(types1), Type::Tuple(types2)) => types1.eq(types2), (Type::Tuple(types1), Type::Tuple(types2)) => types1.eq(types2),
@ -327,15 +318,15 @@ impl PartialEq for Type {
impl Eq for Type {} impl Eq for Type {}
/// ///
/// Returns the data type of the array element and vector of dimensions. /// Returns the data type of the array element.
/// ///
/// Will flatten an array type `[[[u8; 1]; 2]; 3]` into `[u8; (3, 2, 1)]`. /// If the given `type_` is an array, call `get_array_element_type()` on the array element type.
/// If the given `type_` is any other type, return the `type_`.
/// ///
pub fn flatten_array_type<'a>(type_: &'a Type, mut dimensions: Cow<'a, [usize]>) -> (&'a Type, Cow<'a, [usize]>) { pub fn get_array_element_type(type_: &Type) -> &Type {
if let Type::Array(element_type, element_dimensions) = type_ { if let Type::Array(element_type, _) = type_ {
dimensions.to_mut().extend(element_dimensions); get_array_element_type(element_type)
flatten_array_type(element_type, dimensions)
} else { } else {
(type_, dimensions) type_
} }
} }

View File

@ -16,7 +16,7 @@
use crate::TypeAssertionError; use crate::TypeAssertionError;
use leo_ast::Span; use leo_ast::Span;
use leo_symbol_table::{flatten_array_type, Type, TypeVariable}; use leo_symbol_table::{get_array_element_type, Type, TypeVariable};
use std::borrow::Cow; use std::borrow::Cow;
/// A type variable -> type pair. /// A type variable -> type pair.
@ -88,8 +88,8 @@ impl TypeVariablePairs {
match (left, right) { match (left, right) {
(Type::TypeVariable(variable), type_) => Ok(self.push(variable, type_)), (Type::TypeVariable(variable), type_) => Ok(self.push(variable, type_)),
(type_, Type::TypeVariable(variable)) => Ok(self.push(variable, type_)), (type_, Type::TypeVariable(variable)) => Ok(self.push(variable, type_)),
(Type::Array(left_type, left_dimensions), Type::Array(right_type, right_dimensions)) => { (Type::Array(left_type, _), Type::Array(right_type, _)) => {
self.push_pairs_array(*left_type, left_dimensions, *right_type, right_dimensions, span) self.push_pairs_array(*left_type, *right_type, span)
} }
(Type::Tuple(left_types), Type::Tuple(right_types)) => { (Type::Tuple(left_types), Type::Tuple(right_types)) => {
self.push_pairs_tuple(left_types.into_iter(), right_types.into_iter(), span) self.push_pairs_tuple(left_types.into_iter(), right_types.into_iter(), span)
@ -103,29 +103,13 @@ impl TypeVariablePairs {
/// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left /// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left
/// and right type. /// and right type.
/// ///
fn push_pairs_array( fn push_pairs_array(&mut self, left_type: Type, right_type: Type, span: &Span) -> Result<(), TypeAssertionError> {
&mut self, // Get both array element types before comparison.
left_type: Type, let array1_element = get_array_element_type(&left_type);
left_dimensions: Vec<usize>, let array2_element = get_array_element_type(&right_type);
right_type: Type,
right_dimensions: Vec<usize>,
span: &Span,
) -> Result<(), TypeAssertionError> {
// Flatten the array types to get the element types.
let (left_type_flat, left_dimensions_flat) = flatten_array_type(&left_type, Cow::from(&left_dimensions));
let (right_type_flat, right_dimensions_flat) = flatten_array_type(&right_type, Cow::from(&right_dimensions));
// If the dimensions do not match, then throw an error.
if left_dimensions_flat.ne(&right_dimensions_flat) {
return Err(TypeAssertionError::array_dimensions(
&left_dimensions_flat,
&right_dimensions_flat,
span,
));
}
// Compare the array element types. // Compare the array element types.
self.push_pairs(left_type_flat.to_owned(), right_type_flat.to_owned(), span) self.push_pairs(array1_element.to_owned(), array2_element.to_owned(), span)
} }
/// ///

View File

@ -548,7 +548,7 @@ impl Frame {
} }
// Arrays // Arrays
Expression::Array(expressions, span) => self.parse_array(expressions, span), Expression::ArrayInline(expressions, span) => self.parse_array(expressions, span),
Expression::ArrayAccess(array_w_index, span) => { Expression::ArrayAccess(array_w_index, span) => {
self.parse_expression_array_access(&array_w_index.0, &array_w_index.1, span) self.parse_expression_array_access(&array_w_index.0, &array_w_index.1, span)
} }