impl mutable statements in basic conditional

This commit is contained in:
collin 2020-06-13 00:47:09 -07:00
parent 4c64edb032
commit 2d17b39da6
6 changed files with 214 additions and 54 deletions

View File

@ -23,7 +23,7 @@ use snarkos_models::{
curves::{Field, PrimeField},
gadgets::{
r1cs::ConstraintSystem,
utilities::{boolean::Boolean, eq::ConditionalEqGadget, uint::UInt32},
utilities::{boolean::Boolean, eq::ConditionalEqGadget, select::CondSelectGadget, uint::UInt32},
},
};
@ -52,10 +52,13 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
cs: &mut CS,
file_scope: String,
function_scope: String,
indicator: Option<Boolean>,
name: String,
range_or_expression: RangeOrExpression,
new_value: ConstrainedValue<F, G>,
) -> Result<(), StatementError> {
let condition = indicator.unwrap_or(Boolean::Constant(true));
// Resolve index so we know if we are assigning to a single value or a range of values
match range_or_expression {
RangeOrExpression::Expression(index) => {
@ -64,7 +67,11 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
// Modify the single value of the array in place
match self.get_mutable_assignee(name)? {
ConstrainedValue::Array(old) => {
old[index] = new_value;
let selected_value =
ConstrainedValue::conditionally_select(cs, &condition, &new_value, &old[index]).map_err(
|_| StatementError::SelectFail(new_value.to_string(), old[index].to_string()),
)?;
old[index] = selected_value;
}
_ => return Err(StatementError::ArrayAssignIndex),
}
@ -79,14 +86,20 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
None => None,
};
// Modify the range of values of the array in place
match (self.get_mutable_assignee(name)?, new_value) {
(ConstrainedValue::Array(old), ConstrainedValue::Array(ref new)) => {
let to_index = to_index_option.unwrap_or(old.len());
old.splice(from_index..to_index, new.iter().cloned());
// Modify the range of values of the array
let old_array = self.get_mutable_assignee(name)?;
let new_array = match (old_array.clone(), new_value) {
(ConstrainedValue::Array(mut mutable), ConstrainedValue::Array(new)) => {
let to_index = to_index_option.unwrap_or(mutable.len());
mutable.splice(from_index..to_index, new.iter().cloned());
ConstrainedValue::Array(mutable)
}
_ => return Err(StatementError::ArrayAssignRange),
}
};
let selected_array = ConstrainedValue::conditionally_select(cs, &condition, &new_array, old_array)
.map_err(|_| StatementError::SelectFail(new_array.to_string(), old_array.to_string()))?;
*old_array = selected_array;
}
}
@ -95,10 +108,14 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
fn mutute_circuit_field(
&mut self,
cs: &mut CS,
indicator: Option<Boolean>,
circuit_name: String,
object_name: Identifier,
new_value: ConstrainedValue<F, G>,
) -> Result<(), StatementError> {
let condition = indicator.unwrap_or(Boolean::Constant(true));
match self.get_mutable_assignee(circuit_name)? {
ConstrainedValue::CircuitExpression(_variable, members) => {
// Modify the circuit field in place
@ -114,7 +131,14 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
ConstrainedValue::Static(_value) => {
return Err(StatementError::ImmutableCircuitFunction("static".into()));
}
_ => object.1 = new_value.to_owned(),
_ => {
let selected_value = ConstrainedValue::conditionally_select(
cs, &condition, &new_value, &object.1,
)
.map_err(|_| StatementError::SelectFail(new_value.to_string(), object.1.to_string()))?;
object.1 = selected_value.to_owned();
}
},
None => return Err(StatementError::UndefinedCircuitObject(object_name.to_string())),
}
@ -130,6 +154,7 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
cs: &mut CS,
file_scope: String,
function_scope: String,
indicator: Option<Boolean>,
assignee: Assignee,
expression: Expression,
) -> Result<(), StatementError> {
@ -142,9 +167,12 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
// Mutate the old value into the new value
match assignee {
Assignee::Identifier(_identifier) => {
let condition = indicator.unwrap_or(Boolean::Constant(true));
let old_value = self.get_mutable_assignee(variable_name.clone())?;
let selected_value = ConstrainedValue::conditionally_select(cs, &condition, &new_value, old_value)
.map_err(|_| StatementError::SelectFail(new_value.to_string(), old_value.to_string()))?;
*old_value = new_value;
*old_value = selected_value;
Ok(())
}
@ -152,12 +180,13 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
cs,
file_scope,
function_scope,
indicator,
variable_name,
range_or_expression,
new_value,
),
Assignee::CircuitField(_assignee, object_name) => {
self.mutute_circuit_field(variable_name, object_name, new_value)
self.mutute_circuit_field(cs, indicator, variable_name, object_name, new_value)
}
}
}
@ -404,36 +433,7 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
right: &ConstrainedValue<F, G>,
) -> Result<(), StatementError> {
let condition = indicator.unwrap_or(Boolean::Constant(true));
let result = match (left, right) {
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
bool_2,
&condition,
),
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
num_2,
&condition,
),
(ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => fe_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
fe_2,
&condition,
),
(ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => ge_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
ge_2,
&condition,
),
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
for (left, right) in arr_1.into_iter().zip(arr_2.into_iter()) {
self.enforce_assert_eq_statement(cs, indicator.clone(), left, right)?;
}
Ok(())
}
(val_1, val_2) => return Err(StatementError::AssertEq(val_1.to_string(), val_2.to_string())),
};
let result = left.conditional_enforce_equal(cs, right, &condition);
Ok(result.map_err(|_| StatementError::AssertionFailed(left.to_string(), right.to_string()))?)
}
@ -456,7 +456,7 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
self.enforce_definition_statement(cs, file_scope, function_scope, variable, expression)?;
}
Statement::Assign(variable, expression) => {
self.enforce_assign_statement(cs, file_scope, function_scope, variable, expression)?;
self.enforce_assign_statement(cs, file_scope, function_scope, indicator, variable, expression)?;
}
Statement::MultipleAssign(variables, function) => {
self.enforce_multiple_definition_statement(cs, file_scope, function_scope, variables, function)?;

View File

@ -3,11 +3,17 @@
use crate::{errors::ValueError, FieldType, GroupType};
use leo_types::{Circuit, Function, Identifier, Integer, IntegerType, Type};
use snarkos_errors::gadgets::SynthesisError;
use snarkos_models::{
curves::{Field, PrimeField},
gadgets::utilities::{
boolean::Boolean,
uint::{UInt128, UInt16, UInt32, UInt64, UInt8},
gadgets::{
r1cs::ConstraintSystem,
utilities::{
boolean::Boolean,
eq::ConditionalEqGadget,
select::CondSelectGadget,
uint::{UInt128, UInt16, UInt32, UInt64, UInt8},
},
},
};
use std::fmt;
@ -139,3 +145,118 @@ impl<F: Field + PrimeField, G: GroupType<F>> fmt::Debug for ConstrainedValue<F,
write!(f, "{}", self)
}
}
impl<F: Field + PrimeField, G: GroupType<F>> ConditionalEqGadget<F> for ConstrainedValue<F, G> {
fn conditional_enforce_equal<CS: ConstraintSystem<F>>(
&self,
mut cs: CS,
other: &Self,
condition: &Boolean,
) -> Result<(), SynthesisError> {
match (self, other) {
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
bool_2,
&condition,
),
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
num_2,
&condition,
),
(ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => field_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
field_2,
&condition,
),
(ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => group_1.conditional_enforce_equal(
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
group_2,
&condition,
),
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
for (i, (left, right)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
left.conditional_enforce_equal(
cs.ns(|| format!("array[{}] equal {} == {}", i, left.to_string(), right.to_string())),
right,
&condition,
)?;
}
Ok(())
}
(_, _) => return Err(SynthesisError::Unsatisfiable),
}
}
fn cost() -> usize {
unimplemented!()
}
}
impl<F: Field + PrimeField, G: GroupType<F>> CondSelectGadget<F> for ConstrainedValue<F, G> {
fn conditionally_select<CS: ConstraintSystem<F>>(
mut cs: CS,
cond: &Boolean,
first: &Self,
second: &Self,
) -> Result<Self, SynthesisError> {
Ok(match (first, second) {
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => {
ConstrainedValue::Boolean(Boolean::conditionally_select(
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
bool_1,
bool_2,
)?)
}
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => {
ConstrainedValue::Integer(Integer::conditionally_select(
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
num_1,
num_2,
)?)
}
(ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => {
ConstrainedValue::Field(FieldType::conditionally_select(
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
field_1,
field_2,
)?)
}
(ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => {
ConstrainedValue::Group(G::conditionally_select(
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
cond,
group_1,
group_2,
)?)
}
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
let mut array = vec![];
for (i, (first, second)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
array.push(Self::conditionally_select(
cs.ns(|| {
format!(
"array[{}] = if cond ? {} else {}",
i,
first.to_string(),
second.to_string()
)
}),
cond,
first,
second,
)?);
}
ConstrainedValue::Array(array)
}
(_, _) => return Err(SynthesisError::Unsatisfiable),
})
}
fn cost() -> usize {
unimplemented!() //lower bound 1, upper bound 128 or length of static array
}
}

View File

@ -49,6 +49,9 @@ pub enum StatementError {
#[error("Function return statement expected {} return values, got {}", _0, _1)]
InvalidNumberOfReturns(usize, usize),
#[error("Conditional select gadget failed to select between {} or {}", _0, _1)]
SelectFail(String, String),
#[error("{}", _0)]
SynthesisError(#[from] SynthesisError),

View File

@ -1,7 +1,7 @@
function main(bit: private u8) {
if bit == 1u8 {
assert_eq!(bit, 1u8);
function main(bit: private u32) {
if bit == 1u32 {
assert_eq!(bit, 1u32);
} else {
assert_eq!(bit, 0u8);
assert_eq!(bit, 0u32);
}
}

View File

@ -0,0 +1,11 @@
function main(bit: private u32) -> u32 {
let mut a = 5u32;
if bit == 1u32 {
a = 1u32;
} else {
a = 0u32;
}
return a
}

View File

@ -1,6 +1,14 @@
use crate::{get_output, parse_program, EdwardsConstrainedValue, EdwardsTestCompiler};
use leo_inputs::types::{IntegerType, U8Type};
use crate::{
boolean::{output_false, output_true},
get_output,
integers::u32::{output_one, output_zero},
parse_program,
EdwardsConstrainedValue,
EdwardsTestCompiler,
};
use leo_inputs::types::{IntegerType, U32Type};
use leo_types::InputValue;
use snarkos_curves::edwards_bls12::Fq;
use snarkos_models::gadgets::r1cs::TestConstraintSystem;
@ -28,18 +36,35 @@ fn conditional_basic() {
// Check that an input value of 1 satisfies the constraint system
program_1_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 1))]);
program_1_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]);
empty_output_satisfied(program_1_pass);
// Check that an input value of 0 satisfies the constraint system
program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 0))]);
program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 0))]);
empty_output_satisfied(program_0_pass);
// Check that an input value of 2 does not satisfy the constraint system
program_2_fail.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 2))]);
program_2_fail.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 2))]);
let mut cs = TestConstraintSystem::<Fq>::new();
let _output = program_2_fail.compile_constraints(&mut cs).unwrap();
assert!(!cs.is_satisfied());
}
#[test]
fn conditional_mutate() {
let bytes = include_bytes!("conditional_mutate.leo");
let mut program_1_true = parse_program(bytes).unwrap();
let mut program_0_pass = program_1_true.clone();
// Check that an input value of 1 satisfies the constraint system
program_1_true.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]);
output_one(program_1_true);
// Check that an input value of 0 satisfies the constraint system
program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 0))]);
output_zero(program_0_pass);
}