From ac481386216028f3c910fcbdfd83ec9cc02a0217 Mon Sep 17 00:00:00 2001 From: collin Date: Wed, 15 Apr 2020 18:46:54 -0700 Subject: [PATCH] constraints function calls, params, returns --- simple.program | 15 +-- src/aleo_program/constraints.rs | 209 +++++++++++++++++++++++------- src/aleo_program/types.rs | 9 +- src/aleo_program/types_display.rs | 85 +++++++++++- src/aleo_program/types_from.rs | 15 ++- src/ast.rs | 5 - 6 files changed, 260 insertions(+), 78 deletions(-) diff --git a/simple.program b/simple.program index b71faab37d..44b95e3837 100644 --- a/simple.program +++ b/simple.program @@ -1,10 +1,9 @@ -struct Point { - field x - field y -} +def test(field x) -> (field): + return 1 -Point p = Point {x: 1, y: 0} +def test2(bool b) -> (bool): + return b -p.x = 2 - -return p \ No newline at end of file +def main() -> (field): + a = test2(true) + return a \ No newline at end of file diff --git a/src/aleo_program/constraints.rs b/src/aleo_program/constraints.rs index fdaeadca1e..ab262e756e 100644 --- a/src/aleo_program/constraints.rs +++ b/src/aleo_program/constraints.rs @@ -22,6 +22,7 @@ pub enum ResolvedValue { StructDefinition(Struct), StructExpression(Variable, Vec), Function(Function), + Return(Vec), // add Null for function returns } impl fmt::Display for ResolvedValue { @@ -59,7 +60,17 @@ impl fmt::Display for ResolvedValue { } write!(f, "}}") } - _ => unimplemented!("resolve values not finished"), + ResolvedValue::Return(ref values) => { + write!(f, "Return values : [")?; + for (i, value) in values.iter().enumerate() { + write!(f, "{}", value)?; + if i < values.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "]") + } + _ => unimplemented!("display not impl for value"), } } } @@ -544,6 +555,18 @@ impl ResolvedProgram { } } + fn enforce_function_access_expression>( + &mut self, + cs: &mut CS, + function: Box, + arguments: Vec, + ) -> ResolvedValue { + match self.enforce_expression(cs, *function) { + ResolvedValue::Function(function) => self.enforce_function(cs, function, arguments), + value => unimplemented!("Cannot call unknown function {}", value), + } + } + fn enforce_expression>( &mut self, cs: &mut CS, @@ -586,45 +609,134 @@ impl ResolvedProgram { Expression::StructMemberAccess(struct_variable, struct_member) => { self.enforce_struct_access_expression(cs, struct_variable, struct_member) } + Expression::FunctionCall(function, arguments) => { + self.enforce_function_access_expression(cs, function, arguments) + } } } - fn enforce_statement>( + fn enforce_definition_statement>( &mut self, cs: &mut CS, - statement: Statement, + variable: Variable, + expression: Expression, ) { - match statement { - Statement::Definition(variable, expression) => { - let result = self.enforce_expression(cs, expression); - println!(" statement result: {} = {}", variable.0, result); - self.insert(variable, result); - } - Statement::Return(statements) => { - statements - .into_iter() - .for_each(|expression| match expression { - Expression::Boolean(boolean_expression) => { - let res = self.enforce_boolean_expression(cs, boolean_expression); - println!("\n Boolean result = {}", res); + let result = self.enforce_expression(cs, expression); + // println!(" statement result: {} = {}", variable.0, result); + self.insert(variable, result); + } + + fn enforce_return_statement>( + &mut self, + cs: &mut CS, + statements: Vec, + ) -> ResolvedValue { + ResolvedValue::Return( + statements + .into_iter() + .map(|expression| match expression { + Expression::Boolean(boolean_expression) => { + self.enforce_boolean_expression(cs, boolean_expression) + } + Expression::FieldElement(field_expression) => { + self.enforce_field_expression(cs, field_expression) + } + Expression::Variable(variable) => { + self.resolved_variables.get_mut(&variable).unwrap().clone() + } + Expression::Struct(_v, _m) => { + unimplemented!("return struct not impl"); + } + expr => unimplemented!("expression {} can't be returned yet", expr), + }) + .collect::>(), + ) + } + + // fn enforce_statement>( + // &mut self, + // cs: &mut CS, + // statement: Statement, + // ) { + // match statement { + // Statement::Definition(variable, expression) => { + // self.enforce_definition_statement(cs, variable, expression); + // } + // Statement::Return(statements) => { + // let res = self.enforce_return_statement(cs, statements); + // + // } + // }; + // } + + fn enforce_function>( + &mut self, + cs: &mut CS, + function: Function, + arguments: Vec, + ) -> ResolvedValue { + // Make sure we are given the correct number of arguments + if function.parameters.len() != arguments.len() { + unimplemented!( + "function expected {} arguments, got {}", + function.parameters.len(), + arguments.len() + ) + } + + // Store arguments as variables in resolved program + function + .parameters + .clone() + .iter() + .zip(arguments.clone().into_iter()) + .for_each(|(parameter, argument)| { + // Check visibility here + + // Check that argument is correct type + match parameter.ty.clone() { + Type::FieldElement => { + match self.enforce_expression(cs, argument) { + ResolvedValue::FieldElement(field) => { + // Store argument as variable with parameter name + // TODO: this will not support multiple function calls or variables with same name as parameter + self.resolved_variables.insert( + parameter.variable.clone(), + ResolvedValue::FieldElement(field), + ); + } + argument => unimplemented!("expected field argument got {}", argument), } - Expression::FieldElement(field_expression) => { - let res = self.enforce_field_expression(cs, field_expression); - println!("\n Field result = {}", res); + } + Type::Boolean => match self.enforce_expression(cs, argument) { + ResolvedValue::Boolean(bool) => { + self.resolved_variables + .insert(parameter.variable.clone(), ResolvedValue::Boolean(bool)); } - Expression::Variable(variable) => { - println!( - "\n Return = {}", - self.resolved_variables.get_mut(&variable).unwrap().clone() - ); - } - Expression::Struct(_v, _m) => { - unimplemented!("return struct not impl"); - } - _ => unimplemented!("expression can't be returned yet"), - }); - } - }; + argument => unimplemented!("expected boolean argument got {}", argument), + }, + ty => unimplemented!("parameter type {} not matched yet", ty), + } + }); + + // Evaluate function statements + + let mut return_values = ResolvedValue::Return(vec![]); + + function + .statements + .clone() + .into_iter() + .for_each(|statement| match statement { + Statement::Definition(variable, expression) => { + self.enforce_definition_statement(cs, variable, expression) + } + Statement::Return(expressions) => { + return_values = self.enforce_return_statement(cs, expressions) + } + }); + + return_values } pub fn generate_constraints>( @@ -650,24 +762,23 @@ impl ResolvedProgram { .insert(variable, ResolvedValue::Function(function)); }); - // let main = resolved_program - // .resolved_variables - // .get_mut(&Variable("main".into())) - // .expect("main function not defined"); - // - // match main { - // ResolvedValue::Function(function) => function - // .statements - // .clone() - // .into_iter() - // .for_each(|statement| resolved_program.enforce_statement(cs, statement)), - // _ => unimplemented!("main must be a function"), - // } + let main = resolved_program + .resolved_variables + .get(&Variable("main".into())) + .expect("main function not defined"); - program - .statements - .into_iter() - .for_each(|statement| resolved_program.enforce_statement(cs, statement)); + let result = match main.clone() { + ResolvedValue::Function(function) => { + resolved_program.enforce_function(cs, function, vec![]) + } + _ => unimplemented!("main must be a function"), + }; + println!("\n {}", result); + + // program + // .statements + // .into_iter() + // .for_each(|statement| resolved_program.enforce_statement(cs, statement)); } } diff --git a/src/aleo_program/types.rs b/src/aleo_program/types.rs index f2a3112bb2..238290238a 100644 --- a/src/aleo_program/types.rs +++ b/src/aleo_program/types.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; /// A variable in a constraint system. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct Variable(pub String); /// Spread operator @@ -92,9 +92,10 @@ pub enum Expression { Boolean(BooleanExpression), FieldElement(FieldExpression), Variable(Variable), - ArrayAccess(Box, FieldRangeOrExpression), Struct(Variable, Vec), + ArrayAccess(Box, FieldRangeOrExpression), StructMemberAccess(Box, Variable), // (struct name, struct member name) + FunctionCall(Box, Vec), } /// Program statement that defines some action (or expression) to be carried out. @@ -143,14 +144,14 @@ pub enum Visibility { Private, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Parameter { pub visibility: Option, pub ty: Type, pub variable: Variable, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Function { pub variable: Variable, pub parameters: Vec, diff --git a/src/aleo_program/types_display.rs b/src/aleo_program/types_display.rs index c63a4415ab..617557978a 100644 --- a/src/aleo_program/types_display.rs +++ b/src/aleo_program/types_display.rs @@ -6,8 +6,8 @@ use crate::aleo_program::{ BooleanExpression, BooleanSpread, BooleanSpreadOrExpression, Expression, FieldExpression, - FieldRangeOrExpression, FieldSpread, FieldSpreadOrExpression, Statement, Struct, StructField, - Type, Variable, + FieldRangeOrExpression, FieldSpread, FieldSpreadOrExpression, Function, Parameter, Statement, + Struct, StructField, Type, Variable, }; use std::fmt; @@ -17,6 +17,11 @@ impl fmt::Display for Variable { write!(f, "{}", self.0) } } +impl fmt::Debug for Variable { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} impl fmt::Display for FieldSpread { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -145,7 +150,16 @@ impl<'ast> fmt::Display for Expression { Expression::StructMemberAccess(ref struct_variable, ref member) => { write!(f, "{}.{}", struct_variable, member) } - // _ => unimplemented!("can't display expression yet"), + Expression::FunctionCall(ref function, ref arguments) => { + write!(f, "{}(", function,)?; + for (i, param) in arguments.iter().enumerate() { + write!(f, "{}", param)?; + if i < arguments.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, ")") + } // _ => unimplemented!("can't display expression yet"), } } } @@ -156,9 +170,11 @@ impl fmt::Display for Statement { statements.iter().for_each(|statement| { write!(f, "return {}", statement).unwrap(); }); - write!(f, "") + write!(f, "\n") + } + Statement::Definition(ref variable, ref statement) => { + write!(f, "{} = {}", variable, statement) } - _ => unimplemented!(), } } } @@ -170,7 +186,7 @@ impl fmt::Debug for Statement { statements.iter().for_each(|statement| { write!(f, "return {}", statement).unwrap(); }); - write!(f, "") + write!(f, "\n") } Statement::Definition(ref variable, ref statement) => { write!(f, "{} = {}", variable, statement) @@ -205,3 +221,60 @@ impl fmt::Debug for Struct { write!(f, "}}") } } + +impl fmt::Display for Parameter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // let visibility = if self.private { "private " } else { "" }; + write!( + f, + "{} {}", + // visibility, + self.ty, + self.variable + ) + } +} + +impl fmt::Debug for Parameter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Parameter(variable: {:?})", self.ty) + } +} + +impl fmt::Display for Function { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "({}):\n{}", + self.parameters + .iter() + .map(|x| format!("{}", x)) + .collect::>() + .join(","), + self.statements + .iter() + .map(|x| format!("\t{}", x)) + .collect::>() + .join("\n") + ) + } +} + +impl fmt::Debug for Function { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "({}):\n{}", + self.parameters + .iter() + .map(|x| format!("{}", x)) + .collect::>() + .join(","), + self.statements + .iter() + .map(|x| format!("\t{}", x)) + .collect::>() + .join("\n") + ) + } +} diff --git a/src/aleo_program/types_from.rs b/src/aleo_program/types_from.rs index 40ea30eb9a..0ea32a83ac 100644 --- a/src/aleo_program/types_from.rs +++ b/src/aleo_program/types_from.rs @@ -319,10 +319,15 @@ impl<'ast> From> for types::Expression { .accesses .into_iter() .fold(variable, |acc, access| match access { - ast::Access::Call(a) => match acc { - types::Expression::Variable(_) => { - unimplemented!("function calls not implemented") - } + ast::Access::Call(function) => match acc { + types::Expression::Variable(_) => types::Expression::FunctionCall( + Box::new(acc), + function + .expressions + .into_iter() + .map(|expression| types::Expression::from(expression)) + .collect(), + ), expression => { unimplemented!("only function names are callable, found \"{}\"", expression) } @@ -698,12 +703,10 @@ impl<'ast> From> for types::Program { let mut functions = HashMap::new(); file.structs.into_iter().for_each(|struct_def| { - // println!("{:#?}", struct_def); let struct_definition = types::Struct::from(struct_def); structs.insert(struct_definition.variable.clone(), struct_definition); }); file.functions.into_iter().for_each(|function_def| { - // println!("{:#?}", function_def); let function_definition = types::Function::from(function_def); functions.insert(function_definition.variable.clone(), function_definition); }); diff --git a/src/ast.rs b/src/ast.rs index 55dd0b354c..6fc9dbf50f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -57,25 +57,21 @@ fn parse_term(pair: Pair) -> Box { match next.as_rule() { Rule::expression => Expression::from_pest(&mut pair.into_inner()).unwrap(), // Parenthesis case Rule::expression_inline_struct => { - println!("struct inline"); Expression::StructInline( StructInlineExpression::from_pest(&mut pair.into_inner()).unwrap(), ) }, Rule::expression_array_inline => { - println!("array inline"); Expression::ArrayInline( ArrayInlineExpression::from_pest(&mut pair.into_inner()).unwrap() ) }, Rule::expression_array_initializer => { - println!("array initializer"); Expression::ArrayInitializer( ArrayInitializerExpression::from_pest(&mut pair.into_inner()).unwrap() ) }, Rule::expression_conditional => { - println!("conditional expression"); Expression::Ternary( TernaryExpression::from_pest(&mut pair.into_inner()).unwrap(), ) @@ -113,7 +109,6 @@ fn parse_term(pair: Pair) -> Box { Expression::Decrement(DecrementExpression { operation, expression, span }) }, Rule::expression_postfix => { - println!("postfix expression"); Expression::Postfix( PostfixExpression::from_pest(&mut pair.into_inner()).unwrap(), )