impl multiple returns and test

This commit is contained in:
collin 2020-06-23 23:56:51 -07:00
parent 2291cc51fd
commit 25a66d1f58
6 changed files with 224 additions and 56 deletions

View File

@ -11,9 +11,13 @@ use crate::{
};
use leo_types::{Expression, Function, InputValue, Integer, Program, Span, Type};
use crate::errors::StatementError;
use snarkos_models::{
curves::{Field, PrimeField},
gadgets::r1cs::ConstraintSystem,
gadgets::{
r1cs::ConstraintSystem,
utilities::{boolean::Boolean, select::CondSelectGadget},
},
};
impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
@ -45,6 +49,51 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
}
}
/// iterates through a vector of results and selects one based off of indicators
fn conditionally_select_result<CS: ConstraintSystem<F>>(
cs: &mut CS,
return_value: &mut ConstrainedValue<F, G>,
results: Vec<(Option<Boolean>, ConstrainedValue<F, G>)>,
span: Span,
) -> Result<(), StatementError> {
// if there are no results, continue
if results.len() == 0 {
return Ok(());
}
// If all indicators are none, then there are no branch conditions in the function.
// We simply return the last result.
if let None = results.iter().find(|(indicator, _res)| indicator.is_some()) {
let result = &results[results.len() - 1].1;
*return_value = result.clone();
return Ok(());
}
// If there are branches in the function we need to use the `ConditionalSelectGadget` to parse through and select the correct one.
// This can be thought of as de-multiplexing all previous wires that may have returned results into one.
for (i, (indicator, result)) in results.into_iter().enumerate() {
// Set the first value as the starting point
if i == 0 {
*return_value = result.clone();
}
let condition = indicator.unwrap_or(Boolean::Constant(true));
let name_unique = format!("select {} {}:{}", result, span.line, span.start);
let selected_value =
ConstrainedValue::conditionally_select(cs.ns(|| name_unique), &condition, &result, return_value)
.map_err(|_| {
StatementError::select_fail(result.to_string(), return_value.to_string(), span.clone())
})?;
*return_value = selected_value;
}
Ok(())
}
pub(crate) fn enforce_function<CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
@ -79,26 +128,38 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
self.store(input_program_identifier, input_value);
}
// Evaluate function statements
// Evaluate every statement in the function and save all potential results
let mut return_values = ConstrainedValue::Return(vec![]);
let mut results = vec![];
for statement in function.statements.iter() {
if let Some(returned) = self.enforce_statement(
let mut result = self.enforce_statement(
cs,
scope.clone(),
function_name.clone(),
None,
statement.clone(),
function.returns.clone(),
)? {
return_values = returned;
break;
}
)?;
results.append(&mut result);
}
println!("{:?}", results);
// Conditionally select a result based on returned indicators
let mut return_values = ConstrainedValue::Return(vec![]);
Self::conditionally_select_result(cs, &mut return_values, results, function.span.clone())?;
if let ConstrainedValue::Return(ref returns) = return_values {
Self::check_arguments_length(function.returns.len(), returns.len(), function.span.clone())?;
if function.returns.len() != returns.len() {
return Err(FunctionError::return_arguments_length(
function.returns.len(),
returns.len(),
function.span.clone(),
));
}
}
Ok(return_values)

View File

