diff --git a/simple.program b/simple.program index f7d97643bd..8f3857339d 100644 --- a/simple.program +++ b/simple.program @@ -1,2 +1 @@ -x = 5 + a -return x ** 2 * 2 \ No newline at end of file +return if a then 1 else 0 fi \ No newline at end of file diff --git a/src/aleo_program/constraints.rs b/src/aleo_program/constraints.rs index 12570134de..0a24a301b0 100644 --- a/src/aleo_program/constraints.rs +++ b/src/aleo_program/constraints.rs @@ -178,6 +178,7 @@ impl ResolvedProgram { ) -> Boolean { match expression { BooleanExpression::Variable(variable) => self.bool_from_variable(cs, variable), + BooleanExpression::Value(value) => Boolean::Constant(value), BooleanExpression::Not(expression) => self.enforce_not(cs, *expression), BooleanExpression::Or(left, right) => self.enforce_or(cs, *left, *right), BooleanExpression::And(left, right) => self.enforce_and(cs, *left, *right), @@ -185,6 +186,16 @@ impl ResolvedProgram { BooleanExpression::FieldEq(left, right) => { self.enforce_field_equality(cs, *left, *right) } + BooleanExpression::IfElse(first, second, third) => { + if self + .enforce_boolean_expression(cs, *first) + .eq(&Boolean::Constant(true)) + { + self.enforce_boolean_expression(cs, *second) + } else { + self.enforce_boolean_expression(cs, *third) + } + } _ => unimplemented!(), } } @@ -285,12 +296,22 @@ impl ResolvedProgram { ) -> UInt32 { match expression { FieldExpression::Variable(variable) => self.u32_from_variable(cs, variable), + FieldExpression::Number(number) => UInt32::constant(number), FieldExpression::Add(left, right) => self.enforce_add(cs, *left, *right), FieldExpression::Sub(left, right) => self.enforce_sub(cs, *left, *right), FieldExpression::Mul(left, right) => self.enforce_mul(cs, *left, *right), FieldExpression::Div(left, right) => self.enforce_div(cs, *left, *right), FieldExpression::Pow(left, right) => self.enforce_pow(cs, *left, *right), - _ => unimplemented!(), + FieldExpression::IfElse(first, second, third) => { + if self + .enforce_boolean_expression(cs, *first) + .eq(&Boolean::Constant(true)) + { + self.enforce_field_expression(cs, *second) + } else { + self.enforce_field_expression(cs, *third) + } + } } } @@ -309,7 +330,7 @@ impl ResolvedProgram { let res = resolved_program.enforce_boolean_expression(cs, boolean_expression); println!( - "variable boolean result: {} = {}", + " variable boolean result: {} = {}", variable.0, res.get_value().unwrap() ); @@ -376,12 +397,12 @@ impl ResolvedProgram { Expression::Boolean(boolean_expression) => { let res = resolved_program .enforce_boolean_expression(cs, boolean_expression); - println!("boolean result: {}\n", res.get_value().unwrap()); + println!("\n Boolean result = {}", res.get_value().unwrap()); } Expression::FieldElement(field_expression) => { let res = resolved_program.enforce_field_expression(cs, field_expression); - println!("field result: {}\n", res.value.unwrap()); + println!("\n Field result = {}", res.value.unwrap()); } Expression::Variable(variable) => { match resolved_program @@ -391,11 +412,11 @@ impl ResolvedProgram { .clone() { ResolvedValue::Boolean(boolean) => println!( - "variable result: {}\n", + "\n Variable result = {}", boolean.get_value().unwrap() ), ResolvedValue::FieldElement(field_element) => println!( - "variable field result: {}\n", + "\n Variable field result = {}", field_element.value.unwrap() ), } diff --git a/src/aleo_program/gadgets.rs b/src/aleo_program/gadgets.rs deleted file mode 100644 index 0c14a17cd1..0000000000 --- a/src/aleo_program/gadgets.rs +++ /dev/null @@ -1,61 +0,0 @@ -// use snarkos_errors::gadgets::SynthesisError; -// use snarkos_models::{ -// curves::Field, -// gadgets::{ -// r1cs::ConstraintSystem, -// utilities::uint32::UInt32 -// } -// }; -// use snarkos_models::gadgets::utilities::boolean::Boolean; -// -// impl UInt32 { -// pub fn and>(&self, mut cs: CS, other: &Self) -> Result { -// let value= match (self.value, other.valoue) { -// (Some(a), Some(b)) => Some(a & b), -// _=> None, -// }; -// -// let bits = self -// .bits -// .iter() -// .zip(other.bits.iter()) -// .enumerate() -// .map(|(i, (a, b)) | Boolean::and(cs.ns(|| format!("and of bit gadget {}", i)), a, b)) -// .collect(); -// -// Ok(UInt32 { bits, value }) -// } -// -// fn recursive_add>(mut cs: CS, a: &Self, b: &Self) -> Result { -// let uncommon_bits = a.xor(cs.ns(|| format!("{} ^ {}", a.value.unwrap(), b.value.unwrap())),&b)?; -// let common_bits = a.and(cs.ns(|| format!("{} & {}", a.value.unwrap(), b.value.unwrap())), &b)?; -// -// if common_bits.value == 0 { -// return Ok(uncommon_bits) -// } -// let shifted_common_bits = common_bits.rotr(common_bits.bits.len() - 1); -// return Self::recursive_add(cs.ns(|| format!("recursive add {} + {}", uncommon_bits.value, shifted_common_bits.value)), &uncommon_bits, &shifted_common_bits) -// } -// -// pub fn add>(&self, mut cs: CS, other: &Self) -> Result { -// let new_value = match (self.value, other.value) { -// (Some(a), Some(b)) => Some(a + b), -// _ => None, -// }; -// -// return Self::recursive_add(cs.ns( || format!("recursive add {} + {}", self.value, other.value)), &self, &other) -// -// // let bits = self -// // .bits -// // .iter() -// // .zip(other.bits.iter()) -// // .enumerate() -// // .map(|(i, (a, b))| Boo) -// } -// -// pub fn sub>(&self, mut cs: CS, other: &Self) -> Result {} -// -// pub fn mul>(&self, mut cs: CS, other: &Self) -> Result {} -// -// pub fn div>(&self, mut cs: CS, other: &Self) -> Result {} -// } diff --git a/src/aleo_program/mod.rs b/src/aleo_program/mod.rs index 8de9c79985..fa5dcbb1f4 100644 --- a/src/aleo_program/mod.rs +++ b/src/aleo_program/mod.rs @@ -10,9 +10,6 @@ pub use self::types::*; pub mod constraints; pub use self::constraints::*; -pub mod gadgets; -pub use self::gadgets::*; - pub mod types_display; pub use self::types_display::*; diff --git a/src/aleo_program/types.rs b/src/aleo_program/types.rs index 60bf0bedc7..9b72dd7d8e 100644 --- a/src/aleo_program/types.rs +++ b/src/aleo_program/types.rs @@ -33,11 +33,13 @@ pub struct Variable(pub String); pub enum FieldExpression { Variable(Variable), Number(u32), + // Operators Add(Box, Box), Sub(Box, Box), Mul(Box, Box), Div(Box, Box), Pow(Box, Box), + // Conditionals IfElse( Box, Box, @@ -61,6 +63,12 @@ pub enum BooleanExpression { Gt(Box, Box), Leq(Box, Box), Lt(Box, Box), + // Conditionals + IfElse( + Box, + Box, + Box, + ), } /// Expression that evaluates to a value diff --git a/src/aleo_program/types_display.rs b/src/aleo_program/types_display.rs index 398b044b70..7aedc4127c 100644 --- a/src/aleo_program/types_display.rs +++ b/src/aleo_program/types_display.rs @@ -24,7 +24,9 @@ impl<'ast> fmt::Display for FieldExpression { FieldExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), FieldExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), FieldExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), - FieldExpression::IfElse(ref _a, ref _b, ref _c) => unimplemented!(), + FieldExpression::IfElse(ref a, ref b, ref c) => { + write!(f, "if {} then {} else {} fi", a, b, c) + } } } } @@ -44,6 +46,9 @@ impl<'ast> fmt::Display for BooleanExpression { BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), BooleanExpression::Leq(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + BooleanExpression::IfElse(ref a, ref b, ref c) => { + write!(f, "if {} then {} else {} fi", a, b, c) + } } } } diff --git a/src/aleo_program/types_from.rs b/src/aleo_program/types_from.rs index b6605f81d3..ede1f0e8c1 100644 --- a/src/aleo_program/types_from.rs +++ b/src/aleo_program/types_from.rs @@ -5,7 +5,6 @@ //! @author Collin Chin //! @date 2020 -use crate::aleo_program::{BooleanExpression, Statement}; use crate::{aleo_program::types, ast}; impl<'ast> From> for types::FieldExpression { @@ -64,7 +63,7 @@ impl<'ast> From> for types::Expression { impl<'ast> From> for types::Expression { fn from(expression: ast::NotExpression<'ast>) -> Self { - types::Expression::Boolean(BooleanExpression::Not(Box::new( + types::Expression::Boolean(types::BooleanExpression::Not(Box::new( types::BooleanExpression::from(*expression.expression), ))) } @@ -119,7 +118,7 @@ impl<'ast> types::BooleanExpression { Box::new(types::BooleanExpression::Variable(lhs)), Box::new(rhs), ) - } + } //TODO: check case for two variables? // Field equality (types::Expression::FieldElement(lhs), types::Expression::FieldElement(rhs)) => { types::BooleanExpression::FieldEq(Box::new(lhs), Box::new(rhs)) @@ -215,6 +214,69 @@ impl<'ast> From> for types::Expression { } } +impl<'ast> From> for types::Expression { + fn from(expression: ast::TernaryExpression<'ast>) -> Self { + // Evaluate expressions to find out result type + let first = types::BooleanExpression::from(*expression.first); + let second = types::Expression::from(*expression.second); + let third = types::Expression::from(*expression.third); + + match (second, third) { + // Boolean Result + (types::Expression::Boolean(second), types::Expression::Boolean(third)) => { + types::Expression::Boolean(types::BooleanExpression::IfElse( + Box::new(first), + Box::new(second), + Box::new(third), + )) + } + (types::Expression::Boolean(second), types::Expression::Variable(third)) => { + types::Expression::Boolean(types::BooleanExpression::IfElse( + Box::new(first), + Box::new(second), + Box::new(types::BooleanExpression::Variable(third)), + )) + } + (types::Expression::Variable(second), types::Expression::Boolean(third)) => { + types::Expression::Boolean(types::BooleanExpression::IfElse( + Box::new(first), + Box::new(types::BooleanExpression::Variable(second)), + Box::new(third), + )) + } + // Field Result + (types::Expression::FieldElement(second), types::Expression::FieldElement(third)) => { + types::Expression::FieldElement(types::FieldExpression::IfElse( + Box::new(first), + Box::new(second), + Box::new(third), + )) + } + (types::Expression::FieldElement(second), types::Expression::Variable(third)) => { + types::Expression::FieldElement(types::FieldExpression::IfElse( + Box::new(first), + Box::new(second), + Box::new(types::FieldExpression::Variable(third)), + )) + } + (types::Expression::Variable(second), types::Expression::FieldElement(third)) => { + types::Expression::FieldElement(types::FieldExpression::IfElse( + Box::new(first), + Box::new(types::FieldExpression::Variable(second)), + Box::new(third), + )) + } + + (second, third) => unimplemented!( + "pattern if {} then {} else {} unimplemented", + first, + second, + third + ), + } + } +} + impl<'ast> From> for types::Expression { fn from(expression: ast::Expression<'ast>) -> Self { match expression { @@ -222,6 +284,7 @@ impl<'ast> From> for types::Expression { ast::Expression::Variable(variable) => types::Expression::from(variable), ast::Expression::Not(expression) => types::Expression::from(expression), ast::Expression::Binary(expression) => types::Expression::from(expression), + ast::Expression::Ternary(expression) => types::Expression::from(expression), } } } @@ -259,7 +322,7 @@ impl<'ast> From> for types::Statement { impl<'ast> From> for types::Program { fn from(file: ast::File<'ast>) -> Self { // 1. compile ast -> aleo program representation - let statements: Vec = file + let statements: Vec = file .statements .into_iter() .map(|statement| types::Statement::from(statement)) diff --git a/src/ast.rs b/src/ast.rs index fbf5257f47..8041576618 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -76,6 +76,12 @@ fn parse_term(pair: Pair) -> Box { let expression = parse_term(inner.next().unwrap()); Expression::Not(NotExpression { operation, expression, span }) }, + Rule::expression_conditional => { + println!("conditional expression"); + Expression::Ternary( + TernaryExpression::from_pest(&mut pair.into_inner()).unwrap(), + ) + } Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case // Rule::expression_increment => { @@ -315,6 +321,16 @@ pub struct BinaryExpression<'ast> { pub span: Span<'ast>, } +#[derive(Clone, Debug, FromPest, PartialEq)] +#[pest_ast(rule(Rule::expression_conditional))] +pub struct TernaryExpression<'ast> { + pub first: Box>, + pub second: Box>, + pub third: Box>, + #[pest_ast(outer())] + pub span: Span<'ast>, +} + // #[derive(Clong, Debug, PartialEq)] // pub struct IdentifierExpression<'ast> { // pub value: String, @@ -327,6 +343,7 @@ pub enum Expression<'ast> { Variable(Variable<'ast>), Not(NotExpression<'ast>), Binary(BinaryExpression<'ast>), + Ternary(TernaryExpression<'ast>), // Increment(IncrementExpression<'ast>), // Decrement(DecrementExpression<'ast>), } @@ -346,12 +363,27 @@ impl<'ast> Expression<'ast> { }) } + pub fn ternary( + first: Box>, + second: Box>, + third: Box>, + span: Span<'ast>, + ) -> Self { + Expression::Ternary(TernaryExpression { + first, + second, + third, + span, + }) + } + pub fn span(&self) -> &Span<'ast> { match self { Expression::Value(expression) => &expression.span(), Expression::Variable(expression) => &expression.span, Expression::Not(expression) => &expression.span, Expression::Binary(expression) => &expression.span, + Expression::Ternary(expression) => &expression.span, // Expression::Increment(expression) => &expression.span, // Expression::Decrement(expression) => &expression.span, } @@ -367,6 +399,11 @@ impl<'ast> fmt::Display for Expression<'ast> { Expression::Binary(ref expression) => { write!(f, "{} == {}", expression.left, expression.right) } + Expression::Ternary(ref expression) => write!( + f, + "if {} then {} else {} fi", + expression.first, expression.second, expression.third + ), } } } diff --git a/src/language.pest b/src/language.pest index 533162d13b..0a6a5cbb8e 100644 --- a/src/language.pest +++ b/src/language.pest @@ -72,13 +72,24 @@ variable = @{ ((!protected_name ~ ASCII_ALPHA) | (protected_name ~ (ASCII_ALPHAN // Consider structs, conditionals, postfix, primary, inline array, array initializer, and unary expression_primitive = { value | variable } expression_not = { operation_pre_not ~ expression_term } -expression_term = { expression_primitive | expression_not | ("(" ~ expression ~ ")") } +expression_term = { ("(" ~ expression ~ ")") | expression_conditional | expression_primitive | expression_not} expression = { expression_term ~ (operation_binary ~ expression_term)* } - // expression_increment = { expression ~ operation_post_increment } // expression_decrement = { expression ~ operation_post_decrement } +// Conditionals + +expression_conditional = { "if" ~ expression ~ "then" ~ expression ~ "else" ~ expression ~ "fi"} +// conditional_if = { "if" } +// conditional_else = { "else" } +// +// conditional_for = { "for" } +// +// conditional = { conditional_if | conditional_else | conditional_for } + + + expression_tuple = _{ (expression ~ ("," ~ expression)*)? } /// Statements @@ -93,16 +104,6 @@ statement = { (statement_return | (statement_assign) ~ NEWLINE) ~ NEWLINE* } COMMENT = _{ ("/*" ~ (!"*/" ~ ANY)* ~ "*/") | ("//" ~ (!NEWLINE ~ ANY)*) } WHITESPACE = _{ " " | "\t" ~ NEWLINE } - -// /// Conditionals -// -// conditional_if = { "if" } -// conditional_else = { "else" } -// -// conditional_for = { "for" } -// -// conditional = { conditional_if | conditional_else | conditional_for } - // /// Helpers // // helper_range = { expression+ ~ ".." ~ expression+ } // Confirm that '+' is the correct repetition diff --git a/src/lib.rs b/src/lib.rs index 93017c28c2..b63467e415 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,4 +13,4 @@ pub mod ast; pub mod aleo_program; -pub mod zokrates_program; +// pub mod zokrates_program;