mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-12-23 02:01:54 +03:00
constraints function calls, params, returns
This commit is contained in:
parent
3bc7118d71
commit
ac48138621
@ -1,10 +1,9 @@
|
||||
struct Point {
|
||||
field x
|
||||
field y
|
||||
}
|
||||
def test(field x) -> (field):
|
||||
return 1
|
||||
|
||||
Point p = Point {x: 1, y: 0}
|
||||
def test2(bool b) -> (bool):
|
||||
return b
|
||||
|
||||
p.x = 2
|
||||
|
||||
return p
|
||||
def main() -> (field):
|
||||
a = test2(true)
|
||||
return a
|
@ -22,6 +22,7 @@ pub enum ResolvedValue {
|
||||
StructDefinition(Struct),
|
||||
StructExpression(Variable, Vec<StructMember>),
|
||||
Function(Function),
|
||||
Return(Vec<ResolvedValue>), // add Null for function returns
|
||||
}
|
||||
|
||||
impl fmt::Display for ResolvedValue {
|
||||
@ -59,7 +60,17 @@ impl fmt::Display for ResolvedValue {
|
||||
}
|
||||
write!(f, "}}")
|
||||
}
|
||||
_ => unimplemented!("resolve values not finished"),
|
||||
ResolvedValue::Return(ref values) => {
|
||||
write!(f, "Return values : [")?;
|
||||
for (i, value) in values.iter().enumerate() {
|
||||
write!(f, "{}", value)?;
|
||||
if i < values.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "]")
|
||||
}
|
||||
_ => unimplemented!("display not impl for value"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -544,6 +555,18 @@ impl ResolvedProgram {
|
||||
}
|
||||
}
|
||||
|
||||
fn enforce_function_access_expression<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
function: Box<Expression>,
|
||||
arguments: Vec<Expression>,
|
||||
) -> ResolvedValue {
|
||||
match self.enforce_expression(cs, *function) {
|
||||
ResolvedValue::Function(function) => self.enforce_function(cs, function, arguments),
|
||||
value => unimplemented!("Cannot call unknown function {}", value),
|
||||
}
|
||||
}
|
||||
|
||||
fn enforce_expression<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
@ -586,45 +609,134 @@ impl ResolvedProgram {
|
||||
Expression::StructMemberAccess(struct_variable, struct_member) => {
|
||||
self.enforce_struct_access_expression(cs, struct_variable, struct_member)
|
||||
}
|
||||
Expression::FunctionCall(function, arguments) => {
|
||||
self.enforce_function_access_expression(cs, function, arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn enforce_statement<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
fn enforce_definition_statement<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
statement: Statement,
|
||||
variable: Variable,
|
||||
expression: Expression,
|
||||
) {
|
||||
match statement {
|
||||
Statement::Definition(variable, expression) => {
|
||||
let result = self.enforce_expression(cs, expression);
|
||||
println!(" statement result: {} = {}", variable.0, result);
|
||||
self.insert(variable, result);
|
||||
}
|
||||
Statement::Return(statements) => {
|
||||
statements
|
||||
.into_iter()
|
||||
.for_each(|expression| match expression {
|
||||
Expression::Boolean(boolean_expression) => {
|
||||
let res = self.enforce_boolean_expression(cs, boolean_expression);
|
||||
println!("\n Boolean result = {}", res);
|
||||
let result = self.enforce_expression(cs, expression);
|
||||
// println!(" statement result: {} = {}", variable.0, result);
|
||||
self.insert(variable, result);
|
||||
}
|
||||
|
||||
fn enforce_return_statement<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
statements: Vec<Expression>,
|
||||
) -> ResolvedValue {
|
||||
ResolvedValue::Return(
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|expression| match expression {
|
||||
Expression::Boolean(boolean_expression) => {
|
||||
self.enforce_boolean_expression(cs, boolean_expression)
|
||||
}
|
||||
Expression::FieldElement(field_expression) => {
|
||||
self.enforce_field_expression(cs, field_expression)
|
||||
}
|
||||
Expression::Variable(variable) => {
|
||||
self.resolved_variables.get_mut(&variable).unwrap().clone()
|
||||
}
|
||||
Expression::Struct(_v, _m) => {
|
||||
unimplemented!("return struct not impl");
|
||||
}
|
||||
expr => unimplemented!("expression {} can't be returned yet", expr),
|
||||
})
|
||||
.collect::<Vec<ResolvedValue>>(),
|
||||
)
|
||||
}
|
||||
|
||||
// fn enforce_statement<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
// &mut self,
|
||||
// cs: &mut CS,
|
||||
// statement: Statement,
|
||||
// ) {
|
||||
// match statement {
|
||||
// Statement::Definition(variable, expression) => {
|
||||
// self.enforce_definition_statement(cs, variable, expression);
|
||||
// }
|
||||
// Statement::Return(statements) => {
|
||||
// let res = self.enforce_return_statement(cs, statements);
|
||||
//
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
|
||||
fn enforce_function<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
&mut self,
|
||||
cs: &mut CS,
|
||||
function: Function,
|
||||
arguments: Vec<Expression>,
|
||||
) -> ResolvedValue {
|
||||
// Make sure we are given the correct number of arguments
|
||||
if function.parameters.len() != arguments.len() {
|
||||
unimplemented!(
|
||||
"function expected {} arguments, got {}",
|
||||
function.parameters.len(),
|
||||
arguments.len()
|
||||
)
|
||||
}
|
||||
|
||||
// Store arguments as variables in resolved program
|
||||
function
|
||||
.parameters
|
||||
.clone()
|
||||
.iter()
|
||||
.zip(arguments.clone().into_iter())
|
||||
.for_each(|(parameter, argument)| {
|
||||
// Check visibility here
|
||||
|
||||
// Check that argument is correct type
|
||||
match parameter.ty.clone() {
|
||||
Type::FieldElement => {
|
||||
match self.enforce_expression(cs, argument) {
|
||||
ResolvedValue::FieldElement(field) => {
|
||||
// Store argument as variable with parameter name
|
||||
// TODO: this will not support multiple function calls or variables with same name as parameter
|
||||
self.resolved_variables.insert(
|
||||
parameter.variable.clone(),
|
||||
ResolvedValue::FieldElement(field),
|
||||
);
|
||||
}
|
||||
argument => unimplemented!("expected field argument got {}", argument),
|
||||
}
|
||||
Expression::FieldElement(field_expression) => {
|
||||
let res = self.enforce_field_expression(cs, field_expression);
|
||||
println!("\n Field result = {}", res);
|
||||
}
|
||||
Type::Boolean => match self.enforce_expression(cs, argument) {
|
||||
ResolvedValue::Boolean(bool) => {
|
||||
self.resolved_variables
|
||||
.insert(parameter.variable.clone(), ResolvedValue::Boolean(bool));
|
||||
}
|
||||
Expression::Variable(variable) => {
|
||||
println!(
|
||||
"\n Return = {}",
|
||||
self.resolved_variables.get_mut(&variable).unwrap().clone()
|
||||
);
|
||||
}
|
||||
Expression::Struct(_v, _m) => {
|
||||
unimplemented!("return struct not impl");
|
||||
}
|
||||
_ => unimplemented!("expression can't be returned yet"),
|
||||
});
|
||||
}
|
||||
};
|
||||
argument => unimplemented!("expected boolean argument got {}", argument),
|
||||
},
|
||||
ty => unimplemented!("parameter type {} not matched yet", ty),
|
||||
}
|
||||
});
|
||||
|
||||
// Evaluate function statements
|
||||
|
||||
let mut return_values = ResolvedValue::Return(vec![]);
|
||||
|
||||
function
|
||||
.statements
|
||||
.clone()
|
||||
.into_iter()
|
||||
.for_each(|statement| match statement {
|
||||
Statement::Definition(variable, expression) => {
|
||||
self.enforce_definition_statement(cs, variable, expression)
|
||||
}
|
||||
Statement::Return(expressions) => {
|
||||
return_values = self.enforce_return_statement(cs, expressions)
|
||||
}
|
||||
});
|
||||
|
||||
return_values
|
||||
}
|
||||
|
||||
pub fn generate_constraints<F: Field + PrimeField, CS: ConstraintSystem<F>>(
|
||||
@ -650,24 +762,23 @@ impl ResolvedProgram {
|
||||
.insert(variable, ResolvedValue::Function(function));
|
||||
});
|
||||
|
||||
// let main = resolved_program
|
||||
// .resolved_variables
|
||||
// .get_mut(&Variable("main".into()))
|
||||
// .expect("main function not defined");
|
||||
//
|
||||
// match main {
|
||||
// ResolvedValue::Function(function) => function
|
||||
// .statements
|
||||
// .clone()
|
||||
// .into_iter()
|
||||
// .for_each(|statement| resolved_program.enforce_statement(cs, statement)),
|
||||
// _ => unimplemented!("main must be a function"),
|
||||
// }
|
||||
let main = resolved_program
|
||||
.resolved_variables
|
||||
.get(&Variable("main".into()))
|
||||
.expect("main function not defined");
|
||||
|
||||
program
|
||||
.statements
|
||||
.into_iter()
|
||||
.for_each(|statement| resolved_program.enforce_statement(cs, statement));
|
||||
let result = match main.clone() {
|
||||
ResolvedValue::Function(function) => {
|
||||
resolved_program.enforce_function(cs, function, vec![])
|
||||
}
|
||||
_ => unimplemented!("main must be a function"),
|
||||
};
|
||||
println!("\n {}", result);
|
||||
|
||||
// program
|
||||
// .statements
|
||||
// .into_iter()
|
||||
// .for_each(|statement| resolved_program.enforce_statement(cs, statement));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A variable in a constraint system.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
#[derive(Clone, PartialEq, Eq, Hash)]
|
||||
pub struct Variable(pub String);
|
||||
|
||||
/// Spread operator
|
||||
@ -92,9 +92,10 @@ pub enum Expression {
|
||||
Boolean(BooleanExpression),
|
||||
FieldElement(FieldExpression),
|
||||
Variable(Variable),
|
||||
ArrayAccess(Box<Expression>, FieldRangeOrExpression),
|
||||
Struct(Variable, Vec<StructMember>),
|
||||
ArrayAccess(Box<Expression>, FieldRangeOrExpression),
|
||||
StructMemberAccess(Box<Expression>, Variable), // (struct name, struct member name)
|
||||
FunctionCall(Box<Expression>, Vec<Expression>),
|
||||
}
|
||||
|
||||
/// Program statement that defines some action (or expression) to be carried out.
|
||||
@ -143,14 +144,14 @@ pub enum Visibility {
|
||||
Private,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct Parameter {
|
||||
pub visibility: Option<Visibility>,
|
||||
pub ty: Type,
|
||||
pub variable: Variable,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Clone)]
|
||||
pub struct Function {
|
||||
pub variable: Variable,
|
||||
pub parameters: Vec<Parameter>,
|
||||
|
@ -6,8 +6,8 @@
|
||||
|
||||
use crate::aleo_program::{
|
||||
BooleanExpression, BooleanSpread, BooleanSpreadOrExpression, Expression, FieldExpression,
|
||||
FieldRangeOrExpression, FieldSpread, FieldSpreadOrExpression, Statement, Struct, StructField,
|
||||
Type, Variable,
|
||||
FieldRangeOrExpression, FieldSpread, FieldSpreadOrExpression, Function, Parameter, Statement,
|
||||
Struct, StructField, Type, Variable,
|
||||
};
|
||||
|
||||
use std::fmt;
|
||||
@ -17,6 +17,11 @@ impl fmt::Display for Variable {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
impl fmt::Debug for Variable {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FieldSpread {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
@ -145,7 +150,16 @@ impl<'ast> fmt::Display for Expression {
|
||||
Expression::StructMemberAccess(ref struct_variable, ref member) => {
|
||||
write!(f, "{}.{}", struct_variable, member)
|
||||
}
|
||||
// _ => unimplemented!("can't display expression yet"),
|
||||
Expression::FunctionCall(ref function, ref arguments) => {
|
||||
write!(f, "{}(", function,)?;
|
||||
for (i, param) in arguments.iter().enumerate() {
|
||||
write!(f, "{}", param)?;
|
||||
if i < arguments.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, ")")
|
||||
} // _ => unimplemented!("can't display expression yet"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -156,9 +170,11 @@ impl fmt::Display for Statement {
|
||||
statements.iter().for_each(|statement| {
|
||||
write!(f, "return {}", statement).unwrap();
|
||||
});
|
||||
write!(f, "")
|
||||
write!(f, "\n")
|
||||
}
|
||||
Statement::Definition(ref variable, ref statement) => {
|
||||
write!(f, "{} = {}", variable, statement)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -170,7 +186,7 @@ impl fmt::Debug for Statement {
|
||||
statements.iter().for_each(|statement| {
|
||||
write!(f, "return {}", statement).unwrap();
|
||||
});
|
||||
write!(f, "")
|
||||
write!(f, "\n")
|
||||
}
|
||||
Statement::Definition(ref variable, ref statement) => {
|
||||
write!(f, "{} = {}", variable, statement)
|
||||
@ -205,3 +221,60 @@ impl fmt::Debug for Struct {
|
||||
write!(f, "}}")
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Parameter {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
// let visibility = if self.private { "private " } else { "" };
|
||||
write!(
|
||||
f,
|
||||
"{} {}",
|
||||
// visibility,
|
||||
self.ty,
|
||||
self.variable
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Parameter {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "Parameter(variable: {:?})", self.ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Function {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"({}):\n{}",
|
||||
self.parameters
|
||||
.iter()
|
||||
.map(|x| format!("{}", x))
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
self.statements
|
||||
.iter()
|
||||
.map(|x| format!("\t{}", x))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for Function {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"({}):\n{}",
|
||||
self.parameters
|
||||
.iter()
|
||||
.map(|x| format!("{}", x))
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
self.statements
|
||||
.iter()
|
||||
.map(|x| format!("\t{}", x))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -319,10 +319,15 @@ impl<'ast> From<ast::PostfixExpression<'ast>> for types::Expression {
|
||||
.accesses
|
||||
.into_iter()
|
||||
.fold(variable, |acc, access| match access {
|
||||
ast::Access::Call(a) => match acc {
|
||||
types::Expression::Variable(_) => {
|
||||
unimplemented!("function calls not implemented")
|
||||
}
|
||||
ast::Access::Call(function) => match acc {
|
||||
types::Expression::Variable(_) => types::Expression::FunctionCall(
|
||||
Box::new(acc),
|
||||
function
|
||||
.expressions
|
||||
.into_iter()
|
||||
.map(|expression| types::Expression::from(expression))
|
||||
.collect(),
|
||||
),
|
||||
expression => {
|
||||
unimplemented!("only function names are callable, found \"{}\"", expression)
|
||||
}
|
||||
@ -698,12 +703,10 @@ impl<'ast> From<ast::File<'ast>> for types::Program {
|
||||
let mut functions = HashMap::new();
|
||||
|
||||
file.structs.into_iter().for_each(|struct_def| {
|
||||
// println!("{:#?}", struct_def);
|
||||
let struct_definition = types::Struct::from(struct_def);
|
||||
structs.insert(struct_definition.variable.clone(), struct_definition);
|
||||
});
|
||||
file.functions.into_iter().for_each(|function_def| {
|
||||
// println!("{:#?}", function_def);
|
||||
let function_definition = types::Function::from(function_def);
|
||||
functions.insert(function_definition.variable.clone(), function_definition);
|
||||
});
|
||||
|
@ -57,25 +57,21 @@ fn parse_term(pair: Pair<Rule>) -> Box<Expression> {
|
||||
match next.as_rule() {
|
||||
Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case
|
||||
Rule::expression_inline_struct => {
|
||||
println!("struct inline");
|
||||
Expression::StructInline(
|
||||
StructInlineExpression::from_pest(&mut pair.into_inner()).unwrap(),
|
||||
)
|
||||
},
|
||||
Rule::expression_array_inline => {
|
||||
println!("array inline");
|
||||
Expression::ArrayInline(
|
||||
ArrayInlineExpression::from_pest(&mut pair.into_inner()).unwrap()
|
||||
)
|
||||
},
|
||||
Rule::expression_array_initializer => {
|
||||
println!("array initializer");
|
||||
Expression::ArrayInitializer(
|
||||
ArrayInitializerExpression::from_pest(&mut pair.into_inner()).unwrap()
|
||||
)
|
||||
},
|
||||
Rule::expression_conditional => {
|
||||
println!("conditional expression");
|
||||
Expression::Ternary(
|
||||
TernaryExpression::from_pest(&mut pair.into_inner()).unwrap(),
|
||||
)
|
||||
@ -113,7 +109,6 @@ fn parse_term(pair: Pair<Rule>) -> Box<Expression> {
|
||||
Expression::Decrement(DecrementExpression { operation, expression, span })
|
||||
},
|
||||
Rule::expression_postfix => {
|
||||
println!("postfix expression");
|
||||
Expression::Postfix(
|
||||
PostfixExpression::from_pest(&mut pair.into_inner()).unwrap(),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user