diff --git a/benchmark/simple.leo b/benchmark/simple.leo index 5892cfd8a2..8f93db5c7e 100644 --- a/benchmark/simple.leo +++ b/benchmark/simple.leo @@ -1,3 +1,10 @@ -function main() { - let a: u32 = if true ? 1 : 5; +function test(mut a: u32) { + a = 0; +} + +function main() -> u32 { + let a = 1; + test(a); + + return a // <- returns 1 } \ No newline at end of file diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index 0739393c42..5fc60e12b4 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -142,6 +142,9 @@ impl> ConstrainedProgra (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { Ok(self.enforce_field_mul(cs, fe_1, fe_2)?) } + // (ConstrainedValue::GroupElement(group), ConstrainedValue::FieldElement(scalar)) => { + // Ok(Self::evaluate_group_mul(group, scalar)) + // } (ConstrainedValue::Mutable(val_1), val_2) => { self.enforce_mul_expression(cs, *val_1, val_2) } @@ -439,9 +442,20 @@ impl> ConstrainedProgra cs: &mut CS, file_scope: String, function_scope: String, - expected_types: Vec>, + mut expected_types: Vec>, array: Vec>>, ) -> Result, ExpressionError> { + // Check explicit array type dimension if given + let expected_dimensions = vec![]; + if !expected_types.is_empty() { + match expected_types[0] { + Type::Array(ref _type, ref dimensions) => { + expected_types = vec![expected_types[0].inner_dimension(dimensions)]; + } + ref _type => return Err(ExpressionError::IncompatibleTypes(_type.to_string())), + } + } + let mut result = vec![]; for element in array.into_iter() { match *element { @@ -471,6 +485,17 @@ impl> ConstrainedProgra } } } + + // Check expected_dimensions if given + if !expected_dimensions.is_empty() { + if expected_dimensions[expected_dimensions.len() - 1] != result.len() { + return Err(ExpressionError::InvalidLength( + expected_dimensions[expected_dimensions.len() - 1], + result.len(), + )); + } + } + Ok(ConstrainedValue::Array(result)) } @@ -484,6 +509,7 @@ impl> ConstrainedProgra ) -> Result { match self.enforce_expression(cs, file_scope, function_scope, expected_types, index)? { ConstrainedValue::Integer(number) => Ok(number.to_usize()), + ConstrainedValue::Unresolved(string) => Ok(string.parse::()?), value => Err(ExpressionError::InvalidIndex(value.to_string())), } } diff --git a/compiler/src/constraints/function.rs b/compiler/src/constraints/function.rs index 57b6a4427f..000457c605 100644 --- a/compiler/src/constraints/function.rs +++ b/compiler/src/constraints/function.rs @@ -127,7 +127,7 @@ impl> ConstrainedProgra // Allocate each value in the current row for (i, value) in arr.into_iter().enumerate() { let value_name = new_scope(name.clone(), i.to_string()); - let value_type = array_type.next_dimension(&array_dimensions); + let value_type = array_type.outer_dimension(&array_dimensions); array_value.push(self.allocate_main_function_input( cs, @@ -142,7 +142,7 @@ impl> ConstrainedProgra // Allocate all row values as none for i in 0..expected_length { let value_name = new_scope(name.clone(), i.to_string()); - let value_type = array_type.next_dimension(&array_dimensions); + let value_type = array_type.outer_dimension(&array_dimensions); array_value.push( self.allocate_main_function_input( diff --git a/compiler/src/constraints/group_element.rs b/compiler/src/constraints/group_element.rs index 9ccc5e0284..a9b2fbbad0 100644 --- a/compiler/src/constraints/group_element.rs +++ b/compiler/src/constraints/group_element.rs @@ -17,4 +17,8 @@ impl> ConstrainedProgra pub fn evaluate_group_sub(group_element_1: G, group_element_2: G) -> ConstrainedValue { ConstrainedValue::GroupElement(group_element_1.sub(&group_element_2)) } + // + // pub fn evaluate_group_mul(group_element: G, scalar_field: G::ScalarField) -> ConstrainedValue { + // ConstrainedValue::GroupElement(group_element.mul(&scalar_field)) + // } } diff --git a/compiler/src/constraints/value.rs b/compiler/src/constraints/value.rs index c1ae556081..17039e56b3 100644 --- a/compiler/src/constraints/value.rs +++ b/compiler/src/constraints/value.rs @@ -51,27 +51,30 @@ impl ConstrainedValue { } pub(crate) fn from_type(value: String, _type: &Type) -> Result { - Ok(match _type { - Type::IntegerType(integer_type) => ConstrainedValue::Integer(match integer_type { + match _type { + Type::IntegerType(integer_type) => Ok(ConstrainedValue::Integer(match integer_type { IntegerType::U8 => Integer::U8(UInt8::constant(value.parse::()?)), IntegerType::U16 => Integer::U16(UInt16::constant(value.parse::()?)), IntegerType::U32 => Integer::U32(UInt32::constant(value.parse::()?)), IntegerType::U64 => Integer::U64(UInt64::constant(value.parse::()?)), IntegerType::U128 => Integer::U128(UInt128::constant(value.parse::()?)), - }), - Type::FieldElement => ConstrainedValue::FieldElement(FieldElement::Constant( + })), + Type::FieldElement => Ok(ConstrainedValue::FieldElement(FieldElement::Constant( F::from_str(&value).unwrap_or_default(), - )), - Type::GroupElement => ConstrainedValue::GroupElement({ + ))), + Type::GroupElement => Ok(ConstrainedValue::GroupElement({ use std::str::FromStr; let scalar = G::ScalarField::from_str(&value).unwrap_or_default(); let point = G::default().mul(&scalar); point - }), - Type::Boolean => ConstrainedValue::Boolean(Boolean::Constant(value.parse::()?)), - _ => ConstrainedValue::Unresolved(value), - }) + })), + Type::Boolean => Ok(ConstrainedValue::Boolean(Boolean::Constant( + value.parse::()?, + ))), + Type::Array(ref _type, _dimensions) => ConstrainedValue::from_type(value, _type), + _ => Ok(ConstrainedValue::Unresolved(value)), + } } pub(crate) fn to_type(&self) -> Type { diff --git a/compiler/src/errors/constraints/expression.rs b/compiler/src/errors/constraints/expression.rs index 0d011f482e..bbe3939635 100644 --- a/compiler/src/errors/constraints/expression.rs +++ b/compiler/src/errors/constraints/expression.rs @@ -1,6 +1,7 @@ use crate::errors::{BooleanError, FieldElementError, FunctionError, IntegerError, ValueError}; use snarkos_errors::gadgets::SynthesisError; +use std::num::ParseIntError; #[derive(Debug, Error)] pub enum ExpressionError { @@ -21,6 +22,9 @@ pub enum ExpressionError { #[error("{}", _0)] IntegerError(IntegerError), + #[error("{}", _0)] + ParseIntError(ParseIntError), + #[error("{}", _0)] FieldElementError(FieldElementError), @@ -46,6 +50,9 @@ pub enum ExpressionError { #[error("Index must resolve to an integer, got {}", _0)] InvalidIndex(String), + #[error("Expected array length {}, got {}", _0, _1)] + InvalidLength(usize, usize), + // Circuits #[error( "Circuit {} must be declared before it is used in an inline expression", @@ -110,6 +117,12 @@ impl From for ExpressionError { } } +impl From for ExpressionError { + fn from(error: ParseIntError) -> Self { + ExpressionError::ParseIntError(error) + } +} + impl From for ExpressionError { fn from(error: FieldElementError) -> Self { ExpressionError::FieldElementError(error) diff --git a/compiler/src/types.rs b/compiler/src/types.rs index 4b36aacabc..be79c3efa5 100644 --- a/compiler/src/types.rs +++ b/compiler/src/types.rs @@ -156,7 +156,7 @@ pub enum Type { } impl Type { - pub fn next_dimension(&self, dimensions: &Vec) -> Self { + pub fn outer_dimension(&self, dimensions: &Vec) -> Self { let _type = self.clone(); if dimensions.len() > 1 { @@ -168,6 +168,19 @@ impl Type { _type } + + pub fn inner_dimension(&self, dimensions: &Vec) -> Self { + let _type = self.clone(); + + if dimensions.len() > 1 { + let mut next = vec![]; + next.extend_from_slice(&dimensions[..dimensions.len() - 1]); + + return Type::Array(Box::new(_type), next); + } + + _type + } } #[derive(Clone, PartialEq, Eq)] diff --git a/compiler/src/types_from.rs b/compiler/src/types_from.rs index f8e6955dfd..5db3726175 100644 --- a/compiler/src/types_from.rs +++ b/compiler/src/types_from.rs @@ -61,6 +61,12 @@ impl<'ast> types::Integer { )), } } + + pub(crate) fn from_implicit(number: String) -> Self { + types::Integer::U128(UInt128::constant( + number.parse::().expect("unable to parse u128"), + )) + } } impl<'ast, F: Field + PrimeField, G: Group> From> for types::Expression { @@ -79,6 +85,9 @@ impl<'ast, F: Field + PrimeField, G: Group> From> .from .map(|from| match types::Expression::::from(from.0) { types::Expression::Integer(number) => number, + types::Expression::Implicit(string) => { + types::Integer::from_implicit(string) + } expression => { unimplemented!("Range bounds should be integers, found {}", expression) } @@ -87,6 +96,9 @@ impl<'ast, F: Field + PrimeField, G: Group> From> .to .map(|to| match types::Expression::::from(to.0) { types::Expression::Integer(number) => number, + types::Expression::Implicit(string) => { + types::Integer::from_implicit(string) + } expression => { unimplemented!("Range bounds should be integers, found {}", expression) } @@ -375,7 +387,12 @@ impl<'ast, F: Field + PrimeField, G: Group> From> impl<'ast, F: Field + PrimeField, G: Group> types::Expression { fn get_count(count: ast::Value<'ast>) -> usize { match count { - ast::Value::Integer(f) => f + ast::Value::Integer(integer) => integer + .number + .value + .parse::() + .expect("Unable to read array size"), + ast::Value::Implicit(number) => number .number .value .parse::() @@ -590,10 +607,12 @@ impl<'ast, F: Field + PrimeField, G: Group> From> fn from(statement: ast::ForStatement<'ast>) -> Self { let from = match types::Expression::::from(statement.start) { types::Expression::Integer(number) => number, + types::Expression::Implicit(string) => types::Integer::from_implicit(string), expression => unimplemented!("Range bounds should be integers, found {}", expression), }; let to = match types::Expression::::from(statement.stop) { types::Expression::Integer(number) => number, + types::Expression::Implicit(string) => types::Integer::from_implicit(string), expression => unimplemented!("Range bounds should be integers, found {}", expression), };