mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-11-28 01:01:53 +03:00
impl mutable statements in basic conditional
This commit is contained in:
parent
4c64edb032
commit
2d17b39da6
@ -23,7 +23,7 @@ use snarkos_models::{
|
||||
curves::{Field, PrimeField},
|
||||
gadgets::{
|
||||
r1cs::ConstraintSystem,
|
||||
utilities::{boolean::Boolean, eq::ConditionalEqGadget, uint::UInt32},
|
||||
utilities::{boolean::Boolean, eq::ConditionalEqGadget, select::CondSelectGadget, uint::UInt32},
|
||||
},
|
||||
};
|
||||
|
||||
@ -52,10 +52,13 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
cs: &mut CS,
|
||||
file_scope: String,
|
||||
function_scope: String,
|
||||
indicator: Option<Boolean>,
|
||||
name: String,
|
||||
range_or_expression: RangeOrExpression,
|
||||
new_value: ConstrainedValue<F, G>,
|
||||
) -> Result<(), StatementError> {
|
||||
let condition = indicator.unwrap_or(Boolean::Constant(true));
|
||||
|
||||
// Resolve index so we know if we are assigning to a single value or a range of values
|
||||
match range_or_expression {
|
||||
RangeOrExpression::Expression(index) => {
|
||||
@ -64,7 +67,11 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
// Modify the single value of the array in place
|
||||
match self.get_mutable_assignee(name)? {
|
||||
ConstrainedValue::Array(old) => {
|
||||
old[index] = new_value;
|
||||
let selected_value =
|
||||
ConstrainedValue::conditionally_select(cs, &condition, &new_value, &old[index]).map_err(
|
||||
|_| StatementError::SelectFail(new_value.to_string(), old[index].to_string()),
|
||||
)?;
|
||||
old[index] = selected_value;
|
||||
}
|
||||
_ => return Err(StatementError::ArrayAssignIndex),
|
||||
}
|
||||
@ -79,14 +86,20 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Modify the range of values of the array in place
|
||||
match (self.get_mutable_assignee(name)?, new_value) {
|
||||
(ConstrainedValue::Array(old), ConstrainedValue::Array(ref new)) => {
|
||||
let to_index = to_index_option.unwrap_or(old.len());
|
||||
old.splice(from_index..to_index, new.iter().cloned());
|
||||
// Modify the range of values of the array
|
||||
let old_array = self.get_mutable_assignee(name)?;
|
||||
let new_array = match (old_array.clone(), new_value) {
|
||||
(ConstrainedValue::Array(mut mutable), ConstrainedValue::Array(new)) => {
|
||||
let to_index = to_index_option.unwrap_or(mutable.len());
|
||||
|
||||
mutable.splice(from_index..to_index, new.iter().cloned());
|
||||
ConstrainedValue::Array(mutable)
|
||||
}
|
||||
_ => return Err(StatementError::ArrayAssignRange),
|
||||
}
|
||||
};
|
||||
let selected_array = ConstrainedValue::conditionally_select(cs, &condition, &new_array, old_array)
|
||||
.map_err(|_| StatementError::SelectFail(new_array.to_string(), old_array.to_string()))?;
|
||||
*old_array = selected_array;
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,10 +108,14 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
|
||||
fn mutute_circuit_field(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
indicator: Option<Boolean>,
|
||||
circuit_name: String,
|
||||
object_name: Identifier,
|
||||
new_value: ConstrainedValue<F, G>,
|
||||
) -> Result<(), StatementError> {
|
||||
let condition = indicator.unwrap_or(Boolean::Constant(true));
|
||||
|
||||
match self.get_mutable_assignee(circuit_name)? {
|
||||
ConstrainedValue::CircuitExpression(_variable, members) => {
|
||||
// Modify the circuit field in place
|
||||
@ -114,7 +131,14 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
ConstrainedValue::Static(_value) => {
|
||||
return Err(StatementError::ImmutableCircuitFunction("static".into()));
|
||||
}
|
||||
_ => object.1 = new_value.to_owned(),
|
||||
_ => {
|
||||
let selected_value = ConstrainedValue::conditionally_select(
|
||||
cs, &condition, &new_value, &object.1,
|
||||
)
|
||||
.map_err(|_| StatementError::SelectFail(new_value.to_string(), object.1.to_string()))?;
|
||||
|
||||
object.1 = selected_value.to_owned();
|
||||
}
|
||||
},
|
||||
None => return Err(StatementError::UndefinedCircuitObject(object_name.to_string())),
|
||||
}
|
||||
@ -130,6 +154,7 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
cs: &mut CS,
|
||||
file_scope: String,
|
||||
function_scope: String,
|
||||
indicator: Option<Boolean>,
|
||||
assignee: Assignee,
|
||||
expression: Expression,
|
||||
) -> Result<(), StatementError> {
|
||||
@ -142,9 +167,12 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
// Mutate the old value into the new value
|
||||
match assignee {
|
||||
Assignee::Identifier(_identifier) => {
|
||||
let condition = indicator.unwrap_or(Boolean::Constant(true));
|
||||
let old_value = self.get_mutable_assignee(variable_name.clone())?;
|
||||
let selected_value = ConstrainedValue::conditionally_select(cs, &condition, &new_value, old_value)
|
||||
.map_err(|_| StatementError::SelectFail(new_value.to_string(), old_value.to_string()))?;
|
||||
|
||||
*old_value = new_value;
|
||||
*old_value = selected_value;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -152,12 +180,13 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
cs,
|
||||
file_scope,
|
||||
function_scope,
|
||||
indicator,
|
||||
variable_name,
|
||||
range_or_expression,
|
||||
new_value,
|
||||
),
|
||||
Assignee::CircuitField(_assignee, object_name) => {
|
||||
self.mutute_circuit_field(variable_name, object_name, new_value)
|
||||
self.mutute_circuit_field(cs, indicator, variable_name, object_name, new_value)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -404,36 +433,7 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
right: &ConstrainedValue<F, G>,
|
||||
) -> Result<(), StatementError> {
|
||||
let condition = indicator.unwrap_or(Boolean::Constant(true));
|
||||
|
||||
let result = match (left, right) {
|
||||
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
|
||||
bool_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
|
||||
num_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => fe_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
|
||||
fe_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => ge_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", left.to_string(), right.to_string())),
|
||||
ge_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
|
||||
for (left, right) in arr_1.into_iter().zip(arr_2.into_iter()) {
|
||||
self.enforce_assert_eq_statement(cs, indicator.clone(), left, right)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
(val_1, val_2) => return Err(StatementError::AssertEq(val_1.to_string(), val_2.to_string())),
|
||||
};
|
||||
let result = left.conditional_enforce_equal(cs, right, &condition);
|
||||
|
||||
Ok(result.map_err(|_| StatementError::AssertionFailed(left.to_string(), right.to_string()))?)
|
||||
}
|
||||
@ -456,7 +456,7 @@ impl<F: Field + PrimeField, G: GroupType<F>, CS: ConstraintSystem<F>> Constraine
|
||||
self.enforce_definition_statement(cs, file_scope, function_scope, variable, expression)?;
|
||||
}
|
||||
Statement::Assign(variable, expression) => {
|
||||
self.enforce_assign_statement(cs, file_scope, function_scope, variable, expression)?;
|
||||
self.enforce_assign_statement(cs, file_scope, function_scope, indicator, variable, expression)?;
|
||||
}
|
||||
Statement::MultipleAssign(variables, function) => {
|
||||
self.enforce_multiple_definition_statement(cs, file_scope, function_scope, variables, function)?;
|
||||
|
@ -3,11 +3,17 @@
|
||||
use crate::{errors::ValueError, FieldType, GroupType};
|
||||
use leo_types::{Circuit, Function, Identifier, Integer, IntegerType, Type};
|
||||
|
||||
use snarkos_errors::gadgets::SynthesisError;
|
||||
use snarkos_models::{
|
||||
curves::{Field, PrimeField},
|
||||
gadgets::utilities::{
|
||||
boolean::Boolean,
|
||||
uint::{UInt128, UInt16, UInt32, UInt64, UInt8},
|
||||
gadgets::{
|
||||
r1cs::ConstraintSystem,
|
||||
utilities::{
|
||||
boolean::Boolean,
|
||||
eq::ConditionalEqGadget,
|
||||
select::CondSelectGadget,
|
||||
uint::{UInt128, UInt16, UInt32, UInt64, UInt8},
|
||||
},
|
||||
},
|
||||
};
|
||||
use std::fmt;
|
||||
@ -139,3 +145,118 @@ impl<F: Field + PrimeField, G: GroupType<F>> fmt::Debug for ConstrainedValue<F,
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field + PrimeField, G: GroupType<F>> ConditionalEqGadget<F> for ConstrainedValue<F, G> {
|
||||
fn conditional_enforce_equal<CS: ConstraintSystem<F>>(
|
||||
&self,
|
||||
mut cs: CS,
|
||||
other: &Self,
|
||||
condition: &Boolean,
|
||||
) -> Result<(), SynthesisError> {
|
||||
match (self, other) {
|
||||
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => bool_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
|
||||
bool_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => num_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
|
||||
num_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => field_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
|
||||
field_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => group_1.conditional_enforce_equal(
|
||||
cs.ns(|| format!("{} == {}", self.to_string(), other.to_string())),
|
||||
group_2,
|
||||
&condition,
|
||||
),
|
||||
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
|
||||
for (i, (left, right)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
|
||||
left.conditional_enforce_equal(
|
||||
cs.ns(|| format!("array[{}] equal {} == {}", i, left.to_string(), right.to_string())),
|
||||
right,
|
||||
&condition,
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
(_, _) => return Err(SynthesisError::Unsatisfiable),
|
||||
}
|
||||
}
|
||||
|
||||
fn cost() -> usize {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: Field + PrimeField, G: GroupType<F>> CondSelectGadget<F> for ConstrainedValue<F, G> {
|
||||
fn conditionally_select<CS: ConstraintSystem<F>>(
|
||||
mut cs: CS,
|
||||
cond: &Boolean,
|
||||
first: &Self,
|
||||
second: &Self,
|
||||
) -> Result<Self, SynthesisError> {
|
||||
Ok(match (first, second) {
|
||||
(ConstrainedValue::Boolean(bool_1), ConstrainedValue::Boolean(bool_2)) => {
|
||||
ConstrainedValue::Boolean(Boolean::conditionally_select(
|
||||
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
|
||||
cond,
|
||||
bool_1,
|
||||
bool_2,
|
||||
)?)
|
||||
}
|
||||
(ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => {
|
||||
ConstrainedValue::Integer(Integer::conditionally_select(
|
||||
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
|
||||
cond,
|
||||
num_1,
|
||||
num_2,
|
||||
)?)
|
||||
}
|
||||
(ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => {
|
||||
ConstrainedValue::Field(FieldType::conditionally_select(
|
||||
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
|
||||
cond,
|
||||
field_1,
|
||||
field_2,
|
||||
)?)
|
||||
}
|
||||
(ConstrainedValue::Group(group_1), ConstrainedValue::Group(group_2)) => {
|
||||
ConstrainedValue::Group(G::conditionally_select(
|
||||
cs.ns(|| format!("if cond ? {} else {}", first.to_string(), second.to_string())),
|
||||
cond,
|
||||
group_1,
|
||||
group_2,
|
||||
)?)
|
||||
}
|
||||
(ConstrainedValue::Array(arr_1), ConstrainedValue::Array(arr_2)) => {
|
||||
let mut array = vec![];
|
||||
for (i, (first, second)) in arr_1.into_iter().zip(arr_2.into_iter()).enumerate() {
|
||||
array.push(Self::conditionally_select(
|
||||
cs.ns(|| {
|
||||
format!(
|
||||
"array[{}] = if cond ? {} else {}",
|
||||
i,
|
||||
first.to_string(),
|
||||
second.to_string()
|
||||
)
|
||||
}),
|
||||
cond,
|
||||
first,
|
||||
second,
|
||||
)?);
|
||||
}
|
||||
ConstrainedValue::Array(array)
|
||||
}
|
||||
(_, _) => return Err(SynthesisError::Unsatisfiable),
|
||||
})
|
||||
}
|
||||
|
||||
fn cost() -> usize {
|
||||
unimplemented!() //lower bound 1, upper bound 128 or length of static array
|
||||
}
|
||||
}
|
||||
|
@ -49,6 +49,9 @@ pub enum StatementError {
|
||||
#[error("Function return statement expected {} return values, got {}", _0, _1)]
|
||||
InvalidNumberOfReturns(usize, usize),
|
||||
|
||||
#[error("Conditional select gadget failed to select between {} or {}", _0, _1)]
|
||||
SelectFail(String, String),
|
||||
|
||||
#[error("{}", _0)]
|
||||
SynthesisError(#[from] SynthesisError),
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
function main(bit: private u8) {
|
||||
if bit == 1u8 {
|
||||
assert_eq!(bit, 1u8);
|
||||
function main(bit: private u32) {
|
||||
if bit == 1u32 {
|
||||
assert_eq!(bit, 1u32);
|
||||
} else {
|
||||
assert_eq!(bit, 0u8);
|
||||
assert_eq!(bit, 0u32);
|
||||
}
|
||||
}
|
||||
|
11
compiler/tests/conditional/conditional_mutate.leo
Normal file
11
compiler/tests/conditional/conditional_mutate.leo
Normal file
@ -0,0 +1,11 @@
|
||||
function main(bit: private u32) -> u32 {
|
||||
let mut a = 5u32;
|
||||
|
||||
if bit == 1u32 {
|
||||
a = 1u32;
|
||||
} else {
|
||||
a = 0u32;
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
@ -1,6 +1,14 @@
|
||||
use crate::{get_output, parse_program, EdwardsConstrainedValue, EdwardsTestCompiler};
|
||||
use leo_inputs::types::{IntegerType, U8Type};
|
||||
use crate::{
|
||||
boolean::{output_false, output_true},
|
||||
get_output,
|
||||
integers::u32::{output_one, output_zero},
|
||||
parse_program,
|
||||
EdwardsConstrainedValue,
|
||||
EdwardsTestCompiler,
|
||||
};
|
||||
use leo_inputs::types::{IntegerType, U32Type};
|
||||
use leo_types::InputValue;
|
||||
|
||||
use snarkos_curves::edwards_bls12::Fq;
|
||||
use snarkos_models::gadgets::r1cs::TestConstraintSystem;
|
||||
|
||||
@ -28,18 +36,35 @@ fn conditional_basic() {
|
||||
|
||||
// Check that an input value of 1 satisfies the constraint system
|
||||
|
||||
program_1_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 1))]);
|
||||
program_1_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]);
|
||||
empty_output_satisfied(program_1_pass);
|
||||
|
||||
// Check that an input value of 0 satisfies the constraint system
|
||||
|
||||
program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 0))]);
|
||||
program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 0))]);
|
||||
empty_output_satisfied(program_0_pass);
|
||||
|
||||
// Check that an input value of 2 does not satisfy the constraint system
|
||||
|
||||
program_2_fail.set_inputs(vec![Some(InputValue::Integer(IntegerType::U8Type(U8Type {}), 2))]);
|
||||
program_2_fail.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 2))]);
|
||||
let mut cs = TestConstraintSystem::<Fq>::new();
|
||||
let _output = program_2_fail.compile_constraints(&mut cs).unwrap();
|
||||
assert!(!cs.is_satisfied());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditional_mutate() {
|
||||
let bytes = include_bytes!("conditional_mutate.leo");
|
||||
let mut program_1_true = parse_program(bytes).unwrap();
|
||||
let mut program_0_pass = program_1_true.clone();
|
||||
|
||||
// Check that an input value of 1 satisfies the constraint system
|
||||
|
||||
program_1_true.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 1))]);
|
||||
output_one(program_1_true);
|
||||
|
||||
// Check that an input value of 0 satisfies the constraint system
|
||||
|
||||
program_0_pass.set_inputs(vec![Some(InputValue::Integer(IntegerType::U32Type(U32Type {}), 0))]);
|
||||
output_zero(program_0_pass);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user