impl assert_eq( , )

This commit is contained in:
collin 2020-04-30 14:00:30 -07:00
parent 1bfc31c4d5
commit b683e93762
11 changed files with 94 additions and 44 deletions

View File

@ -1,7 +1,5 @@
function main() -> (fe) {
if (4fe >= 4fe) {
return 5fe
} else {
return 2fe
}
function main() -> () {
assert_eq(45, 45);
return
}

View File

@ -896,6 +896,21 @@ pub struct AssignStatement<'ast> {
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::assert_eq))]
pub struct AssertEq<'ast> {
pub left: Expression<'ast>,
pub right: Expression<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::statement_assert))]
pub enum AssertStatement<'ast> {
AssertEq(AssertEq<'ast>),
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::statement))]
pub enum Statement<'ast> {
@ -905,6 +920,7 @@ pub enum Statement<'ast> {
MultipleAssignment(MultipleAssignmentStatement<'ast>),
Conditional(ConditionalStatement<'ast>),
Iteration(ForStatement<'ast>),
Assert(AssertStatement<'ast>),
}
impl<'ast> fmt::Display for ReturnStatement<'ast> {
@ -975,6 +991,16 @@ impl<'ast> fmt::Display for AssignStatement<'ast> {
}
}
impl<'ast> fmt::Display for AssertStatement<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
AssertStatement::AssertEq(ref assert) => {
write!(f, "assert_eq({}, {});", assert.left, assert.right)
}
}
}
}
impl<'ast> fmt::Display for Statement<'ast> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
@ -984,6 +1010,7 @@ impl<'ast> fmt::Display for Statement<'ast> {
Statement::MultipleAssignment(ref statement) => write!(f, "{}", statement),
Statement::Conditional(ref statement) => write!(f, "{}", statement),
Statement::Iteration(ref statement) => write!(f, "{}", statement),
Statement::Assert(ref statement) => write!(f, "{}", statement),
}
}
}

View File

@ -140,15 +140,8 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
ResolvedValue::Boolean(Boolean::Constant(left.eq(&right)))
}
pub(crate) fn enforce_boolean_eq(
&mut self,
cs: &mut CS,
left: Boolean,
right: Boolean,
) -> ResolvedValue<F> {
pub(crate) fn enforce_boolean_eq(&mut self, cs: &mut CS, left: Boolean, right: Boolean) {
left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right)
.unwrap();
ResolvedValue::Boolean(Boolean::Constant(true))
}
}

View File