@ -390,24 +390,23 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
indicator: Option<Boolean>,
statements: Vec<Statement>,
return_types: Vec<Type>,
) -> Result<Option<ConstrainedValue<F, G>>, StatementError> {
let mut res = None;
// Evaluate statements and possibly return early
) -> Result<Vec<(Option<Boolean>, ConstrainedValue<F, G>)>, StatementError> {
let mut results = vec![];
// Evaluate statements. Only allow a single return argument to be returned.
for statement in statements.iter() {
if let Some(early_return) = self.enforce_statement(
let mut value = self.enforce_statement(
cs,
file_scope.clone(),
function_scope.clone(),
indicator.clone(),
statement.clone(),
return_types.clone(),
)? {
res = Some(early_return);
break;
}
)?;
results.append(&mut value);
}
Ok(res)
Ok(results)
}
/// Enforces a statements.conditional statement with one or more branches.
@ -423,7 +422,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
statement: ConditionalStatement,
return_types: Vec<Type>,
span: Span,
) -> Result<Option<ConstrainedValue<F, G>>, StatementError> {
) -> Result<Vec<(Option<Boolean>, ConstrainedValue<F, G>)>, StatementError> {
let statement_string = statement.to_string();
let outer_indicator = indicator.unwrap_or(Boolean::Constant(true));
@ -459,8 +458,10 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
)
.map_err(|_| StatementError::indicator_calculation(branch_1_name, span.clone()))?;
let mut results = vec![];
// Execute branch 1
self.evaluate_branch(
let mut branch_1_result = self.evaluate_branch(
cs,
file_scope.clone(),
function_scope.clone(),
@ -469,6 +470,8 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
return_types.clone(),
)?;
results.append(&mut branch_1_result);
// Determine nested branch 2 selection
let inner_indicator = inner_indicator.not();
let inner_indicator_string = inner_indicator
@ -487,7 +490,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
.map_err(|_| StatementError::indicator_calculation(branch_2_name, span.clone()))?;
// Execute branch 2
match statement.next {
let mut branch_2_result = match statement.next {
Some(next) => match next {
ConditionalNestedOrEndStatement::Nested(nested) => self.enforce_conditional_statement(
cs,
@ -497,7 +500,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
*nested,
return_types,
span,
),
)?,
ConditionalNestedOrEndStatement::End(statements) => self.evaluate_branch(
cs,
file_scope,
@ -505,10 +508,14 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
Some(branch_2_indicator),
statements,
return_types,
),
)?,
},
None => Ok(None),
}
None => vec![],
};
results.append(&mut branch_2_result);
Ok(results)
}
fn enforce_for_statement<CS: ConstraintSystem<F>>(
@ -523,8 +530,8 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
statements: Vec<Statement>,
return_types: Vec<Type>,
span: Span,
) -> Result<Option<ConstrainedValue<F, G>>, StatementError> {
let mut res = None;
) -> Result<Vec<(Option<Boolean>, ConstrainedValue<F, G>)>, StatementError> {
let mut results = vec![];
let from = start.to_usize(span.clone())?;
let to = stop.to_usize(span.clone())?;
@ -540,20 +547,19 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
// Evaluate statements and possibly return early
let name_unique = format!("for loop iteration {} {}:{}", i, span.line, span.start);
if let Some(early_return) = self.evaluate_branch(
let mut result = self.evaluate_branch(
&mut cs.ns(|| name_unique),
file_scope.clone(),
function_scope.clone(),
indicator,
statements.clone(),
return_types.clone(),
)? {
res = Some(early_return);
break;
}
)?;
results.append(&mut result);
}
Ok(res)
Ok(results)
}
fn enforce_assert_eq_statement<CS: ConstraintSystem<F>>(
@ -571,6 +577,10 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
Ok(result.map_err(|_| StatementError::assertion_failed(left.to_string(), right.to_string(), span))?)
}
/// Enforce a program statement.
/// Returns a Vector of (indicator, value) tuples.
/// Each evaluated statement may execute of one or more statements that may return early.
/// To indicate which of these return values to take we conditionally select that value with the indicator bit.
pub(crate) fn enforce_statement<CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
@ -579,18 +589,16 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
indicator: Option<Boolean>,
statement: Statement,
return_types: Vec<Type>,
) -> Result<Option<ConstrainedValue<F, G>>, StatementError> {
let mut res = None;
) -> Result<Vec<(Option<Boolean>, ConstrainedValue<F, G>)>, StatementError> {
let mut results = vec![];
match statement {
Statement::Return(expressions, span) => {
res = Some(self.enforce_return_statement(
cs,
file_scope,
function_scope,
expressions,
return_types,
span,
)?);
let return_value = (
indicator,
self.enforce_return_statement(cs, file_scope, function_scope, expressions, return_types, span)?,
);
results.push(return_value);
}
Statement::Definition(declare, variable, expression, span) => {
self.enforce_definition_statement(cs, file_scope, function_scope, declare, variable, expression, span)?;
@ -602,7 +610,7 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
self.enforce_multiple_definition_statement(cs, file_scope, function_scope, variables, function, span)?;
}
Statement::Conditional(statement, span) => {
if let Some(early_return) = self.enforce_conditional_statement(
let mut result = self.enforce_conditional_statement(
cs,
file_scope,
function_scope,
@ -610,12 +618,12 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
statement,
return_types,
span,
)? {
res = Some(early_return)
}
)?;
results.append(&mut result);
}
Statement::For(index, start, stop, statements, span) => {
if let Some(early_return) = self.enforce_for_statement(
let mut result = self.enforce_for_statement(
cs,
file_scope,
function_scope,
@ -626,9 +634,9 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
statements,
return_types,
span,
)? {
res = Some(early_return)
}
)?;
results.append(&mut result);
}
Statement::AssertEq(left, right, span) => {
let (resolved_left, resolved_right) =
@ -637,17 +645,25 @@ impl<F: Field + PrimeField, G: GroupType<F>> ConstrainedProgram<F, G> {
self.enforce_assert_eq_statement(cs, indicator, &resolved_left, &resolved_right, span)?;
}
Statement::Expression(expression, span) => {
match self.enforce_expression(cs, file_scope, function_scope, &vec![], expression.clone())? {
let expression_string = expression.to_string();
let value = self.enforce_expression(cs, file_scope, function_scope, &vec![], expression)?;
// handle empty return value cases
match &value {
ConstrainedValue::Return(values) => {
if !values.is_empty() {
return Err(StatementError::unassigned(expression.to_string(), span));
return Err(StatementError::unassigned(expression_string, span));
}
}
_ => return Err(StatementError::unassigned(expression.to_string(), span)),
_ => return Err(StatementError::unassigned(expression_string, span)),
}
let result = (indicator, value);
results.push(result);
}
};
Ok(res)
Ok(results)
}
}

View File

@ -319,16 +319,59 @@ impl<F: Field + PrimeField, G: GroupType<F>> CondSelectGadget<F> for Constrained
}
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
let mut array = vec![];
for (i, (first, second)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
array.push(Self::conditionally_select(
cs.ns(|| format!("array[{}]", i,)),
cs.ns(|| format!("array[{}]", i)),
cond,
first,
second,
)?);
}
ConstrainedValue::Array(array)
}
(ConstrainedValue::Function(identifier_1, function_1), ConstrainedValue::Function(_, _)) => {
// This is a no-op. functions cannot hold circuit values
// However, we must return a result here
ConstrainedValue::Function(identifier_1.clone(), function_1.clone())
}
(
ConstrainedValue::CircuitExpression(identifier, members_1),
ConstrainedValue::CircuitExpression(_identifier, members_2),
) => {
let mut members = vec![];
for (i, (first, second)) in members_1.into_iter().zip(members_2.into_iter()).enumerate() {
members.push(ConstrainedCircuitMember::conditionally_select(
cs.ns(|| format!("circuit member[{}]", i)),
cond,
first,
second,
)?);
}
ConstrainedValue::CircuitExpression(identifier.clone(), members)
}
(ConstrainedValue::Return(returns_1), ConstrainedValue::Return(returns_2)) => {
let mut returns = vec![];
for (i, (first, second)) in returns_1.into_iter().zip(returns_2.into_iter()).enumerate() {
returns.push(Self::conditionally_select(
cs.ns(|| format!("return[{}]", i)),
cond,
first,
second,
)?);
}
ConstrainedValue::Return(returns)
}
(ConstrainedValue::Static(first), ConstrainedValue::Static(second)) => {
let value = Self::conditionally_select(cs, cond, first, second)?;
ConstrainedValue::Static(Box::new(value))
}
(ConstrainedValue::Mutable(first), _) => Self::conditionally_select(cs, cond, first, second)?,
(_, ConstrainedValue::Mutable(second)) => Self::conditionally_select(cs, cond, first, second)?,
(_, _) => return Err(SynthesisError::Unsatisfiable),
@ -339,3 +382,21 @@ impl<F: Field + PrimeField, G: GroupType<F>> CondSelectGadget<F> for Constrained
unimplemented!() //lower bound 1, upper bound 128 or length of static array
}
}
impl<F: Field + PrimeField, G: GroupType<F>> CondSelectGadget<F> for ConstrainedCircuitMember<F, G> {
fn conditionally_select<CS: ConstraintSystem<F>>(
cs: CS,
cond: &Boolean,
first: &Self,
second: &Self,
) -> Result<Self, SynthesisError> {
// identifiers will be the same
let value = ConstrainedValue::conditionally_select(cs, cond, &first.1, &second.1)?;
Ok(ConstrainedCircuitMember(first.0.clone(), value))
}
fn cost() -> usize {
unimplemented!()
}
}

