constraint structs and function definitions

This commit is contained in:
collin 2020-04-13 21:21:15 -07:00
parent 1903858e7a
commit 2f9e4b31e0
5 changed files with 242 additions and 144 deletions

View File

@ -2,4 +2,6 @@ struct Foo {
field a
bool b
}
return 1
def main() -> (field) :
a = 1 + 1
return a

View File

@ -1,5 +1,5 @@
use crate::aleo_program::{
BooleanExpression, Expression, FieldExpression, Program, Statement, Variable,
BooleanExpression, Expression, FieldExpression, Function, Program, Statement, Struct, Variable,
};
use snarkos_models::curves::{Field, PrimeField};
@ -14,6 +14,8 @@ use std::collections::HashMap;
pub enum ResolvedValue {
Boolean(Boolean),
FieldElement(UInt32),
Struct(Struct),
Function(Function),
}
pub struct ResolvedProgram {
@ -315,115 +317,141 @@ impl ResolvedProgram {
}
}
fn enforce_statement<F: Field + PrimeField, CS: ConstraintSystem<F>>(
&mut self,
cs: &mut CS,
statement: Statement,
) {
match statement {
Statement::Definition(variable, expression) => match expression {
Expression::Boolean(boolean_expression) => {
let res = self.enforce_boolean_expression(cs, boolean_expression);
println!(
" variable boolean result: {} = {}",
variable.0,
res.get_value().unwrap()
);
self.insert(variable, ResolvedValue::Boolean(res));
}
Expression::FieldElement(field_expression) => {
let res = self.enforce_field_expression(cs, field_expression);
println!(
" variable field result: {} = {}",
variable.0,
res.value.unwrap()
);
self.insert(variable, ResolvedValue::FieldElement(res));
}
Expression::Variable(unresolved_variable) => {
if self.resolved_variables.contains_key(&unresolved_variable) {
// Reassigning variable to another variable
let already_assigned = self
.resolved_variables
.get_mut(&unresolved_variable)
.unwrap()
.clone();
self.insert(variable, already_assigned);
} else {
// The type of the unassigned variable depends on what is passed in
if std::env::args()
.nth(1)
.expect("variable declaration not passed in")
.parse::<bool>()
.is_ok()
{
let resolved_boolean = self.bool_from_variable(cs, unresolved_variable);
println!(
"variable boolean result: {} = {}",
variable.0,
resolved_boolean.get_value().unwrap()
);
self.insert(variable, ResolvedValue::Boolean(resolved_boolean));
} else {
let resolved_field_element =
self.u32_from_variable(cs, unresolved_variable);
println!(
" variable field result: {} = {}",
variable.0,
resolved_field_element.value.unwrap()
);
self.insert(
variable,
ResolvedValue::FieldElement(resolved_field_element),
);
}
}
}
},
Statement::Return(statements) => {
statements
.into_iter()
.for_each(|expression| match expression {
Expression::Boolean(boolean_expression) => {
let res = self.enforce_boolean_expression(cs, boolean_expression);
println!("\n Boolean result = {}", res.get_value().unwrap());
}
Expression::FieldElement(field_expression) => {
let res = self.enforce_field_expression(cs, field_expression);
println!("\n Field result = {}", res.value.unwrap());
}
Expression::Variable(variable) => {
match self.resolved_variables.get_mut(&variable).unwrap().clone() {
ResolvedValue::Boolean(boolean) => println!(
"\n Variable result = {}",
boolean.get_value().unwrap()
),
ResolvedValue::FieldElement(field_element) => println!(
"\n Variable field result = {}",
field_element.value.unwrap()
),
_ => {}
}
}
});
}
};
}
pub fn generate_constraints<F: Field + PrimeField, CS: ConstraintSystem<F>>(
cs: &mut CS,
program: Program,
) {
let mut resolved_program = ResolvedProgram::new();
program
.structs
.into_iter()
.for_each(|(variable, struct_def)| {
resolved_program
.resolved_variables
.insert(variable, ResolvedValue::Struct(struct_def));
});
program
.functions
.into_iter()
.for_each(|(variable, function)| {
resolved_program
.resolved_variables
.insert(variable, ResolvedValue::Function(function));
});
let main = resolved_program
.resolved_variables
.get_mut(&Variable("main".into()))
.expect("main function not defined");
match main {
ResolvedValue::Function(function) => function
.statements
.clone()
.into_iter()
.for_each(|statement| resolved_program.enforce_statement(cs, statement)),
_ => unimplemented!("main must be a function"),
}
program
.statements
.into_iter()
.for_each(|statement| match statement {
Statement::Definition(variable, expression) => match expression {
Expression::Boolean(boolean_expression) => {
let res =
resolved_program.enforce_boolean_expression(cs, boolean_expression);
println!(
" variable boolean result: {} = {}",
variable.0,
res.get_value().unwrap()
);
resolved_program.insert(variable, ResolvedValue::Boolean(res));
}
Expression::FieldElement(field_expression) => {
let res = resolved_program.enforce_field_expression(cs, field_expression);
println!(
" variable field result: {} = {}",
variable.0,
res.value.unwrap()
);
resolved_program.insert(variable, ResolvedValue::FieldElement(res));
}
Expression::Variable(unresolved_variable) => {
if resolved_program
.resolved_variables
.contains_key(&unresolved_variable)
{
// Reassigning variable to another variable
let already_assigned = resolved_program
.resolved_variables
.get_mut(&unresolved_variable)
.unwrap()
.clone();
resolved_program.insert(variable, already_assigned);
} else {
// The type of the unassigned variable depends on what is passed in
if std::env::args()
.nth(1)
.expect("variable declaration not passed in")
.parse::<bool>()
.is_ok()
{
let resolved_boolean =
resolved_program.bool_from_variable(cs, unresolved_variable);
println!(
"variable boolean result: {} = {}",
variable.0,
resolved_boolean.get_value().unwrap()
);
resolved_program
.insert(variable, ResolvedValue::Boolean(resolved_boolean));
} else {
let resolved_field_element =
resolved_program.u32_from_variable(cs, unresolved_variable);
println!(
" variable field result: {} = {}",
variable.0,
resolved_field_element.value.unwrap()
);
resolved_program.insert(
variable,
ResolvedValue::FieldElement(resolved_field_element),
);
}
}
}
},
Statement::Return(statements) => {
statements
.into_iter()
.for_each(|expression| match expression {
Expression::Boolean(boolean_expression) => {
let res = resolved_program
.enforce_boolean_expression(cs, boolean_expression);
println!("\n Boolean result = {}", res.get_value().unwrap());
}
Expression::FieldElement(field_expression) => {
let res =
resolved_program.enforce_field_expression(cs, field_expression);
println!("\n Field result = {}", res.value.unwrap());
}
Expression::Variable(variable) => {
match resolved_program
.resolved_variables
.get_mut(&variable)
.unwrap()
.clone()
{
ResolvedValue::Boolean(boolean) => println!(
"\n Variable result = {}",
boolean.get_value().unwrap()
),
ResolvedValue::FieldElement(field_element) => println!(
"\n Variable field result = {}",
field_element.value.unwrap()
),
}
}
});
}
});
.for_each(|statement| resolved_program.enforce_statement(cs, statement));
}
}