@ -201,27 +201,6 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
}
}
/// Enforce Boolean operations, returns true on success
fn enforce_eq_expression(
&mut self,
cs: &mut CS,
left: ResolvedValue<F>,
right: ResolvedValue<F>,
) -> ResolvedValue<F> {
match (left, right) {
(ResolvedValue::Boolean(bool1), ResolvedValue::Boolean(bool2)) => {
self.enforce_boolean_eq(cs, bool1, bool2)
}
(ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => {
Self::enforce_u32_eq(cs, num1, num2)
}
(ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => {
self.enforce_field_eq(cs, fe1, fe2)
}
(val1, val2) => unimplemented!("cannot enforce equality between {} == {}", val1, val2),
}
}
/// Enforce array expressions
fn enforce_array_expression(
&mut self,

View File

@ -108,7 +108,7 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
ResolvedValue::Boolean(Boolean::Constant(fe1.lt(&fe2)))
}
pub(crate) fn enforce_field_eq(&mut self, cs: &mut CS, fe1: F, fe2: F) -> ResolvedValue<F> {
pub(crate) fn enforce_field_eq(&mut self, cs: &mut CS, fe1: F, fe2: F) {
let mut lc = LinearCombination::zero();
// add (fe1 * 1) and subtract (fe2 * 1) from the linear combination
@ -116,9 +116,6 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
// enforce that the linear combination is zero
cs.enforce(|| "field equality", |lc| lc, |lc| lc, |_| lc);
// return success
ResolvedValue::Boolean(Boolean::constant(true))
}
pub(crate) fn enforce_field_add(&mut self, fe1: F, fe2: F) -> ResolvedValue<F> {

View File

@ -99,15 +99,13 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
ResolvedValue::Boolean(Boolean::Constant(left.eq(&right)))
}
pub(crate) fn enforce_u32_eq(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue<F> {
pub(crate) fn enforce_u32_eq(cs: &mut CS, left: UInt32, right: UInt32) {
left.conditional_enforce_equal(
cs.ns(|| format!("enforce field equal")),
&right,
&Boolean::Constant(true),
)
.unwrap();
ResolvedValue::Boolean(Boolean::Constant(true))
}
pub(crate) fn enforce_u32_add(cs: &mut CS, left: UInt32, right: UInt32) -> ResolvedValue<F> {

View File

@ -342,6 +342,26 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
res
}
fn enforce_assert_eq_statement(
&mut self,
cs: &mut CS,
left: ResolvedValue<F>,
right: ResolvedValue<F>,
) {
match (left, right) {
(ResolvedValue::Boolean(bool1), ResolvedValue::Boolean(bool2)) => {
self.enforce_boolean_eq(cs, bool1, bool2)
}
(ResolvedValue::U32(num1), ResolvedValue::U32(num2)) => {
Self::enforce_u32_eq(cs, num1, num2)
}
(ResolvedValue::FieldElement(fe1), ResolvedValue::FieldElement(fe2)) => {
self.enforce_field_eq(cs, fe1, fe2)
}
(val1, val2) => unimplemented!("cannot enforce equality between {} == {}", val1, val2),
}
}
pub(crate) fn enforce_statement(
&mut self,
cs: &mut CS,
@ -409,6 +429,14 @@ impl<F: Field + PrimeField, CS: ConstraintSystem<F>> ResolvedProgram<F, CS> {
res = Some(early_return)
}
}
Statement::AssertEq(left, right) => {
let resolved_left =
self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), left);
let resolved_right =
self.enforce_expression(cs, file_scope.clone(), function_scope.clone(), right);
self.enforce_assert_eq_statement(cs, resolved_left, resolved_right);
}
};
res

View File

@ -150,21 +150,32 @@ expression_term = {
expression = { expression_term ~ (operation_binary ~ expression_term)* }
expression_tuple = _{ (expression ~ ("," ~ expression)*)? }
/// Asserts
assert_eq = {"assert_eq" ~ "(" ~ NEWLINE* ~ expression ~ "," ~ NEWLINE* ~ expression ~ NEWLINE* ~ ")"}
// assert_true = {"assert"}
/// Conditionals
conditional_nested_or_end = { statement_conditional | "{" ~ NEWLINE* ~ statement+ ~ "}"}
/// Statements
// statement_one_or_more = { (statement ~ NEWLINE*)* }
statement_return = { "return" ~ expression_tuple }
statement_definition = { ty ~ variable ~ "=" ~ expression }
statement_assign = { assignee ~ operation_assign ~ expression }
statement_multiple_assignment = { optionally_typed_variable_tuple ~ "=" ~ variable ~ "(" ~ expression_tuple ~ ")" }
statement_conditional = {"if" ~ "(" ~ expression ~ ")" ~ "{" ~ NEWLINE* ~ statement+ ~ "}" ~ ("else" ~ conditional_nested_or_end)?}
conditional_nested_or_end = { statement_conditional | "{" ~ NEWLINE* ~ statement+ ~ "}"}
statement_for = { "for" ~ variable ~ "in" ~ expression ~ ".." ~ expression ~ "{" ~ NEWLINE* ~ statement+ ~ "}"}
statement_assert = {
assert_eq
// | assert_true |
}
statement = {
(statement_return
| statement_conditional
| statement_for
| (statement_multiple_assignment
| statement_assert
| statement_definition
| statement_assign
) ~ LINE_END

View File

@ -130,6 +130,7 @@ pub enum Statement<F: Field + PrimeField> {
MultipleAssign(Vec<Assignee<F>>, Expression<F>),
Conditional(ConditionalStatement<F>),
For(Variable<F>, Integer, Integer, Vec<Statement<F>>),
AssertEq(Expression<F>, Expression<F>),
}
#[derive(Clone, Debug)]

View File

@ -203,6 +203,9 @@ impl<F: Field + PrimeField> fmt::Display for Statement<F> {
}
write!(f, "\t}}")
}
Statement::AssertEq(ref left, ref right) => {
write!(f, "assert_eq({}, {});", left, right)
}
}
}
}
@ -243,6 +246,9 @@ impl<F: Field + PrimeField> fmt::Debug for Statement<F> {
}
write!(f, "\tendfor;")
}
Statement::AssertEq(ref left, ref right) => {
write!(f, "assert_eq({}, {});", left, right)
}
}
}
}

View File

@ -521,6 +521,17 @@ impl<'ast, F: Field + PrimeField> From<ast::ForStatement<'ast>> for types::State
}
}
impl<'ast, F: Field + PrimeField> From<ast::AssertStatement<'ast>> for types::Statement<F> {
fn from(statement: ast::AssertStatement<'ast>) -> Self {
match statement {
ast::AssertStatement::AssertEq(assert_eq) => types::Statement::AssertEq(
types::Expression::from(assert_eq.left),
types::Expression::from(assert_eq.right),
),
}
}
}
impl<'ast, F: Field + PrimeField> From<ast::Statement<'ast>> for types::Statement<F> {
fn from(statement: ast::Statement<'ast>) -> Self {
match statement {
@ -532,6 +543,7 @@ impl<'ast, F: Field + PrimeField> From<ast::Statement<'ast>> for types::Statemen
types::Statement::Conditional(types::ConditionalStatement::from(statement))
}
ast::Statement::Iteration(statement) => types::Statement::from(statement),
ast::Statement::Assert(statement) => types::Statement::from(statement),
}
}
}