diff --git a/compiler/src/constraints/boolean.rs b/compiler/src/constraints/boolean.rs index 370df34743..93fcacadf3 100644 --- a/compiler/src/constraints/boolean.rs +++ b/compiler/src/constraints/boolean.rs @@ -12,7 +12,7 @@ use snarkos_models::{ curves::{Field, PrimeField}, gadgets::{ r1cs::ConstraintSystem, - utilities::{alloc::AllocGadget, boolean::Boolean, eq::EqGadget}, + utilities::{alloc::AllocGadget, boolean::Boolean}, }, }; @@ -46,10 +46,6 @@ impl> ConstrainedProgram { Ok(ConstrainedValue::Boolean(number)) } - pub(crate) fn get_boolean_constant(bool: Boolean) -> ConstrainedValue { - ConstrainedValue::Boolean(bool) - } - pub(crate) fn evaluate_not(value: ConstrainedValue) -> Result, BooleanError> { match value { ConstrainedValue::Boolean(boolean) => Ok(ConstrainedValue::Boolean(boolean.not())), @@ -90,17 +86,4 @@ impl> ConstrainedProgram { ))), } } - - pub(crate) fn boolean_eq(left: Boolean, right: Boolean) -> ConstrainedValue { - ConstrainedValue::Boolean(Boolean::Constant(left.eq(&right))) - } - - pub(crate) fn enforce_boolean_eq>( - &mut self, - cs: &mut CS, - left: Boolean, - right: Boolean, - ) -> Result<(), BooleanError> { - Ok(left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right)?) - } } diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index e4b5430214..05e5b6805c 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -23,7 +23,7 @@ use snarkos_models::{ curves::{Field, PrimeField}, gadgets::{ r1cs::ConstraintSystem, - utilities::{boolean::Boolean, select::CondSelectGadget}, + utilities::{boolean::Boolean, eq::EvaluateEqGadget, select::CondSelectGadget}, }, }; @@ -185,36 +185,41 @@ impl> ConstrainedProgram { } /// Evaluate Boolean operations - fn evaluate_eq_expression( + fn evaluate_eq_expression>( &mut self, + cs: &mut CS, left: ConstrainedValue, right: ConstrainedValue, ) -> Result, ExpressionError> { - match (left, right) { + let mut expression_namespace = cs.ns(|| format!("evaluate {} == {}", left.to_string(), right.to_string())); + let result_bool = match (left, right) { (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { - Ok(Self::boolean_eq(bool_1, bool_2)) + bool_1.evaluate_equal(expression_namespace, &bool_2)? } (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { - Ok(ConstrainedValue::Boolean(Boolean::Constant(num_1.eq(&num_2)))) + num_1.evaluate_equal(expression_namespace, &num_2)? } (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - Ok(ConstrainedValue::Boolean(Boolean::Constant(fe_1.eq(&fe_2)))) + fe_1.evaluate_equal(expression_namespace, &fe_2)? } (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { - Ok(ConstrainedValue::Boolean(Boolean::Constant(ge_1.eq(&ge_2)))) + ge_1.evaluate_equal(expression_namespace, &ge_2)? } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2)?; - self.evaluate_eq_expression(val_1, val_2) + return self.evaluate_eq_expression(&mut expression_namespace, val_1, val_2); } (val_1, ConstrainedValue::Unresolved(string)) => { let val_2 = ConstrainedValue::from_other(string, &val_1)?; - self.evaluate_eq_expression(val_1, val_2) + return self.evaluate_eq_expression(&mut expression_namespace, val_1, val_2); } - (val_1, val_2) => Err(ExpressionError::IncompatibleTypes(format!("{} == {}", val_1, val_2,))), - } + (val_1, val_2) => return Err(ExpressionError::IncompatibleTypes(format!("{} == {}", val_1, val_2,))), + }; + + Ok(ConstrainedValue::Boolean(result_bool)) } + //TODO: unsafe for allocated values fn evaluate_ge_expression( &mut self, left: ConstrainedValue, @@ -244,6 +249,7 @@ impl> ConstrainedProgram { } } + //TODO: unsafe for allocated values fn evaluate_gt_expression( &mut self, left: ConstrainedValue, @@ -273,6 +279,7 @@ impl> ConstrainedProgram { } } + //TODO: unsafe for allocated values fn evaluate_le_expression( &mut self, left: ConstrainedValue, @@ -302,6 +309,7 @@ impl> ConstrainedProgram { } } + //TODO: unsafe for allocated values fn evaluate_lt_expression( &mut self, left: ConstrainedValue, @@ -354,8 +362,8 @@ impl> ConstrainedProgram { }; let resolved_second = - self.enforce_branch(cs, file_scope.clone(), function_scope.clone(), expected_types, second)?; - let resolved_third = self.enforce_branch(cs, file_scope, function_scope, expected_types, third)?; + self.enforce_expression_value(cs, file_scope.clone(), function_scope.clone(), expected_types, second)?; + let resolved_third = self.enforce_expression_value(cs, file_scope, function_scope, expected_types, third)?; match (resolved_second, resolved_third) { (ConstrainedValue::Boolean(bool_2), ConstrainedValue::Boolean(bool_3)) => { @@ -374,7 +382,7 @@ impl> ConstrainedProgram { let result = G::conditionally_select(cs, &resolved_first, &ge_1, &ge_2)?; Ok(ConstrainedValue::Group(result)) } - (_, _) => unimplemented!("conditional select gadget not implemented between given types"), + (_, _) => unimplemented!("statements.conditional select gadget not implemented between given types"), } } @@ -451,7 +459,7 @@ impl> ConstrainedProgram { index: Expression, ) -> Result { let expected_types = vec![Type::IntegerType(IntegerType::U32)]; - match self.enforce_branch(cs, file_scope.clone(), function_scope.clone(), &expected_types, index)? { + match self.enforce_expression_value(cs, file_scope.clone(), function_scope.clone(), &expected_types, index)? { ConstrainedValue::Integer(number) => Ok(number.to_usize()), value => Err(ExpressionError::InvalidIndex(value.to_string())), } @@ -466,7 +474,13 @@ impl> ConstrainedProgram { array: Box, index: RangeOrExpression, ) -> Result, ExpressionError> { - let array = match self.enforce_branch(cs, file_scope.clone(), function_scope.clone(), expected_types, *array)? { + let array = match self.enforce_expression_value( + cs, + file_scope.clone(), + function_scope.clone(), + expected_types, + *array, + )? { ConstrainedValue::Array(array) => array, value => return Err(ExpressionError::InvalidArrayAccess(value.to_string())), }; @@ -562,7 +576,7 @@ impl> ConstrainedProgram { circuit_identifier: Box, circuit_member: Identifier, ) -> Result, ExpressionError> { - let (circuit_name, members) = match self.enforce_branch( + let (circuit_name, members) = match self.enforce_expression_value( cs, file_scope.clone(), function_scope.clone(), @@ -712,7 +726,7 @@ impl> ConstrainedProgram { /// Enforce a branch of a binary expression. /// We don't care about mutability because we are not changing any variables. /// We try to resolve unresolved types here if the type is given explicitly. - pub(crate) fn enforce_branch>( + pub(crate) fn enforce_expression_value>( &mut self, cs: &mut CS, file_scope: String, @@ -737,10 +751,12 @@ impl> ConstrainedProgram { left: Expression, right: Expression, ) -> Result<(ConstrainedValue, ConstrainedValue), ExpressionError> { - let resolved_left = - self.enforce_branch(cs, file_scope.clone(), function_scope.clone(), expected_types, left)?; - let resolved_right = - self.enforce_branch(cs, file_scope.clone(), function_scope.clone(), expected_types, right)?; + let mut resolved_left = + self.enforce_expression_value(cs, file_scope.clone(), function_scope.clone(), expected_types, left)?; + let mut resolved_right = + self.enforce_expression_value(cs, file_scope.clone(), function_scope.clone(), expected_types, right)?; + + resolved_left.resolve_types(&mut resolved_right, expected_types)?; Ok((resolved_left, resolved_right)) } @@ -763,7 +779,7 @@ impl> ConstrainedProgram { Expression::Integer(integer) => Ok(ConstrainedValue::Integer(integer)), Expression::Field(field) => Ok(ConstrainedValue::Field(FieldType::constant(field)?)), Expression::Group(group_affine) => Ok(ConstrainedValue::Group(G::constant(group_affine)?)), - Expression::Boolean(bool) => Ok(Self::get_boolean_constant(bool)), + Expression::Boolean(bool) => Ok(ConstrainedValue::Boolean(bool)), Expression::Implicit(value) => Self::enforce_number_implicit(expected_types, value), // Binary operations @@ -865,19 +881,19 @@ impl> ConstrainedProgram { cs, file_scope.clone(), function_scope.clone(), - expected_types, + &vec![], *left, *right, )?; - Ok(self.evaluate_eq_expression(resolved_left, resolved_right)?) + Ok(self.evaluate_eq_expression(cs, resolved_left, resolved_right)?) } Expression::Ge(left, right) => { let (resolved_left, resolved_right) = self.enforce_binary_expression( cs, file_scope.clone(), function_scope.clone(), - expected_types, + &vec![], *left, *right, )?; @@ -889,7 +905,7 @@ impl> ConstrainedProgram { cs, file_scope.clone(), function_scope.clone(), - expected_types, + &vec![], *left, *right, )?; @@ -901,7 +917,7 @@ impl> ConstrainedProgram { cs, file_scope.clone(), function_scope.clone(), - expected_types, + &vec![], *left, *right, )?; @@ -913,7 +929,7 @@ impl> ConstrainedProgram { cs, file_scope.clone(), function_scope.clone(), - expected_types, + &vec![], *left, *right, )?; diff --git a/compiler/src/constraints/function.rs b/compiler/src/constraints/function.rs index 2132d99298..992aa8f890 100644 --- a/compiler/src/constraints/function.rs +++ b/compiler/src/constraints/function.rs @@ -87,6 +87,7 @@ impl> ConstrainedProgram { cs, scope.clone(), function_name.clone(), + None, statement.clone(), function.returns.clone(), )? { diff --git a/compiler/src/constraints/statement.rs b/compiler/src/constraints/statement.rs index d923080a03..5fc5e77b10 100644 --- a/compiler/src/constraints/statement.rs +++ b/compiler/src/constraints/statement.rs @@ -23,7 +23,7 @@ use snarkos_models::{ curves::{Field, PrimeField}, gadgets::{ r1cs::ConstraintSystem, - utilities::{boolean::Boolean, eq::EqGadget, uint::UInt32}, + utilities::{boolean::Boolean, eq::ConditionalEqGadget, select::CondSelectGadget, uint::UInt32}, }, }; @@ -52,10 +52,13 @@ impl> ConstrainedProgram { cs: &mut CS, file_scope: String, function_scope: String, + indicator: Option, name: String, range_or_expression: RangeOrExpression, - new_value: ConstrainedValue, + mut new_value: ConstrainedValue, ) -> Result<(), StatementError> { + let condition = indicator.unwrap_or(Boolean::Constant(true)); + // Resolve index so we know if we are assigning to a single value or a range of values match range_or_expression { RangeOrExpression::Expression(index) => { @@ -64,7 +67,14 @@ impl> ConstrainedProgram { // Modify the single value of the array in place match self.get_mutable_assignee(name)? { ConstrainedValue::Array(old) => { - old[index] = new_value; + new_value.resolve_type(&vec![old[index].to_type()])?; + + let selected_value = + ConstrainedValue::conditionally_select(cs, &condition, &new_value, &old[index]).map_err( + |_| StatementError::SelectFail(new_value.to_string(), old[index].to_string()), + )?; + + old[index] = selected_value; } _ => return Err(StatementError::ArrayAssignIndex), } @@ -79,26 +89,36 @@ impl> ConstrainedProgram { None => None, }; - // Modify the range of values of the array in place - match (self.get_mutable_assignee(name)?, new_value) { - (ConstrainedValue::Array(old), ConstrainedValue::Array(ref new)) => { - let to_index = to_index_option.unwrap_or(old.len()); - old.splice(from_index..to_index, new.iter().cloned()); + // Modify the range of values of the array + let old_array = self.get_mutable_assignee(name)?; + let new_array = match (old_array.clone(), new_value) { + (ConstrainedValue::Array(mut mutable), ConstrainedValue::Array(new)) => { + let to_index = to_index_option.unwrap_or(mutable.len()); + + mutable.splice(from_index..to_index, new.iter().cloned()); + ConstrainedValue::Array(mutable) } _ => return Err(StatementError::ArrayAssignRange), - } + }; + let selected_array = ConstrainedValue::conditionally_select(cs, &condition, &new_array, old_array) + .map_err(|_| StatementError::SelectFail(new_array.to_string(), old_array.to_string()))?; + *old_array = selected_array; } } Ok(()) } - fn mutute_circuit_field( + fn mutute_circuit_field>( &mut self, + cs: &mut CS, + indicator: Option, circuit_name: String, object_name: Identifier, - new_value: ConstrainedValue, + mut new_value: ConstrainedValue, ) -> Result<(), StatementError> { + let condition = indicator.unwrap_or(Boolean::Constant(true)); + match self.get_mutable_assignee(circuit_name)? { ConstrainedValue::CircuitExpression(_variable, members) => { // Modify the circuit field in place @@ -114,7 +134,16 @@ impl> ConstrainedProgram { ConstrainedValue::Static(_value) => { return Err(StatementError::ImmutableCircuitFunction("static".into())); } - _ => object.1 = new_value.to_owned(), + _ => { + new_value.resolve_type(&vec![object.1.to_type()])?; + + let selected_value = ConstrainedValue::conditionally_select( + cs, &condition, &new_value, &object.1, + ) + .map_err(|_| StatementError::SelectFail(new_value.to_string(), object.1.to_string()))?; + + object.1 = selected_value.to_owned(); + } }, None => return Err(StatementError::UndefinedCircuitObject(object_name.to_string())), } @@ -130,6 +159,7 @@ impl> ConstrainedProgram { cs: &mut CS, file_scope: String, function_scope: String, + indicator: Option, assignee: Assignee, expression: Expression, ) -> Result<(), StatementError> { @@ -137,14 +167,19 @@ impl> ConstrainedProgram { let variable_name = self.resolve_assignee(function_scope.clone(), assignee.clone()); // Evaluate new value - let new_value = self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), &vec![], expression)?; + let mut new_value = + self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), &vec![], expression)?; // Mutate the old value into the new value match assignee { Assignee::Identifier(_identifier) => { + let condition = indicator.unwrap_or(Boolean::Constant(true)); let old_value = self.get_mutable_assignee(variable_name.clone())?; + new_value.resolve_type(&vec![old_value.to_type()])?; + let selected_value = ConstrainedValue::conditionally_select(cs, &condition, &new_value, old_value) + .map_err(|_| StatementError::SelectFail(new_value.to_string(), old_value.to_string()))?; - *old_value = new_value; + *old_value = selected_value; Ok(()) } @@ -152,12 +187,13 @@ impl> ConstrainedProgram { cs, file_scope, function_scope, + indicator, variable_name, range_or_expression, new_value, ), Assignee::CircuitField(_assignee, object_name) => { - self.mutute_circuit_field(variable_name, object_name, new_value) + self.mutute_circuit_field(cs, indicator, variable_name, object_name, new_value) } } } @@ -262,7 +298,7 @@ impl> ConstrainedProgram { let mut returns = vec![]; for (expression, ty) in expressions.into_iter().zip(return_types.into_iter()) { let expected_types = vec![ty.clone()]; - let result = self.enforce_branch( + let result = self.enforce_expression_value( cs, file_scope.clone(), function_scope.clone(), @@ -276,11 +312,12 @@ impl> ConstrainedProgram { Ok(ConstrainedValue::Return(returns)) } - fn iterate_or_early_return>( + fn evaluate_branch>( &mut self, cs: &mut CS, file_scope: String, function_scope: String, + indicator: Option, statements: Vec, return_types: Vec, ) -> Result>, StatementError> { @@ -291,6 +328,7 @@ impl> ConstrainedProgram { cs, file_scope.clone(), function_scope.clone(), + indicator.clone(), statement.clone(), return_types.clone(), )? { @@ -302,16 +340,24 @@ impl> ConstrainedProgram { Ok(res) } + /// Enforces a statements.conditional statement with one or more branches. + /// Due to R1CS constraints, we must evaluate every branch to properly construct the circuit. + /// At program execution, we will pass an `indicator bit` down to all child statements within each branch. + /// The `indicator bit` will select that branch while keeping the constraint system satisfied. fn enforce_conditional_statement>( &mut self, cs: &mut CS, file_scope: String, function_scope: String, + indicator: Option, statement: ConditionalStatement, return_types: Vec, ) -> Result>, StatementError> { + let statement_string = statement.to_string(); + let outer_indicator = indicator.unwrap_or(Boolean::Constant(true)); + let expected_types = vec![Type::Boolean]; - let condition = match self.enforce_expression( + let inner_indicator = match self.enforce_expression( cs, file_scope.clone(), function_scope.clone(), @@ -322,21 +368,50 @@ impl> ConstrainedProgram { value => return Err(StatementError::IfElseConditional(value.to_string())), }; - // use gadget impl - if condition.eq(&Boolean::Constant(true)) { - self.iterate_or_early_return(cs, file_scope, function_scope, statement.statements, return_types) - } else { - match statement.next { - Some(next) => match next { - ConditionalNestedOrEndStatement::Nested(nested) => { - self.enforce_conditional_statement(cs, file_scope, function_scope, *nested, return_types) - } - ConditionalNestedOrEndStatement::End(statements) => { - self.iterate_or_early_return(cs, file_scope, function_scope, statements, return_types) - } - }, - None => Ok(None), - } + // Determine nested branch selection + let branch_1_indicator = Boolean::and( + &mut cs.ns(|| format!("statement branch 1 indicator {}", statement_string)), + &outer_indicator, + &inner_indicator, + )?; + + // Execute branch 1 + self.evaluate_branch( + cs, + file_scope.clone(), + function_scope.clone(), + Some(branch_1_indicator), + statement.statements, + return_types.clone(), + )?; + + // Execute branch 2 + let branch_2_indicator = Boolean::and( + &mut cs.ns(|| format!("statement branch 2 indicator {}", statement_string)), + &outer_indicator, + &inner_indicator.not(), + )?; + + match statement.next { + Some(next) => match next { + ConditionalNestedOrEndStatement::Nested(nested) => self.enforce_conditional_statement( + cs, + file_scope, + function_scope, + Some(branch_2_indicator), + *nested, + return_types, + ), + ConditionalNestedOrEndStatement::End(statements) => self.evaluate_branch( + cs, + file_scope, + function_scope, + Some(branch_2_indicator), + statements, + return_types, + ), + }, + None => Ok(None), // this is an if with no else, have to pass statements.conditional down to next statements somehow } } @@ -345,6 +420,7 @@ impl> ConstrainedProgram { cs: &mut CS, file_scope: String, function_scope: String, + indicator: Option, index: Identifier, start: Integer, stop: Integer, @@ -362,11 +438,14 @@ impl> ConstrainedProgram { ConstrainedValue::Integer(Integer::U32(UInt32::constant(i as u32))), ); + cs.ns(|| format!("loop {} = {}", index.to_string(), i)); + // Evaluate statements and possibly return early - if let Some(early_return) = self.iterate_or_early_return( + if let Some(early_return) = self.evaluate_branch( cs, file_scope.clone(), function_scope.clone(), + indicator, statements.clone(), return_types.clone(), )? { @@ -381,29 +460,14 @@ impl> ConstrainedProgram { fn enforce_assert_eq_statement>( &mut self, cs: &mut CS, - left: ConstrainedValue, - right: ConstrainedValue, + indicator: Option, + left: &ConstrainedValue, + right: &ConstrainedValue, ) -> Result<(), StatementError> { - Ok(match (left, right) { - (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { - self.enforce_boolean_eq(cs, bool_1, bool_2)? - } - (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1 - .enforce_equal(cs, &num_2) - .map_err(|_| StatementError::AssertionFailed(num_1.to_string(), num_2.to_string()))?, - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => fe_1 - .enforce_equal(cs, &fe_2) - .map_err(|_| StatementError::AssertionFailed(fe_1.to_string(), fe_2.to_string()))?, - (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => ge_1 - .enforce_equal(cs, &ge_2) - .map_err(|_| StatementError::AssertionFailed(ge_1.to_string(), ge_2.to_string()))?, - (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { - for (left, right) in arr_1.into_iter().zip(arr_2.into_iter()) { - self.enforce_assert_eq_statement(cs, left, right)?; - } - } - (val_1, val_2) => return Err(StatementError::AssertEq(val_1.to_string(), val_2.to_string())), - }) + let condition = indicator.unwrap_or(Boolean::Constant(true)); + let result = left.conditional_enforce_equal(cs, right, &condition); + + Ok(result.map_err(|_| StatementError::AssertionFailed(left.to_string(), right.to_string()))?) } pub(crate) fn enforce_statement>( @@ -411,6 +475,7 @@ impl> ConstrainedProgram { cs: &mut CS, file_scope: String, function_scope: String, + indicator: Option, statement: Statement, return_types: Vec, ) -> Result>, StatementError> { @@ -423,15 +488,20 @@ impl> ConstrainedProgram { self.enforce_definition_statement(cs, file_scope, function_scope, variable, expression)?; } Statement::Assign(variable, expression) => { - self.enforce_assign_statement(cs, file_scope, function_scope, variable, expression)?; + self.enforce_assign_statement(cs, file_scope, function_scope, indicator, variable, expression)?; } Statement::MultipleAssign(variables, function) => { self.enforce_multiple_definition_statement(cs, file_scope, function_scope, variables, function)?; } Statement::Conditional(statement) => { - if let Some(early_return) = - self.enforce_conditional_statement(cs, file_scope, function_scope, statement, return_types)? - { + if let Some(early_return) = self.enforce_conditional_statement( + cs, + file_scope, + function_scope, + indicator, + statement, + return_types, + )? { res = Some(early_return) } } @@ -440,6 +510,7 @@ impl> ConstrainedProgram { cs, file_scope, function_scope, + indicator, index, start, stop, @@ -450,12 +521,10 @@ impl> ConstrainedProgram { } } Statement::AssertEq(left, right) => { - let resolved_left = - self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), &vec![], left)?; - let resolved_right = - self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), &vec![], right)?; + let (resolved_left, resolved_right) = + self.enforce_binary_expression(cs, file_scope, function_scope, &vec![], left, right)?; - self.enforce_assert_eq_statement(cs, resolved_left, resolved_right)?; + self.enforce_assert_eq_statement(cs, indicator, &resolved_left, &resolved_right)?; } Statement::Expression(expression) => { match self.enforce_expression(cs, file_scope, function_scope, &vec![], expression.clone())? { diff --git a/compiler/src/constraints/value.rs b/compiler/src/constraints/value.rs index 7c796bafe7..bf658c4691 100644 --- a/compiler/src/constraints/value.rs +++ b/compiler/src/constraints/value.rs @@ -3,11 +3,17 @@ use crate::{errors::ValueError, FieldType, GroupType}; use leo_types::{Circuit, Function, Identifier, Integer, IntegerType, Type}; +use snarkos_errors::gadgets::SynthesisError; use snarkos_models::{ curves::{Field, PrimeField}, - gadgets::utilities::{ - boolean::Boolean, - uint::{UInt128, UInt16, UInt32, UInt64, UInt8}, + gadgets::{ + r1cs::ConstraintSystem, + utilities::{ + boolean::Boolean, + eq::ConditionalEqGadget, + select::CondSelectGadget, + uint::{UInt128, UInt16, UInt32, UInt64, UInt8}, + }, }, }; use std::fmt; @@ -79,6 +85,21 @@ impl> ConstrainedValue { Ok(()) } + /// Expect both `self` and `other` to resolve to the same type + pub(crate) fn resolve_types(&mut self, other: &mut Self, types: &Vec) -> Result<(), ValueError> { + if !types.is_empty() { + self.resolve_type(types)?; + return other.resolve_type(types); + } + + match (&self, &other) { + (ConstrainedValue::Unresolved(_), ConstrainedValue::Unresolved(_)) => Ok(()), + (ConstrainedValue::Unresolved(_), _) => self.resolve_type(&vec![other.to_type()]), + (_, ConstrainedValue::Unresolved(_)) => other.resolve_type(&vec![self.to_type()]), + _ => Ok(()), + } + } + pub(crate) fn get_inner_mut(&mut self) { if let ConstrainedValue::Mutable(inner) = self { *self = *inner.clone() @@ -139,3 +160,118 @@ impl> fmt::Debug for ConstrainedValue> ConditionalEqGadget for ConstrainedValue { + fn conditional_enforce_equal>( + &self, + mut cs: CS, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + match (self, other) { + (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), + bool_2, + &condition, + ), + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), + num_2, + &condition, + ), + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => field_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), + field_2, + &condition, + ), + (ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => group_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())), + group_2, + &condition, + ), + (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { + for (i, (left, right)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() { + left.conditional_enforce_equal( + cs.ns(|| format!("array[{}] equal {} == {}", i, left.to_string(), right.to_string())), + right, + &condition, + )?; + } + Ok(()) + } + (_, _) => return Err(SynthesisError::Unsatisfiable), + } + } + + fn cost() -> usize { + unimplemented!() + } +} + +impl> CondSelectGadget for ConstrainedValue { + fn conditionally_select>( + mut cs: CS, + cond: &Boolean, + first: &Self, + second: &Self, + ) -> Result { + Ok(match (first, second) { + (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { + ConstrainedValue::Boolean(Boolean::conditionally_select( + cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())), + cond, + bool_1, + bool_2, + )?) + } + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { + ConstrainedValue::Integer(Integer::conditionally_select( + cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())), + cond, + num_1, + num_2, + )?) + } + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + ConstrainedValue::Field(FieldType::conditionally_select( + cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())), + cond, + field_1, + field_2, + )?) + } + (ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => { + ConstrainedValue::Group(G::conditionally_select( + cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())), + cond, + group_1, + group_2, + )?) + } + (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { + let mut array = vec![]; + for (i, (first, second)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() { + array.push(Self::conditionally_select( + cs.ns(|| { + format!( + "array[{}] = if cond ? {} else {}", + i, + first.to_string(), + second.to_string() + ) + }), + cond, + first, + second, + )?); + } + ConstrainedValue::Array(array) + } + (_, _) => return Err(SynthesisError::Unsatisfiable), + }) + } + + fn cost() -> usize { + unimplemented!() //lower bound 1, upper bound 128 or length of static array + } +} diff --git a/compiler/src/errors/constraints/expression.rs b/compiler/src/errors/constraints/expression.rs index cf06142228..671eb21839 100644 --- a/compiler/src/errors/constraints/expression.rs +++ b/compiler/src/errors/constraints/expression.rs @@ -81,7 +81,7 @@ pub enum ExpressionError { UndefinedFunction(String), // Conditionals - #[error("If, else conditional must resolve to a boolean, got {}", _0)] + #[error("If, else statements.conditional must resolve to a boolean, got {}", _0)] IfElseConditional(String), #[error("{}", _0)] diff --git a/compiler/src/errors/constraints/statement.rs b/compiler/src/errors/constraints/statement.rs index b41d2894e6..180d5acbb6 100644 --- a/compiler/src/errors/constraints/statement.rs +++ b/compiler/src/errors/constraints/statement.rs @@ -1,4 +1,4 @@ -use crate::errors::{BooleanError, ExpressionError}; +use crate::errors::{BooleanError, ExpressionError, ValueError}; use snarkos_errors::gadgets::SynthesisError; @@ -37,7 +37,7 @@ pub enum StatementError { #[error("Assertion {:?} == {:?} failed", _0, _1)] AssertionFailed(String, String), - #[error("If, else conditional must resolve to a boolean, got {}", _0)] + #[error("If, else statements.conditional must resolve to a boolean, got {}", _0)] IfElseConditional(String), #[error("Cannot assign to immutable variable {}", _0)] @@ -49,9 +49,15 @@ pub enum StatementError { #[error("Function return statement expected {} return values, got {}", _0, _1)] InvalidNumberOfReturns(usize, usize), + #[error("Conditional select gadget failed to select between {} or {}", _0, _1)] + SelectFail(String, String), + #[error("{}", _0)] SynthesisError(#[from] SynthesisError), #[error("Expected assignment of return values for expression {}", _0)] Unassigned(String), + + #[error("{}", _0)] + ValueError(#[from] ValueError), } diff --git a/compiler/src/field/mod.rs b/compiler/src/field/mod.rs index 42fadf4b23..b7a455b1f2 100644 --- a/compiler/src/field/mod.rs +++ b/compiler/src/field/mod.rs @@ -11,7 +11,7 @@ use snarkos_models::{ utilities::{ alloc::AllocGadget, boolean::Boolean, - eq::{ConditionalEqGadget, EqGadget}, + eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget}, select::CondSelectGadget, uint::UInt8, ToBitsGadget, @@ -175,6 +175,22 @@ impl PartialOrd for FieldType { } } +impl EvaluateEqGadget for FieldType { + fn evaluate_equal>(&self, mut cs: CS, other: &Self) -> Result { + match (self, other) { + (FieldType::Constant(first), FieldType::Constant(second)) => Ok(Boolean::constant(first.eq(second))), + (FieldType::Allocated(allocated), FieldType::Constant(constant)) + | (FieldType::Constant(constant), FieldType::Allocated(allocated)) => { + let bool_option = allocated.value.map(|f| f.eq(constant)); + Boolean::alloc(&mut cs.ns(|| "evaluate_equal"), || { + bool_option.ok_or(SynthesisError::AssignmentMissing) + }) + } + (FieldType::Allocated(first), FieldType::Allocated(second)) => first.evaluate_equal(cs, second), + } + } +} + impl EqGadget for FieldType {} impl ConditionalEqGadget for FieldType { diff --git a/compiler/src/group/edwards_bls12.rs b/compiler/src/group/edwards_bls12.rs index 8a52b22892..34ad793900 100644 --- a/compiler/src/group/edwards_bls12.rs +++ b/compiler/src/group/edwards_bls12.rs @@ -14,7 +14,7 @@ use snarkos_models::{ utilities::{ alloc::AllocGadget, boolean::Boolean, - eq::{ConditionalEqGadget, EqGadget}, + eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget}, select::CondSelectGadget, uint::UInt8, ToBitsGadget, @@ -165,6 +165,40 @@ impl PartialEq for EdwardsGroupType { impl Eq for EdwardsGroupType {} +impl EvaluateEqGadget for EdwardsGroupType { + fn evaluate_equal>(&self, mut cs: CS, other: &Self) -> Result { + match (self, other) { + (EdwardsGroupType::Constant(self_value), EdwardsGroupType::Constant(other_value)) => { + Ok(Boolean::Constant(self_value == other_value)) + } + + (EdwardsGroupType::Allocated(self_value), EdwardsGroupType::Allocated(other_value)) => { + let bool_option = + , Fq>>::get_value(self_value) + .and_then(|a| { + , Fq>>::get_value( + other_value, + ) + .map(|b| a.eq(&b)) + }); + Boolean::alloc(&mut cs.ns(|| "evaluate_equal"), || { + bool_option.ok_or(SynthesisError::AssignmentMissing) + }) + } + + (EdwardsGroupType::Constant(constant_value), EdwardsGroupType::Allocated(allocated_value)) + | (EdwardsGroupType::Allocated(allocated_value), EdwardsGroupType::Constant(constant_value)) => { + let bool_option = + , Fq>>::get_value(allocated_value) + .map(|a| a.eq(constant_value)); + Boolean::alloc(&mut cs.ns(|| "evaluate_equal"), || { + bool_option.ok_or(SynthesisError::AssignmentMissing) + }) + } + } + } +} + impl EqGadget for EdwardsGroupType {} impl ConditionalEqGadget for EdwardsGroupType { diff --git a/compiler/src/group/mod.rs b/compiler/src/group/mod.rs index 897c648442..6af5532570 100644 --- a/compiler/src/group/mod.rs +++ b/compiler/src/group/mod.rs @@ -8,7 +8,7 @@ use snarkos_models::{ r1cs::ConstraintSystem, utilities::{ alloc::AllocGadget, - eq::{ConditionalEqGadget, EqGadget}, + eq::{ConditionalEqGadget, EqGadget, EvaluateEqGadget}, select::CondSelectGadget, ToBitsGadget, ToBytesGadget, @@ -24,6 +24,7 @@ pub trait GroupType: + Clone + Debug + Display + + EvaluateEqGadget + EqGadget + ConditionalEqGadget + AllocGadget diff --git a/compiler/tests/integers/macros.rs b/compiler/tests/integers/macros.rs index f322113f39..4973535f7c 100644 --- a/compiler/tests/integers/macros.rs +++ b/compiler/tests/integers/macros.rs @@ -118,11 +118,6 @@ macro_rules! test_uint { let r1: $_type = rand::random(); let r2: $_type = rand::random(); - let quotient = r1.wrapping_div(r2); - - let cs = TestConstraintSystem::::new(); - let quotient_allocated = <$gadget>::alloc(cs, || Ok(quotient)).unwrap(); - let bytes = include_bytes!("div.leo"); let mut program = parse_program(bytes).unwrap(); @@ -131,7 +126,16 @@ macro_rules! test_uint { Some(InputValue::Integer($integer_type, r2 as u128)), ]); - output_expected_allocated(program, quotient_allocated); + // expect an error when dividing by zero + if r2 == 0 { + let _err = get_error(program); + } else { + let cs = TestConstraintSystem::::new(); + let quotient = r1.wrapping_div(r2); + let quotient_allocated = <$gadget>::alloc(cs, || Ok(quotient)).unwrap(); + + output_expected_allocated(program, quotient_allocated); + } } } diff --git a/compiler/tests/integers/u128/mod.rs b/compiler/tests/integers/u128/mod.rs index ad5cc58bd9..2989214a8b 100644 --- a/compiler/tests/integers/u128/mod.rs +++ b/compiler/tests/integers/u128/mod.rs @@ -1,5 +1,6 @@ use crate::{ boolean::{output_expected_boolean, output_false, output_true}, + get_error, get_output, integers::{fail_integer, fail_synthesis, IntegerTester}, parse_program, diff --git a/compiler/tests/integers/u16/mod.rs b/compiler/tests/integers/u16/mod.rs index e6aa8b1060..9978bfdda8 100644 --- a/compiler/tests/integers/u16/mod.rs +++ b/compiler/tests/integers/u16/mod.rs @@ -1,5 +1,6 @@ use crate::{ boolean::{output_expected_boolean, output_false, output_true}, + get_error, get_output, integers::{fail_integer, fail_synthesis, IntegerTester}, parse_program, diff --git a/compiler/tests/integers/u32/mod.rs b/compiler/tests/integers/u32/mod.rs index d9b02855c7..bcc17744c1 100644 --- a/compiler/tests/integers/u32/mod.rs +++ b/compiler/tests/integers/u32/mod.rs @@ -1,5 +1,6 @@ use crate::{ boolean::{output_expected_boolean, output_false, output_true}, + get_error, get_output, integers::{fail_integer, fail_synthesis, IntegerTester}, parse_program, @@ -28,22 +29,21 @@ fn output_expected_allocated(program: EdwardsTestCompiler, expected: UInt32) { } } -pub(crate) fn output_zero(program: EdwardsTestCompiler) { +pub(crate) fn output_number(program: EdwardsTestCompiler, number: u32) { let output = get_output(program); assert_eq!( - EdwardsConstrainedValue::Return(vec![ConstrainedValue::Integer(Integer::U32(UInt32::constant(0u32)))]) + EdwardsConstrainedValue::Return(vec![ConstrainedValue::Integer(Integer::U32(UInt32::constant(number)))]) .to_string(), output.to_string() ) } +pub(crate) fn output_zero(program: EdwardsTestCompiler) { + output_number(program, 0u32); +} + pub(crate) fn output_one(program: EdwardsTestCompiler) { - let output = get_output(program); - assert_eq!( - EdwardsConstrainedValue::Return(vec![ConstrainedValue::Integer(Integer::U32(UInt32::constant(1u32)))]) - .to_string(), - output.to_string() - ) + output_number(program, 1u32); } #[test] diff --git a/compiler/tests/integers/u64/mod.rs b/compiler/tests/integers/u64/mod.rs index 87d73ddcfe..d1e19725ae 100644 --- a/compiler/tests/integers/u64/mod.rs +++ b/compiler/tests/integers/u64/mod.rs @@ -1,5 +1,6 @@ use crate::{ boolean::{output_expected_boolean, output_false, output_true}, + get_error, get_output, integers::{fail_integer, fail_synthesis, IntegerTester}, parse_program, diff --git a/compiler/tests/integers/u8/mod.rs b/compiler/tests/integers/u8/mod.rs index 40796075f8..4e973c9c71 100644 --- a/compiler/tests/integers/u8/mod.rs +++ b/compiler/tests/integers/u8/mod.rs @@ -1,5 +1,6 @@ use crate::{ boolean::{output_expected_boolean, output_false, output_true}, + get_error, get_output, integers::{fail_integer, fail_synthesis, IntegerTester}, parse_program, diff --git a/compiler/tests/statements/conditional/assert.leo b/compiler/tests/statements/conditional/assert.leo new file mode 100644 index 0000000000..0f5fde2224 --- /dev/null +++ b/compiler/tests/statements/conditional/assert.leo @@ -0,0 +1,7 @@ +function main(bit: private u32) { + if bit == 1 { + assert_eq!(bit, 1); + } else { + assert_eq!(bit, 0); + } +} diff --git a/compiler/tests/statements/conditional/chain.leo b/compiler/tests/statements/conditional/chain.leo new file mode 100644 index 0000000000..159c94be8b --- /dev/null +++ b/compiler/tests/statements/conditional/chain.leo @@ -0,0 +1,13 @@ +function main(bit: u32) -> u32 { + let mut result = 0u32; + + if bit == 1 { + result = 1; + } else if bit == 2 { + result = 2; + } else { + result = 3; + } + + return result +} \ No newline at end of file diff --git a/compiler/tests/statements/conditional/for_loop.leo b/compiler/tests/statements/conditional/for_loop.leo new file mode 100644 index 0000000000..2728e2d302 --- /dev/null +++ b/compiler/tests/statements/conditional/for_loop.leo @@ -0,0 +1,11 @@ +function main(cond: bool) -> u32 { + let mut a = 0u32; + + if cond { + for i in 0..4 { + a += i; + } + } + + return a +} diff --git a/compiler/tests/statements/conditional/mod.rs b/compiler/tests/statements/conditional/mod.rs new file mode 100644 index 0000000000..15133bedd1 --- /dev/null +++ b/compiler/tests/statements/conditional/mod.rs @@ -0,0 +1,130 @@ +use crate::{ + get_output, + integers::u32::{output_one, output_zero}, + parse_program, + EdwardsConstrainedValue, + EdwardsTestCompiler, +}; +use leo_inputs::types::{IntegerType, U32Type}; +use leo_types::InputValue; + +use crate::integers::u32::output_number; +use snarkos_curves::edwards_bls12::Fq; +use snarkos_models::gadgets::r1cs::TestConstraintSystem; + +fn empty_output_satisfied(program: EdwardsTestCompiler) { + let output = get_output(program); + + assert_eq!(EdwardsConstrainedValue::Return(vec![]).to_string(), output.to_string()); +} + +// Tests a statements.conditional enforceBit() program +// +// function main(bit: private u8) { +// if bit == 1u8 { +// assert_eq!(bit, 1u8); +// } else { +// assert_eq!(bit, 0u8); +// } +// } +#[test] +fn test_assert() { + let bytes = include_bytes!("assert.leo"); + let mut program_1_pass = parse_program(bytes).unwrap(); + let mut program_0_pass = program_1_pass.clone(); + let mut program_2_fail = program_1_pass.clone(); + + // Check that an input value of 1 satisfies the constraint system + + program_1_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]); + empty_output_satisfied(program_1_pass); + + // Check that an input value of 0 satisfies the constraint system + + program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 0))]); + empty_output_satisfied(program_0_pass); + + // Check that an input value of 2 does not satisfy the constraint system + + program_2_fail.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 2))]); + let mut cs = TestConstraintSystem::::new(); + let _output = program_2_fail.compile_constraints(&mut cs).unwrap(); + assert!(!cs.is_satisfied()); +} + +#[test] +fn test_mutate() { + let bytes = include_bytes!("mutate.leo"); + let mut program_1_true = parse_program(bytes).unwrap(); + let mut program_0_pass = program_1_true.clone(); + + // Check that an input value of 1 satisfies the constraint system + + program_1_true.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]); + output_one(program_1_true); + + // Check that an input value of 0 satisfies the constraint system + + program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 0))]); + output_zero(program_0_pass); +} + +#[test] +fn test_for_loop() { + let bytes = include_bytes!("for_loop.leo"); + let mut program_true_6 = parse_program(bytes).unwrap(); + let mut program_false_0 = program_true_6.clone(); + + // Check that an input value of true satisfies the constraint system + + program_true_6.set_inputs(vec![Some(InputValue::Boolean(true))]); + output_number(program_true_6, 6u32); + + // Check that an input value of false satisfies the constraint system + + program_false_0.set_inputs(vec![Some(InputValue::Boolean(false))]); + output_zero(program_false_0); +} + +#[test] +fn test_chain() { + let bytes = include_bytes!("chain.leo"); + let mut program_1_1 = parse_program(bytes).unwrap(); + let mut program_2_2 = program_1_1.clone(); + let mut program_2_3 = program_1_1.clone(); + + // Check that an input of 1 outputs true + program_1_1.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]); + output_number(program_1_1, 1u32); + + // Check that an input of 0 outputs true + program_2_2.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 2))]); + output_number(program_2_2, 2u32); + + // Check that an input of 0 outputs true + program_2_3.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 5))]); + output_number(program_2_3, 3u32); +} + +#[test] +fn test_nested() { + let bytes = include_bytes!("nested.leo"); + let mut program_true_true_3 = parse_program(bytes).unwrap(); + let mut program_true_false_1 = program_true_true_3.clone(); + let mut program_false_false_0 = program_true_true_3.clone(); + + // Check that an input value of true true satisfies the constraint system + + program_true_true_3.set_inputs(vec![Some(InputValue::Boolean(true)); 2]); + output_number(program_true_true_3, 3u32); + + // Check that an input value of true false satisfies the constraint system + + program_true_false_1.set_inputs(vec![Some(InputValue::Boolean(true)), Some(InputValue::Boolean(false))]); + output_number(program_true_false_1, 1u32); + + // Check that an input value of false false satisfies the constraint system + + program_false_false_0.set_inputs(vec![Some(InputValue::Boolean(false)), Some(InputValue::Boolean(false))]); + output_number(program_false_false_0, 0u32); +} diff --git a/compiler/tests/statements/conditional/mutate.leo b/compiler/tests/statements/conditional/mutate.leo new file mode 100644 index 0000000000..168fa3866c --- /dev/null +++ b/compiler/tests/statements/conditional/mutate.leo @@ -0,0 +1,11 @@ +function main(bit: private u32) -> u32 { + let mut a = 5u32; + + if bit == 1 { + a = 1; + } else { + a = 0; + } + + return a +} diff --git a/compiler/tests/statements/conditional/nested.leo b/compiler/tests/statements/conditional/nested.leo new file mode 100644 index 0000000000..7eb814c9f1 --- /dev/null +++ b/compiler/tests/statements/conditional/nested.leo @@ -0,0 +1,12 @@ +function main(a: bool, b: bool) -> u32 { + let mut result = 0u32; + + if a { + result += 1; + if b { + result += 2; + } + } + + return result +} \ No newline at end of file diff --git a/compiler/tests/statements/mod.rs b/compiler/tests/statements/mod.rs index de9100156d..950f2896f5 100644 --- a/compiler/tests/statements/mod.rs +++ b/compiler/tests/statements/mod.rs @@ -7,6 +7,8 @@ use leo_types::InputValue; use snarkos_curves::edwards_bls12::Fq; use snarkos_models::gadgets::r1cs::TestConstraintSystem; +pub mod conditional; + // Ternary if {bool}? {expression} : {expression}; #[test] diff --git a/examples/fibonacci/src/main.leo b/examples/fibonacci/src/main.leo index b52c0b09ce..b1cd8efb3a 100644 --- a/examples/fibonacci/src/main.leo +++ b/examples/fibonacci/src/main.leo @@ -12,3 +12,16 @@ function fibonacci(i: u32) -> u32 { function main() -> u32 { return fibonacci(1) } + + +Function mutateNoLet(b: bool) { + let mut a = 5; + if b { + // must be turned into statements.conditional expression + a = 0; + // a = if b ? 0 : a; + } else { + a = 3; + // a = if b ? a : 3; + } +} \ No newline at end of file diff --git a/types/src/inputs/inputs.rs b/types/src/inputs/inputs.rs index d5b2b8b701..172b42de0c 100644 --- a/types/src/inputs/inputs.rs +++ b/types/src/inputs/inputs.rs @@ -31,7 +31,7 @@ impl Inputs { } pub fn from_inputs_file(file: File, expected_inputs: Vec) -> Result { - let mut private = vec![]; + let mut program_inputs = vec![]; let mut public = vec![]; for section in file.sections.into_iter() { @@ -62,7 +62,7 @@ impl Inputs { } // push value to vector - private.push(Some(value)); + program_inputs.push(Some(value)); } None => return Err(InputParserError::InputNotFound(input.to_string())), } @@ -70,10 +70,7 @@ impl Inputs { } } - Ok(Self { - program_inputs: private, - public, - }) + Ok(Self { program_inputs, public }) } pub fn get_public_inputs(&self) -> Result, InputParserError> { diff --git a/types/src/integer.rs b/types/src/integer.rs index eff0b91dfe..abd8f8538d 100644 --- a/types/src/integer.rs +++ b/types/src/integer.rs @@ -18,6 +18,7 @@ use snarkos_models::{ }, }; +use snarkos_models::gadgets::utilities::eq::EvaluateEqGadget; use std::fmt; /// An integer type enum wrapping the integer value. Used only in expressions. @@ -374,6 +375,19 @@ impl Integer { } } +impl EvaluateEqGadget for Integer { + fn evaluate_equal>(&self, cs: CS, other: &Self) -> Result { + match (self, other) { + (Integer::U8(left_u8), Integer::U8(right_u8)) => left_u8.evaluate_equal(cs, right_u8), + (Integer::U16(left_u16), Integer::U16(right_u16)) => left_u16.evaluate_equal(cs, right_u16), + (Integer::U32(left_u32), Integer::U32(right_u32)) => left_u32.evaluate_equal(cs, right_u32), + (Integer::U64(left_u64), Integer::U64(right_u64)) => left_u64.evaluate_equal(cs, right_u64), + (Integer::U128(left_u128), Integer::U128(right_u128)) => left_u128.evaluate_equal(cs, right_u128), + (_, _) => Err(SynthesisError::AssignmentMissing), + } + } +} + impl EqGadget for Integer {} impl ConditionalEqGadget for Integer { @@ -440,6 +454,16 @@ impl CondSelectGadget for Integer { impl fmt::Display for Integer { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}{}", self.to_usize(), self.get_type()) + let option = match self { + Integer::U8(u8) => u8.value.map(|num| num as usize), + Integer::U16(u16) => u16.value.map(|num| num as usize), + Integer::U32(u32) => u32.value.map(|num| num as usize), + Integer::U64(u64) => u64.value.map(|num| num as usize), + Integer::U128(u128) => u128.value.map(|num| num as usize), + }; + match option { + Some(number) => write!(f, "{}{}", number, self.get_type()), + None => write!(f, "[input]{}", self.get_type()), + } } }