View File

@ -4,29 +4,11 @@
//! @author Collin Chin <collin@aleo.org>
//! @date 2020
// id == 0 for field values
// id < 0 for boolean values
use std::collections::HashMap;
/// A variable in a constraint system.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Variable(pub String);
//
// /// Linear combination of variables in a program. (a + b + c)
// #[derive(Debug, Clone)]
// pub struct LinearCombination (pub Vec<Variable>);
//
// impl LinearCombination {
// pub fn one() -> Self {
// LinearCombination(vec![Variable{ id: 0, value: "1".into() }])
// }
//
// pub fn value(&self) -> String {
// self.0[0].value.clone()
// }
// }
//
// /// Quadratic combination of variables in a program (a * b)
// #[derive(Debug, Clone)]
// pub struct QuadraticCombination (pub LinearCombination, pub LinearCombination);
/// Expression that evaluates to a field value
#[derive(Debug, Clone)]
@ -109,11 +91,33 @@ pub struct Struct {
pub fields: Vec<StructField>,
}
#[derive(Clone, Debug)]
pub enum Visibility {
Public,
Private,
}
#[derive(Clone, Debug)]
pub struct Parameter {
pub visibility: Option<Visibility>,
pub ty: Type,
pub variable: Variable,
}
#[derive(Clone, Debug)]
pub struct Function {
pub variable: Variable,
pub parameters: Vec<Parameter>,
pub returns: Vec<Type>,
pub statements: Vec<Statement>,
}
/// A simple program with statement expressions, program arguments and program returns.
#[derive(Debug, Clone)]
pub struct Program {
pub id: String,
pub structs: Vec<Struct>,
pub structs: HashMap<Variable, Struct>,
pub functions: HashMap<Variable, Function>,
pub statements: Vec<Statement>,
pub arguments: Vec<Variable>,
pub returns: Vec<Variable>,

View File

@ -6,6 +6,7 @@
//! @date 2020
use crate::{aleo_program::types, ast};
use std::collections::HashMap;
impl<'ast> From<ast::Field<'ast>> for types::FieldExpression {
fn from(field: ast::Field<'ast>) -> Self {
@ -341,8 +342,8 @@ impl<'ast> From<ast::Statement<'ast>> for types::Statement {
impl<'ast> From<ast::BasicType<'ast>> for types::Type {
fn from(basic_type: ast::BasicType<'ast>) -> Self {
match basic_type {
ast::BasicType::Field(ty) => types::Type::FieldElement,
ast::BasicType::Boolean(ty) => types::Type::Boolean,
ast::BasicType::Field(_ty) => types::Type::FieldElement,
ast::BasicType::Boolean(_ty) => types::Type::Boolean,
}
}
}
@ -396,20 +397,82 @@ impl<'ast> From<ast::Struct<'ast>> for types::Struct {
}
}
impl From<ast::Visibility> for types::Visibility {
fn from(visibility: ast::Visibility) -> Self {
match visibility {
ast::Visibility::Private(_private) => types::Visibility::Private,
ast::Visibility::Public(_public) => types::Visibility::Public,
}
}
}
impl<'ast> From<ast::Parameter<'ast>> for types::Parameter {
fn from(parameter: ast::Parameter<'ast>) -> Self {
let ty = types::Type::from(parameter.ty);
let variable = types::Variable::from(parameter.variable);
if parameter.visibility.is_some() {
let visibility = Some(types::Visibility::from(parameter.visibility.unwrap()));
types::Parameter {
visibility,
ty,
variable,
}
} else {
types::Parameter {
visibility: None,
ty,
variable,
}
}
}
}
impl<'ast> From<ast::Function<'ast>> for types::Function {
fn from(function_definition: ast::Function<'ast>) -> Self {
let variable = types::Variable::from(function_definition.variable);
let parameters = function_definition
.parameters
.into_iter()
.map(|parameter| types::Parameter::from(parameter))
.collect();
let returns = function_definition
.returns
.into_iter()
.map(|return_type| types::Type::from(return_type))
.collect();
let statements = function_definition
.statements
.into_iter()
.map(|statement| types::Statement::from(statement))
.collect();
types::Function {
variable,
parameters,
returns,
statements,
}
}
}
impl<'ast> From<ast::File<'ast>> for types::Program {
fn from(file: ast::File<'ast>) -> Self {
// 1. compile ast -> aleo program representation
let structs = file
.structs
.into_iter()
.map(|struct_def| {
println!("{:#?}", struct_def);
types::Struct::from(struct_def)
})
.collect();
file.functions
.into_iter()
.for_each(|function_def| println!("{:#?}", function_def));
let mut structs = HashMap::new();
let mut functions = HashMap::new();
file.structs.into_iter().for_each(|struct_def| {
println!("{:#?}", struct_def);
let struct_definition = types::Struct::from(struct_def);
structs.insert(struct_definition.variable.clone(), struct_definition);
});
file.functions.into_iter().for_each(|function_def| {
println!("{:#?}", function_def);
let function_definition = types::Function::from(function_def);
functions.insert(function_definition.variable.clone(), function_definition);
});
let statements: Vec<types::Statement> = file
.statements
.into_iter()
@ -419,6 +482,7 @@ impl<'ast> From<ast::File<'ast>> for types::Program {
types::Program {
id: "main".into(),
structs,
functions,
statements,
arguments: vec![],
returns: vec![],

View File

@ -408,20 +408,20 @@ pub struct ArrayAccess<'ast> {
pub span: Span<'ast>,
}
// #[derive(Clone, Debug, FromPest, PartialEq)]
// #[pest_ast(rule(Rule::member_access))]
// pub struct MemberAccess<'ast> {
// pub id: IdentifierExpression<'ast>,
// #[pest_ast(outer())]
// pub span: Span<'ast>,
// }
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::access_member))]
pub struct MemberAccess<'ast> {
pub variable: Variable<'ast>,
#[pest_ast(outer())]
pub span: Span<'ast>,
}
#[derive(Clone, Debug, FromPest, PartialEq)]
#[pest_ast(rule(Rule::access))]
pub enum Access<'ast> {
// Call(CallAccess<'ast>),
Select(ArrayAccess<'ast>),
// Member(MemberAccess<'ast>),
Member(MemberAccess<'ast>),
}
#[derive(Clone, Debug, FromPest, PartialEq)]