View File

@ -58,4 +58,10 @@ impl FunctionError {
Self::new_from_span(message, span)
}
pub fn return_arguments_length(expected: usize, actual: usize, span: Span) -> Self {
let message = format!("function expected {} returns, found {} returns", expected, actual);
Self::new_from_span(message, span)
}
}

View File

@ -127,3 +127,20 @@ fn test_nested() {
program_false_false_0.set_inputs(vec![Some(InputValue::Boolean(false)), Some(InputValue::Boolean(false))]);
output_number(program_false_false_0, 0u32);
}
#[test]
fn test_multiple_returns() {
let bytes = include_bytes!("multiple_returns.leo");
let mut program_true_1 = parse_program(bytes).unwrap();
let mut program_false_0 = program_true_1.clone();
// Check that an input value of true returns 1 and satisfies the constraint system
program_true_1.set_inputs(vec![Some(InputValue::Boolean(true))]);
output_number(program_true_1, 1u32);
// Check that an input value of false returns 0 and satisfies the constraint system
program_false_0.set_inputs(vec![Some(InputValue::Boolean(false))]);
output_number(program_false_0, 0u32);
}

View File

@ -0,0 +1,7 @@
function main(cond: bool) -> u32 {
if cond {
return 1u32
} else {
return 0u32
}
}