diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index 63905174de..4b279f709e 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -224,9 +224,9 @@ impl< // (ResolvedValue::FieldElement(fe_1), ResolvedValue::FieldElement(fe_2)) => { // Self::field_eq(fe_1, fe_2) // } - // (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { - // Ok(Self::evaluate_group_eq(ge_1, ge_2)) - // } + (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { + Ok(ConstrainedValue::Boolean(Boolean::Constant(ge_1.eq(&ge_2)))) + } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2)?; self.evaluate_eq_expression(val_1, val_2) diff --git a/compiler/src/constraints/statement.rs b/compiler/src/constraints/statement.rs index 423faab5b4..1521f84782 100644 --- a/compiler/src/constraints/statement.rs +++ b/compiler/src/constraints/statement.rs @@ -422,6 +422,9 @@ impl< (ConstrainedValue::FieldElement(fe_1), ConstrainedValue::FieldElement(fe_2)) => { self.enforce_field_eq(cs, fe_1, fe_2) } + (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { + ge_1.enforce_equal(cs, &ge_2)? + } (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { for (left, right) in arr_1.into_iter().zip(arr_2.into_iter()) { self.enforce_assert_eq_statement(cs, left, right)?; diff --git a/compiler/src/errors/constraints/statement.rs b/compiler/src/errors/constraints/statement.rs index 8e864262f2..7d83567072 100644 --- a/compiler/src/errors/constraints/statement.rs +++ b/compiler/src/errors/constraints/statement.rs @@ -1,4 +1,5 @@ use crate::errors::{BooleanError, ExpressionError, FieldElementError, IntegerError, ValueError}; +use snarkos_errors::gadgets::SynthesisError; #[derive(Debug, Error)] pub enum StatementError { @@ -62,6 +63,9 @@ pub enum StatementError { #[error("Expected assignment of return values for expression {}", _0)] Unassigned(String), + + #[error("{}", _0)] + SynthesisError(#[from] SynthesisError), } impl From for StatementError { diff --git a/compiler/src/group/edwards_bls12.rs b/compiler/src/group/edwards_bls12.rs index 652008c4aa..a7e65c595f 100644 --- a/compiler/src/group/edwards_bls12.rs +++ b/compiler/src/group/edwards_bls12.rs @@ -3,10 +3,13 @@ use crate::GroupType; use snarkos_curves::edwards_bls12::{EdwardsAffine, EdwardsParameters, Fq}; use snarkos_curves::templates::twisted_edwards_extended::GroupAffine; +use snarkos_errors::gadgets::SynthesisError; use snarkos_gadgets::curves::edwards_bls12::EdwardsBlsGadget; use snarkos_models::curves::{AffineCurve, ModelParameters}; -use snarkos_models::gadgets::curves::GroupGadget; +use snarkos_models::gadgets::curves::{FpGadget, GroupGadget}; use snarkos_models::gadgets::r1cs::ConstraintSystem; +use snarkos_models::gadgets::utilities::boolean::Boolean; +use snarkos_models::gadgets::utilities::eq::{ConditionalEqGadget, EqGadget}; use std::ops::Sub; use std::str::FromStr; @@ -84,3 +87,82 @@ impl GroupType<::BaseField, Fq> for Edward } } } + +impl PartialEq for EdwardsGroupType { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (EdwardsGroupType::Constant(self_value), EdwardsGroupType::Constant(other_value)) => { + self_value == other_value + } + + (EdwardsGroupType::Allocated(self_value), EdwardsGroupType::Allocated(other_value)) => { + self_value.eq(other_value) + } + + ( + EdwardsGroupType::Constant(constant_value), + EdwardsGroupType::Allocated(allocated_value), + ) + | ( + EdwardsGroupType::Allocated(allocated_value), + EdwardsGroupType::Constant(constant_value), + ) => , Fq>>::get_value( + allocated_value, + ) + .map(|allocated_value| allocated_value == *constant_value) + .unwrap_or(false), + } + } +} + +impl Eq for EdwardsGroupType {} + +impl EqGadget for EdwardsGroupType {} + +impl ConditionalEqGadget for EdwardsGroupType { + #[inline] + fn conditional_enforce_equal>( + &self, + mut cs: CS, + other: &Self, + condition: &Boolean, + ) -> Result<(), SynthesisError> { + match (self, other) { + // c - c + (EdwardsGroupType::Constant(self_value), EdwardsGroupType::Constant(other_value)) => { + if self_value == other_value { + return Ok(()); + } + Err(SynthesisError::AssignmentMissing) + } + // a - a + (EdwardsGroupType::Allocated(self_value), EdwardsGroupType::Allocated(other_value)) => { + return ::conditional_enforce_equal( + self_value, + cs, + other_value, + condition, + ) + } + // c - a = a - c + ( + EdwardsGroupType::Constant(constant_value), + EdwardsGroupType::Allocated(allocated_value), + ) + | ( + EdwardsGroupType::Allocated(allocated_value), + EdwardsGroupType::Constant(constant_value), + ) => { + let x = FpGadget::from(&mut cs, &constant_value.x); + let y = FpGadget::from(&mut cs, &constant_value.y); + let constant_gadget = EdwardsBlsGadget::new(x, y); + + constant_gadget.conditional_enforce_equal(cs, allocated_value, condition) + } + } + } + + fn cost() -> usize { + 2 * >::cost() //upper bound + } +} diff --git a/compiler/src/group/mod.rs b/compiler/src/group/mod.rs index 13546c919e..07f061f2f9 100644 --- a/compiler/src/group/mod.rs +++ b/compiler/src/group/mod.rs @@ -1,11 +1,14 @@ use crate::errors::GroupError; use snarkos_models::curves::Field; use snarkos_models::gadgets::r1cs::ConstraintSystem; +use snarkos_models::gadgets::utilities::eq::{ConditionalEqGadget, EqGadget}; use std::fmt::Debug; pub mod edwards_bls12; -pub trait GroupType: Sized + Clone + Debug { +pub trait GroupType: + Sized + Clone + Debug + EqGadget + ConditionalEqGadget +{ fn constant(string: String) -> Result; fn add>(&self, cs: CS, other: &Self) -> Result; diff --git a/compiler/tests/boolean/mod.rs b/compiler/tests/boolean/mod.rs index 022cb390cd..bb27c84c95 100644 --- a/compiler/tests/boolean/mod.rs +++ b/compiler/tests/boolean/mod.rs @@ -9,7 +9,7 @@ use snarkos_models::gadgets::utilities::boolean::Boolean; const DIRECTORY_NAME: &str = "tests/boolean/"; -fn output_true(program: EdwardsTestCompiler) { +pub fn output_true(program: EdwardsTestCompiler) { let output = get_output(program); assert_eq!( EdwardsConstrainedValue::Return(vec![ConstrainedValue::Boolean(Boolean::Constant(true))]) @@ -18,7 +18,7 @@ fn output_true(program: EdwardsTestCompiler) { ); } -fn output_false(program: EdwardsTestCompiler) { +pub fn output_false(program: EdwardsTestCompiler) { let output = get_output(program); assert_eq!( EdwardsConstrainedValue::Return(vec![ConstrainedValue::Boolean(Boolean::Constant(false))]) diff --git a/compiler/tests/group/assert_eq_false.leo b/compiler/tests/group/assert_eq_false.leo new file mode 100644 index 0000000000..1ac3f02fd7 --- /dev/null +++ b/compiler/tests/group/assert_eq_false.leo @@ -0,0 +1,6 @@ +function main() { + let point_1 = (7374112779530666882856915975292384652154477718021969292781165691637980424078, 3435195339177955418892975564890903138308061187980579490487898366607011481796)group; + let point_2 = (1005842117974384149622370061042978581211342111653966059496918451529532134799, 79389132189982034519597104273449021362784864778548730890166152019533697186)group; + + assert_eq!(point_1, point_2); +} \ No newline at end of file diff --git a/compiler/tests/group/assert_eq_true.leo b/compiler/tests/group/assert_eq_true.leo new file mode 100644 index 0000000000..b940016ce0 --- /dev/null +++ b/compiler/tests/group/assert_eq_true.leo @@ -0,0 +1,6 @@ +function main() { + let point_1 = (7374112779530666882856915975292384652154477718021969292781165691637980424078, 3435195339177955418892975564890903138308061187980579490487898366607011481796)group; + let point_2 = (7374112779530666882856915975292384652154477718021969292781165691637980424078, 3435195339177955418892975564890903138308061187980579490487898366607011481796)group; + + assert_eq!(point_1, point_2); +} \ No newline at end of file diff --git a/compiler/tests/group/eq_false.leo b/compiler/tests/group/eq_false.leo new file mode 100644 index 0000000000..62630c66ac --- /dev/null +++ b/compiler/tests/group/eq_false.leo @@ -0,0 +1,6 @@ +function main() -> bool { + let point_1 = (7374112779530666882856915975292384652154477718021969292781165691637980424078, 3435195339177955418892975564890903138308061187980579490487898366607011481796)group; + let point_2 = (1005842117974384149622370061042978581211342111653966059496918451529532134799, 79389132189982034519597104273449021362784864778548730890166152019533697186)group; + + return point_1 == point_2 +} \ No newline at end of file diff --git a/compiler/tests/group/eq_true.leo b/compiler/tests/group/eq_true.leo new file mode 100644 index 0000000000..d98a3ca492 --- /dev/null +++ b/compiler/tests/group/eq_true.leo @@ -0,0 +1,6 @@ +function main() -> bool { + let point_1 = (7374112779530666882856915975292384652154477718021969292781165691637980424078, 3435195339177955418892975564890903138308061187980579490487898366607011481796)group; + let point_2 = (7374112779530666882856915975292384652154477718021969292781165691637980424078, 3435195339177955418892975564890903138308061187980579490487898366607011481796)group; + + return point_1 == point_2 +} \ No newline at end of file diff --git a/compiler/tests/group/mod.rs b/compiler/tests/group/mod.rs index d4cac1166a..8d8cc5671c 100644 --- a/compiler/tests/group/mod.rs +++ b/compiler/tests/group/mod.rs @@ -1,10 +1,12 @@ -use crate::{compile_program, get_output, EdwardsConstrainedValue, EdwardsTestCompiler}; +use crate::{compile_program, get_error, get_output, EdwardsConstrainedValue, EdwardsTestCompiler}; use leo_compiler::group::edwards_bls12::EdwardsGroupType; use leo_compiler::ConstrainedValue; use snarkos_curves::edwards_bls12::EdwardsAffine; use snarkos_models::curves::Group; +use crate::boolean::{output_false, output_true}; +use leo_compiler::errors::{CompilerError, FunctionError, StatementError}; use std::str::FromStr; const DIRECTORY_NAME: &str = "tests/group/"; @@ -27,6 +29,15 @@ fn output_zero(program: EdwardsTestCompiler) { output_expected(program, EdwardsAffine::zero()) } +fn fail_enforce(program: EdwardsTestCompiler) { + match get_error(program) { + CompilerError::FunctionError(FunctionError::StatementError( + StatementError::SynthesisError(_), + )) => {} + error => panic!("Expected evaluate error, got {}", error), + } +} + #[test] fn test_zero() { let program = compile_program(DIRECTORY_NAME, "zero.leo").unwrap(); @@ -65,3 +76,27 @@ fn test_sub() { let program = compile_program(DIRECTORY_NAME, "sub.leo").unwrap(); output_expected(program, sum); } + +#[test] +fn test_eq_true() { + let program = compile_program(DIRECTORY_NAME, "eq_true.leo").unwrap(); + output_true(program) +} + +#[test] +fn test_eq_false() { + let program = compile_program(DIRECTORY_NAME, "eq_false.leo").unwrap(); + output_false(program) +} + +#[test] +fn test_assert_eq_true() { + let program = compile_program(DIRECTORY_NAME, "assert_eq_true.leo").unwrap(); + let _res = get_output(program); +} + +#[test] +fn test_assert_eq_false() { + let program = compile_program(DIRECTORY_NAME, "assert_eq_false.leo").unwrap(); + fail_enforce(program); +}