mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-11-11 01:45:48 +03:00
Refactor symbol table array type comparison
This commit is contained in:
parent
0868ef52d3
commit
6f3a235c76
@ -100,20 +100,6 @@ impl Type {
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_array_type(_type: &Type, _dimensions: &[usize]) -> (Type, Vec<usize>) {
|
||||
unimplemented!("deprecated")
|
||||
// if let Type::Array(nested_type, nested_dimensions) = type_ {
|
||||
// // Expand nested array type
|
||||
// let mut expanded_dimensions = dimensions.to_vec();
|
||||
// expanded_dimensions.append(&mut nested_dimensions.clone());
|
||||
//
|
||||
// expand_array_type(nested_type, &expanded_dimensions)
|
||||
// } else {
|
||||
// // Array type is fully expanded
|
||||
// (type_.clone(), dimensions.to_vec())
|
||||
// }
|
||||
}
|
||||
|
||||
/// pest ast -> Explicit Type for defining circuit members and function params
|
||||
|
||||
impl From<DataType> for Type {
|
||||
|
@ -243,7 +243,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
|
||||
),
|
||||
|
||||
// Arrays
|
||||
Expression::Array(array, span) => {
|
||||
Expression::ArrayInline(array, span) => {
|
||||
self.enforce_array(cs, file_scope, function_scope, expected_type, array, span)
|
||||
}
|
||||
Expression::ArrayAccess(array_w_index, span) => self.enforce_array_access(
|
||||
|
@ -77,6 +77,12 @@ impl InputParserError {
|
||||
InputParserError::SyntaxError(InputSyntaxError::from(error))
|
||||
}
|
||||
|
||||
pub fn array_index(actual: String, span: Span) -> Self {
|
||||
let message = format!("Expected constant number for array index, found `{}`", actual);
|
||||
|
||||
Self::new_from_span(message, span)
|
||||
}
|
||||
|
||||
pub fn implicit_type(data_type: DataType, implicit: NumberValue) -> Self {
|
||||
let message = format!("expected `{}`, found `{}`", data_type, implicit);
|
||||
|
||||
|
@ -14,11 +14,10 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
use crate::{SymbolTable, TypeError, TypeVariable};
|
||||
use leo_ast::{Identifier, IntegerType, Span, Type as UnresolvedType};
|
||||
use leo_ast::{ArrayDimensions, Identifier, IntegerType, Span, Type as UnresolvedType};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cmp::{Eq, PartialEq},
|
||||
fmt,
|
||||
};
|
||||
@ -34,7 +33,7 @@ pub enum Type {
|
||||
IntegerType(IntegerType),
|
||||
|
||||
// Data type wrappers
|
||||
Array(Box<Type>, Vec<usize>),
|
||||
Array(Box<Type>, ArrayDimensions),
|
||||
Tuple(Vec<Type>),
|
||||
|
||||
// User defined types
|
||||
@ -145,7 +144,7 @@ impl Type {
|
||||
///
|
||||
/// Returns array element type and dimensions if self is an expected array type `Type::Array`.
|
||||
///
|
||||
pub fn get_type_array(&self, span: Span) -> Result<(&Type, &Vec<usize>), TypeError> {
|
||||
pub fn get_type_array(&self, span: Span) -> Result<(&Type, &ArrayDimensions), TypeError> {
|
||||
match self {
|
||||
Type::Array(element_type, dimensions) => Ok((element_type, dimensions)),
|
||||
// Throw mismatched type error
|
||||
@ -275,15 +274,7 @@ impl fmt::Display for Type {
|
||||
Type::Group => write!(f, "group"),
|
||||
Type::IntegerType(integer_type) => write!(f, "{}", integer_type),
|
||||
|
||||
Type::Array(type_, dimensions) => {
|
||||
let dimensions_string = dimensions
|
||||
.iter()
|
||||
.map(|dimension| dimension.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
write!(f, "[{}; ({})]", *type_, dimensions_string)
|
||||
}
|
||||
Type::Array(type_, dimensions) => write!(f, "[{}; {}]", *type_, dimensions),
|
||||
Type::Tuple(tuple) => {
|
||||
let tuple_string = tuple.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(", ");
|
||||
|
||||
@ -306,13 +297,13 @@ impl PartialEq for Type {
|
||||
(Type::Group, Type::Group) => true,
|
||||
(Type::IntegerType(integer_type1), Type::IntegerType(integer_type2)) => integer_type1.eq(integer_type2),
|
||||
|
||||
(Type::Array(type1, dimensions1), Type::Array(type2, dimensions2)) => {
|
||||
// Flatten both array types before comparison.
|
||||
let (type1_flat, dimensions1_flat) = flatten_array_type(type1, Cow::from(dimensions1));
|
||||
let (type2_flat, dimensions2_flat) = flatten_array_type(type2, Cow::from(dimensions2));
|
||||
(Type::Array(array1, _), Type::Array(array2, _)) => {
|
||||
// Get both array element types before comparison.
|
||||
let array1_element = get_array_element_type(array1);
|
||||
let array2_element = get_array_element_type(array2);
|
||||
|
||||
// Element types and dimensions must match
|
||||
type1_flat.eq(type2_flat) && dimensions1_flat.eq(&dimensions2_flat)
|
||||
// Check that both arrays have the same element type.
|
||||
array1_element.eq(array2_element)
|
||||
}
|
||||
|
||||
(Type::Tuple(types1), Type::Tuple(types2)) => types1.eq(types2),
|
||||
@ -327,15 +318,15 @@ impl PartialEq for Type {
|
||||
impl Eq for Type {}
|
||||
|
||||
///
|
||||
/// Returns the data type of the array element and vector of dimensions.
|
||||
/// Returns the data type of the array element.
|
||||
///
|
||||
/// Will flatten an array type `[[[u8; 1]; 2]; 3]` into `[u8; (3, 2, 1)]`.
|
||||
/// If the given `type_` is an array, call `get_array_element_type()` on the array element type.
|
||||
/// If the given `type_` is any other type, return the `type_`.
|
||||
///
|
||||
pub fn flatten_array_type<'a>(type_: &'a Type, mut dimensions: Cow<'a, [usize]>) -> (&'a Type, Cow<'a, [usize]>) {
|
||||
if let Type::Array(element_type, element_dimensions) = type_ {
|
||||
dimensions.to_mut().extend(element_dimensions);
|
||||
flatten_array_type(element_type, dimensions)
|
||||
pub fn get_array_element_type(type_: &Type) -> &Type {
|
||||
if let Type::Array(element_type, _) = type_ {
|
||||
get_array_element_type(element_type)
|
||||
} else {
|
||||
(type_, dimensions)
|
||||
type_
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
use crate::TypeAssertionError;
|
||||
use leo_ast::Span;
|
||||
use leo_symbol_table::{flatten_array_type, Type, TypeVariable};
|
||||
use leo_symbol_table::{get_array_element_type, Type, TypeVariable};
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// A type variable -> type pair.
|
||||
@ -88,8 +88,8 @@ impl TypeVariablePairs {
|
||||
match (left, right) {
|
||||
(Type::TypeVariable(variable), type_) => Ok(self.push(variable, type_)),
|
||||
(type_, Type::TypeVariable(variable)) => Ok(self.push(variable, type_)),
|
||||
(Type::Array(left_type, left_dimensions), Type::Array(right_type, right_dimensions)) => {
|
||||
self.push_pairs_array(*left_type, left_dimensions, *right_type, right_dimensions, span)
|
||||
(Type::Array(left_type, _), Type::Array(right_type, _)) => {
|
||||
self.push_pairs_array(*left_type, *right_type, span)
|
||||
}
|
||||
(Type::Tuple(left_types), Type::Tuple(right_types)) => {
|
||||
self.push_pairs_tuple(left_types.into_iter(), right_types.into_iter(), span)
|
||||
@ -103,29 +103,13 @@ impl TypeVariablePairs {
|
||||
/// If a `TypeVariable` is found, create a new `TypeVariablePair` between the given left
|
||||
/// and right type.
|
||||
///
|
||||
fn push_pairs_array(
|
||||
&mut self,
|
||||
left_type: Type,
|
||||
left_dimensions: Vec<usize>,
|
||||
right_type: Type,
|
||||
right_dimensions: Vec<usize>,
|
||||
span: &Span,
|
||||
) -> Result<(), TypeAssertionError> {
|
||||
// Flatten the array types to get the element types.
|
||||
let (left_type_flat, left_dimensions_flat) = flatten_array_type(&left_type, Cow::from(&left_dimensions));
|
||||
let (right_type_flat, right_dimensions_flat) = flatten_array_type(&right_type, Cow::from(&right_dimensions));
|
||||
|
||||
// If the dimensions do not match, then throw an error.
|
||||
if left_dimensions_flat.ne(&right_dimensions_flat) {
|
||||
return Err(TypeAssertionError::array_dimensions(
|
||||
&left_dimensions_flat,
|
||||
&right_dimensions_flat,
|
||||
span,
|
||||
));
|
||||
}
|
||||
fn push_pairs_array(&mut self, left_type: Type, right_type: Type, span: &Span) -> Result<(), TypeAssertionError> {
|
||||
// Get both array element types before comparison.
|
||||
let array1_element = get_array_element_type(&left_type);
|
||||
let array2_element = get_array_element_type(&right_type);
|
||||
|
||||
// Compare the array element types.
|
||||
self.push_pairs(left_type_flat.to_owned(), right_type_flat.to_owned(), span)
|
||||
self.push_pairs(array1_element.to_owned(), array2_element.to_owned(), span)
|
||||
}
|
||||
|
||||
///
|
||||
|
@ -548,7 +548,7 @@ impl Frame {
|
||||
}
|
||||
|
||||
// Arrays
|
||||
Expression::Array(expressions, span) => self.parse_array(expressions, span),
|
||||
Expression::ArrayInline(expressions, span) => self.parse_array(expressions, span),
|
||||
Expression::ArrayAccess(array_w_index, span) => {
|
||||
self.parse_expression_array_access(&array_w_index.0, &array_w_index.1, span)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user