impl compare nested array types method for function return type

This commit is contained in:
collin 2020-09-17 12:46:55 -07:00
parent 23bdecf6c4
commit 6250e95277
8 changed files with 110 additions and 3 deletions

View File

@ -30,6 +30,8 @@ fn check_return_type(expected: Option<Type>, actual: Type, span: Span) -> Result
if expected.ne(&actual) {
if expected.is_self() && actual.is_circuit() {
return Ok(());
} else if expected.match_array_types(&actual) {
return Ok(());
} else {
return Err(StatementError::arguments_type(&expected, &actual, span));
}

View File

@ -108,9 +108,9 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedValue<F, G> {
ConstrainedValue::Integer(integer) => Type::IntegerType(integer.get_type()),
// Data type wrappers
ConstrainedValue::Array(types) => {
let array_type = types[0].to_type(span.clone())?;
let mut dimensions = vec![types.len()];
ConstrainedValue::Array(array) => {
let array_type = array[0].to_type(span.clone())?;
let mut dimensions = vec![array.len()];
// Nested array type
if let Type::Array(inner_type, inner_dimensions) = &array_type {

View File

@ -141,3 +141,37 @@ fn test_value_unchanged() {
assert_satisfied(program);
}
// Test return multidimensional arrays
#[test]
fn test_return_array_nested_fail() {
let bytes = include_bytes!("return_array_nested_fail.leo");
let program = parse_program(bytes).unwrap();
expect_compiler_error(program);
}
#[test]
fn test_return_array_nested_pass() {
let bytes = include_bytes!("return_array_nested_pass.leo");
let program = parse_program(bytes).unwrap();
assert_satisfied(program);
}
#[test]
fn test_return_array_tuple_fail() {
let bytes = include_bytes!("return_array_tuple_fail.leo");
let program = parse_program(bytes).unwrap();
expect_compiler_error(program);
}
#[test]
fn test_return_array_tuple_pass() {
let bytes = include_bytes!("return_array_tuple_pass.leo");
let program = parse_program(bytes).unwrap();
assert_satisfied(program);
}

View File

@ -0,0 +1,7 @@
function array_3x2_tuple() -> [[u8; 2]; 3] {
return [0u8; (2, 3)] // The correct 3x2 array tuple is `[0u8; (3, 2)]`
}
function main() {
let b = array_3x2_tuple();
}

View File

@ -0,0 +1,12 @@
function array_3x2_nested() -> [[u8; 2]; 3] {
return [[0u8; 2]; 3]
}
function array_3x2_tuple() -> [[u8; 2]; 3] {
return [0u8; (3, 2)]
}
function main() {
let a = array_3x2_nested();
let b = array_3x2_tuple();
}

View File

@ -0,0 +1,7 @@
function array_3x2_nested() -> [u8; (3, 2)] {
return [[0u8; 3]; 2] // The correct 3x2 nested array is `[0u8; 2]; 3]`
}
function main() {
let a = array_3x2_nested();
}

View File

@ -0,0 +1,12 @@
function array_3x2_nested() -> [u8; (3, 2)] {
return [[0u8; 2]; 3]
}
function array_3x2_tuple() -> [u8; (3, 2)] {
return [0u8; (3, 2)]
}
function main() {
let a = array_3x2_nested();
let b = array_3x2_tuple();
}

View File

@ -58,6 +58,26 @@ impl Type {
false
}
pub fn match_array_types(&self, other: &Type) -> bool {
// Check that both `self` and `other` are of type array
let (type_1, dimensions_1) = match self {
Type::Array(type_, dimensions) => (type_, dimensions),
_ => return false,
};
let (type_2, dimensions_2) = match other {
Type::Array(type_, dimensions) => (type_, dimensions),
_ => return false,
};
// Expand multidimensional array syntax
let (type_1_expanded, dimensions_1_expanded) = expand_array_type(type_1, dimensions_1);
let (type_2_expanded, dimensions_2_expanded) = expand_array_type(type_2, dimensions_2);
// Return true if expanded array types and dimensions match
type_1_expanded.eq(&type_2_expanded) && dimensions_1_expanded.eq(&dimensions_2_expanded)
}
pub fn outer_dimension(&self, dimensions: &Vec<usize>) -> Self {
let type_ = self.clone();
@ -85,6 +105,19 @@ impl Type {
}
}
fn expand_array_type(type_: &Type, dimensions: &Vec<usize>) -> (Type, Vec<usize>) {
if let Type::Array(nested_type, nested_dimensions) = type_ {
// Expand nested array type
let mut expanded_dimensions = dimensions.clone();
expanded_dimensions.append(&mut nested_dimensions.clone());
return expand_array_type(nested_type, &expanded_dimensions);
} else {
// Array type is fully expanded
(type_.clone(), dimensions.clone())
}
}
/// pest ast -> Explicit Type for defining circuit members and function params
impl From<DataType> for Type {