array type refactor finished. all tests pass

This commit is contained in:
collin 2020-11-10 13:22:18 -08:00
parent d6686005a3
commit 04f4e685ed
5 changed files with 67 additions and 61 deletions

View File

@ -81,7 +81,7 @@ impl ArrayDimensions {
///
pub fn remove_first(&mut self) -> Option<PositiveNumber> {
// If there are no dimensions in the array, then return None.
if self.0.get(0).is_none() {
if self.0.first().is_none() {
return None;
}
@ -91,6 +91,16 @@ impl ArrayDimensions {
// Return the first dimension.
Some(removed)
}
///
/// Attempts to remove the last dimension from the array.
///
/// If the last dimension exists, then remove and return `Some(PositiveNumber)`.
/// If the last dimension does not exist, then return `None`.
///
pub fn remove_last(&mut self) -> Option<PositiveNumber> {
self.0.pop()
}
}
/// Create a new [`ArrayDimensions`] from a [`GrammarArrayDimensions`] in a Leo program file.

View File

@ -24,13 +24,10 @@ use leo_input::types::{
};
use serde::{Deserialize, Serialize};
use std::{
fmt,
hash::{Hash, Hasher},
};
use std::fmt;
/// Explicit type used for defining a variable or expression type
#[derive(Clone, Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Type {
// Data types
Address,
@ -47,13 +44,62 @@ pub enum Type {
}
impl Type {
///
/// Returns `true` if the self `Type` is the `SelfType`.
///
pub fn is_self(&self) -> bool {
matches!(self, Type::SelfType)
}
///
/// Returns `true` if the self `Type` is a `Circuit`.
///
pub fn is_circuit(&self) -> bool {
matches!(self, Type::Circuit(_))
}
///
/// Returns `true` if the self `Type` is equal to the other `Type`.
///
/// Flattens array syntax: `[[u8; 1]; 2] == [u8; (2, 1)] == true`
///
pub fn eq_flat(&self, other: &Self) -> bool {
match (self, other) {
(Type::Address, Type::Address) => true,
(Type::Boolean, Type::Boolean) => true,
(Type::Field, Type::Field) => true,
(Type::Group, Type::Group) => true,
(Type::IntegerType(left), Type::IntegerType(right)) => left.eq(&right),
(Type::Circuit(left), Type::Circuit(right)) => left.eq(&right),
(Type::SelfType, Type::SelfType) => true,
(Type::Array(left_type, left_dim), Type::Array(right_type, right_dim)) => {
// Convert array dimensions to owned.
let mut left_dim_owned = left_dim.to_owned();
let mut right_dim_owned = right_dim.to_owned();
// Remove the first element from both dimensions.
let left_first = left_dim_owned.remove_first();
let right_first = right_dim_owned.remove_first();
// Compare the first dimensions.
if left_first.ne(&right_first) {
return false;
}
// Create a new array type from the remaining array dimensions.
let left_new_type = inner_array_type(*left_type.to_owned(), left_dim_owned);
let right_new_type = inner_array_type(*right_type.to_owned(), right_dim_owned);
// Call eq_flat() on the new left and right types.
return left_new_type.eq_flat(&right_new_type);
}
(Type::Tuple(left), Type::Tuple(right)) => left
.iter()
.zip(right)
.all(|(left_type, right_type)| left_type.eq_flat(right_type)),
_ => false,
}
}
}
/// pest ast -> Explicit Type for defining circuit members and function params
@ -166,55 +212,6 @@ impl fmt::Display for Type {
}
}
/// Compares two types while flattening array types.
impl PartialEq for Type {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Type::Address, Type::Address) => true,
(Type::Boolean, Type::Boolean) => true,
(Type::Field, Type::Field) => true,
(Type::Group, Type::Group) => true,
(Type::IntegerType(left), Type::IntegerType(right)) => left.eq(&right),
(Type::Circuit(left), Type::Circuit(right)) => left.eq(&right),
(Type::SelfType, Type::SelfType) => true,
(Type::Array(left_type, left_dim), Type::Array(right_type, right_dim)) => {
let mut left_dim_owned = left_dim.to_owned();
let mut right_dim_owned = right_dim.to_owned();
println!("left_owned {}", left_dim_owned);
println!("right_owned {}", right_dim_owned);
let left_first = left_dim_owned.remove_first();
let right_first = right_dim_owned.remove_first();
if left_first.ne(&right_first) {
return false;
}
let left_new_type = inner_array_type(*left_type.to_owned(), left_dim_owned);
let right_new_type = inner_array_type(*right_type.to_owned(), right_dim_owned);
println!("left_new {}", left_new_type);
println!("right_new {}", right_new_type);
return left_new_type.eq(&right_new_type);
}
(Type::Tuple(left), Type::Tuple(right)) => left
.iter()
.zip(right)
.all(|(left_type, right_type)| left_type.eq(right_type)),
_ => false,
}
}
}
impl Eq for Type {}
impl Hash for Type {
fn hash<H: Hasher>(&self, state: &mut H) {
self.hash(state)
}
}
///
/// Returns the type of the inner array given an array element and array dimensions.
///

View File

@ -130,7 +130,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
self.enforce_expression(cs, file_scope, function_scope, Some(*type_), element_expression)?;
// Allocate the array.
while let Some(dimension) = actual_dimensions.remove_first() {
while let Some(dimension) = actual_dimensions.remove_last() {
// Parse the dimension into a `usize`.
let dimension_usize = parse_index(&dimension, &span)?;
@ -208,7 +208,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
self.enforce_expression(cs, file_scope, function_scope, expected_type, element_expression)?;
// Allocate the array.
while let Some(dimension) = actual_dimensions.remove_first() {
while let Some(dimension) = actual_dimensions.remove_last() {
// Parse the dimension into a `usize`.
let dimension_usize = parse_index(&dimension, &span)?;

View File

@ -28,7 +28,7 @@ fn check_return_type(expected: Option<Type>, actual: Type, span: &Span) -> Resul
match expected {
Some(expected) => {
if expected.ne(&actual) {
if (expected.is_self() && actual.is_circuit()) || expected.eq(&actual) {
if (expected.is_self() && actual.is_circuit()) || expected.eq_flat(&actual) {
return Ok(());
} else {
return Err(StatementError::arguments_type(&expected, &actual, span.to_owned()));
@ -50,10 +50,9 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
return_type: Option<Type>,
span: &Span,
) -> Result<ConstrainedValue<F, G>, StatementError> {
// Make sure we return the correct number of values
let result = self.enforce_operand(cs, file_scope, function_scope, return_type.clone(), expression, span)?;
// Make sure we return the correct type.
check_return_type(return_type, result.to_type(&span)?, span)?;
Ok(result)

View File

@ -1,2 +1,2 @@
[registers]
r0: [u8; (32)] = [174, 9, 219, 124, 213, 79, 66, 180, 144, 239, 9, 182, 188, 84, 26, 246, 136, 228, 149, 155, 184, 197, 63, 53, 154, 111, 86, 227, 138, 180, 84, 163];
r0: [u8; 32] = [174, 9, 219, 124, 213, 79, 66, 180, 144, 239, 9, 182, 188, 84, 26, 246, 136, 228, 149, 155, 184, 197, 63, 53, 154, 111, 86, 227, 138, 180, 84, 163];