diff --git a/benchmark/simple.leo b/benchmark/simple.leo index 7dbe93d5c5..5b35252ebd 100644 --- a/benchmark/simple.leo +++ b/benchmark/simple.leo @@ -1,15 +1,6 @@ -struct Foo { - x: u32 - y: u32 -} - -function main(a: private fe) { - let b = a + 1fe; - assert_eq(b, 2fe); - - let c = Foo { - x: 4, - y: 5, - }; +function main() { + let a = 1u8 + 2u8; + let b = 2u32 + 4; + assert_eq(b, 6); } \ No newline at end of file diff --git a/benchmark/src/main.rs b/benchmark/src/main.rs index fab2618639..7e0f8461f1 100644 --- a/benchmark/src/main.rs +++ b/benchmark/src/main.rs @@ -60,7 +60,7 @@ impl ConstraintSynthesizer for Benchmark { cs: &mut CS, ) -> Result<(), SynthesisError> { let _res = - leo_compiler::ResolvedProgram::generate_constraints(cs, self.program, self.parameters); + leo_compiler::ConstrainedProgram::generate_constraints(cs, self.program, self.parameters); println!(" Result: {}", _res); // Write results to file or something @@ -92,8 +92,8 @@ fn main() { let start = Instant::now(); // Set main function arguments in compiled program - let argument = Some(ParameterValue::Field(Fr::one())); - program.parameters = vec![argument]; + // let argument = Some(ParameterValue::Field(Fr::one())); + // program.parameters = vec![argument]; // Generate proof let proof = create_random_proof(program, ¶ms, rng).unwrap(); diff --git a/compiler/src/ast.rs b/compiler/src/ast.rs index 750a7c5f35..f453528a42 100644 --- a/compiler/src/ast.rs +++ b/compiler/src/ast.rs @@ -122,11 +122,19 @@ pub enum OperationAssign { // Types +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::type_u8))] +pub struct U8Type {} + #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_u32))] -pub struct U32Type<'ast> { - #[pest_ast(outer())] - pub span: Span<'ast>, +pub struct U32Type {} + +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::type_integer))] +pub enum IntegerType { + U8Type(U8Type), + U32Type(U32Type), } #[derive(Clone, Debug, FromPest, PartialEq)] @@ -154,7 +162,7 @@ pub struct StructType<'ast> { #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::type_basic))] pub enum BasicType<'ast> { - U32(U32Type<'ast>), + Integer(IntegerType), Field(FieldType<'ast>), Boolean(BooleanType<'ast>), } @@ -210,15 +218,15 @@ impl<'ast> fmt::Display for Number<'ast> { } #[derive(Clone, Debug, FromPest, PartialEq)] -#[pest_ast(rule(Rule::value_u32))] -pub struct U32<'ast> { +#[pest_ast(rule(Rule::value_integer))] +pub struct Integer<'ast> { pub number: Number<'ast>, - pub _type: Option>, + pub _type: Option, #[pest_ast(outer())] pub span: Span<'ast>, } -impl<'ast> fmt::Display for U32<'ast> { +impl<'ast> fmt::Display for Integer<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self.number) } @@ -257,15 +265,15 @@ impl<'ast> fmt::Display for Boolean<'ast> { #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::value))] pub enum Value<'ast> { + Integer(Integer<'ast>), Field(Field<'ast>), Boolean(Boolean<'ast>), - U32(U32<'ast>), } impl<'ast> Value<'ast> { pub fn span(&self) -> &Span<'ast> { match self { - Value::U32(value) => &value.span, + Value::Integer(value) => &value.span, Value::Field(value) => &value.span, Value::Boolean(value) => &value.span, } @@ -275,7 +283,7 @@ impl<'ast> Value<'ast> { impl<'ast> fmt::Display for Value<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Value::U32(ref value) => write!(f, "{}", value), + Value::Integer(ref value) => write!(f, "{}", value), Value::Field(ref value) => write!(f, "{}", value), Value::Boolean(ref value) => write!(f, "{}", value), } diff --git a/compiler/src/compiler.rs b/compiler/src/compiler.rs index b97a66b17c..134a73916a 100644 --- a/compiler/src/compiler.rs +++ b/compiler/src/compiler.rs @@ -1,6 +1,8 @@ //! Compiles a Leo program from a file path. -use crate::{ast, errors::CompilerError, ParameterValue, Program, ResolvedProgram, ResolvedValue}; +use crate::{ + ast, errors::CompilerError, ConstrainedProgram, ConstrainedValue, ParameterValue, Program, +}; use snarkos_errors::gadgets::SynthesisError; use snarkos_models::{ @@ -18,7 +20,7 @@ pub struct Compiler { main_file_path: PathBuf, program: Program, parameters: Vec>>, - output: Option>, + output: Option>, _engine: PhantomData, } @@ -91,7 +93,7 @@ impl ConstraintSynthesizer for Compiler { self, cs: &mut CS, ) -> Result<(), SynthesisError> { - let _res = ResolvedProgram::generate_constraints(cs, self.program, self.parameters); + let _res = ConstrainedProgram::generate_constraints(cs, self.program, self.parameters); // Write results to file or something diff --git a/compiler/src/constraints/boolean.rs b/compiler/src/constraints/boolean.rs index 49839a504b..18566ae14e 100644 --- a/compiler/src/constraints/boolean.rs +++ b/compiler/src/constraints/boolean.rs @@ -1,7 +1,7 @@ //! Methods to enforce constraints on booleans in a resolved Leo program. use crate::{ - constraints::{new_variable_from_variable, ResolvedProgram, ResolvedValue}, + constraints::{new_variable_from_variable, ConstrainedProgram, ConstrainedValue}, types::{ParameterModel, ParameterValue, Variable}, }; @@ -14,7 +14,7 @@ use snarkos_models::{ }, }; -impl> ResolvedProgram { +impl> ConstrainedProgram { pub(crate) fn bool_from_parameter( &mut self, cs: &mut CS, @@ -45,7 +45,10 @@ impl> ResolvedProgram { let parameter_variable = new_variable_from_variable(scope, ¶meter_model.variable); // store each argument as variable in resolved program - self.store_variable(parameter_variable.clone(), ResolvedValue::Boolean(number)); + self.store_variable( + parameter_variable.clone(), + ConstrainedValue::Boolean(number), + ); parameter_variable } @@ -80,13 +83,13 @@ impl> ResolvedProgram { // parameter_variable } - pub(crate) fn get_boolean_constant(bool: bool) -> ResolvedValue { - ResolvedValue::Boolean(Boolean::Constant(bool)) + pub(crate) fn get_boolean_constant(bool: bool) -> ConstrainedValue { + ConstrainedValue::Boolean(Boolean::Constant(bool)) } - pub(crate) fn evaluate_not(value: ResolvedValue) -> ResolvedValue { + pub(crate) fn evaluate_not(value: ConstrainedValue) -> ConstrainedValue { match value { - ResolvedValue::Boolean(boolean) => ResolvedValue::Boolean(boolean.not()), + ConstrainedValue::Boolean(boolean) => ConstrainedValue::Boolean(boolean.not()), value => unimplemented!("cannot enforce not on non-boolean value {}", value), } } @@ -94,12 +97,12 @@ impl> ResolvedProgram { pub(crate) fn enforce_or( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::Boolean(left_bool), ResolvedValue::Boolean(right_bool)) => { - ResolvedValue::Boolean(Boolean::or(cs, &left_bool, &right_bool).unwrap()) + (ConstrainedValue::Boolean(left_bool), ConstrainedValue::Boolean(right_bool)) => { + ConstrainedValue::Boolean(Boolean::or(cs, &left_bool, &right_bool).unwrap()) } (left_value, right_value) => unimplemented!( "cannot enforce or on non-boolean values {} || {}", @@ -112,12 +115,12 @@ impl> ResolvedProgram { pub(crate) fn enforce_and( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::Boolean(left_bool), ResolvedValue::Boolean(right_bool)) => { - ResolvedValue::Boolean(Boolean::and(cs, &left_bool, &right_bool).unwrap()) + (ConstrainedValue::Boolean(left_bool), ConstrainedValue::Boolean(right_bool)) => { + ConstrainedValue::Boolean(Boolean::and(cs, &left_bool, &right_bool).unwrap()) } (left_value, right_value) => unimplemented!( "cannot enforce and on non-boolean values {} && {}", @@ -127,8 +130,8 @@ impl> ResolvedProgram { } } - pub(crate) fn boolean_eq(left: Boolean, right: Boolean) -> ResolvedValue { - ResolvedValue::Boolean(Boolean::Constant(left.eq(&right))) + pub(crate) fn boolean_eq(left: Boolean, right: Boolean) -> ConstrainedValue { + ConstrainedValue::Boolean(Boolean::Constant(left.eq(&right))) } pub(crate) fn enforce_boolean_eq(&mut self, cs: &mut CS, left: Boolean, right: Boolean) { diff --git a/compiler/src/constraints/resolved_program.rs b/compiler/src/constraints/constrained_program.rs similarity index 77% rename from compiler/src/constraints/resolved_program.rs rename to compiler/src/constraints/constrained_program.rs index 7e19c12120..3bdd543340 100644 --- a/compiler/src/constraints/resolved_program.rs +++ b/compiler/src/constraints/constrained_program.rs @@ -1,6 +1,6 @@ //! An in memory store to keep track of defined names when constraining a Leo program. -use crate::{constraints::ResolvedValue, types::Variable}; +use crate::{constraints::ConstrainedValue, types::Variable}; use snarkos_models::{ curves::{Field, PrimeField}, @@ -8,8 +8,8 @@ use snarkos_models::{ }; use std::{collections::HashMap, marker::PhantomData}; -pub struct ResolvedProgram> { - pub resolved_names: HashMap>, +pub struct ConstrainedProgram> { + pub resolved_names: HashMap>, pub _cs: PhantomData, } @@ -44,7 +44,7 @@ pub fn new_variable_from_variables( } } -impl> ResolvedProgram { +impl> ConstrainedProgram { pub fn new() -> Self { Self { resolved_names: HashMap::new(), @@ -52,11 +52,11 @@ impl> ResolvedProgram { } } - pub(crate) fn store(&mut self, name: String, value: ResolvedValue) { + pub(crate) fn store(&mut self, name: String, value: ConstrainedValue) { self.resolved_names.insert(name, value); } - pub(crate) fn store_variable(&mut self, variable: Variable, value: ResolvedValue) { + pub(crate) fn store_variable(&mut self, variable: Variable, value: ConstrainedValue) { self.store(variable.name, value); } @@ -68,18 +68,18 @@ impl> ResolvedProgram { self.contains_name(&variable.name) } - pub(crate) fn get(&self, name: &String) -> Option<&ResolvedValue> { + pub(crate) fn get(&self, name: &String) -> Option<&ConstrainedValue> { self.resolved_names.get(name) } - pub(crate) fn get_mut(&mut self, name: &String) -> Option<&mut ResolvedValue> { + pub(crate) fn get_mut(&mut self, name: &String) -> Option<&mut ConstrainedValue> { self.resolved_names.get_mut(name) } pub(crate) fn get_mut_variable( &mut self, variable: &Variable, - ) -> Option<&mut ResolvedValue> { + ) -> Option<&mut ConstrainedValue> { self.get_mut(&variable.name) } } diff --git a/compiler/src/constraints/resolved_value.rs b/compiler/src/constraints/constrained_value.rs similarity index 51% rename from compiler/src/constraints/resolved_value.rs rename to compiler/src/constraints/constrained_value.rs index 7ee115ec3b..a35b350aa2 100644 --- a/compiler/src/constraints/resolved_value.rs +++ b/compiler/src/constraints/constrained_value.rs @@ -1,17 +1,18 @@ //! The in memory stored value for a defined name in a resolved Leo program. -use crate::types::{Function, Struct, Type, Variable}; +use crate::{ + constraints::ConstrainedInteger, + types::{Function, Struct, Type, Variable}, +}; use snarkos_models::{ curves::{Field, PrimeField}, - gadgets::{ - r1cs::Variable as R1CSVariable, utilities::boolean::Boolean, utilities::uint32::UInt32, - }, + gadgets::{r1cs::Variable as R1CSVariable, utilities::boolean::Boolean}, }; use std::fmt; #[derive(Clone, PartialEq, Eq)] -pub struct ResolvedStructMember(pub Variable, pub ResolvedValue); +pub struct ConstrainedStructMember(pub Variable, pub ConstrainedValue); #[derive(Clone, PartialEq, Eq)] pub enum FieldElement { @@ -35,55 +36,60 @@ impl fmt::Display for FieldElement { } #[derive(Clone, PartialEq, Eq)] -pub enum ResolvedValue { - U32(UInt32), +pub enum ConstrainedValue { + Integer(ConstrainedInteger), FieldElement(FieldElement), Boolean(Boolean), - Array(Vec>), + Array(Vec>), StructDefinition(Struct), - StructExpression(Variable, Vec>), + StructExpression(Variable, Vec>), Function(Function), - Return(Vec>), // add Null for function returns + Return(Vec>), // add Null for function returns } -impl ResolvedValue { - pub(crate) fn match_type(&self, ty: &Type) -> bool { - match (self, ty) { - (ResolvedValue::U32(ref _i), Type::U32) => true, - (ResolvedValue::FieldElement(ref _f), Type::FieldElement) => true, - (ResolvedValue::Boolean(ref _b), Type::Boolean) => true, - (ResolvedValue::Array(ref arr), Type::Array(ref ty, ref len)) => { +impl ConstrainedValue { + pub(crate) fn expect_type(&self, _type: &Type) { + match (self, _type) { + (ConstrainedValue::Integer(ref integer), Type::IntegerType(ref _type)) => { + integer.expect_type(_type) + } + (ConstrainedValue::FieldElement(ref _f), Type::FieldElement) => {} + (ConstrainedValue::Boolean(ref _b), Type::Boolean) => {} + (ConstrainedValue::Array(ref arr), Type::Array(ref ty, ref len)) => { // check array lengths are equal - let mut res = arr.len() == *len; + if arr.len() != *len { + unimplemented!("array length {} != {}", arr.len(), *len) + } // check each value in array matches for value in arr { - res &= value.match_type(ty) + value.expect_type(ty) } - res } ( - ResolvedValue::StructExpression(ref actual_name, ref _members), + ConstrainedValue::StructExpression(ref actual_name, ref _members), Type::Struct(ref expected_name), - ) => actual_name == expected_name, - (ResolvedValue::Return(ref values), ty) => { - let mut res = true; - for value in values { - res &= value.match_type(ty) + ) => { + if expected_name != actual_name { + unimplemented!("expected struct name {} got {}", expected_name, actual_name) } - res } - (_, _) => false, + (ConstrainedValue::Return(ref values), ty) => { + for value in values { + value.expect_type(ty) + } + } + (value, _type) => unimplemented!("expected type {}, got {}", _type, value), } } } -impl fmt::Display for ResolvedValue { +impl fmt::Display for ConstrainedValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - ResolvedValue::U32(ref value) => write!(f, "{}", value.value.unwrap()), - ResolvedValue::FieldElement(ref value) => write!(f, "{}", value), - ResolvedValue::Boolean(ref value) => write!(f, "{}", value.get_value().unwrap()), - ResolvedValue::Array(ref array) => { + ConstrainedValue::Integer(ref value) => write!(f, "{}", value), + ConstrainedValue::FieldElement(ref value) => write!(f, "{}", value), + ConstrainedValue::Boolean(ref value) => write!(f, "{}", value.get_value().unwrap()), + ConstrainedValue::Array(ref array) => { write!(f, "[")?; for (i, e) in array.iter().enumerate() { write!(f, "{}", e)?; @@ -93,7 +99,7 @@ impl fmt::Display for ResolvedValue { } write!(f, "]") } - ResolvedValue::StructExpression(ref variable, ref members) => { + ConstrainedValue::StructExpression(ref variable, ref members) => { write!(f, "{} {{", variable)?; for (i, member) in members.iter().enumerate() { write!(f, "{}: {}", member.0, member.1)?; @@ -103,7 +109,7 @@ impl fmt::Display for ResolvedValue { } write!(f, "}}") } - ResolvedValue::Return(ref values) => { + ConstrainedValue::Return(ref values) => { write!(f, "Program output: [")?; for (i, value) in values.iter().enumerate() { write!(f, "{}", value)?; @@ -113,15 +119,15 @@ impl fmt::Display for ResolvedValue { } write!(f, "]") } - ResolvedValue::StructDefinition(ref _definition) => { + ConstrainedValue::StructDefinition(ref _definition) => { unimplemented!("cannot return struct definition in program") } - ResolvedValue::Function(ref function) => write!(f, "{}();", function.function_name), + ConstrainedValue::Function(ref function) => write!(f, "{}();", function.function_name), } } } -impl fmt::Debug for ResolvedValue { +impl fmt::Debug for ConstrainedValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}", self) } diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index a2365db7c5..e44291eb98 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -2,8 +2,8 @@ use crate::{ constraints::{ - new_scope_from_variable, new_variable_from_variable, ResolvedProgram, ResolvedStructMember, - ResolvedValue, + new_scope_from_variable, new_variable_from_variable, ConstrainedProgram, + ConstrainedStructMember, ConstrainedValue, }, types::{Expression, RangeOrExpression, SpreadOrExpression, StructMember, Variable}, }; @@ -13,13 +13,13 @@ use snarkos_models::{ gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, }; -impl> ResolvedProgram { +impl> ConstrainedProgram { /// Enforce a variable expression by getting the resolved value pub(crate) fn enforce_variable( &mut self, scope: String, unresolved_variable: Variable, - ) -> ResolvedValue { + ) -> ConstrainedValue { // Evaluate the variable name in the current function scope let variable_name = new_scope_from_variable(scope, &unresolved_variable); @@ -39,14 +39,14 @@ impl> ResolvedProgram { fn enforce_add_expression( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => { - Self::enforce_u32_add(cs, num_1, num_2) + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::enforce_integer_add(cs, num_1, num_2) } - (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { + (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_add(cs, fe_1, fe_2) } (val_1, val_2) => unimplemented!("cannot add {} + {}", val_1, val_2), @@ -56,14 +56,14 @@ impl> ResolvedProgram { fn enforce_sub_expression( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => { - Self::enforce_u32_sub(cs, num_1, num_2) + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::enforce_integer_sub(cs, num_1, num_2) } - (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { + (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_sub(cs, fe_1, fe_2) } (val_1, val_2) => unimplemented!("cannot subtract {} - {}", val_1, val_2), @@ -73,14 +73,14 @@ impl> ResolvedProgram { fn enforce_mul_expression( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => { - Self::enforce_u32_mul(cs, num_1, num_2) + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::enforce_integer_mul(cs, num_1, num_2) } - (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { + (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_mul(cs, fe_1, fe_2) } (val_1, val_2) => unimplemented!("cannot multiply {} * {}", val_1, val_2), @@ -90,14 +90,14 @@ impl> ResolvedProgram { fn enforce_div_expression( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => { - Self::enforce_u32_div(cs, num_1, num_2) + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::enforce_integer_div(cs, num_1, num_2) } - (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { + (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_div(cs, fe_1, fe_2) } (val_1, val_2) => unimplemented!("cannot divide {} / {}", val_1, val_2), @@ -106,17 +106,17 @@ impl> ResolvedProgram { fn enforce_pow_expression( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => { - Self::enforce_u32_pow(cs, num_1, num_2) + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::enforce_integer_pow(cs, num_1, num_2) } - (ResolvedValue::FieldElement(fe_1), ResolvedValue::U32(num_2)) => { + (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::Integer(num_2)) => { self.enforce_field_pow(cs, fe_1, num_2) } - (_, ResolvedValue::FieldElement(num_2)) => { + (_, ConstrainedValue::FieldElement(num_2)) => { unimplemented!("exponent power must be an integer, got field {}", num_2) } (val_1, val_2) => unimplemented!("cannot enforce exponentiation {} * {}", val_1, val_2), @@ -126,14 +126,16 @@ impl> ResolvedProgram { /// Evaluate Boolean operations fn evaluate_eq_expression( &mut self, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { - (ResolvedValue::Boolean(bool_1), ResolvedValue::Boolean(bool_2)) => { + (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { Self::boolean_eq(bool_1, bool_2) } - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => Self::u32_eq(num_1, num_2), + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::evaluate_integer_eq(num_1, num_2) + } // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_eq(fe_1, fe_2) // } @@ -143,9 +145,9 @@ impl> ResolvedProgram { fn evaluate_geq_expression( &mut self, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_geq(fe_1, fe_2) @@ -160,9 +162,9 @@ impl> ResolvedProgram { fn evaluate_gt_expression( &mut self, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_gt(fe_1, fe_2) @@ -177,9 +179,9 @@ impl> ResolvedProgram { fn evaluate_leq_expression( &mut self, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_leq(fe_1, fe_2) @@ -194,9 +196,9 @@ impl> ResolvedProgram { fn evaluate_lt_expression( &mut self, - left: ResolvedValue, - right: ResolvedValue, - ) -> ResolvedValue { + left: ConstrainedValue, + right: ConstrainedValue, + ) -> ConstrainedValue { match (left, right) { // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_lt(fe_1, fe_2) @@ -216,7 +218,7 @@ impl> ResolvedProgram { file_scope: String, function_scope: String, array: Vec>>, - ) -> ResolvedValue { + ) -> ConstrainedValue { let mut result = vec![]; array.into_iter().for_each(|element| match *element { SpreadOrExpression::Spread(spread) => match spread { @@ -224,7 +226,7 @@ impl> ResolvedProgram { let array_name = new_scope_from_variable(function_scope.clone(), &variable); match self.get(&array_name) { Some(value) => match value { - ResolvedValue::Array(array) => result.extend(array.clone()), + ConstrainedValue::Array(array) => result.extend(array.clone()), value => { unimplemented!("spreads only implemented for arrays, got {}", value) } @@ -246,7 +248,7 @@ impl> ResolvedProgram { )); } }); - ResolvedValue::Array(result) + ConstrainedValue::Array(result) } pub(crate) fn enforce_index( @@ -257,7 +259,7 @@ impl> ResolvedProgram { index: Expression, ) -> usize { match self.enforce_expression(cs, file_scope, function_scope, index) { - ResolvedValue::U32(number) => number.value.unwrap() as usize, + ConstrainedValue::Integer(number) => number.get_value() as usize, value => unimplemented!("From index must resolve to an integer, got {}", value), } } @@ -269,9 +271,9 @@ impl> ResolvedProgram { function_scope: String, array: Box>, index: RangeOrExpression, - ) -> ResolvedValue { + ) -> ConstrainedValue { match self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), *array) { - ResolvedValue::Array(array) => { + ConstrainedValue::Array(array) => { match index { RangeOrExpression::Range(from, to) => { let from_resolved = match from { @@ -282,7 +284,7 @@ impl> ResolvedProgram { Some(to_index) => to_index.to_usize(), None => array.len(), // Array slice ends at array length }; - ResolvedValue::Array(array[from_resolved..to_resolved].to_owned()) + ConstrainedValue::Array(array[from_resolved..to_resolved].to_owned()) } RangeOrExpression::Expression(index) => { let index_resolved = @@ -302,12 +304,12 @@ impl> ResolvedProgram { function_scope: String, variable: Variable, members: Vec>, - ) -> ResolvedValue { + ) -> ConstrainedValue { let struct_name = new_variable_from_variable(file_scope.clone(), &variable); if let Some(resolved_value) = self.get_mut_variable(&struct_name) { match resolved_value { - ResolvedValue::StructDefinition(struct_definition) => { + ConstrainedValue::StructDefinition(struct_definition) => { let resolved_members = struct_definition .fields .clone() @@ -325,11 +327,11 @@ impl> ResolvedProgram { member.expression, ); - ResolvedStructMember(member.variable, member_value) + ConstrainedStructMember(member.variable, member_value) }) .collect(); - ResolvedValue::StructExpression(variable, resolved_members) + ConstrainedValue::StructExpression(variable, resolved_members) } _ => unimplemented!("Inline struct type is not defined as a struct"), } @@ -348,9 +350,9 @@ impl> ResolvedProgram { function_scope: String, struct_variable: Box>, struct_member: Variable, - ) -> ResolvedValue { + ) -> ConstrainedValue { match self.enforce_expression(cs, file_scope, function_scope, *struct_variable) { - ResolvedValue::StructExpression(_name, members) => { + ConstrainedValue::StructExpression(_name, members) => { let matched_member = members.into_iter().find(|member| member.0 == struct_member); match matched_member { Some(member) => member.1, @@ -368,19 +370,19 @@ impl> ResolvedProgram { function_scope: String, function: Variable, arguments: Vec>, - ) -> ResolvedValue { + ) -> ConstrainedValue { let function_name = new_variable_from_variable(file_scope.clone(), &function); match self.get_mut_variable(&function_name) { Some(value) => match value.clone() { - ResolvedValue::Function(function) => { + ConstrainedValue::Function(function) => { // this function call is inline so we unwrap the return value match self.enforce_function(cs, file_scope, function_scope, function, arguments) { - ResolvedValue::Return(return_values) => { + ConstrainedValue::Return(return_values) => { if return_values.len() == 1 { return_values[0].clone() } else { - ResolvedValue::Return(return_values) + ConstrainedValue::Return(return_values) } } value => unimplemented!( @@ -402,7 +404,7 @@ impl> ResolvedProgram { file_scope: String, function_scope: String, expression: Expression, - ) -> ResolvedValue { + ) -> ConstrainedValue { match expression { // Variables Expression::Variable(unresolved_variable) => { @@ -528,7 +530,7 @@ impl> ResolvedProgram { function_scope.clone(), *first, ) { - ResolvedValue::Boolean(resolved) => resolved, + ConstrainedValue::Boolean(resolved) => resolved, _ => unimplemented!("if else conditional must resolve to boolean"), }; diff --git a/compiler/src/constraints/field_element.rs b/compiler/src/constraints/field_element.rs index 37fc6047e5..8dbd1ce813 100644 --- a/compiler/src/constraints/field_element.rs +++ b/compiler/src/constraints/field_element.rs @@ -1,20 +1,18 @@ //! Methods to enforce constraints on field elements in a resolved Leo program. use crate::{ - constraints::{new_variable_from_variable, FieldElement, ResolvedProgram, ResolvedValue}, + constraints::{new_variable_from_variable, ConstrainedProgram, ConstrainedValue, FieldElement}, types::{ParameterModel, ParameterValue, Variable}, + ConstrainedInteger, }; use snarkos_errors::gadgets::SynthesisError; use snarkos_models::{ curves::{Field, PrimeField}, - gadgets::{ - r1cs::{ConstraintSystem, LinearCombination, Variable as R1CSVariable}, - utilities::uint32::UInt32, - }, + gadgets::r1cs::{ConstraintSystem, LinearCombination, Variable as R1CSVariable}, }; -impl> ResolvedProgram { +impl> ConstrainedProgram { pub(crate) fn field_element_from_parameter( &mut self, cs: &mut CS, @@ -49,7 +47,7 @@ impl> ResolvedProgram { // Store parameter as variable in resolved program self.store_variable( parameter_variable.clone(), - ResolvedValue::FieldElement(FieldElement::Allocated(field_option, field_value)), + ConstrainedValue::FieldElement(FieldElement::Allocated(field_option, field_value)), ); parameter_variable @@ -83,8 +81,8 @@ impl> ResolvedProgram { // parameter_variable } - pub(crate) fn get_field_element_constant(fe: F) -> ResolvedValue { - ResolvedValue::FieldElement(FieldElement::Constant(fe)) + pub(crate) fn get_field_element_constant(fe: F) -> ConstrainedValue { + ConstrainedValue::FieldElement(FieldElement::Constant(fe)) } // pub(crate) fn field_eq(fe1: F, fe2: F) -> ResolvedValue { @@ -156,11 +154,11 @@ impl> ResolvedProgram { cs: &mut CS, fe_1: FieldElement, fe_2: FieldElement, - ) -> ResolvedValue { + ) -> ConstrainedValue { match (fe_1, fe_2) { // if both constants, then return a constant result (FieldElement::Constant(fe_1_constant), FieldElement::Constant(fe_2_constant)) => { - ResolvedValue::FieldElement(FieldElement::Constant( + ConstrainedValue::FieldElement(FieldElement::Constant( fe_1_constant.add(&fe_2_constant), )) } @@ -184,7 +182,7 @@ impl> ResolvedProgram { |lc| lc + sum_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(sum_value, sum_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(sum_value, sum_variable)) } ( FieldElement::Constant(fe_1_constant), @@ -205,7 +203,7 @@ impl> ResolvedProgram { |lc| lc + sum_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(sum_value, sum_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(sum_value, sum_variable)) } ( FieldElement::Allocated(fe_1_value, fe_1_variable), @@ -229,7 +227,7 @@ impl> ResolvedProgram { |lc| lc + sum_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(sum_value, sum_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(sum_value, sum_variable)) } } } @@ -239,11 +237,11 @@ impl> ResolvedProgram { cs: &mut CS, fe_1: FieldElement, fe_2: FieldElement, - ) -> ResolvedValue { + ) -> ConstrainedValue { match (fe_1, fe_2) { // if both constants, then return a constant result (FieldElement::Constant(fe_1_constant), FieldElement::Constant(fe_2_constant)) => { - ResolvedValue::FieldElement(FieldElement::Constant( + ConstrainedValue::FieldElement(FieldElement::Constant( fe_1_constant.sub(&fe_2_constant), )) } @@ -267,7 +265,7 @@ impl> ResolvedProgram { |lc| lc + sub_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(sub_value, sub_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(sub_value, sub_variable)) } ( FieldElement::Constant(fe_1_constant), @@ -288,7 +286,7 @@ impl> ResolvedProgram { |lc| lc + sub_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(sub_value, sub_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(sub_value, sub_variable)) } ( FieldElement::Allocated(fe_1_value, fe_1_variable), @@ -312,7 +310,7 @@ impl> ResolvedProgram { |lc| lc + sub_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(sub_value, sub_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(sub_value, sub_variable)) } } } @@ -322,11 +320,11 @@ impl> ResolvedProgram { cs: &mut CS, fe_1: FieldElement, fe_2: FieldElement, - ) -> ResolvedValue { + ) -> ConstrainedValue { match (fe_1, fe_2) { // if both constants, then return a constant result (FieldElement::Constant(fe_1_constant), FieldElement::Constant(fe_2_constant)) => { - ResolvedValue::FieldElement(FieldElement::Constant( + ConstrainedValue::FieldElement(FieldElement::Constant( fe_1_constant.mul(&fe_2_constant), )) } @@ -350,7 +348,7 @@ impl> ResolvedProgram { |lc| lc + mul_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(mul_value, mul_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(mul_value, mul_variable)) } ( FieldElement::Constant(fe_1_constant), @@ -371,7 +369,7 @@ impl> ResolvedProgram { |lc| lc + mul_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(mul_value, mul_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(mul_value, mul_variable)) } ( FieldElement::Allocated(fe_1_value, fe_1_variable), @@ -395,7 +393,7 @@ impl> ResolvedProgram { |lc| lc + mul_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(mul_value, mul_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(mul_value, mul_variable)) } } } @@ -405,11 +403,11 @@ impl> ResolvedProgram { cs: &mut CS, fe_1: FieldElement, fe_2: FieldElement, - ) -> ResolvedValue { + ) -> ConstrainedValue { match (fe_1, fe_2) { // if both constants, then return a constant result (FieldElement::Constant(fe_1_constant), FieldElement::Constant(fe_2_constant)) => { - ResolvedValue::FieldElement(FieldElement::Constant( + ConstrainedValue::FieldElement(FieldElement::Constant( fe_1_constant.div(&fe_2_constant), )) } @@ -434,7 +432,7 @@ impl> ResolvedProgram { |lc| lc + div_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(div_value, div_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(div_value, div_variable)) } ( FieldElement::Constant(fe_1_constant), @@ -462,7 +460,7 @@ impl> ResolvedProgram { |lc| lc + div_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(div_value, div_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(div_value, div_variable)) } ( FieldElement::Allocated(fe_1_value, fe_1_variable), @@ -493,7 +491,7 @@ impl> ResolvedProgram { |lc| lc + div_variable.clone(), ); - ResolvedValue::FieldElement(FieldElement::Allocated(div_value, div_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(div_value, div_variable)) } } } @@ -502,16 +500,16 @@ impl> ResolvedProgram { &mut self, cs: &mut CS, fe_1: FieldElement, - num: UInt32, - ) -> ResolvedValue { + num: ConstrainedInteger, + ) -> ConstrainedValue { match fe_1 { // if both constants, then return a constant result - FieldElement::Constant(fe_1_constant) => ResolvedValue::FieldElement( - FieldElement::Constant(fe_1_constant.pow(&[num.value.unwrap() as u64])), + FieldElement::Constant(fe_1_constant) => ConstrainedValue::FieldElement( + FieldElement::Constant(fe_1_constant.pow(&[num.get_value() as u64])), ), // else, return an allocated result FieldElement::Allocated(fe_1_value, _fe_1_variable) => { - let pow_value: Option = fe_1_value.map(|v| v.pow(&[num.value.unwrap() as u64])); + let pow_value: Option = fe_1_value.map(|v| v.pow(&[num.get_value() as u64])); let pow_variable: R1CSVariable = cs .alloc( || "field exponentiation", @@ -525,7 +523,7 @@ impl> ResolvedProgram { // |lc| lc + (fe_2_inverse_value, CS::one()), // |lc| lc + pow_variable.clone()); - ResolvedValue::FieldElement(FieldElement::Allocated(pow_value, pow_variable)) + ConstrainedValue::FieldElement(FieldElement::Allocated(pow_value, pow_variable)) } } } diff --git a/compiler/src/constraints/integer/integer.rs b/compiler/src/constraints/integer/integer.rs new file mode 100644 index 0000000000..05dc045db2 --- /dev/null +++ b/compiler/src/constraints/integer/integer.rs @@ -0,0 +1,244 @@ +//! Methods to enforce constraints on integers in a resolved Leo program. + +use crate::{ + constraints::{ConstrainedProgram, ConstrainedValue}, + types::{Integer, ParameterModel, ParameterValue, Type, Variable}, + IntegerType, +}; + +use snarkos_models::{ + curves::{Field, PrimeField}, + gadgets::{ + r1cs::ConstraintSystem, + utilities::{boolean::Boolean, uint32::UInt32, uint8::UInt8}, + }, +}; +use std::fmt; + +#[derive(Clone, PartialEq, Eq)] +pub enum ConstrainedInteger { + U8(UInt8), + U32(UInt32), +} + +impl ConstrainedInteger { + pub(crate) fn get_value(&self) -> usize { + match self { + ConstrainedInteger::U8(u8) => u8.value.unwrap() as usize, + ConstrainedInteger::U32(u32) => u32.value.unwrap() as usize, + } + } + + pub(crate) fn expect_type(&self, integer_type: &IntegerType) { + match (self, integer_type) { + (ConstrainedInteger::U8(_u8), IntegerType::U8) => {} + (ConstrainedInteger::U32(_u32), IntegerType::U32) => {} + (actual, expected) => { + unimplemented!("expected integer type {}, got {}", expected, actual) + } + } + } +} + +impl fmt::Display for ConstrainedInteger { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ConstrainedInteger::U8(u8) => write!(f, "{}", u8.value.unwrap()), + ConstrainedInteger::U32(u32) => write!(f, "{}", u32.value.unwrap()), + } + } +} + +impl> ConstrainedProgram { + pub(crate) fn get_integer_constant(integer: Integer) -> ConstrainedValue { + ConstrainedValue::Integer(match integer { + Integer::U8(u8_value) => ConstrainedInteger::U8(UInt8::constant(u8_value)), + Integer::U32(u32_value) => ConstrainedInteger::U32(UInt32::constant(u32_value)), + }) + } + + pub(crate) fn evaluate_integer_eq( + left: ConstrainedInteger, + right: ConstrainedInteger, + ) -> ConstrainedValue { + ConstrainedValue::Boolean(Boolean::Constant(match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + left_u8.eq(&right_u8) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + left_u32.eq(&right_u32) + } + (left, right) => unimplemented!( + "cannot evaluate integer equality between {} == {}", + left, + right + ), + })) + } + + pub(crate) fn integer_from_parameter( + &mut self, + cs: &mut CS, + scope: String, + parameter_model: ParameterModel, + parameter_value: Option>, + ) -> Variable { + let integer_type = match ¶meter_model._type { + Type::IntegerType(integer_type) => integer_type, + _type => unimplemented!("expected integer parameter, got {}", _type), + }; + + match integer_type { + IntegerType::U8 => self.u8_from_parameter(cs, scope, parameter_model, parameter_value), + IntegerType::U32 => { + self.u32_from_parameter(cs, scope, parameter_model, parameter_value) + } + } + } + + pub(crate) fn integer_array_from_parameter( + &mut self, + _cs: &mut CS, + _scope: String, + _parameter_model: ParameterModel, + _parameter_value: Option>, + ) -> Variable { + unimplemented!("Cannot enforce integer array as parameter") + // // Check visibility of parameter + // let mut array_value = vec![]; + // let name = parameter.variable.name.clone(); + // for argument in argument_array { + // let number = if parameter.private { + // UInt32::alloc(cs.ns(|| name), Some(argument)).unwrap() + // } else { + // UInt32::alloc_input(cs.ns(|| name), Some(argument)).unwrap() + // }; + // + // array_value.push(number); + // } + // + // + // let parameter_variable = new_variable_from_variable(scope, ¶meter.variable); + // + // // store array as variable in resolved program + // self.store_variable(parameter_variable.clone(), ResolvedValue::U32Array(array_value)); + // + // parameter_variable + } + + pub(crate) fn enforce_integer_eq( + cs: &mut CS, + left: ConstrainedInteger, + right: ConstrainedInteger, + ) { + match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + Self::enforce_u8_eq(cs, left_u8, right_u8) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + Self::enforce_u32_eq(cs, left_u32, right_u32) + } + (left, right) => unimplemented!( + "cannot enforce integer equality between {} == {}", + left, + right + ), + } + } + + pub(crate) fn enforce_integer_add( + cs: &mut CS, + left: ConstrainedInteger, + right: ConstrainedInteger, + ) -> ConstrainedValue { + ConstrainedValue::Integer(match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + ConstrainedInteger::U8(Self::enforce_u8_add(cs, left_u8, right_u8)) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + ConstrainedInteger::U32(Self::enforce_u32_add(cs, left_u32, right_u32)) + } + (left, right) => unimplemented!( + "cannot enforce integer addition between {} + {}", + left, + right + ), + }) + } + pub(crate) fn enforce_integer_sub( + cs: &mut CS, + left: ConstrainedInteger, + right: ConstrainedInteger, + ) -> ConstrainedValue { + ConstrainedValue::Integer(match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + ConstrainedInteger::U8(Self::enforce_u8_sub(cs, left_u8, right_u8)) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + ConstrainedInteger::U32(Self::enforce_u32_sub(cs, left_u32, right_u32)) + } + (left, right) => unimplemented!( + "cannot enforce integer subtraction between {} - {}", + left, + right + ), + }) + } + pub(crate) fn enforce_integer_mul( + cs: &mut CS, + left: ConstrainedInteger, + right: ConstrainedInteger, + ) -> ConstrainedValue { + ConstrainedValue::Integer(match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + ConstrainedInteger::U8(Self::enforce_u8_mul(cs, left_u8, right_u8)) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + ConstrainedInteger::U32(Self::enforce_u32_mul(cs, left_u32, right_u32)) + } + (left, right) => unimplemented!( + "cannot enforce integer multiplication between {} * {}", + left, + right + ), + }) + } + pub(crate) fn enforce_integer_div( + cs: &mut CS, + left: ConstrainedInteger, + right: ConstrainedInteger, + ) -> ConstrainedValue { + ConstrainedValue::Integer(match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + ConstrainedInteger::U8(Self::enforce_u8_div(cs, left_u8, right_u8)) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + ConstrainedInteger::U32(Self::enforce_u32_div(cs, left_u32, right_u32)) + } + (left, right) => unimplemented!( + "cannot enforce integer division between {} / {}", + left, + right + ), + }) + } + pub(crate) fn enforce_integer_pow( + cs: &mut CS, + left: ConstrainedInteger, + right: ConstrainedInteger, + ) -> ConstrainedValue { + ConstrainedValue::Integer(match (left, right) { + (ConstrainedInteger::U8(left_u8), ConstrainedInteger::U8(right_u8)) => { + ConstrainedInteger::U8(Self::enforce_u8_pow(cs, left_u8, right_u8)) + } + (ConstrainedInteger::U32(left_u32), ConstrainedInteger::U32(right_u32)) => { + ConstrainedInteger::U32(Self::enforce_u32_pow(cs, left_u32, right_u32)) + } + (left, right) => unimplemented!( + "cannot enforce integer exponentiation between {} ** {}", + left, + right + ), + }) + } +} diff --git a/compiler/src/constraints/integer/mod.rs b/compiler/src/constraints/integer/mod.rs new file mode 100644 index 0000000000..d07e2512ac --- /dev/null +++ b/compiler/src/constraints/integer/mod.rs @@ -0,0 +1,10 @@ +//! Module containing methods to enforce constraints on integers in a Leo program + +pub mod integer; +pub use integer::*; + +pub mod uint8; +pub use uint8::*; + +pub mod uint32; +pub use uint32::*; diff --git a/compiler/src/constraints/integer.rs b/compiler/src/constraints/integer/uint32.rs similarity index 58% rename from compiler/src/constraints/integer.rs rename to compiler/src/constraints/integer/uint32.rs index 75ddf9464b..2463e93306 100644 --- a/compiler/src/constraints/integer.rs +++ b/compiler/src/constraints/integer/uint32.rs @@ -1,8 +1,9 @@ -//! Methods to enforce constraints on integers in a resolved Leo program. +//! Methods to enforce constraints on uint32s in a resolved Leo program. use crate::{ - constraints::{new_variable_from_variable, ResolvedProgram, ResolvedValue}, - types::{Integer, ParameterModel, ParameterValue, Variable}, + constraints::{new_variable_from_variable, ConstrainedProgram, ConstrainedValue}, + types::{ParameterModel, ParameterValue, Variable}, + ConstrainedInteger, }; use snarkos_errors::gadgets::SynthesisError; @@ -10,11 +11,11 @@ use snarkos_models::{ curves::{Field, PrimeField}, gadgets::{ r1cs::ConstraintSystem, - utilities::{alloc::AllocGadget, boolean::Boolean, eq::EqGadget, uint32::UInt32}, + utilities::{alloc::AllocGadget, eq::EqGadget, uint32::UInt32}, }, }; -impl> ResolvedProgram { +impl> ConstrainedProgram { pub(crate) fn u32_from_parameter( &mut self, cs: &mut CS, @@ -45,7 +46,10 @@ impl> ResolvedProgram { let parameter_variable = new_variable_from_variable(scope, ¶meter_model.variable); // store each argument as variable in resolved program - self.store_variable(parameter_variable.clone(), ResolvedValue::U32(integer)); + self.store_variable( + parameter_variable.clone(), + ConstrainedValue::Integer(ConstrainedInteger::U32(integer)), + ); parameter_variable } @@ -80,72 +84,52 @@ impl> ResolvedProgram { // parameter_variable } - pub(crate) fn get_integer_constant(integer: Integer) -> ResolvedValue { - match integer { - Integer::U32(u32_value) => ResolvedValue::U32(UInt32::constant(u32_value)), - } - } - - pub(crate) fn u32_eq(left: UInt32, right: UInt32) -> ResolvedValue { - ResolvedValue::Boolean(Boolean::Constant(left.eq(&right))) - } - pub(crate) fn enforce_u32_eq(cs: &mut CS, left: UInt32, right: UInt32) { left.enforce_equal(cs.ns(|| format!("enforce u32 equal")), &right) .unwrap(); } - pub(crate) fn enforce_u32_add(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { - ResolvedValue::U32( - UInt32::addmany( - cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), - &[left, right], - ) - .unwrap(), + pub(crate) fn enforce_u32_add(cs: &mut CS, left: UInt32, right: UInt32) -> UInt32 { + UInt32::addmany( + cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), + &[left, right], ) + .unwrap() } - pub(crate) fn enforce_u32_sub(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { - ResolvedValue::U32( - left.sub( - cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap(), + pub(crate) fn enforce_u32_sub(cs: &mut CS, left: UInt32, right: UInt32) -> UInt32 { + left.sub( + cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), + &right, ) + .unwrap() } - pub(crate) fn enforce_u32_mul(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { - ResolvedValue::U32( - left.mul( - cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap(), + pub(crate) fn enforce_u32_mul(cs: &mut CS, left: UInt32, right: UInt32) -> UInt32 { + left.mul( + cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), + &right, ) + .unwrap() } - pub(crate) fn enforce_u32_div(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { - ResolvedValue::U32( - left.div( - cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap(), + pub(crate) fn enforce_u32_div(cs: &mut CS, left: UInt32, right: UInt32) -> UInt32 { + left.div( + cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), + &right, ) + .unwrap() } - pub(crate) fn enforce_u32_pow(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { - ResolvedValue::U32( - left.pow( - cs.ns(|| { - format!( - "enforce {} ** {}", - left.value.unwrap(), - right.value.unwrap() - ) - }), - &right, - ) - .unwrap(), + pub(crate) fn enforce_u32_pow(cs: &mut CS, left: UInt32, right: UInt32) -> UInt32 { + left.pow( + cs.ns(|| { + format!( + "enforce {} ** {}", + left.value.unwrap(), + right.value.unwrap() + ) + }), + &right, ) + .unwrap() } } diff --git a/compiler/src/constraints/integer/uint8.rs b/compiler/src/constraints/integer/uint8.rs new file mode 100644 index 0000000000..3fc5925be6 --- /dev/null +++ b/compiler/src/constraints/integer/uint8.rs @@ -0,0 +1,135 @@ +//! Methods to enforce constraints on uint8s in a resolved Leo program. + +use crate::{ + constraints::{new_variable_from_variable, ConstrainedProgram, ConstrainedValue}, + types::{ParameterModel, ParameterValue, Variable}, + ConstrainedInteger, +}; + +use snarkos_errors::gadgets::SynthesisError; +use snarkos_models::{ + curves::{Field, PrimeField}, + gadgets::{ + r1cs::ConstraintSystem, + utilities::{alloc::AllocGadget, eq::EqGadget, uint8::UInt8}, + }, +}; + +impl> ConstrainedProgram { + pub(crate) fn u8_from_parameter( + &mut self, + cs: &mut CS, + scope: String, + parameter_model: ParameterModel, + parameter_value: Option>, + ) -> Variable { + // Check that the parameter value is the correct type + let integer_option = parameter_value.map(|parameter| match parameter { + ParameterValue::Integer(i) => i as u8, + value => unimplemented!("expected integer parameter, got {}", value), + }); + + // Check visibility of parameter + let name = parameter_model.variable.name.clone(); + let integer = if parameter_model.private { + UInt8::alloc(cs.ns(|| name), || { + integer_option.ok_or(SynthesisError::AssignmentMissing) + }) + .unwrap() + } else { + UInt8::alloc_input(cs.ns(|| name), || { + integer_option.ok_or(SynthesisError::AssignmentMissing) + }) + .unwrap() + }; + + let parameter_variable = new_variable_from_variable(scope, ¶meter_model.variable); + + // store each argument as variable in resolved program + self.store_variable( + parameter_variable.clone(), + ConstrainedValue::Integer(ConstrainedInteger::U8(integer)), + ); + + parameter_variable + } + + pub(crate) fn u8_array_from_parameter( + &mut self, + _cs: &mut CS, + _scope: String, + _parameter_model: ParameterModel, + _parameter_value: Option>, + ) -> Variable { + unimplemented!("Cannot enforce integer array as parameter") + // // Check visibility of parameter + // let mut array_value = vec![]; + // let name = parameter.variable.name.clone(); + // for argument in argument_array { + // let number = if parameter.private { + // UInt32::alloc(cs.ns(|| name), Some(argument)).unwrap() + // } else { + // UInt32::alloc_input(cs.ns(|| name), Some(argument)).unwrap() + // }; + // + // array_value.push(number); + // } + // + // + // let parameter_variable = new_variable_from_variable(scope, ¶meter.variable); + // + // // store array as variable in resolved program + // self.store_variable(parameter_variable.clone(), ResolvedValue::U32Array(array_value)); + // + // parameter_variable + } + + pub(crate) fn enforce_u8_eq(cs: &mut CS, left: UInt8, right: UInt8) { + left.enforce_equal(cs.ns(|| format!("enforce u8 equal")), &right) + .unwrap(); + } + + pub(crate) fn enforce_u8_add(cs: &mut CS, left: UInt8, right: UInt8) -> UInt8 { + UInt8::addmany( + cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), + &[left, right], + ) + .unwrap() + } + + pub(crate) fn enforce_u8_sub(cs: &mut CS, left: UInt8, right: UInt8) -> UInt8 { + left.sub( + cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap() + } + + pub(crate) fn enforce_u8_mul(cs: &mut CS, left: UInt8, right: UInt8) -> UInt8 { + left.mul( + cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap() + } + pub(crate) fn enforce_u8_div(cs: &mut CS, left: UInt8, right: UInt8) -> UInt8 { + left.div( + cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap() + } + pub(crate) fn enforce_u8_pow(cs: &mut CS, left: UInt8, right: UInt8) -> UInt8 { + left.pow( + cs.ns(|| { + format!( + "enforce {} ** {}", + left.value.unwrap(), + right.value.unwrap() + ) + }), + &right, + ) + .unwrap() + } +} diff --git a/compiler/src/constraints/main_function.rs b/compiler/src/constraints/main_function.rs index 3800815668..724788c18e 100644 --- a/compiler/src/constraints/main_function.rs +++ b/compiler/src/constraints/main_function.rs @@ -4,8 +4,8 @@ use crate::{ ast, constraints::{ - new_scope, new_scope_from_variable, new_variable_from_variables, ResolvedProgram, - ResolvedValue, + new_scope, new_scope_from_variable, new_variable_from_variables, ConstrainedProgram, + ConstrainedValue, }, types::{Expression, Function, ParameterValue, Program, Type}, Import, @@ -19,7 +19,7 @@ use snarkos_models::{ use std::fs; use std::path::Path; -impl> ResolvedProgram { +impl> ConstrainedProgram { fn enforce_argument( &mut self, cs: &mut CS, @@ -27,7 +27,7 @@ impl> ResolvedProgram { caller_scope: String, function_name: String, argument: Expression, - ) -> ResolvedValue { + ) -> ConstrainedValue { match argument { Expression::Variable(variable) => self.enforce_variable(caller_scope, variable), expression => self.enforce_expression(cs, scope, function_name, expression), @@ -41,7 +41,7 @@ impl> ResolvedProgram { caller_scope: String, function: Function, arguments: Vec>, - ) -> ResolvedValue { + ) -> ConstrainedValue { let function_name = new_scope(scope.clone(), function.get_name()); // Make sure we are given the correct number of arguments @@ -62,7 +62,7 @@ impl> ResolvedProgram { .for_each(|(parameter, argument)| { // Check that argument is correct type match parameter._type.clone() { - Type::U32 => { + Type::IntegerType(integer_type) => { match self.enforce_argument( cs, scope.clone(), @@ -70,13 +70,14 @@ impl> ResolvedProgram { function_name.clone(), argument, ) { - ResolvedValue::U32(number) => { + ConstrainedValue::Integer(number) => { + number.expect_type(&integer_type); // Store argument as variable with {function_name}_{parameter name} let variable_name = new_scope_from_variable( function_name.clone(), ¶meter.variable, ); - self.store(variable_name, ResolvedValue::U32(number)); + self.store(variable_name, ConstrainedValue::Integer(number)); } argument => { unimplemented!("expected integer argument got {}", argument) @@ -91,13 +92,13 @@ impl> ResolvedProgram { function_name.clone(), argument, ) { - ResolvedValue::FieldElement(fe) => { + ConstrainedValue::FieldElement(fe) => { // Store argument as variable with {function_name}_{parameter name} let variable_name = new_scope_from_variable( function_name.clone(), ¶meter.variable, ); - self.store(variable_name, ResolvedValue::FieldElement(fe)); + self.store(variable_name, ConstrainedValue::FieldElement(fe)); } argument => unimplemented!("expected field argument got {}", argument), } @@ -110,13 +111,13 @@ impl> ResolvedProgram { function_name.clone(), argument, ) { - ResolvedValue::Boolean(bool) => { + ConstrainedValue::Boolean(bool) => { // Store argument as variable with {function_name}_{parameter name} let variable_name = new_scope_from_variable( function_name.clone(), ¶meter.variable, ); - self.store(variable_name, ResolvedValue::Boolean(bool)); + self.store(variable_name, ConstrainedValue::Boolean(bool)); } argument => { unimplemented!("expected boolean argument got {}", argument) @@ -129,7 +130,7 @@ impl> ResolvedProgram { // Evaluate function statements - let mut return_values = ResolvedValue::Return(vec![]); + let mut return_values = ConstrainedValue::Return(vec![]); for statement in function.statements.iter() { if let Some(returned) = self.enforce_statement( @@ -153,7 +154,7 @@ impl> ResolvedProgram { scope: String, function: Function, parameters: Vec>>, - ) -> ResolvedValue { + ) -> ConstrainedValue { let function_name = new_scope(scope.clone(), function.get_name()); let mut arguments = vec![]; @@ -168,7 +169,7 @@ impl> ResolvedProgram { .for_each(|(parameter_model, parameter_value)| { // append each variable to arguments vector arguments.push(Expression::Variable(match parameter_model._type { - Type::U32 => self.u32_from_parameter( + Type::IntegerType(ref _integer_type) => self.integer_from_parameter( cs, function_name.clone(), parameter_model, @@ -187,7 +188,7 @@ impl> ResolvedProgram { parameter_value, ), Type::Array(ref ty, _length) => match *ty.clone() { - Type::U32 => self.u32_array_from_parameter( + Type::IntegerType(_type) => self.integer_array_from_parameter( cs, function_name.clone(), parameter_model, @@ -259,7 +260,7 @@ impl> ResolvedProgram { // store imported struct under resolved name self.store_variable( resolved_struct_name, - ResolvedValue::StructDefinition(struct_def), + ConstrainedValue::StructDefinition(struct_def), ); } None => { @@ -280,7 +281,7 @@ impl> ResolvedProgram { // store imported function under resolved name self.store_variable( resolved_function_name, - ResolvedValue::Function(function), + ConstrainedValue::Function(function), ) } None => unimplemented!( @@ -318,7 +319,7 @@ impl> ResolvedProgram { new_variable_from_variables(&program_name.clone(), &variable); self.store_variable( resolved_struct_name, - ResolvedValue::StructDefinition(struct_def), + ConstrainedValue::StructDefinition(struct_def), ); }); @@ -328,7 +329,7 @@ impl> ResolvedProgram { .into_iter() .for_each(|(function_name, function)| { let resolved_function_name = new_scope(program_name.name.clone(), function_name.0); - self.store(resolved_function_name, ResolvedValue::Function(function)); + self.store(resolved_function_name, ConstrainedValue::Function(function)); }); } @@ -336,8 +337,8 @@ impl> ResolvedProgram { cs: &mut CS, program: Program, parameters: Vec>>, - ) -> ResolvedValue { - let mut resolved_program = ResolvedProgram::new(); + ) -> ConstrainedValue { + let mut resolved_program = ConstrainedProgram::new(); let program_name = program.get_name(); let main_function_name = new_scope(program_name.clone(), "main".into()); @@ -348,7 +349,7 @@ impl> ResolvedProgram { .expect("main function not defined"); match main.clone() { - ResolvedValue::Function(function) => { + ConstrainedValue::Function(function) => { let result = resolved_program.enforce_main_function(cs, program_name, function, parameters); log::debug!("{}", result); diff --git a/compiler/src/constraints/mod.rs b/compiler/src/constraints/mod.rs index 907625ec94..d67b5309a1 100644 --- a/compiler/src/constraints/mod.rs +++ b/compiler/src/constraints/mod.rs @@ -15,11 +15,11 @@ pub use integer::*; pub mod field_element; pub use field_element::*; -pub mod resolved_program; -pub use resolved_program::*; +pub mod constrained_program; +pub use constrained_program::*; -pub mod resolved_value; -pub use resolved_value::*; +pub mod constrained_value; +pub use constrained_value::*; pub mod statement; pub use statement::*; diff --git a/compiler/src/constraints/statement.rs b/compiler/src/constraints/statement.rs index 54ce16bd1d..15b25ca046 100644 --- a/compiler/src/constraints/statement.rs +++ b/compiler/src/constraints/statement.rs @@ -1,11 +1,12 @@ //! Methods to enforce constraints on statements in a resolved Leo program. use crate::{ - constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}, + constraints::{new_scope_from_variable, ConstrainedProgram, ConstrainedValue}, types::{ Assignee, ConditionalNestedOrEnd, ConditionalStatement, Expression, Integer, RangeOrExpression, Statement, Type, Variable, }, + ConstrainedInteger, }; use snarkos_models::{ @@ -13,7 +14,7 @@ use snarkos_models::{ gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean, utilities::uint32::UInt32}, }; -impl> ResolvedProgram { +impl> ConstrainedProgram { fn resolve_assignee(&mut self, scope: String, assignee: Assignee) -> String { match assignee { Assignee::Variable(name) => new_scope_from_variable(scope, &name), @@ -30,7 +31,7 @@ impl> ResolvedProgram { file_scope: String, function_scope: String, assignee: Assignee, - return_value: &mut ResolvedValue, + return_value: &mut ConstrainedValue, ) { match assignee { Assignee::Variable(name) => { @@ -56,7 +57,7 @@ impl> ResolvedProgram { // Modify the single value of the array in place match self.get_mut(&expected_array_name) { Some(value) => match value { - ResolvedValue::Array(old) => { + ConstrainedValue::Array(old) => { old[index] = return_value.to_owned(); } _ => { @@ -82,7 +83,7 @@ impl> ResolvedProgram { // Modify the range of values of the array in place match self.get_mut(&expected_array_name) { Some(value) => match (value, return_value) { - (ResolvedValue::Array(old), ResolvedValue::Array(new)) => { + (ConstrainedValue::Array(old), ConstrainedValue::Array(new)) => { let to_index = to_index_option.unwrap_or(old.len()); old.splice(from_index..to_index, new.iter().cloned()); } @@ -105,7 +106,7 @@ impl> ResolvedProgram { match self.get_mut(&expected_struct_name) { Some(value) => match value { - ResolvedValue::StructExpression(_variable, members) => { + ConstrainedValue::StructExpression(_variable, members) => { // Modify the struct member in place let matched_member = members.into_iter().find(|member| member.0 == struct_member); @@ -176,11 +177,8 @@ impl> ResolvedProgram { match ty { // Explicit type Some(ty) => { - if result_value.match_type(&ty) { - self.store_assignment(cs, file_scope, function_scope, assignee, result_value); - } else { - unimplemented!("incompatible types {} = {}", assignee, result_value) - } + result_value.expect_type(&ty); + self.store_assignment(cs, file_scope, function_scope, assignee, result_value); } // Implicit type None => self.store_assignment(cs, file_scope, function_scope, assignee, result_value), @@ -199,7 +197,7 @@ impl> ResolvedProgram { let return_values = match self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), function) { - ResolvedValue::Return(values) => values, + ConstrainedValue::Return(values) => values, value => unimplemented!( "multiple assignment only implemented for functions, got {}", value @@ -227,7 +225,7 @@ impl> ResolvedProgram { function_scope: String, expressions: Vec>, return_types: Vec>, - ) -> ResolvedValue { + ) -> ConstrainedValue { // Make sure we return the correct number of values if return_types.len() != expressions.len() { unimplemented!( @@ -237,7 +235,7 @@ impl> ResolvedProgram { ) } - ResolvedValue::Return( + ConstrainedValue::Return( expressions .into_iter() .zip(return_types.into_iter()) @@ -248,13 +246,10 @@ impl> ResolvedProgram { function_scope.clone(), expression, ); - if !result.match_type(&ty) { - unimplemented!("expected return type {}, got {}", ty, result) - } else { - result - } + result.expect_type(&ty); + result }) - .collect::>>(), + .collect::>>(), ) } @@ -265,7 +260,7 @@ impl> ResolvedProgram { function_scope: String, statements: Vec>, return_types: Vec>, - ) -> Option> { + ) -> Option> { let mut res = None; // Evaluate statements and possibly return early for statement in statements.iter() { @@ -291,14 +286,14 @@ impl> ResolvedProgram { function_scope: String, statement: ConditionalStatement, return_types: Vec>, - ) -> Option> { + ) -> Option> { let condition = match self.enforce_expression( cs, file_scope.clone(), function_scope.clone(), statement.condition.clone(), ) { - ResolvedValue::Boolean(resolved) => resolved, + ConstrainedValue::Boolean(resolved) => resolved, value => unimplemented!("if else conditional must resolve to boolean, got {}", value), }; @@ -344,14 +339,17 @@ impl> ResolvedProgram { stop: Integer, statements: Vec>, return_types: Vec>, - ) -> Option> { + ) -> Option> { let mut res = None; for i in start.to_usize()..stop.to_usize() { // Store index in current function scope. // For loop scope is not implemented. let index_name = new_scope_from_variable(function_scope.clone(), &index); - self.store(index_name, ResolvedValue::U32(UInt32::constant(i as u32))); + self.store( + index_name, + ConstrainedValue::Integer(ConstrainedInteger::U32(UInt32::constant(i as u32))), + ); // Evaluate statements and possibly return early if let Some(early_return) = self.iterate_or_early_return( @@ -372,17 +370,17 @@ impl> ResolvedProgram { fn enforce_assert_eq_statement( &mut self, cs: &mut CS, - left: ResolvedValue, - right: ResolvedValue, + left: ConstrainedValue, + right: ConstrainedValue, ) { match (left, right) { - (ResolvedValue::Boolean(bool_1), ResolvedValue::Boolean(bool_2)) => { + (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { self.enforce_boolean_eq(cs, bool_1, bool_2) } - (ResolvedValue::U32(num_1), ResolvedValue::U32(num_2)) => { - Self::enforce_u32_eq(cs, num_1, num_2) + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + Self::enforce_integer_eq(cs, num_1, num_2) } - (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { + (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_eq(cs, fe_1, fe_2) } (val_1, val_2) => { @@ -398,7 +396,7 @@ impl> ResolvedProgram { function_scope: String, statement: Statement, return_types: Vec>, - ) -> Option> { + ) -> Option> { let mut res = None; match statement { Statement::Return(expressions) => { @@ -467,7 +465,7 @@ impl> ResolvedProgram { } Statement::Expression(expression) => { match self.enforce_expression(cs, file_scope, function_scope, expression.clone()) { - ResolvedValue::Return(values) => { + ConstrainedValue::Return(values) => { if !values.is_empty() { unimplemented!("function return values not assigned {:#?}", values) } diff --git a/compiler/src/leo.pest b/compiler/src/leo.pest index 9f83663ca9..8e4ddfbd21 100644 --- a/compiler/src/leo.pest +++ b/compiler/src/leo.pest @@ -56,11 +56,16 @@ operation_assign = { } /// types - +type_u8 = {"u8"} type_u32 = {"u32"} +type_integer = { + type_u8 + | type_u32 +} + type_field = {"fe"} type_bool = {"bool"} -type_basic = { type_u32 | type_field | type_bool } +type_basic = { type_integer | type_field | type_bool } type_struct = { variable } type_basic_or_struct = {type_basic | type_struct } type_array = {type_basic ~ ("[" ~ value ~ "]")+ } @@ -70,10 +75,10 @@ type_list = _{(_type ~ ("," ~ _type)*)?} /// Values value_number = @{ "0" | ASCII_NONZERO_DIGIT ~ ASCII_DIGIT* } -value_u32 = { value_number ~ type_u32? } +value_integer = { value_number ~ type_integer? } value_field = { value_number ~ type_field } value_boolean = { "true" | "false" } -value = { value_field | value_boolean | value_u32 } +value = { value_field | value_boolean | value_integer } /// Variables diff --git a/compiler/src/types.rs b/compiler/src/types.rs index 36272da3c2..e1cea6f85f 100644 --- a/compiler/src/types.rs +++ b/compiler/src/types.rs @@ -17,7 +17,7 @@ pub struct Variable { /// An integer type enum wrapping the integer value #[derive(Debug, Clone, PartialEq, Eq)] pub enum Integer { - // U8(u8), + U8(u8), U32(u32), // U64(u64), } @@ -25,7 +25,7 @@ pub enum Integer { impl Integer { pub fn to_usize(&self) -> usize { match *self { - // U8(u8) + Integer::U8(num) => num as usize, Integer::U32(num) => num as usize, // U64(u64) } @@ -97,10 +97,17 @@ pub enum Assignee { StructMember(Box>, Variable), } +/// Explicit integer type +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum IntegerType { + U8, + U32, +} + /// Explicit type used for defining a variable or expression type #[derive(Clone, Debug, PartialEq, Eq)] pub enum Type { - U32, + IntegerType(IntegerType), FieldElement, Boolean, Array(Box>, usize), diff --git a/compiler/src/types_display.rs b/compiler/src/types_display.rs index 8b2f49e215..262f44901e 100644 --- a/compiler/src/types_display.rs +++ b/compiler/src/types_display.rs @@ -2,8 +2,8 @@ use crate::{ Assignee, ConditionalNestedOrEnd, ConditionalStatement, Expression, Function, FunctionName, - Integer, ParameterModel, ParameterValue, RangeOrExpression, SpreadOrExpression, Statement, - Struct, StructField, Type, Variable, + Integer, IntegerType, ParameterModel, ParameterValue, RangeOrExpression, SpreadOrExpression, + Statement, Struct, StructField, Type, Variable, }; use snarkos_models::curves::{Field, PrimeField}; @@ -23,6 +23,7 @@ impl fmt::Debug for Variable { impl fmt::Display for Integer { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { + Integer::U8(ref num) => write!(f, "{}", num), Integer::U32(ref num) => write!(f, "{}", num), } } @@ -217,10 +218,19 @@ impl fmt::Display for Statement { } } +impl fmt::Display for IntegerType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + IntegerType::U8 => write!(f, "u8"), + IntegerType::U32 => write!(f, "u32"), + } + } +} + impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Type::U32 => write!(f, "u32"), + Type::IntegerType(ref integer_type) => write!(f, "{}", integer_type), Type::FieldElement => write!(f, "fe"), Type::Boolean => write!(f, "bool"), Type::Struct(ref variable) => write!(f, "{}", variable), diff --git a/compiler/src/types_from.rs b/compiler/src/types_from.rs index 25fb3f69a8..cad12c57ae 100644 --- a/compiler/src/types_from.rs +++ b/compiler/src/types_from.rs @@ -23,15 +23,32 @@ impl<'ast, F: Field + PrimeField> From> for types::Expressio } /// pest ast - types::Integer -impl<'ast, F: Field + PrimeField> From> for types::Expression { - fn from(field: ast::U32<'ast>) -> Self { - types::Expression::Integer(types::Integer::U32( - field - .number - .value - .parse::() - .expect("unable to parse u32"), - )) +impl<'ast> types::Integer { + pub(crate) fn from(number: ast::Number<'ast>, _type: ast::IntegerType) -> Self { + match _type { + ast::IntegerType::U8Type(_u8) => { + types::Integer::U8(number.value.parse::().expect("unable to parse u8")) + } + ast::IntegerType::U32Type(_u32) => { + types::Integer::U32(number.value.parse::().expect("unable to parse u32")) + } + } + } +} + +impl<'ast, F: Field + PrimeField> From> for types::Expression { + fn from(field: ast::Integer<'ast>) -> Self { + types::Expression::Integer(match field._type { + Some(_type) => types::Integer::from(field.number, _type), + // default integer type is u32 + None => types::Integer::U32( + field + .number + .value + .parse::() + .expect("unable to parse u32"), + ), + }) } } @@ -91,7 +108,7 @@ impl<'ast, F: Field + PrimeField> From> for types::Expression impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(value: ast::Value<'ast>) -> Self { match value { - ast::Value::U32(num) => types::Expression::from(num), + ast::Value::Integer(num) => types::Expression::from(num), ast::Value::Field(fe) => types::Expression::from(fe), ast::Value::Boolean(bool) => types::Expression::from(bool), } @@ -290,7 +307,7 @@ impl<'ast, F: Field + PrimeField> From> for types::Express impl<'ast, F: Field + PrimeField> types::Expression { fn get_count(count: ast::Value<'ast>) -> usize { match count { - ast::Value::U32(f) => f + ast::Value::Integer(f) => f .number .value .parse::() @@ -554,10 +571,19 @@ impl<'ast, F: Field + PrimeField> From> for types::Statemen /// pest ast -> Explicit types::Type for defining struct members and function params +impl From for types::IntegerType { + fn from(integer_type: ast::IntegerType) -> Self { + match integer_type { + ast::IntegerType::U8Type(_type) => types::IntegerType::U8, + ast::IntegerType::U32Type(_type) => types::IntegerType::U32, + } + } +} + impl<'ast, F: Field + PrimeField> From> for types::Type { fn from(basic_type: ast::BasicType<'ast>) -> Self { match basic_type { - ast::BasicType::U32(_ty) => types::Type::U32, + ast::BasicType::Integer(ty) => types::Type::IntegerType(types::IntegerType::from(ty)), ast::BasicType::Field(_ty) => types::Type::FieldElement, ast::BasicType::Boolean(_ty) => types::Type::Boolean, } diff --git a/compiler/tests/u32/mod.rs b/compiler/tests/u32/mod.rs index 2f20f8b60c..78b0e719c8 100644 --- a/compiler/tests/u32/mod.rs +++ b/compiler/tests/u32/mod.rs @@ -1,4 +1,4 @@ -use leo_compiler::{compiler::Compiler, ResolvedValue}; +use leo_compiler::{compiler::Compiler, ConstrainedValue}; use snarkos_curves::bls12_377::Fr; @@ -39,7 +39,7 @@ fn test_zero() { let output = output.unwrap(); assert_eq!( - ResolvedValue::::Return(vec![ResolvedValue::::U32(UInt32::constant(0))]), + ConstrainedValue::::Return(vec![ConstrainedValue::::Integer(UInt32::constant(0))]), output ); println!("{}", output); @@ -54,7 +54,7 @@ fn test_one() { let output = output.unwrap(); assert_eq!( - ResolvedValue::::Return(vec![ResolvedValue::::U32(UInt32::constant(1))]), + ConstrainedValue::::Return(vec![ConstrainedValue::::Integer(UInt32::constant(1))]), output ); println!("{}", output); @@ -69,7 +69,7 @@ fn test_1_plus_1() { let output = output.unwrap(); assert_eq!( - ResolvedValue::::Return(vec![ResolvedValue::::U32(UInt32::constant(2))]), + ConstrainedValue::::Return(vec![ConstrainedValue::::Integer(UInt32::constant(2))]), output ); println!("{}", output); @@ -84,7 +84,7 @@ fn test_1_minus_1() { let output = output.unwrap(); assert_eq!( - ResolvedValue::::Return(vec![ResolvedValue::::U32(UInt32::constant(0))]), + ConstrainedValue::::Return(vec![ConstrainedValue::::Integer(UInt32::constant(0))]), output ); println!("{}", output);