composable return statements working

This commit is contained in:
collin 2020-04-03 12:22:03 -07:00
parent 82eaba499c
commit acd8992843
3 changed files with 108 additions and 88 deletions

View File

@ -1 +1 @@
return a / 2 x = 5 + 3

View File

@ -13,15 +13,15 @@ fn bool_from_variable<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
variable: Variable, variable: Variable,
) -> Boolean { ) -> Boolean {
// let argument = std::env::args() let argument = std::env::args()
// .nth(1) .nth(1)
// .unwrap_or("true".into()) .unwrap_or("true".into())
// .parse::<bool>() .parse::<bool>()
// .unwrap(); .unwrap();
//
// println!(" argument passed to command line a = {:?}", argument); println!(" argument passed to command line a = {:?}", argument);
let a = true; // let a = true;
Boolean::alloc(cs.ns(|| variable.0), || Ok(a)).unwrap() Boolean::alloc_input(cs.ns(|| variable.0), || Ok(argument)).unwrap()
} }
fn u32_from_variable<F: Field, CS: ConstraintSystem<F>>(cs: &mut CS, variable: Variable) -> UInt32 { fn u32_from_variable<F: Field, CS: ConstraintSystem<F>>(cs: &mut CS, variable: Variable) -> UInt32 {
@ -44,7 +44,7 @@ fn get_bool_value<F: Field, CS: ConstraintSystem<F>>(
match expression { match expression {
BooleanExpression::Variable(variable) => bool_from_variable(cs, variable), BooleanExpression::Variable(variable) => bool_from_variable(cs, variable),
BooleanExpression::Value(value) => Boolean::Constant(value), BooleanExpression::Value(value) => Boolean::Constant(value),
_ => unimplemented!(), expression => enforce_boolean_expression(cs, expression),
} }
} }
@ -55,7 +55,7 @@ fn get_u32_value<F: Field, CS: ConstraintSystem<F>>(
match expression { match expression {
FieldExpression::Variable(variable) => u32_from_variable(cs, variable), FieldExpression::Variable(variable) => u32_from_variable(cs, variable),
FieldExpression::Number(number) => UInt32::constant(number), FieldExpression::Number(number) => UInt32::constant(number),
_ => unimplemented!(), // FieldExpression::Add(left, right) => field => enforce_field_expression(cs, field),
} }
} }
@ -63,41 +63,43 @@ fn enforce_or<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: BooleanExpression, left: BooleanExpression,
right: BooleanExpression, right: BooleanExpression,
) { ) -> Boolean {
let left = get_bool_value(cs, left); let left = get_bool_value(cs, left);
let right = get_bool_value(cs, right); let right = get_bool_value(cs, right);
let _result = Boolean::or(cs, &left, &right).unwrap(); Boolean::or(cs, &left, &right).unwrap()
} }
fn enforce_and<F: Field, CS: ConstraintSystem<F>>( fn enforce_and<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: BooleanExpression, left: BooleanExpression,
right: BooleanExpression, right: BooleanExpression,
) { ) -> Boolean {
let left = get_bool_value(cs, left); let left = get_bool_value(cs, left);
let right = get_bool_value(cs, right); let right = get_bool_value(cs, right);
let _result = Boolean::and(cs, &left, &right).unwrap(); Boolean::and(cs, &left, &right).unwrap()
} }
fn enforce_bool_equality<F: Field, CS: ConstraintSystem<F>>( fn enforce_bool_equality<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: BooleanExpression, left: BooleanExpression,
right: BooleanExpression, right: BooleanExpression,
) { ) -> Boolean {
let left = get_bool_value(cs, left); let left = get_bool_value(cs, left);
let right = get_bool_value(cs, right); let right = get_bool_value(cs, right);
left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right) left.enforce_equal(cs.ns(|| format!("enforce bool equal")), &right)
.unwrap(); .unwrap();
Boolean::Constant(true)
} }
fn enforce_field_equality<F: Field, CS: ConstraintSystem<F>>( fn enforce_field_equality<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: FieldExpression, left: FieldExpression,
right: FieldExpression, right: FieldExpression,
) { ) -> Boolean {
let left = get_u32_value(cs, left); let left = get_u32_value(cs, left);
let right = get_u32_value(cs, right); let right = get_u32_value(cs, right);
@ -107,25 +109,19 @@ fn enforce_field_equality<F: Field, CS: ConstraintSystem<F>>(
&Boolean::Constant(true), &Boolean::Constant(true),
) )
.unwrap(); .unwrap();
Boolean::Constant(true)
} }
fn enforce_boolean_expression<F: Field, CS: ConstraintSystem<F>>( fn enforce_boolean_expression<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
expression: BooleanExpression, expression: BooleanExpression,
) { ) -> Boolean {
match expression { match expression {
BooleanExpression::Or(left, right) => { BooleanExpression::Or(left, right) => enforce_or(cs, *left, *right),
enforce_or(cs, *left, *right); BooleanExpression::And(left, right) => enforce_and(cs, *left, *right),
} BooleanExpression::BoolEq(left, right) => enforce_bool_equality(cs, *left, *right),
BooleanExpression::And(left, right) => { BooleanExpression::FieldEq(left, right) => enforce_field_equality(cs, *left, *right),
enforce_and(cs, *left, *right);
}
BooleanExpression::BoolEq(left, right) => {
enforce_bool_equality(cs, *left, *right);
}
BooleanExpression::FieldEq(left, right) => {
enforce_field_equality(cs, *left, *right);
}
_ => unimplemented!(), _ => unimplemented!(),
} }
} }
@ -134,87 +130,93 @@ fn enforce_add<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: FieldExpression, left: FieldExpression,
right: FieldExpression, right: FieldExpression,
) { ) -> UInt32 {
let left = get_u32_value(cs, left); let left = get_u32_value(cs, left);
let right = get_u32_value(cs, right); let right = get_u32_value(cs, right);
let r = left left.add(
.add(
cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())), cs.ns(|| format!("enforce {} + {}", left.value.unwrap(), right.value.unwrap())),
&right, &right,
) )
.unwrap(); .unwrap()
println!("result {}", r.value.unwrap());
} }
fn enforce_sub<F: Field, CS: ConstraintSystem<F>>( fn enforce_sub<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: FieldExpression, left: FieldExpression,
right: FieldExpression, right: FieldExpression,
) { ) -> UInt32 {
let left = get_u32_value(cs, left); let left = get_u32_value(cs, left);
let right = get_u32_value(cs, right); let right = get_u32_value(cs, right);
let r = left left.sub(
.sub(
cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())), cs.ns(|| format!("enforce {} - {}", left.value.unwrap(), right.value.unwrap())),
&right, &right,
) )
.unwrap(); .unwrap()
println!("result {}", r.value.unwrap());
} }
fn enforce_mul<F: Field, CS: ConstraintSystem<F>>( fn enforce_mul<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: FieldExpression, left: FieldExpression,
right: FieldExpression, right: FieldExpression,
) { ) -> UInt32 {
let left = get_u32_value(cs, left); let left = get_u32_value(cs, left);
let right = get_u32_value(cs, right); let right = get_u32_value(cs, right);
let r = left left.mul(
.mul(
cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())), cs.ns(|| format!("enforce {} * {}", left.value.unwrap(), right.value.unwrap())),
&right, &right,
) )
.unwrap(); .unwrap()
println!("result {}", r.value.unwrap());
} }
fn enforce_div<F: Field, CS: ConstraintSystem<F>>( fn enforce_div<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
left: FieldExpression, left: FieldExpression,
right: FieldExpression, right: FieldExpression,
) { ) -> UInt32 {
let left = get_u32_value(cs, left); let left = get_u32_value(cs, left);
let right = get_u32_value(cs, right); let right = get_u32_value(cs, right);
let r = left left.div(
.div(
cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())), cs.ns(|| format!("enforce {} / {}", left.value.unwrap(), right.value.unwrap())),
&right, &right,
) )
.unwrap(); .unwrap()
println!("result {}", r.value.unwrap()); }
fn enforce_pow<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS,
left: FieldExpression,
right: FieldExpression,
) -> UInt32 {
let left = get_u32_value(cs, left);
let right = get_u32_value(cs, right);
left.pow(
cs.ns(|| {
format!(
"enforce {} ** {}",
left.value.unwrap(),
right.value.unwrap()
)
}),
&right,
)
.unwrap()
} }
fn enforce_field_expression<F: Field, CS: ConstraintSystem<F>>( fn enforce_field_expression<F: Field, CS: ConstraintSystem<F>>(
cs: &mut CS, cs: &mut CS,
expression: FieldExpression, expression: FieldExpression,
) { ) -> UInt32 {
match expression { match expression {
FieldExpression::Add(left, right) => { FieldExpression::Add(left, right) => enforce_add(cs, *left, *right),
enforce_add(cs, *left, *right); FieldExpression::Sub(left, right) => enforce_sub(cs, *left, *right),
} FieldExpression::Mul(left, right) => enforce_mul(cs, *left, *right),
FieldExpression::Sub(left, right) => { FieldExpression::Div(left, right) => enforce_div(cs, *left, *right),
enforce_sub(cs, *left, *right); FieldExpression::Pow(left, right) => enforce_pow(cs, *left, *right),
}
FieldExpression::Mul(left, right) => {
enforce_mul(cs, *left, *right);
}
FieldExpression::Div(left, right) => {
enforce_div(cs, *left, *right);
}
_ => unimplemented!(), _ => unimplemented!(),
} }
} }
@ -224,15 +226,28 @@ pub fn generate_constraints<F: Field, CS: ConstraintSystem<F>>(cs: &mut CS, prog
.statements .statements
.into_iter() .into_iter()
.for_each(|statement| match statement { .for_each(|statement| match statement {
Statement::Definition(variable, expression) => match expression {
Expression::Boolean(boolean_expression) => {
let res = enforce_boolean_expression(cs, boolean_expression);
println!("boolean result: {}", res.get_value().unwrap());
}
Expression::FieldElement(field_expression) => {
let res = enforce_field_expression(cs, field_expression);
println!("field result: {}", res.value.unwrap());
}
_ => unimplemented!(),
},
Statement::Return(statements) => { Statement::Return(statements) => {
statements statements
.into_iter() .into_iter()
.for_each(|expression| match expression { .for_each(|expression| match expression {
Expression::Boolean(boolean_expression) => { Expression::Boolean(boolean_expression) => {
enforce_boolean_expression(cs, boolean_expression); let res = enforce_boolean_expression(cs, boolean_expression);
println!("boolean result: {}", res.get_value().unwrap());
} }
Expression::FieldElement(field_expression) => { Expression::FieldElement(field_expression) => {
enforce_field_expression(cs, field_expression); let res = enforce_field_expression(cs, field_expression);
println!("field result: {}", res.value.unwrap());
} }
_ => unimplemented!(), _ => unimplemented!(),
}); });

View File

@ -220,16 +220,21 @@ impl<'ast> From<ast::Expression<'ast>> for types::Expression {
} }
} }
// impl<'ast> From<ast::AssignStatement<'ast>> for types::StatementNode<'ast> { impl<'ast> From<ast::Variable<'ast>> for types::Variable {
// fn from(statement: ast::AssignStatement<'ast>) -> Self { fn from(variable: ast::Variable<'ast>) -> Self {
// types::Statement::Definition( types::Variable(variable.value)
// types::VariableNode::from(statement.variable), }
// types::ExpressionNode::from(statement.expression), }
// )
// .span(statement.span) impl<'ast> From<ast::AssignStatement<'ast>> for types::Statement {
// } fn from(statement: ast::AssignStatement<'ast>) -> Self {
// } types::Statement::Definition(
// types::Variable::from(statement.variable),
types::Expression::from(statement.expression),
)
}
}
impl<'ast> From<ast::ReturnStatement<'ast>> for types::Statement { impl<'ast> From<ast::ReturnStatement<'ast>> for types::Statement {
fn from(statement: ast::ReturnStatement<'ast>) -> Self { fn from(statement: ast::ReturnStatement<'ast>) -> Self {
types::Statement::Return( types::Statement::Return(