diff --git a/compiler/src/constraints/statement.rs b/compiler/src/constraints/statement.rs index 11f8c1aff7..c2fb9f5dc6 100644 --- a/compiler/src/constraints/statement.rs +++ b/compiler/src/constraints/statement.rs @@ -404,19 +404,28 @@ impl, CS: ConstraintSystem> Constraine right: &ConstrainedValue, ) -> Result<(), StatementError> { let condition = indicator.unwrap_or(Boolean::Constant(true)); + let result = match (left, right) { - (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => { - bool_1.conditional_enforce_equal(cs, bool_2, &condition) - } - (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { - num_1.conditional_enforce_equal(cs, num_2, &condition) - } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - fe_1.conditional_enforce_equal(cs, fe_2, &condition) - } - (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { - ge_1.conditional_enforce_equal(cs, ge_2, &condition) - } + (ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())), + bool_2, + &condition, + ), + (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())), + num_2, + &condition, + ), + (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => fe_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())), + fe_2, + &condition, + ), + (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => ge_1.conditional_enforce_equal( + cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())), + ge_2, + &condition, + ), (ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => { for (left, right) in arr_1.into_iter().zip(arr_2.into_iter()) { self.enforce_assert_eq_statement(cs, indicator.clone(), left, right)?; diff --git a/compiler/tests/conditional/conditional_basic.leo b/compiler/tests/conditional/conditional_basic.leo new file mode 100644 index 0000000000..a04bbd8d1b --- /dev/null +++ b/compiler/tests/conditional/conditional_basic.leo @@ -0,0 +1,7 @@ +function main(bit: private u8) { + if bit == 1u8 { + assert_eq!(bit, 1u8); + } else { + assert_eq!(bit, 0u8); + } +} diff --git a/compiler/tests/conditional/mod.rs b/compiler/tests/conditional/mod.rs new file mode 100644 index 0000000000..7ec0cef8dc --- /dev/null +++ b/compiler/tests/conditional/mod.rs @@ -0,0 +1,45 @@ +use crate::{get_output, parse_program, EdwardsConstrainedValue, EdwardsTestCompiler}; +use leo_inputs::types::{IntegerType, U8Type}; +use leo_types::InputValue; +use snarkos_curves::edwards_bls12::Fq; +use snarkos_models::gadgets::r1cs::TestConstraintSystem; + +fn empty_output_satisfied(program: EdwardsTestCompiler) { + let output = get_output(program); + + assert_eq!(EdwardsConstrainedValue::Return(vec![]).to_string(), output.to_string()); +} + +// Tests a conditional enforceBit() program +// +// function main(bit: private u8) { +// if bit == 1u8 { +// assert_eq!(bit, 1u8); +// } else { +// assert_eq!(bit, 0u8); +// } +// } +#[test] +fn conditional_basic() { + let bytes = include_bytes!("conditional_basic.leo"); + let mut program_1_pass = parse_program(bytes).unwrap(); + let mut program_0_pass = program_1_pass.clone(); + let mut program_2_fail = program_1_pass.clone(); + + // Check that an input value of 1 satisfies the constraint system + + program_1_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 1))]); + empty_output_satisfied(program_1_pass); + + // Check that an input value of 0 satisfies the constraint system + + program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 0))]); + empty_output_satisfied(program_0_pass); + + // Check that an input value of 2 does not satisfy the constraint system + + program_2_fail.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 2))]); + let mut cs = TestConstraintSystem::::new(); + let _output = program_2_fail.compile_constraints(&mut cs).unwrap(); + assert!(!cs.is_satisfied()); +} diff --git a/compiler/tests/mod.rs b/compiler/tests/mod.rs index 1822a963f8..c350c00c30 100644 --- a/compiler/tests/mod.rs +++ b/compiler/tests/mod.rs @@ -1,6 +1,7 @@ pub mod array; pub mod boolean; pub mod circuits; +pub mod conditional; pub mod field; pub mod function; pub mod group;