From 7233e55dc1ca9561f57bc61d7cfb6719fc538048 Mon Sep 17 00:00:00 2001 From: collin Date: Thu, 25 Jun 2020 15:58:37 -0700 Subject: [PATCH] impl comparator gadgets for fields and integers --- compiler/src/constraints/comparator.rs | 28 ++++ compiler/src/constraints/expression.rs | 187 ++++++++++++++----------- compiler/src/constraints/integer.rs | 40 +++++- compiler/src/constraints/mod.rs | 3 + compiler/src/field/mod.rs | 26 ++++ 5 files changed, 202 insertions(+), 82 deletions(-) create mode 100644 compiler/src/constraints/comparator.rs diff --git a/compiler/src/constraints/comparator.rs b/compiler/src/constraints/comparator.rs new file mode 100644 index 0000000000..a22c6ede4d --- /dev/null +++ b/compiler/src/constraints/comparator.rs @@ -0,0 +1,28 @@ +use snarkos_errors::gadgets::SynthesisError; +use snarkos_models::{ + curves::Field, + gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, +}; + +pub trait EvaluateLtGadget { + fn less_than>(&self, cs: CS, other: &Self) -> Result; +} + +// implementing `EvaluateLtGadget` will implement `ComparatorGadget` +pub trait ComparatorGadget +where + Self: EvaluateLtGadget, +{ + fn greater_than>(&self, cs: CS, other: &Self) -> Result { + other.less_than(cs, other) + } + + fn less_than_or_equal>(&self, cs: CS, other: &Self) -> Result { + let is_gt = self.greater_than(cs, other)?; + Ok(is_gt.not()) + } + + fn greater_than_or_equal>(&self, cs: CS, other: &Self) -> Result { + other.less_than_or_equal(cs, self) + } +} diff --git a/compiler/src/constraints/expression.rs b/compiler/src/constraints/expression.rs index 62aa35a585..8a3cbca5d0 100644 --- a/compiler/src/constraints/expression.rs +++ b/compiler/src/constraints/expression.rs @@ -1,6 +1,7 @@ //! Methods to enforce constraints on expressions in a resolved Leo program. use crate::{ + comparator::{ComparatorGadget, EvaluateLtGadget}, constraints::{ConstrainedCircuitMember, ConstrainedProgram, ConstrainedValue}, enforce_and, enforce_or, @@ -72,11 +73,11 @@ impl> ConstrainedProgram { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.add(cs, num_2, span)?)) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - Ok(ConstrainedValue::Field(fe_1.add(cs, &fe_2, span)?)) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + Ok(ConstrainedValue::Field(field_1.add(cs, &field_2, span)?)) } - (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { - Ok(ConstrainedValue::Group(ge_1.add(cs, &ge_2, span)?)) + (ConstrainedValue::Group(point_1), ConstrainedValue::Group(point_2)) => { + Ok(ConstrainedValue::Group(point_1.add(cs, &point_2, span)?)) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; @@ -104,11 +105,11 @@ impl> ConstrainedProgram { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.sub(cs, num_2, span)?)) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - Ok(ConstrainedValue::Field(fe_1.sub(cs, &fe_2, span)?)) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + Ok(ConstrainedValue::Field(field_1.sub(cs, &field_2, span)?)) } - (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { - Ok(ConstrainedValue::Group(ge_1.sub(cs, &ge_2, span)?)) + (ConstrainedValue::Group(point_1), ConstrainedValue::Group(point_2)) => { + Ok(ConstrainedValue::Group(point_1.sub(cs, &point_2, span)?)) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; @@ -136,8 +137,8 @@ impl> ConstrainedProgram { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.mul(cs, num_2, span)?)) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - Ok(ConstrainedValue::Field(fe_1.mul(cs, &fe_2, span)?)) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + Ok(ConstrainedValue::Field(field_1.mul(cs, &field_2, span)?)) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; @@ -167,8 +168,8 @@ impl> ConstrainedProgram { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { Ok(ConstrainedValue::Integer(num_1.div(cs, num_2, span)?)) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - Ok(ConstrainedValue::Field(fe_1.div(cs, &fe_2, span)?)) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + Ok(ConstrainedValue::Field(field_1.div(cs, &field_2, span)?)) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; @@ -229,11 +230,11 @@ impl> ConstrainedProgram { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { num_1.evaluate_equal(unique_namespace, &num_2) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - fe_1.evaluate_equal(unique_namespace, &fe_2) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + field_1.evaluate_equal(unique_namespace, &field_2) } - (ConstrainedValue::Group(ge_1), ConstrainedValue::Group(ge_2)) => { - ge_1.evaluate_equal(unique_namespace, &ge_2) + (ConstrainedValue::Group(point_1), ConstrainedValue::Group(point_2)) => { + point_1.evaluate_equal(unique_namespace, &point_2) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; @@ -252,133 +253,157 @@ impl> ConstrainedProgram { }; let boolean = - constraint_result.map_err(|e| ExpressionError::cannot_enforce(format!("evaluate equals"), e, span))?; + constraint_result.map_err(|e| ExpressionError::cannot_enforce(format!("evaluate equal"), e, span))?; Ok(ConstrainedValue::Boolean(boolean)) } - //TODO: unsafe for allocated values - fn evaluate_ge_expression( + fn evaluate_ge_expression>( &mut self, + cs: &mut CS, left: ConstrainedValue, right: ConstrainedValue, span: Span, ) -> Result, ExpressionError> { - match (left, right) { + let mut unique_namespace = cs.ns(|| format!("evaluate {} >= {} {}:{}", left, right, span.line, span.start)); + let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { - let result = num_1.ge(&num_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + num_1.greater_than_or_equal(unique_namespace, &num_2) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - let result = fe_1.ge(&fe_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + field_1.greater_than_or_equal(unique_namespace, &field_2) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; - self.evaluate_ge_expression(val_1, val_2, span) + return self.evaluate_ge_expression(&mut unique_namespace, val_1, val_2, span); } (val_1, ConstrainedValue::Unresolved(string)) => { let val_2 = ConstrainedValue::from_other(string, &val_1, span.clone())?; - self.evaluate_ge_expression(val_1, val_2, span) + return self.evaluate_ge_expression(&mut unique_namespace, val_1, val_2, span); } - (val_1, val_2) => Err(ExpressionError::incompatible_types( - format!("{} >= {}", val_1, val_2), - span, - )), - } + (val_1, val_2) => { + return Err(ExpressionError::incompatible_types( + format!("{} >= {}", val_1, val_2), + span, + )); + } + }; + + let boolean = constraint_result + .map_err(|e| ExpressionError::cannot_enforce(format!("evaluate greater than or equal"), e, span))?; + + Ok(ConstrainedValue::Boolean(boolean)) } - //TODO: unsafe for allocated values - fn evaluate_gt_expression( + fn evaluate_gt_expression>( &mut self, + cs: &mut CS, left: ConstrainedValue, right: ConstrainedValue, span: Span, ) -> Result, ExpressionError> { - match (left, right) { + let mut unique_namespace = cs.ns(|| format!("evaluate {} > {} {}:{}", left, right, span.line, span.start)); + let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { - let result = num_1.gt(&num_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + num_1.greater_than(unique_namespace, &num_2) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - let result = fe_1.gt(&fe_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + field_1.greater_than(unique_namespace, &field_2) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; - self.evaluate_gt_expression(val_1, val_2, span) + return self.evaluate_gt_expression(&mut unique_namespace, val_1, val_2, span); } (val_1, ConstrainedValue::Unresolved(string)) => { let val_2 = ConstrainedValue::from_other(string, &val_1, span.clone())?; - self.evaluate_gt_expression(val_1, val_2, span) + return self.evaluate_gt_expression(&mut unique_namespace, val_1, val_2, span); } - (val_1, val_2) => Err(ExpressionError::incompatible_types( - format!("{} > {}", val_1, val_2), - span, - )), - } + (val_1, val_2) => { + return Err(ExpressionError::incompatible_types( + format!("{} > {}", val_1, val_2), + span, + )); + } + }; + + let boolean = constraint_result + .map_err(|e| ExpressionError::cannot_enforce(format!("evaluate greater than"), e, span))?; + + Ok(ConstrainedValue::Boolean(boolean)) } - //TODO: unsafe for allocated values - fn evaluate_le_expression( + fn evaluate_le_expression>( &mut self, + cs: &mut CS, left: ConstrainedValue, right: ConstrainedValue, span: Span, ) -> Result, ExpressionError> { - match (left, right) { + let mut unique_namespace = cs.ns(|| format!("evaluate {} <= {} {}:{}", left, right, span.line, span.start)); + let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { - let result = num_1.le(&num_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + num_1.less_than_or_equal(unique_namespace, &num_2) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - let result = fe_1.le(&fe_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + field_1.less_than_or_equal(unique_namespace, &field_2) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; - self.evaluate_le_expression(val_1, val_2, span) + return self.evaluate_le_expression(&mut unique_namespace, val_1, val_2, span); } (val_1, ConstrainedValue::Unresolved(string)) => { let val_2 = ConstrainedValue::from_other(string, &val_1, span.clone())?; - self.evaluate_le_expression(val_1, val_2, span) + return self.evaluate_le_expression(&mut unique_namespace, val_1, val_2, span); } - (val_1, val_2) => Err(ExpressionError::incompatible_types( - format!("{} <= {}", val_1, val_2), - span, - )), - } + (val_1, val_2) => { + return Err(ExpressionError::incompatible_types( + format!("{} <= {}", val_1, val_2), + span, + )); + } + }; + + let boolean = constraint_result + .map_err(|e| ExpressionError::cannot_enforce(format!("evaluate less than or equal"), e, span))?; + + Ok(ConstrainedValue::Boolean(boolean)) } - //TODO: unsafe for allocated values - fn evaluate_lt_expression( + fn evaluate_lt_expression>( &mut self, + cs: &mut CS, left: ConstrainedValue, right: ConstrainedValue, span: Span, ) -> Result, ExpressionError> { - match (left, right) { + let mut unique_namespace = cs.ns(|| format!("evaluate {} < {} {}:{}", left, right, span.line, span.start)); + let constraint_result = match (left, right) { (ConstrainedValue::Integer(num_1), ConstrainedValue::Integer(num_2)) => { - let result = num_1.lt(&num_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + num_1.less_than(unique_namespace, &num_2) } - (ConstrainedValue::Field(fe_1), ConstrainedValue::Field(fe_2)) => { - let result = fe_1.lt(&fe_2); - Ok(ConstrainedValue::Boolean(Boolean::Constant(result))) + (ConstrainedValue::Field(field_1), ConstrainedValue::Field(field_2)) => { + field_1.less_than(unique_namespace, &field_2) } (ConstrainedValue::Unresolved(string), val_2) => { let val_1 = ConstrainedValue::from_other(string, &val_2, span.clone())?; - self.evaluate_lt_expression(val_1, val_2, span) + return self.evaluate_lt_expression(&mut unique_namespace, val_1, val_2, span); } (val_1, ConstrainedValue::Unresolved(string)) => { let val_2 = ConstrainedValue::from_other(string, &val_1, span.clone())?; - self.evaluate_lt_expression(val_1, val_2, span) + return self.evaluate_lt_expression(&mut unique_namespace, val_1, val_2, span); } - (val_1, val_2) => Err(ExpressionError::incompatible_types( - format!("{} < {}", val_1, val_2,), - span, - )), - } + (val_1, val_2) => { + return Err(ExpressionError::incompatible_types( + format!("{} < {}", val_1, val_2), + span, + )); + } + }; + + let boolean = + constraint_result.map_err(|e| ExpressionError::cannot_enforce(format!("evaluate less than"), e, span))?; + + Ok(ConstrainedValue::Boolean(boolean)) } /// Enforce ternary conditional expression @@ -1034,7 +1059,7 @@ impl> ConstrainedProgram { span.clone(), )?; - Ok(self.evaluate_ge_expression(resolved_left, resolved_right, span)?) + Ok(self.evaluate_ge_expression(cs, resolved_left, resolved_right, span)?) } Expression::Gt(left, right, span) => { let (resolved_left, resolved_right) = self.enforce_binary_expression( @@ -1047,7 +1072,7 @@ impl> ConstrainedProgram { span.clone(), )?; - Ok(self.evaluate_gt_expression(resolved_left, resolved_right, span)?) + Ok(self.evaluate_gt_expression(cs, resolved_left, resolved_right, span)?) } Expression::Le(left, right, span) => { let (resolved_left, resolved_right) = self.enforce_binary_expression( @@ -1060,7 +1085,7 @@ impl> ConstrainedProgram { span.clone(), )?; - Ok(self.evaluate_le_expression(resolved_left, resolved_right, span)?) + Ok(self.evaluate_le_expression(cs, resolved_left, resolved_right, span)?) } Expression::Lt(left, right, span) => { let (resolved_left, resolved_right) = self.enforce_binary_expression( @@ -1073,7 +1098,7 @@ impl> ConstrainedProgram { span.clone(), )?; - Ok(self.evaluate_lt_expression(resolved_left, resolved_right, span)?) + Ok(self.evaluate_lt_expression(cs, resolved_left, resolved_right, span)?) } // Conditionals diff --git a/compiler/src/constraints/integer.rs b/compiler/src/constraints/integer.rs index 5ece95d229..aa8b66991a 100644 --- a/compiler/src/constraints/integer.rs +++ b/compiler/src/constraints/integer.rs @@ -1,5 +1,5 @@ //! Conversion of integer declarations to constraints in Leo. -use crate::errors::IntegerError; +use crate::{errors::IntegerError, ComparatorGadget, EvaluateLtGadget}; use leo_types::{InputValue, IntegerType, Span}; use snarkos_errors::gadgets::SynthesisError; @@ -88,6 +88,16 @@ impl Integer { Ok(value as usize) } + pub fn to_bits_le(&self) -> Vec { + match self { + Integer::U8(num) => num.bits.clone(), + Integer::U16(num) => num.bits.clone(), + Integer::U32(num) => num.bits.clone(), + Integer::U64(num) => num.bits.clone(), + Integer::U128(num) => num.bits.clone(), + } + } + pub fn get_type(&self) -> IntegerType { match self { Integer::U8(_u8) => IntegerType::U8, @@ -435,6 +445,34 @@ impl EvaluateEqGadget for Integer { } } +impl EvaluateLtGadget for Integer { + fn less_than>(&self, mut cs: CS, other: &Self) -> Result { + if self.to_bits_le().len() != other.to_bits_le().len() { + return Err(SynthesisError::Unsatisfiable); + } + + for (i, (self_bit, other_bit)) in self + .to_bits_le() + .iter() + .rev() + .zip(other.to_bits_le().iter().rev()) + .enumerate() + { + let is_less = Boolean::and(&mut cs, self_bit, &other_bit.not())?; + + if is_less.eq(&Boolean::constant(true)) { + return Ok(is_less); + } else if i == self.to_bits_le().len() - 1 { + return Ok(is_less); + } + } + + Err(SynthesisError::Unsatisfiable) + } +} + +impl ComparatorGadget for Integer {} + impl EqGadget for Integer {} impl ConditionalEqGadget for Integer { diff --git a/compiler/src/constraints/mod.rs b/compiler/src/constraints/mod.rs index a393f6342d..fe6ec7086b 100644 --- a/compiler/src/constraints/mod.rs +++ b/compiler/src/constraints/mod.rs @@ -3,6 +3,9 @@ pub(crate) mod boolean; pub(crate) use boolean::*; +pub(crate) mod comparator; +pub(crate) use comparator::*; + pub mod function; pub use function::*; diff --git a/compiler/src/field/mod.rs b/compiler/src/field/mod.rs index cef4cfec77..6db0f95a24 100644 --- a/compiler/src/field/mod.rs +++ b/compiler/src/field/mod.rs @@ -3,6 +3,7 @@ use crate::errors::FieldError; use leo_types::Span; +use crate::{ComparatorGadget, EvaluateLtGadget}; use snarkos_errors::gadgets::SynthesisError; use snarkos_models::{ curves::{Field, PrimeField}, @@ -210,6 +211,31 @@ impl EvaluateEqGadget for FieldType { } } +impl EvaluateLtGadget for FieldType { + fn less_than>(&self, mut cs: CS, other: &Self) -> Result { + match (self, other) { + (FieldType::Constant(first), FieldType::Constant(second)) => Ok(Boolean::constant(first.lt(second))), + (FieldType::Allocated(allocated), FieldType::Constant(constant)) + | (FieldType::Constant(constant), FieldType::Allocated(allocated)) => { + let bool_option = allocated.value.map(|f| f.lt(constant)); + + Boolean::alloc(&mut cs.ns(|| "less than"), || { + bool_option.ok_or(SynthesisError::AssignmentMissing) + }) + } + (FieldType::Allocated(first), FieldType::Allocated(second)) => { + let bool_option = first.value.and_then(|a| second.value.map(|b| a.eq(&b))); + + Boolean::alloc(&mut cs.ns(|| "evaluate_equal"), || { + bool_option.ok_or(SynthesisError::AssignmentMissing) + }) + } + } + } +} + +impl ComparatorGadget for FieldType {} + impl EqGadget for FieldType {} impl ConditionalEqGadget for FieldType {