From 6250e952771037d9584246aa99329814b1b0d1e9 Mon Sep 17 00:00:00 2001 From: collin Date: Thu, 17 Sep 2020 12:46:55 -0700 Subject: [PATCH] impl compare nested array types method for function return type --- compiler/src/statement/return_/return_.rs | 2 ++ compiler/src/value/value.rs | 6 ++-- compiler/tests/function/mod.rs | 34 +++++++++++++++++++ .../function/return_array_nested_fail.leo | 7 ++++ .../function/return_array_nested_pass.leo | 12 +++++++ .../function/return_array_tuple_fail.leo | 7 ++++ .../function/return_array_tuple_pass.leo | 12 +++++++ typed/src/types/type_.rs | 33 ++++++++++++++++++ 8 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 compiler/tests/function/return_array_nested_fail.leo create mode 100644 compiler/tests/function/return_array_nested_pass.leo create mode 100644 compiler/tests/function/return_array_tuple_fail.leo create mode 100644 compiler/tests/function/return_array_tuple_pass.leo diff --git a/compiler/src/statement/return_/return_.rs b/compiler/src/statement/return_/return_.rs index 53c2817279..686c22af35 100644 --- a/compiler/src/statement/return_/return_.rs +++ b/compiler/src/statement/return_/return_.rs @@ -30,6 +30,8 @@ fn check_return_type(expected: Option, actual: Type, span: Span) -> Result if expected.ne(&actual) { if expected.is_self() && actual.is_circuit() { return Ok(()); + } else if expected.match_array_types(&actual) { + return Ok(()); } else { return Err(StatementError::arguments_type(&expected, &actual, span)); } diff --git a/compiler/src/value/value.rs b/compiler/src/value/value.rs index 9ce2a93961..596f292e76 100644 --- a/compiler/src/value/value.rs +++ b/compiler/src/value/value.rs @@ -108,9 +108,9 @@ impl> ConstrainedValue { ConstrainedValue::Integer(integer) => Type::IntegerType(integer.get_type()), // Data type wrappers - ConstrainedValue::Array(types) => { - let array_type = types[0].to_type(span.clone())?; - let mut dimensions = vec![types.len()]; + ConstrainedValue::Array(array) => { + let array_type = array[0].to_type(span.clone())?; + let mut dimensions = vec![array.len()]; // Nested array type if let Type::Array(inner_type, inner_dimensions) = &array_type { diff --git a/compiler/tests/function/mod.rs b/compiler/tests/function/mod.rs index 43f8a39f80..d5a3c3548b 100644 --- a/compiler/tests/function/mod.rs +++ b/compiler/tests/function/mod.rs @@ -141,3 +141,37 @@ fn test_value_unchanged() { assert_satisfied(program); } + +// Test return multidimensional arrays + +#[test] +fn test_return_array_nested_fail() { + let bytes = include_bytes!("return_array_nested_fail.leo"); + let program = parse_program(bytes).unwrap(); + + expect_compiler_error(program); +} + +#[test] +fn test_return_array_nested_pass() { + let bytes = include_bytes!("return_array_nested_pass.leo"); + let program = parse_program(bytes).unwrap(); + + assert_satisfied(program); +} + +#[test] +fn test_return_array_tuple_fail() { + let bytes = include_bytes!("return_array_tuple_fail.leo"); + let program = parse_program(bytes).unwrap(); + + expect_compiler_error(program); +} + +#[test] +fn test_return_array_tuple_pass() { + let bytes = include_bytes!("return_array_tuple_pass.leo"); + let program = parse_program(bytes).unwrap(); + + assert_satisfied(program); +} diff --git a/compiler/tests/function/return_array_nested_fail.leo b/compiler/tests/function/return_array_nested_fail.leo new file mode 100644 index 0000000000..dca001d9cc --- /dev/null +++ b/compiler/tests/function/return_array_nested_fail.leo @@ -0,0 +1,7 @@ +function array_3x2_tuple() -> [[u8; 2]; 3] { + return [0u8; (2, 3)] // The correct 3x2 array tuple is `[0u8; (3, 2)]` +} + +function main() { + let b = array_3x2_tuple(); +} \ No newline at end of file diff --git a/compiler/tests/function/return_array_nested_pass.leo b/compiler/tests/function/return_array_nested_pass.leo new file mode 100644 index 0000000000..dda5b4342b --- /dev/null +++ b/compiler/tests/function/return_array_nested_pass.leo @@ -0,0 +1,12 @@ +function array_3x2_nested() -> [[u8; 2]; 3] { + return [[0u8; 2]; 3] +} + +function array_3x2_tuple() -> [[u8; 2]; 3] { + return [0u8; (3, 2)] +} + +function main() { + let a = array_3x2_nested(); + let b = array_3x2_tuple(); +} \ No newline at end of file diff --git a/compiler/tests/function/return_array_tuple_fail.leo b/compiler/tests/function/return_array_tuple_fail.leo new file mode 100644 index 0000000000..4b7377e327 --- /dev/null +++ b/compiler/tests/function/return_array_tuple_fail.leo @@ -0,0 +1,7 @@ +function array_3x2_nested() -> [u8; (3, 2)] { + return [[0u8; 3]; 2] // The correct 3x2 nested array is `[0u8; 2]; 3]` +} + +function main() { + let a = array_3x2_nested(); +} \ No newline at end of file diff --git a/compiler/tests/function/return_array_tuple_pass.leo b/compiler/tests/function/return_array_tuple_pass.leo new file mode 100644 index 0000000000..a700bcabad --- /dev/null +++ b/compiler/tests/function/return_array_tuple_pass.leo @@ -0,0 +1,12 @@ +function array_3x2_nested() -> [u8; (3, 2)] { + return [[0u8; 2]; 3] +} + +function array_3x2_tuple() -> [u8; (3, 2)] { + return [0u8; (3, 2)] +} + +function main() { + let a = array_3x2_nested(); + let b = array_3x2_tuple(); +} \ No newline at end of file diff --git a/typed/src/types/type_.rs b/typed/src/types/type_.rs index 70a3d107b0..6f017f7ed4 100644 --- a/typed/src/types/type_.rs +++ b/typed/src/types/type_.rs @@ -58,6 +58,26 @@ impl Type { false } + pub fn match_array_types(&self, other: &Type) -> bool { + // Check that both `self` and `other` are of type array + let (type_1, dimensions_1) = match self { + Type::Array(type_, dimensions) => (type_, dimensions), + _ => return false, + }; + + let (type_2, dimensions_2) = match other { + Type::Array(type_, dimensions) => (type_, dimensions), + _ => return false, + }; + + // Expand multidimensional array syntax + let (type_1_expanded, dimensions_1_expanded) = expand_array_type(type_1, dimensions_1); + let (type_2_expanded, dimensions_2_expanded) = expand_array_type(type_2, dimensions_2); + + // Return true if expanded array types and dimensions match + type_1_expanded.eq(&type_2_expanded) && dimensions_1_expanded.eq(&dimensions_2_expanded) + } + pub fn outer_dimension(&self, dimensions: &Vec) -> Self { let type_ = self.clone(); @@ -85,6 +105,19 @@ impl Type { } } +fn expand_array_type(type_: &Type, dimensions: &Vec) -> (Type, Vec) { + if let Type::Array(nested_type, nested_dimensions) = type_ { + // Expand nested array type + let mut expanded_dimensions = dimensions.clone(); + expanded_dimensions.append(&mut nested_dimensions.clone()); + + return expand_array_type(nested_type, &expanded_dimensions); + } else { + // Array type is fully expanded + (type_.clone(), dimensions.clone()) + } +} + /// pest ast -> Explicit Type for defining circuit members and function params impl From for Type {