From 34d8a552e79aafb66e07f36d39812c85fb1264b4 Mon Sep 17 00:00:00 2001 From: collin Date: Thu, 23 Apr 2020 15:24:05 -0700 Subject: [PATCH] refactor type resolution --- simple.leo | 9 +- src/program/constraints/boolean.rs | 322 ++++---- src/program/constraints/expression.rs | 454 ++++++++--- src/program/constraints/field_element.rs | 337 ++++---- src/program/constraints/integer.rs | 499 ++++++------ src/program/constraints/resolved_value.rs | 81 +- src/program/constraints/statement.rs | 41 +- src/program/types.rs | 247 +++--- src/program/types_display.rs | 288 ++++--- src/program/types_from.rs | 905 +++++++++++----------- 10 files changed, 1766 insertions(+), 1417 deletions(-) diff --git a/simple.leo b/simple.leo index f1a2948819..f95c5f4a8b 100644 --- a/simple.leo +++ b/simple.leo @@ -1,5 +1,6 @@ -def foo() -> (fe): - return 3fe - def main() -> (fe): - return foo() \ No newline at end of file + a = 1fe + for i in 0..4 do + a = a + 1fe + endfor + return a \ No newline at end of file diff --git a/src/program/constraints/boolean.rs b/src/program/constraints/boolean.rs index 28ee30f41c..3ec10c8edd 100644 --- a/src/program/constraints/boolean.rs +++ b/src/program/constraints/boolean.rs @@ -4,10 +4,8 @@ //! @author Collin Chin //! @date 2020 -use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}; -use crate::program::{ - new_variable_from_variable, BooleanExpression, BooleanSpreadOrExpression, Parameter, Variable, -}; +use crate::program::constraints::{ResolvedProgram, ResolvedValue}; +use crate::program::{new_variable_from_variable, Parameter, Variable}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::{ @@ -95,181 +93,177 @@ impl> ResolvedProgram { // parameter_variable } - pub(crate) fn bool_from_variable( + pub(crate) fn get_boolean_constant(bool: bool) -> ResolvedValue { + ResolvedValue::Boolean(Boolean::Constant(bool)) + } + + // pub(crate) fn bool_from_variable(&mut self, scope: String, variable: Variable) -> Boolean { + // // Evaluate variable name in current function scope + // let variable_name = new_scope_from_variable(scope, &variable); + // + // match self.get(&variable_name) { + // Some(value) => match value { + // ResolvedValue::Boolean(boolean) => boolean.clone(), + // value => unimplemented!( + // "expected boolean for variable {}, got {}", + // variable_name, + // value + // ), + // }, + // None => unimplemented!("cannot resolve variable {} in program", variable_name), + // } + // } + + // fn get_bool_value( + // &mut self, + // cs: &mut CS, + // scope: String, + // expression: BooleanExpression, + // ) -> Boolean { + // match expression { + // BooleanExpression::Variable(variable) => self.bool_from_variable(scope, variable), + // BooleanExpression::Value(value) => Boolean::Constant(value), + // expression => match self.enforce_boolean_expression(cs, scope, expression) { + // ResolvedValue::Boolean(value) => value, + // _ => unimplemented!("boolean expression did not resolve to boolean"), + // }, + // } + // } + + pub(crate) fn enforce_not(value: ResolvedValue) -> ResolvedValue { + match value { + ResolvedValue::Boolean(boolean) => ResolvedValue::Boolean(boolean.not()), + value => unimplemented!("cannot enforce not on non-boolean value {}", value), + } + } + + pub(crate) fn enforce_or( &mut self, cs: &mut CS, - scope: String, - variable: Variable, - ) -> Boolean { - // Evaluate variable name in current function scope - let variable_name = new_scope_from_variable(scope, &variable); - - if self.contains_name(&variable_name) { - // TODO: return synthesis error: "assignment missing" here - match self.get(&variable_name).unwrap() { - ResolvedValue::Boolean(boolean) => boolean.clone(), - _ => panic!("expected a boolean, got field"), + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::Boolean(left_bool), ResolvedValue::Boolean(right_bool)) => { + ResolvedValue::Boolean(Boolean::or(cs, &left_bool, &right_bool).unwrap()) } - } else { - let argument = std::env::args() - .nth(1) - .unwrap_or("true".into()) - .parse::() - .unwrap(); - println!(" argument passed to command line a = {:?}\n", argument); - // let a = true; - Boolean::alloc(cs.ns(|| variable.name), || Ok(argument)).unwrap() + (left_value, right_value) => unimplemented!( + "cannot enforce or on non-boolean values {} || {}", + left_value, + right_value + ), } } - fn get_bool_value( + pub(crate) fn enforce_and( &mut self, cs: &mut CS, - scope: String, - expression: BooleanExpression, - ) -> Boolean { - match expression { - BooleanExpression::Variable(variable) => self.bool_from_variable(cs, scope, variable), - BooleanExpression::Value(value) => Boolean::Constant(value), - expression => match self.enforce_boolean_expression(cs, scope, expression) { - ResolvedValue::Boolean(value) => value, - _ => unimplemented!("boolean expression did not resolve to boolean"), - }, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::Boolean(left_bool), ResolvedValue::Boolean(right_bool)) => { + ResolvedValue::Boolean(Boolean::and(cs, &left_bool, &right_bool).unwrap()) + } + (left_value, right_value) => unimplemented!( + "cannot enforce and on non-boolean values {} && {}", + left_value, + right_value + ), } } - fn enforce_not( + pub(crate) fn enforce_boolean_eq( &mut self, cs: &mut CS, - scope: String, - expression: BooleanExpression, - ) -> Boolean { - let expression = self.get_bool_value(cs, scope, expression); - - expression.not() - } - - fn enforce_or( - &mut self, - cs: &mut CS, - scope: String, - left: BooleanExpression, - right: BooleanExpression, - ) -> Boolean { - let left = self.get_bool_value(cs, scope.clone(), left); - let right = self.get_bool_value(cs, scope.clone(), right); - - Boolean::or(cs, &left, &right).unwrap() - } - - fn enforce_and( - &mut self, - cs: &mut CS, - scope: String, - left: BooleanExpression, - right: BooleanExpression, - ) -> Boolean { - let left = self.get_bool_value(cs, scope.clone(), left); - let right = self.get_bool_value(cs, scope.clone(), right); - - Boolean::and(cs, &left, &right).unwrap() - } - - fn enforce_bool_equality( - &mut self, - cs: &mut CS, - scope: String, - left: BooleanExpression, - right: BooleanExpression, - ) -> Boolean { - let left = self.get_bool_value(cs, scope.clone(), left); - let right = self.get_bool_value(cs, scope.clone(), right); - + left: Boolean, + right: Boolean, + ) -> ResolvedValue { left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right) .unwrap(); - Boolean::Constant(true) - } - - pub(crate) fn enforce_boolean_expression( - &mut self, - cs: &mut CS, - scope: String, - expression: BooleanExpression, - ) -> ResolvedValue { - match expression { - BooleanExpression::Variable(variable) => { - ResolvedValue::Boolean(self.bool_from_variable(cs, scope, variable)) - } - BooleanExpression::Value(value) => ResolvedValue::Boolean(Boolean::Constant(value)), - BooleanExpression::Not(expression) => { - ResolvedValue::Boolean(self.enforce_not(cs, scope, *expression)) - } - BooleanExpression::Or(left, right) => { - ResolvedValue::Boolean(self.enforce_or(cs, scope, *left, *right)) - } - BooleanExpression::And(left, right) => { - ResolvedValue::Boolean(self.enforce_and(cs, scope, *left, *right)) - } - BooleanExpression::IntegerEq(left, right) => { - ResolvedValue::Boolean(self.enforce_integer_equality(cs, scope, *left, *right)) - } - BooleanExpression::FieldEq(left, right) => { - ResolvedValue::Boolean(self.enforce_field_equality(cs, scope, *left, *right)) - } - BooleanExpression::BoolEq(left, right) => { - ResolvedValue::Boolean(self.enforce_bool_equality(cs, scope, *left, *right)) - } - BooleanExpression::IfElse(first, second, third) => { - let resolved_first = - match self.enforce_boolean_expression(cs, scope.clone(), *first) { - ResolvedValue::Boolean(resolved) => resolved, - _ => unimplemented!("if else conditional must resolve to boolean"), - }; - if resolved_first.eq(&Boolean::Constant(true)) { - self.enforce_boolean_expression(cs, scope, *second) - } else { - self.enforce_boolean_expression(cs, scope, *third) - } - } - BooleanExpression::Array(array) => { - let mut result = vec![]; - array.into_iter().for_each(|element| match *element { - BooleanSpreadOrExpression::Spread(spread) => match spread { - BooleanExpression::Variable(variable) => { - let array_name = new_scope_from_variable(scope.clone(), &variable); - match self.get(&array_name) { - Some(value) => match value { - ResolvedValue::BooleanArray(array) => { - result.extend(array.clone()) - } - value => unimplemented!( - "spreads only implemented for arrays, got {}", - value - ), - }, - None => unimplemented!( - "cannot copy elements from array that does not exist {}", - variable.name - ), - } - } - value => { - unimplemented!("spreads only implemented for arrays, got {}", value) - } - }, - BooleanSpreadOrExpression::Expression(expression) => { - match self.enforce_boolean_expression(cs, scope.clone(), expression) { - ResolvedValue::Boolean(value) => result.push(value), - value => { - unimplemented!("expected boolean for boolean array, got {}", value) - } - } - } - }); - ResolvedValue::BooleanArray(result) - } - expression => unimplemented!("boolean expression {}", expression), - } + ResolvedValue::Boolean(Boolean::Constant(true)) } + // + // pub(crate) fn enforce_boolean_expression( + // &mut self, + // cs: &mut CS, + // scope: String, + // expression: BooleanExpression, + // ) -> ResolvedValue { + // match expression { + // BooleanExpression::Variable(variable) => { + // ResolvedValue::Boolean(self.bool_from_variable(cs, scope, variable)) + // } + // BooleanExpression::Value(value) => ResolvedValue::Boolean(Boolean::Constant(value)), + // BooleanExpression::Not(expression) => { + // ResolvedValue::Boolean(self.enforce_not(cs, scope, *expression)) + // } + // BooleanExpression::Or(left, right) => { + // ResolvedValue::Boolean(self.enforce_or(cs, scope, *left, *right)) + // } + // BooleanExpression::And(left, right) => { + // ResolvedValue::Boolean(self.enforce_and(cs, scope, *left, *right)) + // } + // BooleanExpression::IntegerEq(left, right) => { + // ResolvedValue::Boolean(self.enforce_integer_equality(cs, scope, *left, *right)) + // } + // BooleanExpression::FieldEq(left, right) => { + // ResolvedValue::Boolean(self.enforce_field_equality(cs, scope, *left, *right)) + // } + // BooleanExpression::BoolEq(left, right) => { + // ResolvedValue::Boolean(self.enforce_bool_equality(cs, scope, *left, *right)) + // } + // BooleanExpression::IfElse(first, second, third) => { + // let resolved_first = + // match self.enforce_boolean_expression(cs, scope.clone(), *first) { + // ResolvedValue::Boolean(resolved) => resolved, + // _ => unimplemented!("if else conditional must resolve to boolean"), + // }; + // if resolved_first.eq(&Boolean::Constant(true)) { + // self.enforce_boolean_expression(cs, scope, *second) + // } else { + // self.enforce_boolean_expression(cs, scope, *third) + // } + // } + // BooleanExpression::Array(array) => { + // let mut result = vec![]; + // array.into_iter().for_each(|element| match *element { + // BooleanSpreadOrExpression::Spread(spread) => match spread { + // BooleanExpression::Variable(variable) => { + // let array_name = new_scope_from_variable(scope.clone(), &variable); + // match self.get(&array_name) { + // Some(value) => match value { + // ResolvedValue::BooleanArray(array) => { + // result.extend(array.clone()) + // } + // value => unimplemented!( + // "spreads only implemented for arrays, got {}", + // value + // ), + // }, + // None => unimplemented!( + // "cannot copy elements from array that does not exist {}", + // variable.name + // ), + // } + // } + // value => { + // unimplemented!("spreads only implemented for arrays, got {}", value) + // } + // }, + // BooleanSpreadOrExpression::Expression(expression) => { + // match self.enforce_boolean_expression(cs, scope.clone(), expression) { + // ResolvedValue::Boolean(value) => result.push(value), + // value => { + // unimplemented!("expected boolean for boolean array, got {}", value) + // } + // } + // } + // }); + // ResolvedValue::BooleanArray(result) + // } + // expression => unimplemented!("boolean expression {}", expression), + // } + // } } diff --git a/src/program/constraints/expression.rs b/src/program/constraints/expression.rs index f36dfae8fa..f2561cff5b 100644 --- a/src/program/constraints/expression.rs +++ b/src/program/constraints/expression.rs @@ -5,14 +5,256 @@ //! @date 2020 use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}; -use crate::program::{ - Expression, IntegerExpression, IntegerRangeOrExpression, StructMember, Variable, -}; +use crate::program::{Expression, RangeOrExpression, SpreadOrExpression, StructMember, Variable}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::r1cs::ConstraintSystem; +use snarkos_models::gadgets::utilities::boolean::Boolean; impl> ResolvedProgram { + /// Enforce a variable expression by getting the resolved value + fn enforce_variable( + &mut self, + scope: String, + unresolved_variable: Variable, + ) -> ResolvedValue { + // Evaluate the variable name in the current function scope + let variable_name = new_scope_from_variable(scope, &unresolved_variable); + + if self.contains_name(&variable_name) { + // Reassigning variable to another variable + self.get_mut(&variable_name).unwrap().clone() + } else if self.contains_variable(&unresolved_variable) { + // Check global scope (function and struct names) + self.get_mut_variable(&unresolved_variable).unwrap().clone() + } else { + unimplemented!("variable declaration {} not found", variable_name) + } + } + + /// Enforce numerical operations + fn enforce_add_expression( + &mut self, + cs: &mut CS, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => { + Self::enforce_u32_add(cs, num1, num2) + } + (ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => { + self.enforce_field_add(fe1, fe2) + } + (val1, val2) => unimplemented!("cannot add {} + {}", val1, val2), + } + } + + fn enforce_sub_expression( + &mut self, + cs: &mut CS, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => { + Self::enforce_u32_sub(cs, num1, num2) + } + (ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => { + self.enforce_field_sub(fe1, fe2) + } + (val1, val2) => unimplemented!("cannot subtract {} - {}", val1, val2), + } + } + + fn enforce_mul_expression( + &mut self, + cs: &mut CS, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => { + Self::enforce_u32_mul(cs, num1, num2) + } + (ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => { + self.enforce_field_mul(fe1, fe2) + } + (val1, val2) => unimplemented!("cannot multiply {} * {}", val1, val2), + } + } + + fn enforce_div_expression( + &mut self, + cs: &mut CS, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => { + Self::enforce_u32_div(cs, num1, num2) + } + (ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => { + self.enforce_field_div(fe1, fe2) + } + (val1, val2) => unimplemented!("cannot multiply {} * {}", val1, val2), + } + } + fn enforce_pow_expression( + &mut self, + cs: &mut CS, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => { + Self::enforce_u32_pow(cs, num1, num2) + } + (ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => { + self.enforce_field_pow(fe1, fe2) + } + (val1, val2) => unimplemented!("cannot multiply {} * {}", val1, val2), + } + } + + /// Enforce Boolean operations + fn enforce_eq_expression( + &mut self, + cs: &mut CS, + left: ResolvedValue, + right: ResolvedValue, + ) -> ResolvedValue { + match (left, right) { + (ResolvedValue::Boolean(bool1), ResolvedValue::Boolean(bool2)) => { + self.enforce_boolean_eq(cs, bool1, bool2) + } + (ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => { + Self::enforce_u32_eq(cs, num1, num2) + } + (ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => { + self.enforce_field_eq(fe1, fe2) + } + (val1, val2) => unimplemented!("cannot enforce equality between {} == {}", val1, val2), + } + } + + /// Enforce array expressions + fn enforce_array_expression( + &mut self, + cs: &mut CS, + scope: String, + array: Vec>>, + ) -> ResolvedValue { + let mut result = vec![]; + array.into_iter().for_each(|element| match *element { + SpreadOrExpression::Spread(spread) => match spread { + Expression::Variable(variable) => { + let array_name = new_scope_from_variable(scope.clone(), &variable); + match self.get(&array_name) { + Some(value) => match value { + ResolvedValue::Array(array) => result.extend(array.clone()), + value => { + unimplemented!("spreads only implemented for arrays, got {}", value) + } + }, + None => unimplemented!( + "cannot copy elements from array that does not exist {}", + variable.name + ), + } + } + value => unimplemented!("spreads only implemented for arrays, got {}", value), + }, + SpreadOrExpression::Expression(expression) => { + result.push(self.enforce_expression(cs, scope.clone(), expression)); + } + }); + ResolvedValue::Array(result) + } + + pub(crate) fn enforce_index( + &mut self, + cs: &mut CS, + scope: String, + index: Expression, + ) -> usize { + match self.enforce_expression(cs, scope.clone(), index) { + ResolvedValue::U32(number) => number.value.unwrap() as usize, + value => unimplemented!("From index must resolve to an integer, got {}", value), + } + } + + fn enforce_array_access_expression( + &mut self, + cs: &mut CS, + scope: String, + array: Box>, + index: RangeOrExpression, + ) -> ResolvedValue { + match self.enforce_expression(cs, scope.clone(), *array) { + ResolvedValue::Array(array) => { + match index { + RangeOrExpression::Range(from, to) => { + let from_resolved = match from { + Some(from_index) => from_index.to_usize(), + None => 0usize, // Array slice starts at index 0 + }; + let to_resolved = match to { + Some(to_index) => to_index.to_usize(), + None => array.len(), // Array slice ends at array length + }; + ResolvedValue::Array(array[from_resolved..to_resolved].to_owned()) + } + RangeOrExpression::Expression(index) => { + let index_resolved = self.enforce_index(cs, scope.clone(), index); + array[index_resolved].to_owned() + } + } + } + // ResolvedValue::U32Array(field_array) => { + // match index { + // RangeOrExpression::Range(from, to) => { + // let from_resolved = match from { + // Some(from_index) => self.enforce_index(cs, scope.clone(), from_index), + // None => 0usize, // Array slice starts at index 0 + // }; + // let to_resolved = match to { + // Some(to_index) => self.enforce_index(cs, scope.clone(), to_index), + // None => field_array.len(), // Array slice ends at array length + // }; + // ResolvedValue::U32Array(field_array[from_resolved..to_resolved].to_owned()) + // } + // RangeOrExpression::Expression(index) => { + // let index_resolved = self.enforce_index(cs, scope.clone(), index); + // ResolvedValue::U32(field_array[index_resolved].to_owned()) + // } + // } + // } + // ResolvedValue::BooleanArray(bool_array) => { + // match index { + // RangeOrExpression::Range(from, to) => { + // let from_resolved = match from { + // Some(from_index) => self.enforce_index(cs, scope.clone(), from_index), + // None => 0usize, // Array slice starts at index 0 + // }; + // let to_resolved = match to { + // Some(to_index) => self.enforce_index(cs, scope.clone(), to_index), + // None => bool_array.len(), // Array slice ends at array length + // }; + // ResolvedValue::BooleanArray( + // bool_array[from_resolved..to_resolved].to_owned(), + // ) + // } + // RangeOrExpression::Expression(index) => { + // let index_resolved = self.enforce_index(cs, scope.clone(), index); + // ResolvedValue::Boolean(bool_array[index_resolved].to_owned()) + // } + // } + // } + value => unimplemented!("Cannot access element of untyped array {}", value), + } + } + fn enforce_struct_expression( &mut self, cs: &mut CS, @@ -47,70 +289,6 @@ impl> ResolvedProgram { } } - pub(crate) fn enforce_index( - &mut self, - cs: &mut CS, - scope: String, - index: IntegerExpression, - ) -> usize { - match self.enforce_integer_expression(cs, scope.clone(), index) { - ResolvedValue::U32(number) => number.value.unwrap() as usize, - value => unimplemented!("From index must resolve to a uint32, got {}", value), - } - } - - fn enforce_array_access_expression( - &mut self, - cs: &mut CS, - scope: String, - array: Box>, - index: IntegerRangeOrExpression, - ) -> ResolvedValue { - match self.enforce_expression(cs, scope.clone(), *array) { - ResolvedValue::U32Array(field_array) => { - match index { - IntegerRangeOrExpression::Range(from, to) => { - let from_resolved = match from { - Some(from_index) => self.enforce_index(cs, scope.clone(), from_index), - None => 0usize, // Array slice starts at index 0 - }; - let to_resolved = match to { - Some(to_index) => self.enforce_index(cs, scope.clone(), to_index), - None => field_array.len(), // Array slice ends at array length - }; - ResolvedValue::U32Array(field_array[from_resolved..to_resolved].to_owned()) - } - IntegerRangeOrExpression::Expression(index) => { - let index_resolved = self.enforce_index(cs, scope.clone(), index); - ResolvedValue::U32(field_array[index_resolved].to_owned()) - } - } - } - ResolvedValue::BooleanArray(bool_array) => { - match index { - IntegerRangeOrExpression::Range(from, to) => { - let from_resolved = match from { - Some(from_index) => self.enforce_index(cs, scope.clone(), from_index), - None => 0usize, // Array slice starts at index 0 - }; - let to_resolved = match to { - Some(to_index) => self.enforce_index(cs, scope.clone(), to_index), - None => bool_array.len(), // Array slice ends at array length - }; - ResolvedValue::BooleanArray( - bool_array[from_resolved..to_resolved].to_owned(), - ) - } - IntegerRangeOrExpression::Expression(index) => { - let index_resolved = self.enforce_index(cs, scope.clone(), index); - ResolvedValue::Boolean(bool_array[index_resolved].to_owned()) - } - } - } - value => unimplemented!("Cannot access element of untyped array {}", value), - } - } - fn enforce_struct_access_expression( &mut self, cs: &mut CS, @@ -152,55 +330,125 @@ impl> ResolvedProgram { expression: Expression, ) -> ResolvedValue { match expression { - Expression::Boolean(boolean_expression) => { - self.enforce_boolean_expression(cs, scope, boolean_expression) - } - Expression::Integer(integer_expression) => { - self.enforce_integer_expression(cs, scope, integer_expression) - } - Expression::FieldElement(field_expression) => { - self.enforce_field_expression(cs, scope, field_expression) - } + // Variables Expression::Variable(unresolved_variable) => { - let variable_name = new_scope_from_variable(scope, &unresolved_variable); + self.enforce_variable(scope, unresolved_variable) + } - // Evaluate the variable name in the current function scope - if self.contains_name(&variable_name) { - // Reassigning variable to another variable - self.get_mut(&variable_name).unwrap().clone() - } else if self.contains_variable(&unresolved_variable) { - // Check global scope (function and struct names) - self.get_mut_variable(&unresolved_variable).unwrap().clone() + // Values + Expression::Integer(integer) => Self::get_integer_constant(integer), + Expression::FieldElement(fe) => ResolvedValue::FieldElement(fe), + Expression::Boolean(bool) => Self::get_boolean_constant(bool), + + // Binary operations + Expression::Add(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_add_expression(cs, resolved_left, resolved_right) + } + Expression::Sub(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_sub_expression(cs, resolved_left, resolved_right) + } + Expression::Mul(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_mul_expression(cs, resolved_left, resolved_right) + } + Expression::Div(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_div_expression(cs, resolved_left, resolved_right) + } + Expression::Pow(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_pow_expression(cs, resolved_left, resolved_right) + } + + // Boolean operations + Expression::Not(expression) => { + Self::enforce_not(self.enforce_expression(cs, scope, *expression)) + } + Expression::Or(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_or(cs, resolved_left, resolved_right) + } + Expression::And(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_and(cs, resolved_left, resolved_right) + } + Expression::Eq(left, right) => { + let resolved_left = self.enforce_expression(cs, scope.clone(), *left); + let resolved_right = self.enforce_expression(cs, scope.clone(), *right); + + self.enforce_eq_expression(cs, resolved_left, resolved_right) + } + Expression::Geq(left, right) => { + unimplemented!("expression {} >= {} unimplemented", left, right) + } + Expression::Gt(left, right) => { + unimplemented!("expression {} > {} unimplemented", left, right) + } + Expression::Leq(left, right) => { + unimplemented!("expression {} <= {} unimplemented", left, right) + } + Expression::Lt(left, right) => { + unimplemented!("expression {} < {} unimplemented", left, right) + } + + // Conditionals + Expression::IfElse(first, second, third) => { + let resolved_first = match self.enforce_expression(cs, scope.clone(), *first) { + ResolvedValue::Boolean(resolved) => resolved, + _ => unimplemented!("if else conditional must resolve to boolean"), + }; + + if resolved_first.eq(&Boolean::Constant(true)) { + self.enforce_expression(cs, scope, *second) } else { - // The type of the unassigned variable depends on what is passed in - if std::env::args() - .nth(1) - .expect("variable declaration not passed in") - .parse::() - .is_ok() - { - ResolvedValue::Boolean(self.bool_from_variable( - cs, - variable_name, - unresolved_variable, - )) - } else { - self.integer_from_variable(variable_name, unresolved_variable) - } + self.enforce_expression(cs, scope, *third) } } + + // Arrays + Expression::Array(array) => self.enforce_array_expression(cs, scope, array), + Expression::ArrayAccess(array, index) => { + self.enforce_array_access_expression(cs, scope, array, *index) + } + + // Structs Expression::Struct(struct_name, members) => { self.enforce_struct_expression(cs, scope, struct_name, members) } - Expression::ArrayAccess(array, index) => { - self.enforce_array_access_expression(cs, scope, array, index) - } Expression::StructMemberAccess(struct_variable, struct_member) => { self.enforce_struct_access_expression(cs, scope, struct_variable, struct_member) } + + // Functions Expression::FunctionCall(function, arguments) => { self.enforce_function_access_expression(cs, scope, function, arguments) - } // expression => unimplemented!("expression not impl {}", expression), + } + // Expression::BooleanExp(boolean_expression) => { + // self.enforce_boolean_expression(cs, scope, boolean_expression) + // } + // Expression::IntegerExp(integer_expression) => { + // self.enforce_integer_expression(cs, scope, integer_expression) + // } + // Expression::FieldElementExp(field_expression) => { + // self.enforce_field_expression(cs, scope, field_expression) + // } + _ => unimplemented!(), } } } diff --git a/src/program/constraints/field_element.rs b/src/program/constraints/field_element.rs index d3d0b97710..fad93380be 100644 --- a/src/program/constraints/field_element.rs +++ b/src/program/constraints/field_element.rs @@ -4,10 +4,8 @@ //! @author Collin Chin //! @date 2020 -use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}; -use crate::program::{ - new_variable_from_variable, FieldExpression, FieldSpreadOrExpression, Parameter, Variable, -}; +use crate::program::constraints::{ResolvedProgram, ResolvedValue}; +use crate::program::{new_variable_from_variable, Parameter, Variable}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}; @@ -92,181 +90,188 @@ impl> ResolvedProgram { // parameter_variable } - fn field_element_from_variable(&mut self, scope: String, variable: Variable) -> F { - // Evaluate variable name in current function scope - let variable_name = new_scope_from_variable(scope, &variable); + // fn field_element_from_variable(&mut self, scope: String, variable: Variable) -> F { + // // Evaluate variable name in current function scope + // let variable_name = new_scope_from_variable(scope, &variable); + // + // match self.get(&variable_name) { + // Some(value) => match value { + // ResolvedValue::FieldElement(fe) => fe.clone(), + // value => unimplemented!( + // "expected field element for variable {}, got {}", + // variable_name, + // value + // ), + // }, + // None => unimplemented!("cannot resolve variable {} in program", variable_name), + // } + // } - if self.contains_name(&variable_name) { - // TODO: return synthesis error: "assignment missing" here - match self.get(&variable_name).unwrap().clone() { - ResolvedValue::FieldElement(fe) => fe, - value => unimplemented!( - "expected field element for variable {}, got {}", - variable_name, - value - ), - } - } else { - unimplemented!("cannot resolve variable {} in program", variable_name) - } + // fn get_field_value(&mut self, cs: &mut CS, scope: String, expression: FieldExpression) -> F { + // match expression { + // FieldExpression::Variable(variable) => { + // self.field_element_from_variable(scope, variable) + // } + // FieldExpression::Number(element) => element, + // } + // } + + pub(crate) fn enforce_field_eq(&mut self, fe1: F, fe2: F) -> ResolvedValue { + ResolvedValue::Boolean(Boolean::Constant(fe1.eq(&fe2))) } - fn get_field_value(&mut self, cs: &mut CS, scope: String, expression: FieldExpression) -> F { - match expression { - FieldExpression::Variable(variable) => { - self.field_element_from_variable(scope, variable) - } - FieldExpression::Number(element) => element, - expression => match self.enforce_field_expression(cs, scope, expression) { - ResolvedValue::FieldElement(element) => element, - value => unimplemented!("expected field element, got {}", value), - }, - } + pub(crate) fn enforce_field_add(&mut self, fe1: F, fe2: F) -> ResolvedValue { + ResolvedValue::FieldElement(fe1.add(&fe2)) } - pub(crate) fn enforce_field_equality( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> Boolean { - let left = self.get_field_value(cs, scope.clone(), left); - let right = self.get_field_value(cs, scope.clone(), right); - - Boolean::Constant(left.eq(&right)) + pub(crate) fn enforce_field_sub(&mut self, fe1: F, fe2: F) -> ResolvedValue { + ResolvedValue::FieldElement(fe1.sub(&fe2)) } - fn enforce_field_add( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> ResolvedValue { - let left = self.get_field_value(cs, scope.clone(), left); - let right = self.get_field_value(cs, scope.clone(), right); - - ResolvedValue::FieldElement(left.add(&right)) + pub(crate) fn enforce_field_mul(&mut self, fe1: F, fe2: F) -> ResolvedValue { + ResolvedValue::FieldElement(fe1.mul(&fe2)) } - fn enforce_field_sub( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> ResolvedValue { - let left = self.get_field_value(cs, scope.clone(), left); - let right = self.get_field_value(cs, scope.clone(), right); - - ResolvedValue::FieldElement(left.sub(&right)) + pub(crate) fn enforce_field_div(&mut self, fe1: F, fe2: F) -> ResolvedValue { + ResolvedValue::FieldElement(fe1.div(&fe2)) } - fn enforce_field_mul( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> ResolvedValue { - let left = self.get_field_value(cs, scope.clone(), left); - let right = self.get_field_value(cs, scope.clone(), right); - - ResolvedValue::FieldElement(left.mul(&right)) - } - - fn enforce_field_div( - &mut self, - cs: &mut CS, - scope: String, - left: FieldExpression, - right: FieldExpression, - ) -> ResolvedValue { - let left = self.get_field_value(cs, scope.clone(), left); - let right = self.get_field_value(cs, scope.clone(), right); - - ResolvedValue::FieldElement(left.div(&right)) - } - - fn enforce_field_pow( - &mut self, - _cs: &mut CS, - _scope: String, - _left: FieldExpression, - _right: FieldExpression, - ) -> ResolvedValue { + pub(crate) fn enforce_field_pow(&mut self, _fe1: F, _fe2: F) -> ResolvedValue { unimplemented!("field element exponentiation not supported") - // let left = self.get_field_value(cs, scope.clone(), left); - // let right = self.get_field_value(cs, scope.clone(), right); - // - // ResolvedValue::FieldElement(left.pow(&right)) + + // ResolvedValue::FieldElement(fe1.pow(&fe2)) } - pub(crate) fn enforce_field_expression( - &mut self, - cs: &mut CS, - scope: String, - expression: FieldExpression, - ) -> ResolvedValue { - match expression { - FieldExpression::Variable(variable) => { - ResolvedValue::FieldElement(self.field_element_from_variable(scope, variable)) - } - FieldExpression::Number(field) => ResolvedValue::FieldElement(field), - FieldExpression::Add(left, right) => self.enforce_field_add(cs, scope, *left, *right), - FieldExpression::Sub(left, right) => self.enforce_field_sub(cs, scope, *left, *right), - FieldExpression::Mul(left, right) => self.enforce_field_mul(cs, scope, *left, *right), - FieldExpression::Div(left, right) => self.enforce_field_div(cs, scope, *left, *right), - FieldExpression::Pow(left, right) => self.enforce_field_pow(cs, scope, *left, *right), - FieldExpression::IfElse(first, second, third) => { - let resolved_first = - match self.enforce_boolean_expression(cs, scope.clone(), *first) { - ResolvedValue::Boolean(resolved) => resolved, - _ => unimplemented!("if else conditional must resolve to boolean"), - }; + // fn enforce_field_add_old( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: FieldExpression, + // right: FieldExpression, + // ) -> ResolvedValue { + // let left = self.get_field_value(cs, scope.clone(), left); + // let right = self.get_field_value(cs, scope.clone(), right); + // + // ResolvedValue::FieldElement(left.add(&right)) + // } + // + // fn enforce_field_sub_old( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: FieldExpression, + // right: FieldExpression, + // ) -> ResolvedValue { + // let left = self.get_field_value(cs, scope.clone(), left); + // let right = self.get_field_value(cs, scope.clone(), right); + // + // ResolvedValue::FieldElement(left.sub(&right)) + // } + // + // fn enforce_field_mul_old( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: FieldExpression, + // right: FieldExpression, + // ) -> ResolvedValue { + // let left = self.get_field_value(cs, scope.clone(), left); + // let right = self.get_field_value(cs, scope.clone(), right); + // + // ResolvedValue::FieldElement(left.mul(&right)) + // } + // + // fn enforce_field_div_old( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: FieldExpression, + // right: FieldExpression, + // ) -> ResolvedValue { + // let left = self.get_field_value(cs, scope.clone(), left); + // let right = self.get_field_value(cs, scope.clone(), right); + // + // ResolvedValue::FieldElement(left.div(&right)) + // } + // + // fn enforce_field_pow_old( + // &mut self, + // _cs: &mut CS, + // _scope: String, + // _left: FieldExpression, + // _right: FieldExpression, + // ) -> ResolvedValue { + // unimplemented!("field element exponentiation not supported") + // // let left = self.get_field_value(cs, scope.clone(), left); + // // let right = self.get_field_value(cs, scope.clone(), right); + // // + // // ResolvedValue::FieldElement(left.pow(&right)) + // } - if resolved_first.eq(&Boolean::Constant(true)) { - self.enforce_field_expression(cs, scope, *second) - } else { - self.enforce_field_expression(cs, scope, *third) - } - } - FieldExpression::Array(array) => { - let mut result = vec![]; - array.into_iter().for_each(|element| match *element { - FieldSpreadOrExpression::Spread(spread) => match spread { - FieldExpression::Variable(variable) => { - let array_name = new_scope_from_variable(scope.clone(), &variable); - match self.get(&array_name) { - Some(value) => match value { - ResolvedValue::FieldElementArray(array) => { - result.extend(array.clone()) - } - value => unimplemented!( - "spreads only implemented for arrays, got {}", - value - ), - }, - None => unimplemented!( - "cannot copy elements from array that does not exist {}", - variable.name - ), - } - } - value => { - unimplemented!("spreads only implemented for arrays, got {}", value) - } - }, - FieldSpreadOrExpression::Expression(expression) => { - match self.enforce_field_expression(cs, scope.clone(), expression) { - ResolvedValue::FieldElement(value) => result.push(value), - _ => unimplemented!("cannot resolve field"), - } - } - }); - ResolvedValue::FieldElementArray(result) - } - } - } + // pub(crate) fn enforce_field_expression( + // &mut self, + // cs: &mut CS, + // scope: String, + // expression: FieldExpression, + // ) -> ResolvedValue { + // match expression { + // FieldExpression::Variable(variable) => { + // ResolvedValue::FieldElement(self.field_element_from_variable(scope, variable)) + // } + // FieldExpression::Number(field) => ResolvedValue::FieldElement(field), + // FieldExpression::Add(left, right) => self.enforce_field_add_old(cs, scope, *left, *right), + // FieldExpression::Sub(left, right) => self.enforce_field_sub_old(cs, scope, *left, *right), + // FieldExpression::Mul(left, right) => self.enforce_field_mul_old(cs, scope, *left, *right), + // FieldExpression::Div(left, right) => self.enforce_field_div_old(cs, scope, *left, *right), + // FieldExpression::Pow(left, right) => self.enforce_field_pow_old(cs, scope, *left, *right), + // FieldExpression::IfElse(first, second, third) => { + // let resolved_first = + // match self.enforce_boolean_expression(cs, scope.clone(), *first) { + // ResolvedValue::Boolean(resolved) => resolved, + // _ => unimplemented!("if else conditional must resolve to boolean"), + // }; + // + // if resolved_first.eq(&Boolean::Constant(true)) { + // self.enforce_field_expression(cs, scope, *second) + // } else { + // self.enforce_field_expression(cs, scope, *third) + // } + // } + // FieldExpression::Array(array) => { + // let mut result = vec![]; + // array.into_iter().for_each(|element| match *element { + // FieldSpreadOrExpression::Spread(spread) => match spread { + // FieldExpression::Variable(variable) => { + // let array_name = new_scope_from_variable(scope.clone(), &variable); + // match self.get(&array_name) { + // Some(value) => match value { + // ResolvedValue::FieldElementArray(array) => { + // result.extend(array.clone()) + // } + // value => unimplemented!( + // "spreads only implemented for arrays, got {}", + // value + // ), + // }, + // None => unimplemented!( + // "cannot copy elements from array that does not exist {}", + // variable.name + // ), + // } + // } + // value => { + // unimplemented!("spreads only implemented for arrays, got {}", value) + // } + // }, + // FieldSpreadOrExpression::Expression(expression) => { + // match self.enforce_field_expression(cs, scope.clone(), expression) { + // ResolvedValue::FieldElement(value) => result.push(value), + // _ => unimplemented!("cannot resolve field"), + // } + // } + // }); + // ResolvedValue::FieldElementArray(result) + // } + // } + // } } diff --git a/src/program/constraints/integer.rs b/src/program/constraints/integer.rs index 7e17484eb9..8b88748361 100644 --- a/src/program/constraints/integer.rs +++ b/src/program/constraints/integer.rs @@ -4,11 +4,8 @@ //! @author Collin Chin //! @date 2020 -use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}; -use crate::program::{ - new_variable_from_variable, Integer, IntegerExpression, IntegerSpreadOrExpression, Parameter, - Variable, -}; +use crate::program::constraints::{ResolvedProgram, ResolvedValue}; +use crate::program::{new_variable_from_variable, Integer, Parameter, Variable}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::{ @@ -96,42 +93,36 @@ impl> ResolvedProgram { // parameter_variable } - pub(crate) fn integer_from_variable( - &mut self, - scope: String, - variable: Variable, - ) -> ResolvedValue { - // Evaluate variable name in current function scope - let variable_name = new_scope_from_variable(scope, &variable); + // pub(crate) fn integer_from_variable( + // &mut self, + // scope: String, + // variable: Variable, + // ) -> ResolvedValue { + // // Evaluate variable name in current function scope + // let variable_name = new_scope_from_variable(scope, &variable); + // + // match self.get(&variable_name) { + // Some(value) => value.clone(), + // None => unimplemented!("cannot resolve variable {} in program", variable_name), + // } + // } - if self.contains_name(&variable_name) { - // TODO: return synthesis error: "assignment missing" here - self.get(&variable_name).unwrap().clone() - } else { - unimplemented!("cannot resolve variable {} in program", variable_name) - } - } - - fn get_integer_constant(integer: Integer) -> ResolvedValue { + pub(crate) fn get_integer_constant(integer: Integer) -> ResolvedValue { match integer { Integer::U32(u32_value) => ResolvedValue::U32(UInt32::constant(u32_value)), } } + // + // pub(crate) fn get_integer_value( + // integer: Integer + // ) -> ResolvedValue { + // match expression { + // IntegerExpression::Variable(variable) => self.integer_from_variable(scope, variable), + // IntegerExpression::Number(number) => Self::get_integer_constant(number), + // } + // } - fn get_integer_value( - &mut self, - cs: &mut CS, - scope: String, - expression: IntegerExpression, - ) -> ResolvedValue { - match expression { - IntegerExpression::Variable(variable) => self.integer_from_variable(scope, variable), - IntegerExpression::Number(number) => Self::get_integer_constant(number), - expression => self.enforce_integer_expression(cs, scope, expression), - } - } - - fn enforce_u32_equality(cs: &mut CS, left: UInt32, right: UInt32) -> Boolean { + pub(crate) fn enforce_u32_eq(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { left.conditional_enforce_equal( cs.ns(|| format!("enforce field equal")), &right, @@ -139,30 +130,10 @@ impl> ResolvedProgram { ) .unwrap(); - Boolean::Constant(true) + ResolvedValue::Boolean(Boolean::Constant(true)) } - pub(crate) fn enforce_integer_equality( - &mut self, - cs: &mut CS, - scope: String, - left: IntegerExpression, - right: IntegerExpression, - ) -> Boolean { - let left = self.get_integer_value(cs, scope.clone(), left); - let right = self.get_integer_value(cs, scope.clone(), right); - - match (left, right) { - (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { - Self::enforce_u32_equality(cs, left_u32, right_u32) - } - (left_int, right_int) => { - unimplemented!("equality not impl between {} == {}", left_int, right_int) - } - } - } - - fn enforce_u32_add(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + pub(crate) fn enforce_u32_add(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { ResolvedValue::U32( UInt32::addmany( cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), @@ -172,27 +143,7 @@ impl> ResolvedProgram { ) } - fn enforce_integer_add( - &mut self, - cs: &mut CS, - scope: String, - left: IntegerExpression, - right: IntegerExpression, - ) -> ResolvedValue { - let left = self.get_integer_value(cs, scope.clone(), left); - let right = self.get_integer_value(cs, scope.clone(), right); - - match (left, right) { - (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { - Self::enforce_u32_add(cs, left_u32, right_u32) - } - (left_int, right_int) => { - unimplemented!("add not impl between {} + {}", left_int, right_int) - } - } - } - - fn enforce_u32_sub(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + pub(crate) fn enforce_u32_sub(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { ResolvedValue::U32( left.sub( cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), @@ -202,27 +153,7 @@ impl> ResolvedProgram { ) } - fn enforce_integer_sub( - &mut self, - cs: &mut CS, - scope: String, - left: IntegerExpression, - right: IntegerExpression, - ) -> ResolvedValue { - let left = self.get_integer_value(cs, scope.clone(), left); - let right = self.get_integer_value(cs, scope.clone(), right); - - match (left, right) { - (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { - Self::enforce_u32_sub(cs, left_u32, right_u32) - } - (left_int, right_int) => { - unimplemented!("add not impl between {} + {}", left_int, right_int) - } - } - } - - fn enforce_u32_mul(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + pub(crate) fn enforce_u32_mul(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { ResolvedValue::U32( left.mul( cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), @@ -231,28 +162,7 @@ impl> ResolvedProgram { .unwrap(), ) } - - fn enforce_integer_mul( - &mut self, - cs: &mut CS, - scope: String, - left: IntegerExpression, - right: IntegerExpression, - ) -> ResolvedValue { - let left = self.get_integer_value(cs, scope.clone(), left); - let right = self.get_integer_value(cs, scope.clone(), right); - - match (left, right) { - (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { - Self::enforce_u32_mul(cs, left_u32, right_u32) - } - (left_int, right_int) => { - unimplemented!("add not impl between {} + {}", left_int, right_int) - } - } - } - - fn enforce_u32_div(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + pub(crate) fn enforce_u32_div(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { ResolvedValue::U32( left.div( cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), @@ -261,28 +171,7 @@ impl> ResolvedProgram { .unwrap(), ) } - - fn enforce_integer_div( - &mut self, - cs: &mut CS, - scope: String, - left: IntegerExpression, - right: IntegerExpression, - ) -> ResolvedValue { - let left = self.get_integer_value(cs, scope.clone(), left); - let right = self.get_integer_value(cs, scope.clone(), right); - - match (left, right) { - (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { - Self::enforce_u32_div(cs, left_u32, right_u32) - } - (left_int, right_int) => { - unimplemented!("add not impl between {} + {}", left_int, right_int) - } - } - } - - fn enforce_u32_pow(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + pub(crate) fn enforce_u32_pow(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { ResolvedValue::U32( left.pow( cs.ns(|| { @@ -298,96 +187,242 @@ impl> ResolvedProgram { ) } - fn enforce_integer_pow( - &mut self, - cs: &mut CS, - scope: String, - left: IntegerExpression, - right: IntegerExpression, - ) -> ResolvedValue { - let left = self.get_integer_value(cs, scope.clone(), left); - let right = self.get_integer_value(cs, scope.clone(), right); + // pub(crate) fn enforce_integer_equality( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: UInt32, + // right: UInt32, + // ) -> Boolean { + // let left = self.get_integer_value(cs, scope.clone(), left); + // let right = self.get_integer_value(cs, scope.clone(), right); + // + // match (left, right) { + // (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + // Self::enforce_u32_equality(cs, left_u32, right_u32) + // } + // (left_int, right_int) => { + // unimplemented!("equality not impl between {} == {}", left_int, right_int) + // } + // } + // } - match (left, right) { - (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { - Self::enforce_u32_pow(cs, left_u32, right_u32) - } - (left_int, right_int) => { - unimplemented!("add not impl between {} + {}", left_int, right_int) - } - } - } + // fn enforce_integer_add_old( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: IntegerExpression, + // right: IntegerExpression, + // ) -> ResolvedValue { + // let left = self.get_integer_value(cs, scope.clone(), left); + // let right = self.get_integer_value(cs, scope.clone(), right); + // + // match (left, right) { + // (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + // Self::enforce_u32_add(cs, left_u32, right_u32) + // } + // (left_int, right_int) => { + // unimplemented!("add not impl between {} + {}", left_int, right_int) + // } + // } + // } - pub(crate) fn enforce_integer_expression( - &mut self, - cs: &mut CS, - scope: String, - expression: IntegerExpression, - ) -> ResolvedValue { - match expression { - IntegerExpression::Variable(variable) => self.integer_from_variable(scope, variable), - IntegerExpression::Number(number) => Self::get_integer_constant(number), - IntegerExpression::Add(left, right) => { - self.enforce_integer_add(cs, scope, *left, *right) - } - IntegerExpression::Sub(left, right) => { - self.enforce_integer_sub(cs, scope, *left, *right) - } - IntegerExpression::Mul(left, right) => { - self.enforce_integer_mul(cs, scope, *left, *right) - } - IntegerExpression::Div(left, right) => { - self.enforce_integer_div(cs, scope, *left, *right) - } - IntegerExpression::Pow(left, right) => { - self.enforce_integer_pow(cs, scope, *left, *right) - } - IntegerExpression::IfElse(first, second, third) => { - let resolved_first = - match self.enforce_boolean_expression(cs, scope.clone(), *first) { - ResolvedValue::Boolean(resolved) => resolved, - _ => unimplemented!("if else conditional must resolve to boolean"), - }; + // fn enforce_u32_sub_old(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + // ResolvedValue::U32( + // left.sub( + // cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), + // &right, + // ) + // .unwrap(), + // ) + // } - if resolved_first.eq(&Boolean::Constant(true)) { - self.enforce_integer_expression(cs, scope, *second) - } else { - self.enforce_integer_expression(cs, scope, *third) - } - } - IntegerExpression::Array(array) => { - let mut result = vec![]; - array.into_iter().for_each(|element| match *element { - IntegerSpreadOrExpression::Spread(spread) => match spread { - IntegerExpression::Variable(variable) => { - let array_name = new_scope_from_variable(scope.clone(), &variable); - match self.get(&array_name) { - Some(value) => match value { - ResolvedValue::U32Array(array) => result.extend(array.clone()), - value => unimplemented!( - "spreads only implemented for arrays, got {}", - value - ), - }, - None => unimplemented!( - "cannot copy elements from array that does not exist {}", - variable.name - ), - } - } - value => { - unimplemented!("spreads only implemented for arrays, got {}", value) - } - }, - IntegerSpreadOrExpression::Expression(expression) => { - match self.enforce_integer_expression(cs, scope.clone(), expression) { - ResolvedValue::U32(value) => result.push(value), - _ => unimplemented!("cannot resolve field"), - } - } - }); - ResolvedValue::U32Array(result) - } - } - } + // fn enforce_integer_sub( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: IntegerExpression, + // right: IntegerExpression, + // ) -> ResolvedValue { + // let left = self.get_integer_value(cs, scope.clone(), left); + // let right = self.get_integer_value(cs, scope.clone(), right); + // + // match (left, right) { + // (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + // Self::enforce_u32_sub_old(cs, left_u32, right_u32) + // } + // (left_int, right_int) => { + // unimplemented!("add not impl between {} + {}", left_int, right_int) + // } + // } + // } + + // fn enforce_u32_mul_old(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + // ResolvedValue::U32( + // left.mul( + // cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), + // &right, + // ) + // .unwrap(), + // ) + // } + // + // fn enforce_integer_mul( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: IntegerExpression, + // right: IntegerExpression, + // ) -> ResolvedValue { + // let left = self.get_integer_value(cs, scope.clone(), left); + // let right = self.get_integer_value(cs, scope.clone(), right); + // + // match (left, right) { + // (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + // Self::enforce_u32_mul_old(cs, left_u32, right_u32) + // } + // (left_int, right_int) => { + // unimplemented!("add not impl between {} + {}", left_int, right_int) + // } + // } + // } + // + // fn enforce_u32_div_old(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + // ResolvedValue::U32( + // left.div( + // cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), + // &right, + // ) + // .unwrap(), + // ) + // } + // + // fn enforce_integer_div( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: IntegerExpression, + // right: IntegerExpression, + // ) -> ResolvedValue { + // let left = self.get_integer_value(cs, scope.clone(), left); + // let right = self.get_integer_value(cs, scope.clone(), right); + // + // match (left, right) { + // (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + // Self::enforce_u32_div_old(cs, left_u32, right_u32) + // } + // (left_int, right_int) => { + // unimplemented!("add not impl between {} + {}", left_int, right_int) + // } + // } + // } + // + // fn enforce_u32_pow_old(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue { + // ResolvedValue::U32( + // left.pow( + // cs.ns(|| { + // format!( + // "enforce {} ** {}", + // left.value.unwrap(), + // right.value.unwrap() + // ) + // }), + // &right, + // ) + // .unwrap(), + // ) + // } + // + // fn enforce_integer_pow( + // &mut self, + // cs: &mut CS, + // scope: String, + // left: IntegerExpression, + // right: IntegerExpression, + // ) -> ResolvedValue { + // let left = self.get_integer_value(cs, scope.clone(), left); + // let right = self.get_integer_value(cs, scope.clone(), right); + // + // match (left, right) { + // (ResolvedValue::U32(left_u32), ResolvedValue::U32(right_u32)) => { + // Self::enforce_u32_pow_old(cs, left_u32, right_u32) + // } + // (left_int, right_int) => { + // unimplemented!("add not impl between {} + {}", left_int, right_int) + // } + // } + // } + + // pub(crate) fn enforce_integer_expression( + // &mut self, + // cs: &mut CS, + // scope: String, + // expression: IntegerExpression, + // ) -> ResolvedValue { + // match expression { + // IntegerExpression::Variable(variable) => self.integer_from_variable(scope, variable), + // IntegerExpression::Number(number) => Self::get_integer_constant(number), + // IntegerExpression::Add(left, right) => { + // self.enforce_integer_add_old(cs, scope, *left, *right) + // } + // IntegerExpression::Sub(left, right) => { + // self.enforce_integer_sub(cs, scope, *left, *right) + // } + // IntegerExpression::Mul(left, right) => { + // self.enforce_integer_mul(cs, scope, *left, *right) + // } + // IntegerExpression::Div(left, right) => { + // self.enforce_integer_div(cs, scope, *left, *right) + // } + // IntegerExpression::Pow(left, right) => { + // self.enforce_integer_pow(cs, scope, *left, *right) + // } + // IntegerExpression::IfElse(first, second, third) => { + // let resolved_first = + // match self.enforce_boolean_expression(cs, scope.clone(), *first) { + // ResolvedValue::Boolean(resolved) => resolved, + // _ => unimplemented!("if else conditional must resolve to boolean"), + // }; + // + // if resolved_first.eq(&Boolean::Constant(true)) { + // self.enforce_integer_expression(cs, scope, *second) + // } else { + // self.enforce_integer_expression(cs, scope, *third) + // } + // } + // IntegerExpression::Array(array) => { + // let mut result = vec![]; + // array.into_iter().for_each(|element| match *element { + // IntegerSpreadOrExpression::Spread(spread) => match spread { + // IntegerExpression::Variable(variable) => { + // let array_name = new_scope_from_variable(scope.clone(), &variable); + // match self.get(&array_name) { + // Some(value) => match value { + // ResolvedValue::U32Array(array) => result.extend(array.clone()), + // value => unimplemented!( + // "spreads only implemented for arrays, got {}", + // value + // ), + // }, + // None => unimplemented!( + // "cannot copy elements from array that does not exist {}", + // variable.name + // ), + // } + // } + // value => { + // unimplemented!("spreads only implemented for arrays, got {}", value) + // } + // }, + // IntegerSpreadOrExpression::Expression(expression) => { + // match self.enforce_integer_expression(cs, scope.clone(), expression) { + // ResolvedValue::U32(value) => result.push(value), + // _ => unimplemented!("cannot resolve field"), + // } + // } + // }); + // ResolvedValue::U32Array(result) + // } + // } + // } } diff --git a/src/program/constraints/resolved_value.rs b/src/program/constraints/resolved_value.rs index 1f5a06be67..1a717764b8 100644 --- a/src/program/constraints/resolved_value.rs +++ b/src/program/constraints/resolved_value.rs @@ -13,11 +13,9 @@ use std::fmt; #[derive(Clone)] pub enum ResolvedValue { U32(UInt32), - U32Array(Vec), FieldElement(F), - FieldElementArray(Vec), Boolean(Boolean), - BooleanArray(Vec), + Array(Vec>), StructDefinition(Struct), StructExpression(Variable, Vec>), Function(Function), @@ -28,17 +26,22 @@ impl ResolvedValue { pub(crate) fn match_type(&self, ty: &Type) -> bool { match (self, ty) { (ResolvedValue::U32(ref _a), Type::U32) => true, - (ResolvedValue::U32Array(ref arr), Type::Array(ref arr_type, ref len)) => { - (arr.len() == *len) & (**arr_type == Type::U32) - } (ResolvedValue::FieldElement(ref _a), Type::FieldElement) => true, - (ResolvedValue::FieldElementArray(ref arr), Type::Array(ref arr_type, ref len)) => { - (arr.len() == *len) & (**arr_type == Type::FieldElement) - } (ResolvedValue::Boolean(ref _a), Type::Boolean) => true, - (ResolvedValue::BooleanArray(ref arr), Type::Array(ref arr_type, ref len)) => { - (arr.len() == *len) & (**arr_type == Type::Boolean) - } + (ResolvedValue::Array(ref _arr), Type::Array(ref _ty, ref _len)) => true, // todo: add array types + // (ResolvedValue::U32Array(ref arr), Type::Array(ref arr_type, ref len)) => { + // (arr.len() == *len) & (**arr_type == Type::U32) + // } + // (ResolvedValue::FieldElementArray(ref arr), Type::Array(ref arr_type, ref len)) => { + // (arr.len() == *len) & (**arr_type == Type::FieldElement) + // } + // (ResolvedValue::BooleanArray(ref arr), Type::Array(ref arr_type, ref len)) => { + // (arr.len() == *len) & (**arr_type == Type::Boolean) + // } + ( + ResolvedValue::StructExpression(ref actual_name, ref _members), + Type::Struct(ref expected_name), + ) => actual_name == expected_name, (ResolvedValue::Return(ref values), ty) => { let mut res = true; for value in values { @@ -55,18 +58,9 @@ impl fmt::Display for ResolvedValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { ResolvedValue::U32(ref value) => write!(f, "{}", value.value.unwrap()), - ResolvedValue::U32Array(ref array) => { - write!(f, "[")?; - for (i, e) in array.iter().enumerate() { - write!(f, "{}", e.value.unwrap())?; - if i < array.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") - } ResolvedValue::FieldElement(ref value) => write!(f, "{}", value), - ResolvedValue::FieldElementArray(ref array) => { + ResolvedValue::Boolean(ref value) => write!(f, "{}", value.get_value().unwrap()), + ResolvedValue::Array(ref array) => { write!(f, "[")?; for (i, e) in array.iter().enumerate() { write!(f, "{}", e)?; @@ -76,17 +70,36 @@ impl fmt::Display for ResolvedValue { } write!(f, "]") } - ResolvedValue::Boolean(ref value) => write!(f, "{}", value.get_value().unwrap()), - ResolvedValue::BooleanArray(ref array) => { - write!(f, "[")?; - for (i, e) in array.iter().enumerate() { - write!(f, "{}", e.get_value().unwrap())?; - if i < array.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") - } + // ResolvedValue::U32Array(ref array) => { + // write!(f, "[")?; + // for (i, e) in array.iter().enumerate() { + // write!(f, "{}", e.value.unwrap())?; + // if i < array.len() - 1 { + // write!(f, ", ")?; + // } + // } + // write!(f, "]") + // } + // ResolvedValue::FieldElementArray(ref array) => { + // write!(f, "[")?; + // for (i, e) in array.iter().enumerate() { + // write!(f, "{}", e)?; + // if i < array.len() - 1 { + // write!(f, ", ")?; + // } + // } + // write!(f, "]") + // } + // ResolvedValue::BooleanArray(ref array) => { + // write!(f, "[")?; + // for (i, e) in array.iter().enumerate() { + // write!(f, "{}", e.get_value().unwrap())?; + // if i < array.len() - 1 { + // write!(f, ", ")?; + // } + // } + // write!(f, "]") + // } ResolvedValue::StructExpression(ref variable, ref members) => { write!(f, "{} {{", variable)?; for (i, member) in members.iter().enumerate() { diff --git a/src/program/constraints/statement.rs b/src/program/constraints/statement.rs index af76a89dc7..4b0f02c8fc 100644 --- a/src/program/constraints/statement.rs +++ b/src/program/constraints/statement.rs @@ -5,9 +5,7 @@ //! @date 2020 use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}; -use crate::program::{ - Assignee, Expression, IntegerExpression, IntegerRangeOrExpression, Statement, Type, Variable, -}; +use crate::program::{Assignee, Expression, Integer, RangeOrExpression, Statement, Type, Variable}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::{r1cs::ConstraintSystem, utilities::uint32::UInt32}; @@ -50,17 +48,14 @@ impl> ResolvedProgram { // Resolve index so we know if we are assigning to a single value or a range of values match index_expression { - IntegerRangeOrExpression::Expression(index) => { + RangeOrExpression::Expression(index) => { let index = self.enforce_index(cs, scope.clone(), index); // Modify the single value of the array in place match self.get_mut(&expected_array_name) { - Some(value) => match (value, result) { - (ResolvedValue::U32Array(old), ResolvedValue::U32(new)) => { - old[index] = new.to_owned(); - } - (ResolvedValue::BooleanArray(old), ResolvedValue::Boolean(new)) => { - old[index] = new.to_owned(); + Some(value) => match value { + ResolvedValue::Array(old) => { + old[index] = result.to_owned(); } _ => { unimplemented!("Cannot assign single index to array of values ") @@ -72,29 +67,20 @@ impl> ResolvedProgram { ), } } - IntegerRangeOrExpression::Range(from, to) => { + RangeOrExpression::Range(from, to) => { let from_index = match from { - Some(expression) => self.enforce_index(cs, scope.clone(), expression), + Some(integer) => integer.to_usize(), None => 0usize, }; let to_index_option = match to { - Some(expression) => { - Some(self.enforce_index(cs, scope.clone(), expression)) - } + Some(integer) => Some(integer.to_usize()), None => None, }; // Modify the range of values of the array in place match self.get_mut(&expected_array_name) { Some(value) => match (value, result) { - (ResolvedValue::U32Array(old), ResolvedValue::U32Array(new)) => { - let to_index = to_index_option.unwrap_or(old.len()); - old.splice(from_index..to_index, new.iter().cloned()); - } - ( - ResolvedValue::BooleanArray(old), - ResolvedValue::BooleanArray(new), - ) => { + (ResolvedValue::Array(old), ResolvedValue::Array(new)) => { let to_index = to_index_option.unwrap_or(old.len()); old.splice(from_index..to_index, new.iter().cloned()); } @@ -192,14 +178,11 @@ impl> ResolvedProgram { cs: &mut CS, scope: String, index: Variable, - start: IntegerExpression, - stop: IntegerExpression, + start: Integer, + stop: Integer, statements: Vec>, ) { - let start_index = self.enforce_index(cs, scope.clone(), start); - let stop_index = self.enforce_index(cs, scope.clone(), stop); - - for i in start_index..stop_index { + for i in start.to_usize()..stop.to_usize() { // Store index in current function scope. // For loop scope is not implemented. let index_name = new_scope_from_variable(scope.clone(), &index); diff --git a/src/program/types.rs b/src/program/types.rs index 5965d3a02e..60a7172205 100644 --- a/src/program/types.rs +++ b/src/program/types.rs @@ -26,122 +26,80 @@ pub enum Integer { // U64(u64), } -/// Spread operator or u32 expression enum -#[derive(Debug, Clone)] -pub enum IntegerSpreadOrExpression { - Spread(IntegerExpression), - Expression(IntegerExpression), +impl Integer { + pub fn to_usize(&self) -> usize { + match *self { + // U8(u8) + Integer::U32(num) => num as usize, // U64(u64) + } + } } -/// Range or integer expression enum +/// Range or expression enum #[derive(Debug, Clone)] -pub enum IntegerRangeOrExpression { - Range(Option>, Option>), - Expression(IntegerExpression), +pub enum RangeOrExpression { + Range(Option, Option), + Expression(Expression), } -/// Expression that evaluates to a u32 value +/// Spread or expression #[derive(Debug, Clone)] -pub enum IntegerExpression { - Variable(Variable), - Number(Integer), - // Operators - Add(Box>, Box>), - Sub(Box>, Box>), - Mul(Box>, Box>), - Div(Box>, Box>), - Pow(Box>, Box>), - // Conditionals - IfElse( - Box>, - Box>, - Box>, - ), - // Arrays - Array(Vec>>), -} - -/// Spread or field expression enum -#[derive(Debug, Clone)] -pub enum FieldSpreadOrExpression { - Spread(FieldExpression), - Expression(FieldExpression), -} - -/// Expression that evaluates to a field value -#[derive(Debug, Clone)] -pub enum FieldExpression { - Variable(Variable), - Number(F), - // Operators - Add(Box>, Box>), - Sub(Box>, Box>), - Mul(Box>, Box>), - Div(Box>, Box>), - Pow(Box>, Box>), - // Conditionals - IfElse( - Box>, - Box>, - Box>, - ), - // Arrays - Array(Vec>>), -} - -/// Spread or field expression enum -#[derive(Debug, Clone)] -pub enum BooleanSpreadOrExpression { - Spread(BooleanExpression), - Expression(BooleanExpression), -} - -/// Expression that evaluates to a boolean value -#[derive(Debug, Clone)] -pub enum BooleanExpression { - Variable(Variable), - Value(bool), - // Boolean operators - Not(Box>), - Or(Box>, Box>), - And(Box>, Box>), - BoolEq(Box>, Box>), - // Integer operators - IntegerEq(Box>, Box>), - Geq(Box>, Box>), - Gt(Box>, Box>), - Leq(Box>, Box>), - Lt(Box>, Box>), - // Field operators - FieldEq(Box>, Box>), - // Conditionals - IfElse( - Box>, - Box>, - Box>, - ), - // Arrays - Array(Vec>>), +pub enum SpreadOrExpression { + Spread(Expression), + Expression(Expression), } /// Expression that evaluates to a value #[derive(Debug, Clone)] pub enum Expression { - Integer(IntegerExpression), - FieldElement(FieldExpression), - Boolean(BooleanExpression), + // Variable Variable(Variable), + + // Values + Integer(Integer), + FieldElement(F), + Boolean(bool), + + // Number operations + Add(Box>, Box>), + Sub(Box>, Box>), + Mul(Box>, Box>), + Div(Box>, Box>), + Pow(Box>, Box>), + + // Boolean operations + Not(Box>), + Or(Box>, Box>), + And(Box>, Box>), + Eq(Box>, Box>), + Geq(Box>, Box>), + Gt(Box>, Box>), + Leq(Box>, Box>), + Lt(Box>, Box>), + + // Conditionals + IfElse(Box>, Box>, Box>), + + // Arrays + Array(Vec>>), + ArrayAccess(Box>, Box>), + + // Structs Struct(Variable, Vec>), - ArrayAccess(Box>, IntegerRangeOrExpression), StructMemberAccess(Box>, Variable), // (struct name, struct member name) + + // Functions FunctionCall(Box>, Vec>), + // IntegerExp(IntegerExpression), + // FieldElementExp(FieldExpression), + // BooleanExp(BooleanExpression), } /// Definition assignee: v, arr[0..2], Point p.x #[derive(Debug, Clone)] pub enum Assignee { Variable(Variable), - Array(Box>, IntegerRangeOrExpression), + Array(Box>, RangeOrExpression), StructMember(Box>, Variable), } @@ -150,12 +108,7 @@ pub enum Assignee { pub enum Statement { // Declaration(Variable), Definition(Assignee, Expression), - For( - Variable, - IntegerExpression, - IntegerExpression, - Vec>, - ), + For(Variable, Integer, Integer, Vec>), Return(Vec>), } @@ -232,3 +185,95 @@ impl<'ast, F: Field + PrimeField> Program<'ast, F> { self } } + +// /// Spread operator or u32 expression enum +// #[derive(Debug, Clone)] +// pub enum IntegerSpreadOrExpression { +// Spread(IntegerExpression), +// Expression(IntegerExpression), +// } + +// Expression that evaluates to a u32 value +// #[derive(Debug, Clone)] +// pub enum IntegerExpression { +// Variable(Variable), +// Number(Integer), +// Operators +// Add(Box>, Box>), +// Sub(Box>, Box>), +// Mul(Box>, Box>), +// Div(Box>, Box>), +// Pow(Box>, Box>), +// Conditionals +// IfElse( +// Box>, +// Box>, +// Box>, +// ), +// Arrays +// Array(Vec>>), +// // Unresolved +// Unresolved(Box>) // placeholder for array/struct access, function calls +// } + +// /// Spread or field expression enum +// #[derive(Debug, Clone)] +// pub enum FieldSpreadOrExpression { +// Spread(FieldExpression), +// Expression(FieldExpression), +// } + +// /// Expression that evaluates to a field value +// #[derive(Debug, Clone)] +// pub enum FieldExpression { +// Variable(Variable), +// Number(F), +// Operators +// Add(Box>, Box>), +// Sub(Box>, Box>), +// Mul(Box>, Box>), +// Div(Box>, Box>), +// Pow(Box>, Box>), +// Conditionals +// IfElse( +// Box>, +// Box>, +// Box>, +// ), +// Arrays +// Array(Vec>>), +// } + +// /// Spread or field expression enum +// #[derive(Debug, Clone)] +// pub enum BooleanSpreadOrExpression { +// Spread(BooleanExpression), +// Expression(BooleanExpression), +// } + +// Expression that evaluates to a boolean value +// #[derive(Debug, Clone)] +// pub enum BooleanExpression { +// Variable(Variable), +// Value(bool), +// Boolean operators +// Or(Box>, Box>), +// And(Box>, Box>), +// BoolEq(Box>, Box>), +// // Integer operators +// IntegerEq(Box>, Box>), +// Geq(Box>, Box>), +// Gt(Box>, Box>), +// Leq(Box>, Box>), +// Lt(Box>, Box>), +// // Field operators +// FieldEq(Box>, Box>), +// Conditionals +// IfElse( +// Box>, +// Box>, +// Box>, +// ), +// Arrays +// Array(Vec>>), +// } diff --git a/src/program/types_display.rs b/src/program/types_display.rs index 4e890f2c97..65464dffd4 100644 --- a/src/program/types_display.rs +++ b/src/program/types_display.rs @@ -5,10 +5,8 @@ //! @date 2020 use crate::program::{ - Assignee, BooleanExpression, BooleanSpreadOrExpression, Expression, FieldExpression, - FieldSpreadOrExpression, Function, FunctionName, Integer, IntegerExpression, - IntegerRangeOrExpression, IntegerSpreadOrExpression, Parameter, Statement, Struct, StructField, - Type, Variable, + Assignee, Expression, Function, FunctionName, Integer, Parameter, RangeOrExpression, + SpreadOrExpression, Statement, Struct, StructField, Type, Variable, }; use snarkos_models::curves::{Field, PrimeField}; @@ -33,132 +31,29 @@ impl fmt::Display for Integer { } } -impl fmt::Display for IntegerSpreadOrExpression { +impl<'ast, F: Field + PrimeField> fmt::Display for RangeOrExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - IntegerSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), - IntegerSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), - } - } -} - -impl<'ast, F: Field + PrimeField> fmt::Display for IntegerRangeOrExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - IntegerRangeOrExpression::Range(ref from, ref to) => write!( + RangeOrExpression::Range(ref from, ref to) => write!( f, "{}..{}", from.as_ref() - .map(|e| e.to_string()) + .map(|e| format!("{}", e)) .unwrap_or("".to_string()), - to.as_ref().map(|e| e.to_string()).unwrap_or("".to_string()) + to.as_ref() + .map(|e| format!("{}", e)) + .unwrap_or("".to_string()) ), - IntegerRangeOrExpression::Expression(ref e) => write!(f, "{}", e), + RangeOrExpression::Expression(ref e) => write!(f, "{}", e), } } } -impl<'ast, F: Field + PrimeField> fmt::Display for IntegerExpression { +impl fmt::Display for SpreadOrExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - IntegerExpression::Variable(ref variable) => write!(f, "{}", variable), - IntegerExpression::Number(ref number) => write!(f, "{}", number), - IntegerExpression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), - IntegerExpression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), - IntegerExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), - IntegerExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), - IntegerExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), - IntegerExpression::IfElse(ref a, ref b, ref c) => { - write!(f, "if {} then {} else {} fi", a, b, c) - } - IntegerExpression::Array(ref array) => { - write!(f, "[")?; - for (i, e) in array.iter().enumerate() { - write!(f, "{}", e)?; - if i < array.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") - } - } - } -} - -impl fmt::Display for FieldSpreadOrExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - FieldSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), - FieldSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), - } - } -} - -impl<'ast, F: Field + PrimeField> fmt::Display for FieldExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - FieldExpression::Variable(ref variable) => write!(f, "{}", variable), - FieldExpression::Number(ref number) => write!(f, "{}", number), - FieldExpression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), - FieldExpression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), - FieldExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), - FieldExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), - FieldExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), - FieldExpression::IfElse(ref a, ref b, ref c) => { - write!(f, "if {} then {} else {} fi", a, b, c) - } - FieldExpression::Array(ref array) => { - write!(f, "[")?; - for (i, e) in array.iter().enumerate() { - write!(f, "{}", e)?; - if i < array.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") - } // _ => unimplemented!("not all field expressions can be displayed") - } - } -} - -impl fmt::Display for BooleanSpreadOrExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - BooleanSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), - BooleanSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), - } - } -} - -impl<'ast, F: Field + PrimeField> fmt::Display for BooleanExpression { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - BooleanExpression::Variable(ref variable) => write!(f, "{}", variable), - BooleanExpression::Value(ref value) => write!(f, "{}", value), - BooleanExpression::Not(ref expression) => write!(f, "!{}", expression), - BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), - BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs), - BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), - BooleanExpression::IntegerEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), - BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), - // BooleanExpression::Neq(ref lhs, ref rhs) => write!(f, "{} != {}", lhs, rhs), - BooleanExpression::Geq(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), - BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), - BooleanExpression::Leq(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), - BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), - BooleanExpression::IfElse(ref a, ref b, ref c) => { - write!(f, "if {} then {} else {} fi", a, b, c) - } - BooleanExpression::Array(ref array) => { - write!(f, "[")?; - for (i, e) in array.iter().enumerate() { - write!(f, "{}", e)?; - if i < array.len() - 1 { - write!(f, ", ")?; - } - } - write!(f, "]") - } + SpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), + SpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), } } } @@ -166,10 +61,48 @@ impl<'ast, F: Field + PrimeField> fmt::Display for BooleanExpression { impl<'ast, F: Field + PrimeField> fmt::Display for Expression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Expression::Integer(ref integer_expression) => write!(f, "{}", integer_expression), - Expression::FieldElement(ref field_expression) => write!(f, "{}", field_expression), - Expression::Boolean(ref boolean_expression) => write!(f, "{}", boolean_expression), + // Variables Expression::Variable(ref variable) => write!(f, "{}", variable), + + // Values + Expression::Integer(ref integer) => write!(f, "{}", integer), + Expression::FieldElement(ref fe) => write!(f, "{}", fe), + Expression::Boolean(ref bool) => write!(f, "{}", bool), + + // Number operations + Expression::Add(ref left, ref right) => write!(f, "{} + {}", left, right), + Expression::Sub(ref left, ref right) => write!(f, "{} - {}", left, right), + Expression::Mul(ref left, ref right) => write!(f, "{} * {}", left, right), + Expression::Div(ref left, ref right) => write!(f, "{} / {}", left, right), + Expression::Pow(ref left, ref right) => write!(f, "{} ** {}", left, right), + + // Boolean operations + Expression::Not(ref expression) => write!(f, "!{}", expression), + Expression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), + Expression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs), + Expression::Eq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), + Expression::Geq(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), + Expression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), + Expression::Leq(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), + Expression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), + + // Conditionals + Expression::IfElse(ref first, ref second, ref third) => { + write!(f, "if {} then {} else {} fi", first, second, third) + } + + Expression::Array(ref array) => { + write!(f, "[")?; + for (i, e) in array.iter().enumerate() { + write!(f, "{}", e)?; + if i < array.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "]") + } + Expression::ArrayAccess(ref array, ref index) => write!(f, "{}[{}]", array, index), + Expression::Struct(ref var, ref members) => { write!(f, "{} {{", var)?; for (i, member) in members.iter().enumerate() { @@ -180,7 +113,6 @@ impl<'ast, F: Field + PrimeField> fmt::Display for Expression { } write!(f, "}}") } - Expression::ArrayAccess(ref array, ref index) => write!(f, "{}[{}]", array, index), Expression::StructMemberAccess(ref struct_variable, ref member) => { write!(f, "{}.{}", struct_variable, member) } @@ -350,3 +282,117 @@ impl fmt::Debug for Function { ) } } + +// impl fmt::Display for IntegerSpreadOrExpression { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// match *self { +// IntegerSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), +// IntegerSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), +// } +// } +// } + +// impl<'ast, F: Field + PrimeField> fmt::Display for IntegerExpression { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// match *self { +// IntegerExpression::Variable(ref variable) => write!(f, "{}", variable), +// IntegerExpression::Number(ref number) => write!(f, "{}", number), +// IntegerExpression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), +// IntegerExpression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), +// IntegerExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), +// IntegerExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), +// IntegerExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), +// IntegerExpression::IfElse(ref a, ref b, ref c) => { +// write!(f, "if {} then {} else {} fi", a, b, c) +// } +// IntegerExpression::Array(ref array) => { +// write!(f, "[")?; +// for (i, e) in array.iter().enumerate() { +// write!(f, "{}", e)?; +// if i < array.len() - 1 { +// write!(f, ", ")?; +// } +// } +// write!(f, "]") +// } +// } +// } +// } + +// impl fmt::Display for FieldSpreadOrExpression { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// match *self { +// FieldSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), +// FieldSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), +// } +// } +// } +// +// impl<'ast, F: Field + PrimeField> fmt::Display for FieldExpression { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// match *self { +// FieldExpression::Variable(ref variable) => write!(f, "{}", variable), +// FieldExpression::Number(ref number) => write!(f, "{}", number), +// FieldExpression::Add(ref lhs, ref rhs) => write!(f, "{} + {}", lhs, rhs), +// FieldExpression::Sub(ref lhs, ref rhs) => write!(f, "{} - {}", lhs, rhs), +// FieldExpression::Mul(ref lhs, ref rhs) => write!(f, "{} * {}", lhs, rhs), +// FieldExpression::Div(ref lhs, ref rhs) => write!(f, "{} / {}", lhs, rhs), +// FieldExpression::Pow(ref lhs, ref rhs) => write!(f, "{} ** {}", lhs, rhs), +// FieldExpression::IfElse(ref a, ref b, ref c) => { +// write!(f, "if {} then {} else {} fi", a, b, c) +// } +// FieldExpression::Array(ref array) => { +// write!(f, "[")?; +// for (i, e) in array.iter().enumerate() { +// write!(f, "{}", e)?; +// if i < array.len() - 1 { +// write!(f, ", ")?; +// } +// } +// write!(f, "]") +// } // _ => unimplemented!("not all field expressions can be displayed") +// } +// } +// } + +// impl fmt::Display for BooleanSpreadOrExpression { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// match *self { +// BooleanSpreadOrExpression::Spread(ref spread) => write!(f, "...{}", spread), +// BooleanSpreadOrExpression::Expression(ref expression) => write!(f, "{}", expression), +// } +// } +// } + +// impl<'ast, F: Field + PrimeField> fmt::Display for BooleanExpression { +// fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +// match *self { +// BooleanExpression::Variable(ref variable) => write!(f, "{}", variable), +// BooleanExpression::Value(ref value) => write!(f, "{}", value), +// BooleanExpression::Not(ref expression) => write!(f, "!{}", expression), +// BooleanExpression::Or(ref lhs, ref rhs) => write!(f, "{} || {}", lhs, rhs), +// BooleanExpression::And(ref lhs, ref rhs) => write!(f, "{} && {}", lhs, rhs), +// BooleanExpression::BoolEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), +// BooleanExpression::IntegerEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), +// BooleanExpression::FieldEq(ref lhs, ref rhs) => write!(f, "{} == {}", lhs, rhs), +// // BooleanExpression::Neq(ref lhs, ref rhs) => write!(f, "{} != {}", lhs, rhs), +// BooleanExpression::Geq(ref lhs, ref rhs) => write!(f, "{} >= {}", lhs, rhs), +// BooleanExpression::Gt(ref lhs, ref rhs) => write!(f, "{} > {}", lhs, rhs), +// BooleanExpression::Leq(ref lhs, ref rhs) => write!(f, "{} <= {}", lhs, rhs), +// BooleanExpression::Lt(ref lhs, ref rhs) => write!(f, "{} < {}", lhs, rhs), +// BooleanExpression::IfElse(ref a, ref b, ref c) => { +// write!(f, "if {} then {} else {} fi", a, b, c) +// } +// BooleanExpression::Array(ref array) => { +// write!(f, "[")?; +// for (i, e) in array.iter().enumerate() { +// write!(f, "{}", e)?; +// if i < array.len() - 1 { +// write!(f, ", ")?; +// } +// } +// write!(f, "]") +// } +// } +// } +// } diff --git a/src/program/types_from.rs b/src/program/types_from.rs index c262cc6091..c4c88f505b 100644 --- a/src/program/types_from.rs +++ b/src/program/types_from.rs @@ -23,291 +23,291 @@ impl<'ast, F: Field + PrimeField> From> for types::Variable< } } -/// pest ast - types::Integer - -impl<'ast, F: Field + PrimeField> From> for types::IntegerExpression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(variable: ast::Variable<'ast>) -> Self { - types::IntegerExpression::Variable(types::Variable::from(variable)) + types::Expression::Variable(types::Variable::from(variable)) } } +/// pest ast - types::Integer -impl<'ast, F: Field + PrimeField> From> for types::IntegerExpression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(field: ast::U32<'ast>) -> Self { - types::IntegerExpression::Number(types::Integer::U32( + types::Expression::Integer(types::Integer::U32( field .number .value .parse::() - .expect("unable to unwrap u32"), + .expect("unable to parse u32"), )) } } -impl<'ast, F: Field + PrimeField> From> for types::IntegerExpression { - fn from(expression: ast::Expression<'ast>) -> Self { - match types::Expression::from(expression) { - types::Expression::Integer(integer_expression) => integer_expression, - types::Expression::Variable(variable) => types::IntegerExpression::Variable(variable), - _ => unimplemented!("expected integer in integer expression"), - } - } -} +// impl<'ast, F: Field + PrimeField> From> for types::Expression { +// fn from(expression: ast::Expression<'ast>) -> Self { +// match types::Expression::from(expression) { +// types::Expression::Variable(variable) => types::IntegerExpression::Variable(variable), +// types::Expression::IntegerExp(integer_expression) => integer_expression, +// expression => unimplemented!("expected integer in integer expression, got {}", expression), +// } +// } +// } -impl<'ast, F: Field + PrimeField> From> - for types::IntegerSpreadOrExpression -{ - fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { - match s_or_e { - ast::SpreadOrExpression::Spread(spread) => types::IntegerSpreadOrExpression::Spread( - types::IntegerExpression::from(spread.expression), - ), - ast::SpreadOrExpression::Expression(expression) => { - types::IntegerSpreadOrExpression::Expression(types::IntegerExpression::from( - expression, - )) - } - } - } -} +// impl<'ast, F: Field + PrimeField> From> +// for types::IntegerSpreadOrExpression +// { +// fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { +// match s_or_e { +// ast::SpreadOrExpression::Spread(spread) => types::IntegerSpreadOrExpression::Spread( +// types::IntegerExpression::from(spread.expression), +// ), +// ast::SpreadOrExpression::Expression(expression) => { +// types::IntegerSpreadOrExpression::Expression(types::IntegerExpression::from( +// expression, +// )) +// } +// } +// } +// } impl<'ast, F: Field + PrimeField> From> - for types::IntegerRangeOrExpression + for types::RangeOrExpression { fn from(range_or_expression: ast::RangeOrExpression<'ast>) -> Self { match range_or_expression { ast::RangeOrExpression::Range(range) => { let from = range .from - .map(|from| match types::Expression::from(from.0) { + .map(|from| match types::Expression::::from(from.0) { types::Expression::Integer(number) => number, expression => { - unimplemented!("Range bounds should be numbers, found {}", expression) + unimplemented!("Range bounds should be integers, found {}", expression) } }); - let to = range.to.map(|to| match types::Expression::from(to.0) { + let to = range.to.map(|to| match types::Expression::::from(to.0) { types::Expression::Integer(number) => number, expression => { - unimplemented!("Range bounds should be numbers, found {}", expression) + unimplemented!("Range bounds should be intgers, found {}", expression) } }); - types::IntegerRangeOrExpression::Range(from, to) + types::RangeOrExpression::Range(from, to) } ast::RangeOrExpression::Expression(expression) => { - match types::Expression::from(expression) { - types::Expression::Integer(expression) => { - types::IntegerRangeOrExpression::Expression(expression) - } - // types::Expression::ArrayAccess(expression, field), // recursive array access - expression => unimplemented!("expression must be number, found {}", expression), - } + types::RangeOrExpression::Expression(types::Expression::from(expression)) } } } } +// +// impl<'ast, F: Field + PrimeField> From> for types::FieldExpression { +// fn from(variable: ast::Variable<'ast>) -> Self { +// types::FieldExpression::Variable(types::Variable::from(variable)) +// } +// } /// pest ast -> types::FieldExpression -impl<'ast, F: Field + PrimeField> From> for types::FieldExpression { - fn from(variable: ast::Variable<'ast>) -> Self { - types::FieldExpression::Variable(types::Variable::from(variable)) - } -} - -impl<'ast, F: Field + PrimeField> From> for types::FieldExpression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(field: ast::Field<'ast>) -> Self { - types::FieldExpression::Number(F::from_str(&field.number.value).unwrap_or_default()) + types::Expression::FieldElement(F::from_str(&field.number.value).unwrap_or_default()) } } -impl<'ast, F: Field + PrimeField> From> for types::FieldExpression { - fn from(expression: ast::Expression<'ast>) -> Self { - match types::Expression::from(expression) { - types::Expression::FieldElement(field_expression) => field_expression, - types::Expression::Variable(variable) => types::FieldExpression::Variable(variable), - ty => unimplemented!("expected field in field expression, got {}", ty), - } - } -} +// impl<'ast, F: Field + PrimeField> From> for types::FieldExpression { +// fn from(expression: ast::Expression<'ast>) -> Self { +// match types::Expression::from(expression) { +// types::Expression::FieldElementExp(field_expression) => field_expression, +// types::Expression::Variable(variable) => types::FieldExpression::Variable(variable), +// expression => unimplemented!("expected field in field expression, got {}", expression), +// } +// } +// } -impl<'ast, F: Field + PrimeField> From> - for types::FieldSpreadOrExpression -{ - fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { - match s_or_e { - ast::SpreadOrExpression::Spread(spread) => types::FieldSpreadOrExpression::Spread( - types::FieldExpression::from(spread.expression), - ), - ast::SpreadOrExpression::Expression(expression) => { - types::FieldSpreadOrExpression::Expression(types::FieldExpression::from(expression)) - } - } - } -} +// impl<'ast, F: Field + PrimeField> From> +// for types::FieldSpreadOrExpression +// { +// fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { +// match s_or_e { +// ast::SpreadOrExpression::Spread(spread) => types::FieldSpreadOrExpression::Spread( +// types::FieldExpression::from(spread.expression), +// ), +// ast::SpreadOrExpression::Expression(expression) => { +// types::FieldSpreadOrExpression::Expression(types::FieldExpression::from(expression)) +// } +// } +// } +// } + +// impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { +// fn from(variable: ast::Variable<'ast>) -> Self { +// types::BooleanExpression::Variable(types::Variable::from(variable)) +// } +// } /// pest ast -> types::Boolean -impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { - fn from(variable: ast::Variable<'ast>) -> Self { - types::BooleanExpression::Variable(types::Variable::from(variable)) - } -} - -impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { +impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(boolean: ast::Boolean<'ast>) -> Self { - types::BooleanExpression::Value( + types::Expression::Boolean( boolean .value .parse::() - .expect("unable to unwrap boolean"), + .expect("unable to parse boolean"), ) } } -impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { - fn from(expression: ast::Expression<'ast>) -> Self { - match types::Expression::from(expression) { - types::Expression::Boolean(boolean_expression) => boolean_expression, - types::Expression::Variable(variable) => types::BooleanExpression::Variable(variable), - _ => unimplemented!("expected boolean in boolean expression"), - } - } -} +// impl<'ast, F: Field + PrimeField> From> for types::BooleanExpression { +// fn from(expression: ast::Expression<'ast>) -> Self { +// match types::Expression::from(expression) { +// types::Expression::BooleanExp(boolean_expression) => boolean_expression, +// types::Expression::Variable(variable) => types::BooleanExpression::Variable(variable), +// expression => unimplemented!("expected boolean in boolean expression, got {}", expression), +// } +// } +// } -impl<'ast, F: Field + PrimeField> From> - for types::BooleanSpreadOrExpression -{ - fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { - match s_or_e { - ast::SpreadOrExpression::Spread(spread) => types::BooleanSpreadOrExpression::Spread( - types::BooleanExpression::from(spread.expression), - ), - ast::SpreadOrExpression::Expression(expression) => { - types::BooleanSpreadOrExpression::Expression(types::BooleanExpression::from( - expression, - )) - } - } - } -} +// impl<'ast, F: Field + PrimeField> From> +// for types::BooleanSpreadOrExpression +// { +// fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { +// match s_or_e { +// ast::SpreadOrExpression::Spread(spread) => types::BooleanSpreadOrExpression::Spread( +// types::BooleanExpression::from(spread.expression), +// ), +// ast::SpreadOrExpression::Expression(expression) => { +// types::BooleanSpreadOrExpression::Expression(types::BooleanExpression::from( +// expression, +// )) +// } +// } +// } +// } /// pest ast -> types::Expression impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(value: ast::Value<'ast>) -> Self { match value { - ast::Value::U32(value) => { - types::Expression::Integer(types::IntegerExpression::from(value)) - } - ast::Value::Field(field) => { - types::Expression::FieldElement(types::FieldExpression::from(field)) - } - ast::Value::Boolean(value) => { - types::Expression::Boolean(types::BooleanExpression::from(value)) - } + ast::Value::U32(num) => types::Expression::from(num), + ast::Value::Field(fe) => types::Expression::from(fe), + ast::Value::Boolean(bool) => types::Expression::from(bool), } } } -impl<'ast, F: Field + PrimeField> From> for types::Expression { - fn from(variable: ast::Variable<'ast>) -> Self { - types::Expression::Variable(types::Variable::from(variable)) - } -} +// impl<'ast, F: Field + PrimeField> From> for types::Expression { +// fn from(variable: ast::Variable<'ast>) -> Self { +// types::Expression::Variable(types::Variable::from(variable)) +// } +// } impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::NotExpression<'ast>) -> Self { - types::Expression::Boolean(types::BooleanExpression::Not(Box::new( - types::BooleanExpression::from(*expression.expression), - ))) + types::Expression::Not(Box::new(types::Expression::from(*expression.expression))) } } -impl<'ast, F: Field + PrimeField> types::BooleanExpression { - /// Find out which types we are comparing and output the corresponding expression. - fn from_eq(expression: ast::BinaryExpression<'ast>) -> Self { - let left = types::Expression::from(*expression.left); - let right = types::Expression::from(*expression.right); +// impl<'ast, F: Field + PrimeField> types::BooleanExpression { +// /// Find out which types we are comparing and output the corresponding expression. +// fn from_eq(expression: ast::BinaryExpression<'ast>) -> Self { +// let left = types::Expression::from(*expression.left); +// let right = types::Expression::from(*expression.right); +// +// // When matching a variable, look at the opposite side to see what we are comparing to and assume that variable type +// match (left, right) { +// // Boolean equality +// (types::Expression::BooleanExp(lhs), types::Expression::BooleanExp(rhs)) => { +// types::BooleanExpression::BoolEq(Box::new(lhs), Box::new(rhs)) +// } +// (types::Expression::BooleanExp(lhs), types::Expression::Variable(rhs)) => { +// types::BooleanExpression::BoolEq( +// Box::new(lhs), +// Box::new(types::BooleanExpression::Variable(rhs)), +// ) +// } +// (types::Expression::Variable(lhs), types::Expression::BooleanExp(rhs)) => { +// types::BooleanExpression::BoolEq( +// Box::new(types::BooleanExpression::Variable(lhs)), +// Box::new(rhs), +// ) +// } //TODO: check case for two variables? +// // Integer equality +// (types::Expression::IntegerExp(lhs), types::Expression::IntegerExp(rhs)) => { +// types::BooleanExpression::IntegerEq(Box::new(lhs), Box::new(rhs)) +// } +// (types::Expression::IntegerExp(lhs), types::Expression::Variable(rhs)) => { +// types::BooleanExpression::IntegerEq( +// Box::new(lhs), +// Box::new(types::IntegerExpression::Variable(rhs)), +// ) +// } +// (types::Expression::Variable(lhs), types::Expression::IntegerExp(rhs)) => { +// types::BooleanExpression::IntegerEq( +// Box::new(types::IntegerExpression::Variable(lhs)), +// Box::new(rhs), +// ) +// } +// // Field equality +// (types::Expression::FieldElementExp(lhs), types::Expression::FieldElementExp(rhs)) => { +// types::BooleanExpression::FieldEq(Box::new(lhs), Box::new(rhs)) +// } +// (types::Expression::FieldElementExp(lhs), types::Expression::Variable(rhs)) => { +// types::BooleanExpression::FieldEq( +// Box::new(lhs), +// Box::new(types::FieldExpression::Variable(rhs)), +// ) +// } +// (types::Expression::Variable(lhs), types::Expression::FieldElementExp(rhs)) => { +// types::BooleanExpression::FieldEq( +// Box::new(types::FieldExpression::Variable(lhs)), +// Box::new(rhs), +// ) +// } +// +// (lhs, rhs) => unimplemented!("pattern {} == {} unimplemented", lhs, rhs), +// } +// } +// +// fn from_neq(expression: ast::BinaryExpression<'ast>) -> Self { +// types::BooleanExpression::Not(Box::new(Self::from_eq(expression))) +// } +// } +// impl<'ast, F: Field + PrimeField> types::Type { +// fn resolve_type(left: &Box>, right: &Box>) -> Self { +// let left = types::Expression::::from(*left.clone()); +// let right = types::Expression::::from(*right.clone()); +// +// match (left, right) { +// // Integer operation +// (types::Expression::IntegerExp(_), _) | (_, types::Expression::IntegerExp(_)) => { +// types::Type::U32 +// } +// // Field operation +// (types::Expression::FieldElementExp(_), _) | (_, types::Expression::FieldElementExp(_)) => { +// types::Type::FieldElement +// } +// // Unmatched: two array accesses, two variables +// (lhs, rhs) => unimplemented!( +// "operand types {} and {} must match for binary expression", +// lhs, +// rhs +// ), +// } +// } +// } - // When matching a variable, look at the opposite side to see what we are comparing to and assume that variable type - match (left, right) { - // Boolean equality - (types::Expression::Boolean(lhs), types::Expression::Boolean(rhs)) => { - types::BooleanExpression::BoolEq(Box::new(lhs), Box::new(rhs)) +impl<'ast, F: Field + PrimeField> From> + for types::SpreadOrExpression +{ + fn from(s_or_e: ast::SpreadOrExpression<'ast>) -> Self { + match s_or_e { + ast::SpreadOrExpression::Spread(spread) => { + types::SpreadOrExpression::Spread(types::Expression::from(spread.expression)) } - (types::Expression::Boolean(lhs), types::Expression::Variable(rhs)) => { - types::BooleanExpression::BoolEq( - Box::new(lhs), - Box::new(types::BooleanExpression::Variable(rhs)), - ) + ast::SpreadOrExpression::Expression(expression) => { + types::SpreadOrExpression::Expression(types::Expression::from(expression)) } - (types::Expression::Variable(lhs), types::Expression::Boolean(rhs)) => { - types::BooleanExpression::BoolEq( - Box::new(types::BooleanExpression::Variable(lhs)), - Box::new(rhs), - ) - } //TODO: check case for two variables? - // Integer equality - (types::Expression::Integer(lhs), types::Expression::Integer(rhs)) => { - types::BooleanExpression::IntegerEq(Box::new(lhs), Box::new(rhs)) - } - (types::Expression::Integer(lhs), types::Expression::Variable(rhs)) => { - types::BooleanExpression::IntegerEq( - Box::new(lhs), - Box::new(types::IntegerExpression::Variable(rhs)), - ) - } - (types::Expression::Variable(lhs), types::Expression::Integer(rhs)) => { - types::BooleanExpression::IntegerEq( - Box::new(types::IntegerExpression::Variable(lhs)), - Box::new(rhs), - ) - } - // Field equality - (types::Expression::FieldElement(lhs), types::Expression::FieldElement(rhs)) => { - types::BooleanExpression::FieldEq(Box::new(lhs), Box::new(rhs)) - } - (types::Expression::FieldElement(lhs), types::Expression::Variable(rhs)) => { - types::BooleanExpression::FieldEq( - Box::new(lhs), - Box::new(types::FieldExpression::Variable(rhs)), - ) - } - (types::Expression::Variable(lhs), types::Expression::FieldElement(rhs)) => { - types::BooleanExpression::FieldEq( - Box::new(types::FieldExpression::Variable(lhs)), - Box::new(rhs), - ) - } - - (lhs, rhs) => unimplemented!("pattern {} == {} unimplemented", lhs, rhs), - } - } - - fn from_neq(expression: ast::BinaryExpression<'ast>) -> Self { - types::BooleanExpression::Not(Box::new(Self::from_eq(expression))) - } -} -impl<'ast, F: Field + PrimeField> types::Type { - fn resolve_type(left: &Box>, right: &Box>) -> Self { - let left = types::Expression::::from(*left.clone()); - let right = types::Expression::::from(*right.clone()); - - match (left, right) { - // Integer operation - (types::Expression::Integer(_), _) | (_, types::Expression::Integer(_)) => { - types::Type::U32 - } - // Field operation - (types::Expression::FieldElement(_), _) | (_, types::Expression::FieldElement(_)) => { - types::Type::FieldElement - } - // Unmatched: two array accesses, two variables - (lhs, rhs) => unimplemented!( - "operand types {} and {} must match for binary expression", - lhs, - rhs - ), } } } @@ -316,112 +316,58 @@ impl<'ast, F: Field + PrimeField> From> for types::E fn from(expression: ast::BinaryExpression<'ast>) -> Self { match expression.operation { // Boolean operations - ast::BinaryOperator::Or => types::Expression::Boolean(types::BooleanExpression::Or( - Box::new(types::BooleanExpression::from(*expression.left)), - Box::new(types::BooleanExpression::from(*expression.right)), - )), - ast::BinaryOperator::And => types::Expression::Boolean(types::BooleanExpression::And( - Box::new(types::BooleanExpression::from(*expression.left)), - Box::new(types::BooleanExpression::from(*expression.right)), - )), - ast::BinaryOperator::Eq => { - types::Expression::Boolean(types::BooleanExpression::from_eq(expression)) - } + ast::BinaryOperator::Or => types::Expression::Or( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::And => types::Expression::And( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Eq => types::Expression::Eq( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), ast::BinaryOperator::Neq => { - types::Expression::Boolean(types::BooleanExpression::from_neq(expression)) - } - ast::BinaryOperator::Geq => types::Expression::Boolean(types::BooleanExpression::Geq( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - ast::BinaryOperator::Gt => types::Expression::Boolean(types::BooleanExpression::Gt( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - ast::BinaryOperator::Leq => types::Expression::Boolean(types::BooleanExpression::Leq( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - ast::BinaryOperator::Lt => types::Expression::Boolean(types::BooleanExpression::Lt( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - // Operations - ast::BinaryOperator::Add => { - match types::Type::::resolve_type(&expression.left, &expression.right) { - types::Type::U32 => types::Expression::Integer(types::IntegerExpression::Add( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - types::Type::FieldElement => { - types::Expression::FieldElement(types::FieldExpression::Add( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - _ => unimplemented!("unreachable"), - } - } - ast::BinaryOperator::Sub => { - match types::Type::::resolve_type(&expression.left, &expression.right) { - types::Type::U32 => types::Expression::Integer(types::IntegerExpression::Sub( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - types::Type::FieldElement => { - types::Expression::FieldElement(types::FieldExpression::Sub( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - _ => unimplemented!("unreachable"), - } - } - ast::BinaryOperator::Mul => { - match types::Type::::resolve_type(&expression.left, &expression.right) { - types::Type::U32 => types::Expression::Integer(types::IntegerExpression::Mul( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - types::Type::FieldElement => { - types::Expression::FieldElement(types::FieldExpression::Mul( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - _ => unimplemented!("unreachable"), - } - } - ast::BinaryOperator::Div => { - match types::Type::::resolve_type(&expression.left, &expression.right) { - types::Type::U32 => types::Expression::Integer(types::IntegerExpression::Div( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - types::Type::FieldElement => { - types::Expression::FieldElement(types::FieldExpression::Div( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - _ => unimplemented!("unreachable"), - } - } - ast::BinaryOperator::Pow => { - match types::Type::::resolve_type(&expression.left, &expression.right) { - types::Type::U32 => types::Expression::Integer(types::IntegerExpression::Pow( - Box::new(types::IntegerExpression::from(*expression.left)), - Box::new(types::IntegerExpression::from(*expression.right)), - )), - types::Type::FieldElement => { - types::Expression::FieldElement(types::FieldExpression::Pow( - Box::new(types::FieldExpression::from(*expression.left)), - Box::new(types::FieldExpression::from(*expression.right)), - )) - } - _ => unimplemented!("unreachable"), - } + types::Expression::Not(Box::new(types::Expression::from(expression))) } + ast::BinaryOperator::Geq => types::Expression::Geq( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Gt => types::Expression::Gt( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Leq => types::Expression::Leq( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Lt => types::Expression::Lt( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + // Number operations + ast::BinaryOperator::Add => types::Expression::Add( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Sub => types::Expression::Sub( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Mul => types::Expression::Mul( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Div => types::Expression::Div( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), + ast::BinaryOperator::Pow => types::Expression::Pow( + Box::new(types::Expression::from(*expression.left)), + Box::new(types::Expression::from(*expression.right)), + ), } } } @@ -429,85 +375,91 @@ impl<'ast, F: Field + PrimeField> From> for types::E impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::TernaryExpression<'ast>) -> Self { // Evaluate expressions to find out result type - let first = types::BooleanExpression::from(*expression.first); - let second = types::Expression::from(*expression.second); - let third = types::Expression::from(*expression.third); + // let first = ; + // let second = ; + // let third = ; - match (second, third) { - // Boolean Result - (types::Expression::Boolean(second), types::Expression::Boolean(third)) => { - types::Expression::Boolean(types::BooleanExpression::IfElse( - Box::new(first), - Box::new(second), - Box::new(third), - )) - } - (types::Expression::Boolean(second), types::Expression::Variable(third)) => { - types::Expression::Boolean(types::BooleanExpression::IfElse( - Box::new(first), - Box::new(second), - Box::new(types::BooleanExpression::Variable(third)), - )) - } - (types::Expression::Variable(second), types::Expression::Boolean(third)) => { - types::Expression::Boolean(types::BooleanExpression::IfElse( - Box::new(first), - Box::new(types::BooleanExpression::Variable(second)), - Box::new(third), - )) - } - // Integer Result - (types::Expression::Integer(second), types::Expression::Integer(third)) => { - types::Expression::Integer(types::IntegerExpression::IfElse( - Box::new(first), - Box::new(second), - Box::new(third), - )) - } - (types::Expression::Integer(second), types::Expression::Variable(third)) => { - types::Expression::Integer(types::IntegerExpression::IfElse( - Box::new(first), - Box::new(second), - Box::new(types::IntegerExpression::Variable(third)), - )) - } - (types::Expression::Variable(second), types::Expression::Integer(third)) => { - types::Expression::Integer(types::IntegerExpression::IfElse( - Box::new(first), - Box::new(types::IntegerExpression::Variable(second)), - Box::new(third), - )) - } - // Field Result - (types::Expression::FieldElement(second), types::Expression::FieldElement(third)) => { - types::Expression::FieldElement(types::FieldExpression::IfElse( - Box::new(first), - Box::new(second), - Box::new(third), - )) - } - (types::Expression::FieldElement(second), types::Expression::Variable(third)) => { - types::Expression::FieldElement(types::FieldExpression::IfElse( - Box::new(first), - Box::new(second), - Box::new(types::FieldExpression::Variable(third)), - )) - } - (types::Expression::Variable(second), types::Expression::FieldElement(third)) => { - types::Expression::FieldElement(types::FieldExpression::IfElse( - Box::new(first), - Box::new(types::FieldExpression::Variable(second)), - Box::new(third), - )) - } + types::Expression::IfElse( + Box::new(types::Expression::from(*expression.first)), + Box::new(types::Expression::from(*expression.second)), + Box::new(types::Expression::from(*expression.third)), + ) - (second, third) => unimplemented!( - "pattern if {} then {} else {} unimplemented", - first, - second, - third - ), - } + // match (second, third) { + // // Boolean Result + // (types::Expression::BooleanExp(second), types::Expression::BooleanExp(third)) => { + // types::Expression::BooleanExp(types::Expression::IfElse( + // Box::new(first), + // Box::new(second), + // Box::new(third), + // )) + // } + // (types::Expression::BooleanExp(second), types::Expression::Variable(third)) => { + // types::Expression::BooleanExp(types::BooleanExpression::IfElse( + // Box::new(first), + // Box::new(second), + // Box::new(types::BooleanExpression::Variable(third)), + // )) + // } + // (types::Expression::Variable(second), types::Expression::BooleanExp(third)) => { + // types::Expression::BooleanExp(types::BooleanExpression::IfElse( + // Box::new(first), + // Box::new(types::BooleanExpression::Variable(second)), + // Box::new(third), + // )) + // } + // // Integer Result + // (types::Expression::IntegerExp(second), types::Expression::IntegerExp(third)) => { + // types::Expression::IntegerExp(types::IntegerExpression::IfElse( + // Box::new(first), + // Box::new(second), + // Box::new(third), + // )) + // } + // (types::Expression::IntegerExp(second), types::Expression::Variable(third)) => { + // types::Expression::IntegerExp(types::IntegerExpression::IfElse( + // Box::new(first), + // Box::new(second), + // Box::new(types::IntegerExpression::Variable(third)), + // )) + // } + // (types::Expression::Variable(second), types::Expression::IntegerExp(third)) => { + // types::Expression::IntegerExp(types::IntegerExpression::IfElse( + // Box::new(first), + // Box::new(types::IntegerExpression::Variable(second)), + // Box::new(third), + // )) + // } + // // Field Result + // (types::Expression::FieldElementExp(second), types::Expression::FieldElementExp(third)) => { + // types::Expression::FieldElementExp(types::FieldExpression::IfElse( + // Box::new(first), + // Box::new(second), + // Box::new(third), + // )) + // } + // (types::Expression::FieldElementExp(second), types::Expression::Variable(third)) => { + // types::Expression::FieldElementExp(types::FieldExpression::IfElse( + // Box::new(first), + // Box::new(second), + // Box::new(types::FieldExpression::Variable(third)), + // )) + // } + // (types::Expression::Variable(second), types::Expression::FieldElementExp(third)) => { + // types::Expression::FieldElementExp(types::FieldExpression::IfElse( + // Box::new(first), + // Box::new(types::FieldExpression::Variable(second)), + // Box::new(third), + // )) + // } + // + // (second, third) => unimplemented!( + // "pattern if {} then {} else {} unimplemented", + // first, + // second, + // third + // ), + // } } } @@ -542,12 +494,34 @@ impl<'ast, F: Field + PrimeField> From> for types:: ), ast::Access::Array(array) => types::Expression::ArrayAccess( Box::new(acc), - types::IntegerRangeOrExpression::from(array.expression), + Box::new(types::RangeOrExpression::from(array.expression)), ), }) } } +impl<'ast, F: Field + PrimeField> From> for types::Expression { + fn from(array: ast::ArrayInlineExpression<'ast>) -> Self { + types::Expression::Array( + array + .expressions + .into_iter() + .map(|s_or_e| Box::new(types::SpreadOrExpression::from(s_or_e))) + .collect(), + ) + } +} +impl<'ast, F: Field + PrimeField> From> + for types::Expression +{ + fn from(array: ast::ArrayInitializerExpression<'ast>) -> Self { + let count = types::Expression::::get_count(array.count); + let expression = Box::new(types::SpreadOrExpression::from(*array.expression)); + + types::Expression::Array(vec![expression; count]) + } +} + impl<'ast, F: Field + PrimeField> From> for types::Expression { fn from(expression: ast::Expression<'ast>) -> Self { match expression { @@ -556,12 +530,8 @@ impl<'ast, F: Field + PrimeField> From> for types::Express ast::Expression::Not(expression) => types::Expression::from(expression), ast::Expression::Binary(expression) => types::Expression::from(expression), ast::Expression::Ternary(expression) => types::Expression::from(expression), - ast::Expression::ArrayInline(_expression) => { - unimplemented!("unknown type for inline array expression") - } - ast::Expression::ArrayInitializer(_expression) => { - unimplemented!("unknown type for array initializer expression") - } + ast::Expression::ArrayInline(expression) => types::Expression::from(expression), + ast::Expression::ArrayInitializer(expression) => types::Expression::from(expression), ast::Expression::StructInline(_expression) => { unimplemented!("unknown type for inline struct expression") } @@ -575,9 +545,9 @@ impl<'ast, F: Field + PrimeField> From> for types::Express /// For defined types (ex: u32[4]) we manually construct the expression instead of implementing the From trait. /// This saves us from having to resolve things at a later point in time. impl<'ast, F: Field + PrimeField> types::Expression { - fn from_basic(_ty: ast::BasicType<'ast>, _expression: ast::Expression<'ast>) -> Self { - unimplemented!("from basic not impl"); - } + // fn from_basic(_ty: ast::BasicType<'ast>, expression: ast::Expression<'ast>) -> Self { + // types::Expression::from(expression) + // } fn get_count(count: ast::Value<'ast>) -> usize { match count { @@ -590,64 +560,64 @@ impl<'ast, F: Field + PrimeField> types::Expression { } } - fn from_array(ty: ast::ArrayType<'ast>, expression: ast::Expression<'ast>) -> Self { - match ty.ty { - ast::BasicType::U32(_ty) => { - let elements: Vec>> = match expression { - ast::Expression::ArrayInline(array) => array - .expressions - .into_iter() - .map(|s_or_e| Box::new(types::IntegerSpreadOrExpression::from(s_or_e))) - .collect(), - ast::Expression::ArrayInitializer(array) => { - let count = types::Expression::::get_count(array.count); - let expression = - Box::new(types::IntegerSpreadOrExpression::from(*array.expression)); - - vec![expression; count] - } - _ => unimplemented!("expected array after array type"), - }; - types::Expression::Integer(types::IntegerExpression::Array(elements)) - } - ast::BasicType::Field(_ty) => { - let elements: Vec>> = match expression { - ast::Expression::ArrayInline(array) => array - .expressions - .into_iter() - .map(|s_or_e| Box::new(types::FieldSpreadOrExpression::from(s_or_e))) - .collect(), - ast::Expression::ArrayInitializer(array) => { - let count = types::Expression::::get_count(array.count); - let expression = - Box::new(types::FieldSpreadOrExpression::from(*array.expression)); - - vec![expression; count] - } - _ => unimplemented!("expected array after array type"), - }; - types::Expression::FieldElement(types::FieldExpression::Array(elements)) - } - ast::BasicType::Boolean(_ty) => { - let elements: Vec>> = match expression { - ast::Expression::ArrayInline(array) => array - .expressions - .into_iter() - .map(|s_or_e| Box::new(types::BooleanSpreadOrExpression::from(s_or_e))) - .collect(), - ast::Expression::ArrayInitializer(array) => { - let count = types::Expression::::get_count(array.count); - let expression = - Box::new(types::BooleanSpreadOrExpression::from(*array.expression)); - - vec![expression; count] - } - _ => unimplemented!("expected array after array type"), - }; - types::Expression::Boolean(types::BooleanExpression::Array(elements)) - } - } - } + // fn from_array(ty: ast::ArrayType<'ast>, expression: ast::Expression<'ast>) -> Self { + // match ty.ty { + // ast::BasicType::U32(_ty) => { + // let elements: Vec>> = match expression { + // ast::Expression::ArrayInline(array) => array + // .expressions + // .into_iter() + // .map(|s_or_e| Box::new(types::IntegerSpreadOrExpression::from(s_or_e))) + // .collect(), + // ast::Expression::ArrayInitializer(array) => { + // let count = types::Expression::::get_count(array.count); + // let expression = + // Box::new(types::IntegerSpreadOrExpression::from(*array.expression)); + // + // vec![expression; count] + // } + // _ => unimplemented!("expected array after array type"), + // }; + // types::Expression::IntegerExp(types::IntegerExpression::Array(elements)) + // } + // ast::BasicType::Field(_ty) => { + // let elements: Vec>> = match expression { + // ast::Expression::ArrayInline(array) => array + // .expressions + // .into_iter() + // .map(|s_or_e| Box::new(types::FieldSpreadOrExpression::from(s_or_e))) + // .collect(), + // ast::Expression::ArrayInitializer(array) => { + // let count = types::Expression::::get_count(array.count); + // let expression = + // Box::new(types::FieldSpreadOrExpression::from(*array.expression)); + // + // vec![expression; count] + // } + // _ => unimplemented!("expected array after array type"), + // }; + // types::Expression::FieldElementExp(types::FieldExpression::Array(elements)) + // } + // ast::BasicType::Boolean(_ty) => { + // let elements: Vec>> = match expression { + // ast::Expression::ArrayInline(array) => array + // .expressions + // .into_iter() + // .map(|s_or_e| Box::new(types::BooleanSpreadOrExpression::from(s_or_e))) + // .collect(), + // ast::Expression::ArrayInitializer(array) => { + // let count = types::Expression::::get_count(array.count); + // let expression = + // Box::new(types::BooleanSpreadOrExpression::from(*array.expression)); + // + // vec![expression; count] + // } + // _ => unimplemented!("expected array after array type"), + // }; + // types::Expression::BooleanExp(types::BooleanExpression::Array(elements)) + // } + // } + // } fn from_struct(ty: ast::StructType<'ast>, expression: ast::Expression<'ast>) -> Self { let declaration_struct = ty.variable.value; @@ -671,8 +641,8 @@ impl<'ast, F: Field + PrimeField> types::Expression { fn from_type(ty: ast::Type<'ast>, expression: ast::Expression<'ast>) -> Self { match ty { - ast::Type::Basic(ty) => Self::from_basic(ty, expression), - ast::Type::Array(ty) => Self::from_array(ty, expression), + ast::Type::Basic(_ty) => Self::from(expression), + ast::Type::Array(_ty) => Self::from(expression), ast::Type::Struct(ty) => Self::from_struct(ty, expression), } } @@ -697,7 +667,7 @@ impl<'ast, F: Field + PrimeField> From> for types::Assignee< .fold(variable, |acc, access| match access { ast::AssigneeAccess::Array(array) => types::Assignee::Array( Box::new(acc), - types::IntegerRangeOrExpression::from(array.expression), + types::RangeOrExpression::from(array.expression), ), ast::AssigneeAccess::Member(struct_member) => types::Assignee::StructMember( Box::new(acc), @@ -741,10 +711,19 @@ impl<'ast, F: Field + PrimeField> From> for types::St impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::ForStatement<'ast>) -> Self { + let from = match types::Expression::::from(statement.start) { + types::Expression::Integer(number) => number, + expression => unimplemented!("Range bounds should be integers, found {}", expression), + }; + let to = match types::Expression::::from(statement.stop) { + types::Expression::Integer(number) => number, + expression => unimplemented!("Range bounds should be integers, found {}", expression), + }; + types::Statement::For( types::Variable::from(statement.index), - types::IntegerExpression::from(statement.start), - types::IntegerExpression::from(statement.stop), + from, + to, statement .statements .into_iter()