diff --git a/simple.program b/simple.program index 0007bcbca0..8e0dc94810 100644 --- a/simple.program +++ b/simple.program @@ -1 +1 @@ -x = 5 + 3 \ No newline at end of file +return 5 + a \ No newline at end of file diff --git a/src/aleo_program/constraints.rs b/src/aleo_program/constraints.rs index f4a482a2b0..4435d36b51 100644 --- a/src/aleo_program/constraints.rs +++ b/src/aleo_program/constraints.rs @@ -8,252 +8,338 @@ use snarkos_models::gadgets::{ r1cs::ConstraintSystem, utilities::{alloc::AllocGadget, boolean::Boolean, eq::ConditionalEqGadget, uint32::UInt32}, }; +use std::collections::HashMap; -fn bool_from_variable>( - cs: &mut CS, - variable: Variable, -) -> Boolean { - let argument = std::env::args() - .nth(1) - .unwrap_or("true".into()) - .parse::() +pub enum ResolvedValue { + Boolean(Boolean), + FieldElement(UInt32), +} + +pub struct ResolvedProgram { + pub resolved_variables: HashMap, +} + +impl ResolvedProgram { + fn new() -> Self { + Self { + resolved_variables: HashMap::new(), + } + } + + fn insert(&mut self, variable: Variable, value: ResolvedValue) { + self.resolved_variables.insert(variable, value); + } + + fn bool_from_variable>( + &mut self, + cs: &mut CS, + variable: Variable, + ) -> Boolean { + if self.resolved_variables.contains_key(&variable) { + match self.resolved_variables.get(&variable).unwrap() { + ResolvedValue::Boolean(boolean) => boolean.clone(), + _ => panic!("expected a boolean, got field"), + }; + Boolean::Constant(true) + } else { + let argument = std::env::args() + .nth(1) + .unwrap_or("true".into()) + .parse::() + .unwrap(); + println!(" argument passed to command line a = {:?}", argument); + // let a = true; + Boolean::alloc_input(cs.ns(|| variable.0), || Ok(argument)).unwrap() + } + } + + fn u32_from_variable>( + &mut self, + cs: &mut CS, + variable: Variable, + ) -> UInt32 { + if self.resolved_variables.contains_key(&variable) { + match self.resolved_variables.get(&variable).unwrap() { + ResolvedValue::FieldElement(field) => field.clone(), + _ => panic!("expected a field, got boolean"), + } + } else { + let argument = std::env::args() + .nth(1) + .unwrap_or("1".into()) + .parse::() + .unwrap(); + + println!(" argument passed to command line a = {:?}", argument); + + // let a = 1; + UInt32::alloc(cs.ns(|| variable.0), Some(argument)).unwrap() + } + } + + fn get_bool_value>( + &mut self, + cs: &mut CS, + expression: BooleanExpression, + ) -> Boolean { + match expression { + BooleanExpression::Variable(variable) => self.bool_from_variable(cs, variable), + BooleanExpression::Value(value) => Boolean::Constant(value), + expression => self.enforce_boolean_expression(cs, expression), + } + } + + fn get_u32_value>( + &mut self, + cs: &mut CS, + expression: FieldExpression, + ) -> UInt32 { + match expression { + FieldExpression::Variable(variable) => self.u32_from_variable(cs, variable), + FieldExpression::Number(number) => UInt32::constant(number), + field => self.enforce_field_expression(cs, field), + } + } + + fn enforce_or>( + &mut self, + cs: &mut CS, + left: BooleanExpression, + right: BooleanExpression, + ) -> Boolean { + let left = self.get_bool_value(cs, left); + let right = self.get_bool_value(cs, right); + + Boolean::or(cs, &left, &right).unwrap() + } + + fn enforce_and>( + &mut self, + cs: &mut CS, + left: BooleanExpression, + right: BooleanExpression, + ) -> Boolean { + let left = self.get_bool_value(cs, left); + let right = self.get_bool_value(cs, right); + + Boolean::and(cs, &left, &right).unwrap() + } + + fn enforce_bool_equality>( + &mut self, + cs: &mut CS, + left: BooleanExpression, + right: BooleanExpression, + ) -> Boolean { + let left = self.get_bool_value(cs, left); + let right = self.get_bool_value(cs, right); + + left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right) + .unwrap(); + + Boolean::Constant(true) + } + + fn enforce_field_equality>( + &mut self, + cs: &mut CS, + left: FieldExpression, + right: FieldExpression, + ) -> Boolean { + let left = self.get_u32_value(cs, left); + let right = self.get_u32_value(cs, right); + + left.conditional_enforce_equal( + cs.ns(|| format!("enforce field equal")), + &right, + &Boolean::Constant(true), + ) .unwrap(); - println!(" argument passed to command line a = {:?}", argument); - // let a = true; - Boolean::alloc_input(cs.ns(|| variable.0), || Ok(argument)).unwrap() -} - -fn u32_from_variable>(cs: &mut CS, variable: Variable) -> UInt32 { - let argument = std::env::args() - .nth(1) - .unwrap_or("1".into()) - .parse::() - .unwrap(); - - println!(" argument passed to command line a = {:?}", argument); - - // let a = 1; - UInt32::alloc(cs.ns(|| variable.0), Some(argument)).unwrap() -} - -fn get_bool_value>( - cs: &mut CS, - expression: BooleanExpression, -) -> Boolean { - match expression { - BooleanExpression::Variable(variable) => bool_from_variable(cs, variable), - BooleanExpression::Value(value) => Boolean::Constant(value), - expression => enforce_boolean_expression(cs, expression), + Boolean::Constant(true) } -} -fn get_u32_value>( - cs: &mut CS, - expression: FieldExpression, -) -> UInt32 { - match expression { - FieldExpression::Variable(variable) => u32_from_variable(cs, variable), - FieldExpression::Number(number) => UInt32::constant(number), - field => enforce_field_expression(cs, field), - } -} - -fn enforce_or>( - cs: &mut CS, - left: BooleanExpression, - right: BooleanExpression, -) -> Boolean { - let left = get_bool_value(cs, left); - let right = get_bool_value(cs, right); - - Boolean::or(cs, &left, &right).unwrap() -} - -fn enforce_and>( - cs: &mut CS, - left: BooleanExpression, - right: BooleanExpression, -) -> Boolean { - let left = get_bool_value(cs, left); - let right = get_bool_value(cs, right); - - Boolean::and(cs, &left, &right).unwrap() -} - -fn enforce_bool_equality>( - cs: &mut CS, - left: BooleanExpression, - right: BooleanExpression, -) -> Boolean { - let left = get_bool_value(cs, left); - let right = get_bool_value(cs, right); - - left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right) - .unwrap(); - - Boolean::Constant(true) -} - -fn enforce_field_equality>( - cs: &mut CS, - left: FieldExpression, - right: FieldExpression, -) -> Boolean { - let left = get_u32_value(cs, left); - let right = get_u32_value(cs, right); - - left.conditional_enforce_equal( - cs.ns(|| format!("enforce field equal")), - &right, - &Boolean::Constant(true), - ) - .unwrap(); - - Boolean::Constant(true) -} - -fn enforce_boolean_expression>( - cs: &mut CS, - expression: BooleanExpression, -) -> Boolean { - match expression { - BooleanExpression::Or(left, right) => enforce_or(cs, *left, *right), - BooleanExpression::And(left, right) => enforce_and(cs, *left, *right), - BooleanExpression::BoolEq(left, right) => enforce_bool_equality(cs, *left, *right), - BooleanExpression::FieldEq(left, right) => enforce_field_equality(cs, *left, *right), - _ => unimplemented!(), - } -} - -fn enforce_add>( - cs: &mut CS, - left: FieldExpression, - right: FieldExpression, -) -> UInt32 { - let left = get_u32_value(cs, left); - let right = get_u32_value(cs, right); - - left.add( - cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap() -} - -fn enforce_sub>( - cs: &mut CS, - left: FieldExpression, - right: FieldExpression, -) -> UInt32 { - let left = get_u32_value(cs, left); - let right = get_u32_value(cs, right); - - left.sub( - cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap() -} - -fn enforce_mul>( - cs: &mut CS, - left: FieldExpression, - right: FieldExpression, -) -> UInt32 { - let left = get_u32_value(cs, left); - let right = get_u32_value(cs, right); - - left.mul( - cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap() -} - -fn enforce_div>( - cs: &mut CS, - left: FieldExpression, - right: FieldExpression, -) -> UInt32 { - let left = get_u32_value(cs, left); - let right = get_u32_value(cs, right); - - left.div( - cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), - &right, - ) - .unwrap() -} - -fn enforce_pow>( - cs: &mut CS, - left: FieldExpression, - right: FieldExpression, -) -> UInt32 { - let left = get_u32_value(cs, left); - let right = get_u32_value(cs, right); - - left.pow( - cs.ns(|| { - format!( - "enforce {} ** {}", - left.value.unwrap(), - right.value.unwrap() - ) - }), - &right, - ) - .unwrap() -} - -fn enforce_field_expression>( - cs: &mut CS, - expression: FieldExpression, -) -> UInt32 { - match expression { - FieldExpression::Add(left, right) => enforce_add(cs, *left, *right), - FieldExpression::Sub(left, right) => enforce_sub(cs, *left, *right), - FieldExpression::Mul(left, right) => enforce_mul(cs, *left, *right), - FieldExpression::Div(left, right) => enforce_div(cs, *left, *right), - FieldExpression::Pow(left, right) => enforce_pow(cs, *left, *right), - _ => unimplemented!(), - } -} - -pub fn generate_constraints>(cs: &mut CS, program: Program) { - program - .statements - .into_iter() - .for_each(|statement| match statement { - Statement::Definition(variable, expression) => match expression { - Expression::Boolean(boolean_expression) => { - let res = enforce_boolean_expression(cs, boolean_expression); - println!("boolean result: {}", res.get_value().unwrap()); - } - Expression::FieldElement(field_expression) => { - let res = enforce_field_expression(cs, field_expression); - println!("field result: {}", res.value.unwrap()); - } - _ => unimplemented!(), - }, - Statement::Return(statements) => { - statements - .into_iter() - .for_each(|expression| match expression { - Expression::Boolean(boolean_expression) => { - let res = enforce_boolean_expression(cs, boolean_expression); - println!("boolean result: {}", res.get_value().unwrap()); - } - Expression::FieldElement(field_expression) => { - let res = enforce_field_expression(cs, field_expression); - println!("field result: {}", res.value.unwrap()); - } - _ => unimplemented!(), - }); + fn enforce_boolean_expression>( + &mut self, + cs: &mut CS, + expression: BooleanExpression, + ) -> Boolean { + match expression { + BooleanExpression::Or(left, right) => self.enforce_or(cs, *left, *right), + BooleanExpression::And(left, right) => self.enforce_and(cs, *left, *right), + BooleanExpression::BoolEq(left, right) => self.enforce_bool_equality(cs, *left, *right), + BooleanExpression::FieldEq(left, right) => { + self.enforce_field_equality(cs, *left, *right) } _ => unimplemented!(), - }); + } + } + + fn enforce_add>( + &mut self, + cs: &mut CS, + left: FieldExpression, + right: FieldExpression, + ) -> UInt32 { + let left = self.get_u32_value(cs, left); + let right = self.get_u32_value(cs, right); + + println!("left: {:#?}", left.value.unwrap()); + println!("right: {:#?}", right.value.unwrap()); + // println!("expected: {:#?}", UInt32::alloc(cs.ns(|| format!("expected")), Some(3))); + + let res = left + .add( + cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap(); + + println!("result: {:#?}", res.bits.to_vec()); + + res + } + + fn enforce_sub>( + &mut self, + cs: &mut CS, + left: FieldExpression, + right: FieldExpression, + ) -> UInt32 { + let left = self.get_u32_value(cs, left); + let right = self.get_u32_value(cs, right); + + left.sub( + cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap() + } + + fn enforce_mul>( + &mut self, + cs: &mut CS, + left: FieldExpression, + right: FieldExpression, + ) -> UInt32 { + let left = self.get_u32_value(cs, left); + let right = self.get_u32_value(cs, right); + + println!("left: {}", left.value.unwrap()); + println!("right: {}", right.value.unwrap()); + + let res = left + .mul( + cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap(); + + println!("result: {}", res.value.unwrap()); + + res + } + + fn enforce_div>( + &mut self, + cs: &mut CS, + left: FieldExpression, + right: FieldExpression, + ) -> UInt32 { + let left = self.get_u32_value(cs, left); + let right = self.get_u32_value(cs, right); + + left.div( + cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), + &right, + ) + .unwrap() + } + + fn enforce_pow>( + &mut self, + cs: &mut CS, + left: FieldExpression, + right: FieldExpression, + ) -> UInt32 { + let left = self.get_u32_value(cs, left); + let right = self.get_u32_value(cs, right); + + left.pow( + cs.ns(|| { + format!( + "enforce {} ** {}", + left.value.unwrap(), + right.value.unwrap() + ) + }), + &right, + ) + .unwrap() + } + + fn enforce_field_expression>( + &mut self, + cs: &mut CS, + expression: FieldExpression, + ) -> UInt32 { + println!("enforcing: {}", expression); + match expression { + FieldExpression::Add(left, right) => self.enforce_add(cs, *left, *right), + FieldExpression::Sub(left, right) => self.enforce_sub(cs, *left, *right), + FieldExpression::Mul(left, right) => self.enforce_mul(cs, *left, *right), + FieldExpression::Div(left, right) => self.enforce_div(cs, *left, *right), + FieldExpression::Pow(left, right) => self.enforce_pow(cs, *left, *right), + _ => unimplemented!(), + } + } + + pub fn generate_constraints>(cs: &mut CS, program: Program) { + let mut resolved_program = ResolvedProgram::new(); + + program + .statements + .into_iter() + .for_each(|statement| match statement { + Statement::Definition(variable, expression) => match expression { + Expression::Boolean(boolean_expression) => { + let res = + resolved_program.enforce_boolean_expression(cs, boolean_expression); + // println!("variable boolean result: {}", res.get_value().unwrap()); + resolved_program.insert(variable, ResolvedValue::Boolean(res)); + } + Expression::FieldElement(field_expression) => { + let res = resolved_program.enforce_field_expression(cs, field_expression); + println!( + " variable field result: {} = {}", + variable.0, + res.value.unwrap() + ); + resolved_program.insert(variable, ResolvedValue::FieldElement(res)); + } + _ => unimplemented!(), + }, + Statement::Return(statements) => { + statements + .into_iter() + .for_each(|expression| match expression { + Expression::Boolean(boolean_expression) => { + let res = resolved_program + .enforce_boolean_expression(cs, boolean_expression); + println!("boolean result: {}", res.get_value().unwrap()); + } + Expression::FieldElement(field_expression) => { + println!("expression {:?}", field_expression); + let res = + resolved_program.enforce_field_expression(cs, field_expression); + println!("field result: {}", res.value.unwrap()); + } + _ => unimplemented!(), + }); + } + statement => unimplemented!("statement unimplemented: {}", statement), + }); + } } // impl Program { diff --git a/src/aleo_program/types.rs b/src/aleo_program/types.rs index 0037788a05..53d71e70fe 100644 --- a/src/aleo_program/types.rs +++ b/src/aleo_program/types.rs @@ -7,7 +7,7 @@ // id == 0 for field values // id < 0 for boolean values /// A variable in a constraint system. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Variable(pub String); // // /// Linear combination of variables in a program. (a + b + c) diff --git a/src/aleo_program/types_display.rs b/src/aleo_program/types_display.rs index 292eade165..4af0998c44 100644 --- a/src/aleo_program/types_display.rs +++ b/src/aleo_program/types_display.rs @@ -80,6 +80,9 @@ impl fmt::Debug for Statement { }); write!(f, "") } + Statement::Definition(ref variable, ref statement) => { + write!(f, "{} = {}", variable, statement) + } _ => unimplemented!(), } } diff --git a/src/aleo_program/types_from.rs b/src/aleo_program/types_from.rs index 7b04a16c61..b6605f81d3 100644 --- a/src/aleo_program/types_from.rs +++ b/src/aleo_program/types_from.rs @@ -5,7 +5,7 @@ //! @author Collin Chin //! @date 2020 -use crate::aleo_program::BooleanExpression; +use crate::aleo_program::{BooleanExpression, Statement}; use crate::{aleo_program::types, ast}; impl<'ast> From> for types::FieldExpression { @@ -38,6 +38,12 @@ impl<'ast> From> for types::Expression { } } +impl<'ast> From> for types::Variable { + fn from(variable: ast::Variable<'ast>) -> Self { + types::Variable(variable.value) + } +} + impl<'ast> From> for types::FieldExpression { fn from(variable: ast::Variable<'ast>) -> Self { types::FieldExpression::Variable(types::Variable(variable.value)) @@ -220,12 +226,6 @@ impl<'ast> From> for types::Expression { } } -impl<'ast> From> for types::Variable { - fn from(variable: ast::Variable<'ast>) -> Self { - types::Variable(variable.value) - } -} - impl<'ast> From> for types::Statement { fn from(statement: ast::AssignStatement<'ast>) -> Self { types::Statement::Definition( @@ -250,7 +250,7 @@ impl<'ast> From> for types::Statement { impl<'ast> From> for types::Statement { fn from(statement: ast::Statement<'ast>) -> Self { match statement { - ast::Statement::Assign(_statement) => unimplemented!(), + ast::Statement::Assign(statement) => types::Statement::from(statement), ast::Statement::Return(statement) => types::Statement::from(statement), } } @@ -258,8 +258,9 @@ impl<'ast> From> for types::Statement { impl<'ast> From> for types::Program { fn from(file: ast::File<'ast>) -> Self { - let statements = file - .statement + // 1. compile ast -> aleo program representation + let statements: Vec = file + .statements .into_iter() .map(|statement| types::Statement::from(statement)) .collect(); diff --git a/src/ast.rs b/src/ast.rs index 6a6801b703..fbf5257f47 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -139,7 +139,7 @@ fn binary_expression<'ast>( #[derive(Clone, Debug, FromPest, PartialEq)] #[pest_ast(rule(Rule::file))] pub struct File<'ast> { - pub statement: Vec>, + pub statements: Vec>, pub eoi: EOI, #[pest_ast(outer())] pub span: Span<'ast>, diff --git a/src/main.rs b/src/main.rs index 5b6b8a74bb..e624bf55c9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,7 +50,7 @@ impl ConstraintSynthesizer for Benchmark { let program = aleo_program::Program::from(syntax_tree); println!(" compiled: {:#?}", program); - aleo_program::generate_constraints(cs, program); + aleo_program::ResolvedProgram::generate_constraints(cs, program); Ok(()) } diff --git a/src/zokrates_program/types_from.rs b/src/zokrates_program/types_from.rs index f9ad4dddec..f475e0c481 100644 --- a/src/zokrates_program/types_from.rs +++ b/src/zokrates_program/types_from.rs @@ -170,7 +170,7 @@ impl<'ast> From> for program::Program<'ast> { fn from(file: ast::File<'ast>) -> Self { program::Program { nodes: file - .statement + .statements .iter() .map(|statement| types::StatementNode::from(statement.clone())) .collect(),