mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-12-24 18:52:58 +03:00
impl chained and nested conditionals
This commit is contained in:
parent
1e1e4b86d9
commit
1eaaed269d
@ -185,17 +185,20 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
}
|
||||
|
||||
/// Evaluate Boolean operations
|
||||
fn evaluate_eq_expression(
|
||||
fn evaluate_eq_expression<CSM: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
cs: &mut CSM,
|
||||
left: ConstrainedValue<F, G>,
|
||||
right: ConstrainedValue<F, G>,
|
||||
) -> Result<ConstrainedValue<F, G>, ExpressionError> {
|
||||
let mut expression_namespace = cs.ns(|| format!("evaluate {} == {}", left.to_string(), right.to_string()));
|
||||
let result_bool = match (left, right) {
|
||||
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => {
|
||||
bool_1.evaluate_equal(cs, &bool_2)?
|
||||
bool_1.evaluate_equal(expression_namespace, &bool_2)?
|
||||
}
|
||||
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => {
|
||||
num_1.evaluate_equal(expression_namespace, &num_2)?
|
||||
}
|
||||
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.evaluate_equal(cs, &num_2)?,
|
||||
(ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => {
|
||||
Boolean::Constant(fe_1.eq(&fe_2)) //TODO impl evaluate eq gadget
|
||||
}
|
||||
@ -204,11 +207,11 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
}
|
||||
(ConstrainedValue::Unresolved(string), val_2) => {
|
||||
let val_1 = ConstrainedValue::from_other(string, &val_2)?;
|
||||
return self.evaluate_eq_expression(cs, val_1, val_2);
|
||||
return self.evaluate_eq_expression(&mut expression_namespace, val_1, val_2);
|
||||
}
|
||||
(val_1, ConstrainedValue::Unresolved(string)) => {
|
||||
let val_2 = ConstrainedValue::from_other(string, &val_1)?;
|
||||
return self.evaluate_eq_expression(cs, val_1, val_2);
|
||||
return self.evaluate_eq_expression(&mut expression_namespace, val_1, val_2);
|
||||
}
|
||||
(val_1, val_2) => return Err(ExpressionError::IncompatibleTypes(format!("{} == {}", val_1, val_2,))),
|
||||
};
|
||||
|
@ -342,11 +342,15 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
cs: &mut CS,
|
||||
file_scope: String,
|
||||
function_scope: String,
|
||||
indicator: Option<Boolean>,
|
||||
statement: ConditionalStatement,
|
||||
return_types: Vec<Type>,
|
||||
) -> Result<Option<ConstrainedValue<F, G>>, StatementError> {
|
||||
let statement_string = statement.to_string();
|
||||
let outer_indicator = indicator.unwrap_or(Boolean::Constant(true));
|
||||
|
||||
let expected_types = vec![Type::Boolean];
|
||||
let indicator = match self.enforce_expression(
|
||||
let inner_indicator = match self.enforce_expression(
|
||||
cs,
|
||||
file_scope.clone(),
|
||||
function_scope.clone(),
|
||||
@ -357,27 +361,45 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
value => return Err(StatementError::IfElseConditional(value.to_string())),
|
||||
};
|
||||
|
||||
// Determine nested branch selection
|
||||
let branch_1_indicator = Boolean::and(
|
||||
&mut cs.ns(|| format!("statement branch 1 indicator {}", statement_string)),
|
||||
&outer_indicator,
|
||||
&inner_indicator,
|
||||
)?;
|
||||
|
||||
// Execute branch 1
|
||||
self.evaluate_branch(
|
||||
cs,
|
||||
file_scope.clone(),
|
||||
function_scope.clone(),
|
||||
Some(indicator),
|
||||
Some(branch_1_indicator),
|
||||
statement.statements,
|
||||
return_types.clone(),
|
||||
)?;
|
||||
|
||||
// Execute branch 2
|
||||
let branch_2_indicator = Boolean::and(
|
||||
&mut cs.ns(|| format!("statement branch 2 indicator {}", statement_string)),
|
||||
&outer_indicator,
|
||||
&inner_indicator.not(),
|
||||
)?;
|
||||
|
||||
match statement.next {
|
||||
Some(next) => match next {
|
||||
ConditionalNestedOrEndStatement::Nested(nested) => {
|
||||
self.enforce_conditional_statement(cs, file_scope, function_scope, *nested, return_types)
|
||||
}
|
||||
ConditionalNestedOrEndStatement::Nested(nested) => self.enforce_conditional_statement(
|
||||
cs,
|
||||
file_scope,
|
||||
function_scope,
|
||||
Some(branch_2_indicator),
|
||||
*nested,
|
||||
return_types,
|
||||
),
|
||||
ConditionalNestedOrEndStatement::End(statements) => self.evaluate_branch(
|
||||
cs,
|
||||
file_scope,
|
||||
function_scope,
|
||||
Some(indicator.not()),
|
||||
Some(branch_2_indicator),
|
||||
statements,
|
||||
return_types,
|
||||
),
|
||||
@ -465,9 +487,14 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
self.enforce_multiple_definition_statement(cs, file_scope, function_scope, variables, function)?;
|
||||
}
|
||||
Statement::Conditional(statement) => {
|
||||
if let Some(early_return) =
|
||||
self.enforce_conditional_statement(cs, file_scope, function_scope, statement, return_types)?
|
||||
{
|
||||
if let Some(early_return) = self.enforce_conditional_statement(
|
||||
cs,
|
||||
file_scope,
|
||||
function_scope,
|
||||
indicator,
|
||||
statement,
|
||||
return_types,
|
||||
)? {
|
||||
res = Some(early_return)
|
||||
}
|
||||
}
|
||||
|
13
compiler/tests/statements/conditional/chain.leo
Normal file
13
compiler/tests/statements/conditional/chain.leo
Normal file
@ -0,0 +1,13 @@
|
||||
function main(bit: u32) -> u32 {
|
||||
let mut result = 0u32;
|
||||
|
||||
if bit == 1u32 {
|
||||
result = 1u32;
|
||||
} else if bit == 2u32 {
|
||||
result = 2u32;
|
||||
} else {
|
||||
result = 3u32;
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
@ -1,5 +1,4 @@
|
||||
use crate::{
|
||||
boolean::{output_false, output_true},
|
||||
get_output,
|
||||
integers::u32::{output_one, output_zero},
|
||||
parse_program,
|
||||
@ -76,13 +75,56 @@ fn conditional_for_loop() {
|
||||
let mut program_true_6 = parse_program(bytes).unwrap();
|
||||
let mut program_false_0 = program_true_6.clone();
|
||||
|
||||
// Check that an input value of 1 satisfies the constraint system
|
||||
// Check that an input value of true satisfies the constraint system
|
||||
|
||||
program_true_6.set_inputs(vec![Some(InputValue::Boolean(true))]);
|
||||
output_number(program_true_6, 6u32);
|
||||
|
||||
// Check that an input value of 0 satisfies the constraint system
|
||||
// Check that an input value of false satisfies the constraint system
|
||||
|
||||
program_false_0.set_inputs(vec![Some(InputValue::Boolean(false))]);
|
||||
output_zero(program_false_0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditional_chain() {
|
||||
let bytes = include_bytes!("chain.leo");
|
||||
let mut program_1_1 = parse_program(bytes).unwrap();
|
||||
let mut program_2_2 = program_1_1.clone();
|
||||
let mut program_2_3 = program_1_1.clone();
|
||||
|
||||
// Check that an input of 1 outputs true
|
||||
program_1_1.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]);
|
||||
output_number(program_1_1, 1u32);
|
||||
|
||||
// Check that an input of 0 outputs true
|
||||
program_2_2.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 2))]);
|
||||
output_number(program_2_2, 2u32);
|
||||
|
||||
// Check that an input of 0 outputs true
|
||||
program_2_3.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 5))]);
|
||||
output_number(program_2_3, 3u32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditional_nested() {
|
||||
let bytes = include_bytes!("nested.leo");
|
||||
let mut program_true_true_3 = parse_program(bytes).unwrap();
|
||||
let mut program_true_false_1 = program_true_true_3.clone();
|
||||
let mut program_false_false_0 = program_true_true_3.clone();
|
||||
|
||||
// Check that an input value of true true satisfies the constraint system
|
||||
|
||||
program_true_true_3.set_inputs(vec![Some(InputValue::Boolean(true)); 2]);
|
||||
output_number(program_true_true_3, 3u32);
|
||||
|
||||
// Check that an input value of true false satisfies the constraint system
|
||||
|
||||
program_true_false_1.set_inputs(vec![Some(InputValue::Boolean(true)), Some(InputValue::Boolean(false))]);
|
||||
output_number(program_true_false_1, 1u32);
|
||||
|
||||
// Check that an input value of false false satisfies the constraint system
|
||||
|
||||
program_false_false_0.set_inputs(vec![Some(InputValue::Boolean(false)), Some(InputValue::Boolean(false))]);
|
||||
output_number(program_false_false_0, 0u32);
|
||||
}
|
||||
|
12
compiler/tests/statements/conditional/nested.leo
Normal file
12
compiler/tests/statements/conditional/nested.leo
Normal file
@ -0,0 +1,12 @@
|
||||
function main(a: bool, b: bool) -> u32 {
|
||||
let mut result = 0u32;
|
||||
|
||||
if a {
|
||||
result += 1;
|
||||
if b {
|
||||
result += 2;
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
Loading…
Reference in New Issue
Block a user