diff --git a/simple.leo b/simple.leo index 847c630eb2..2a3887c58a 100644 --- a/simple.leo +++ b/simple.leo @@ -1,4 +1,10 @@ -function main() -> (u32) { - a = 1 + 1 - return a -} \ No newline at end of file +function test() -> (u32, u32) { + return 4, 4 +} + +function main() -> (u32, u32) { + a, b = test() + + return a, b +} + diff --git a/src/ast.rs b/src/ast.rs index fe5b35bca4..07bb440e66 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -881,11 +881,11 @@ impl<'ast> fmt::Display for MultipleAssignmentStatement<'ast> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, id) in self.assignees.iter().enumerate() { write!(f, "{}", id)?; - if i < ids.len() - 1 { + if i < self.assignees.len() - 1 { write!(f, ", ")?; } } - write!(f, " = {}", self.function_id) + write!(f, " = {}", self.function_name) } } diff --git a/src/leo.pest b/src/leo.pest index 6ed5f27fcd..2bce1314a4 100644 --- a/src/leo.pest +++ b/src/leo.pest @@ -163,7 +163,7 @@ parameter = {variable ~ ":" ~ visibility? ~ ty} parameter_list = _{(parameter ~ ("," ~ parameter)*)?} function_name = @{ ((!protected_name ~ ASCII_ALPHA) | (protected_name ~ (ASCII_ALPHANUMERIC | "_"))) ~ (ASCII_ALPHANUMERIC | "_")* } -function_definition = {"function" ~ function_name ~ "(" ~ parameter_list ~ ")" ~ "->" ~ "(" ~ type_list ~ ")" ~ "{" ~ NEWLINE* ~ statement* ~ "}"} +function_definition = {"function" ~ function_name ~ "(" ~ parameter_list ~ ")" ~ "->" ~ "(" ~ type_list ~ ")" ~ "{" ~ NEWLINE* ~ statement* ~ NEWLINE* ~ "}" ~ NEWLINE* } /// Utilities diff --git a/src/program/constraints/constraints.rs b/src/program/constraints/constraints.rs index f76558b68c..057cbf6676 100644 --- a/src/program/constraints/constraints.rs +++ b/src/program/constraints/constraints.rs @@ -94,14 +94,21 @@ impl> ResolvedProgram { .clone() .into_iter() .for_each(|statement| match statement { - Statement::Definition(variable, expression) => { - self.enforce_definition_statement( + Statement::Return(expressions) => { + return_values = self.enforce_return_statement( cs, function.get_name(), - variable, - expression, - ); + expressions, + function.returns.to_owned(), + ) } + Statement::MultipleDefinition(assignees, function_call) => self + .enforce_multiple_definition_statement( + cs, + function.get_name(), + assignees, + function_call, + ), Statement::For(index, start, stop, statements) => { self.enforce_for_statement( cs, @@ -112,13 +119,13 @@ impl> ResolvedProgram { statements, ); } - Statement::Return(expressions) => { - return_values = self.enforce_return_statement( + Statement::Definition(variable, expression) => { + self.enforce_definition_statement( cs, function.get_name(), - expressions, - function.returns.to_owned(), - ) + variable, + expression, + ); } }); diff --git a/src/program/constraints/expression.rs b/src/program/constraints/expression.rs index 6b832c1c92..423b9c21a5 100644 --- a/src/program/constraints/expression.rs +++ b/src/program/constraints/expression.rs @@ -5,7 +5,9 @@ //! @date 2020 use crate::program::constraints::{new_scope_from_variable, ResolvedProgram, ResolvedValue}; -use crate::program::{Expression, RangeOrExpression, SpreadOrExpression, StructMember, Variable}; +use crate::program::{ + Expression, RangeOrExpression, ResolvedStructMember, SpreadOrExpression, StructMember, Variable, +}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::r1cs::ConstraintSystem; @@ -225,22 +227,24 @@ impl> ResolvedProgram { if let Some(resolved_value) = self.get_mut_variable(&variable) { match resolved_value { ResolvedValue::StructDefinition(struct_definition) => { - struct_definition + let resolved_members = struct_definition .fields .clone() .iter() .zip(members.clone().into_iter()) - .for_each(|(field, member)| { + .map(|(field, member)| { if field.variable != member.variable { unimplemented!("struct field variables do not match") } - // Resolve and possibly enforce struct fields - // do we need to store the results here? - let _result = + // Resolve and enforce struct fields + let member_value = self.enforce_expression(cs, scope.clone(), member.expression); - }); - ResolvedValue::StructExpression(variable, members) + ResolvedStructMember(member.variable, member_value) + }) + .collect(); + + ResolvedValue::StructExpression(variable, resolved_members) } _ => unimplemented!("Inline struct type is not defined as a struct"), } @@ -258,11 +262,9 @@ impl> ResolvedProgram { ) -> ResolvedValue { match self.enforce_expression(cs, scope.clone(), *struct_variable) { ResolvedValue::StructExpression(_name, members) => { - let matched_member = members - .into_iter() - .find(|member| member.variable == struct_member); + let matched_member = members.into_iter().find(|member| member.0 == struct_member); match matched_member { - Some(member) => self.enforce_expression(cs, scope.clone(), member.expression), + Some(member) => member.1, None => unimplemented!("Cannot access struct member {}", struct_member.name), } } @@ -273,13 +275,17 @@ impl> ResolvedProgram { fn enforce_function_access_expression( &mut self, cs: &mut CS, - scope: String, - function: Box>, + function: &Variable, arguments: Vec>, ) -> ResolvedValue { - match self.enforce_expression(cs, scope, *function) { - ResolvedValue::Function(function) => self.enforce_function(cs, function, arguments), - value => unimplemented!("Cannot call unknown function {}", value), + match self.get_mut_variable(function) { + Some(value) => match value.clone() { + ResolvedValue::Function(function) => { + self.enforce_function(cs, function.to_owned(), arguments) + } + value => unimplemented!("Cannot make function call to {}", value), + }, + None => unimplemented!("Cannot call unknown function {}", function), } } @@ -397,7 +403,7 @@ impl> ResolvedProgram { // Functions Expression::FunctionCall(function, arguments) => { - self.enforce_function_access_expression(cs, scope, function, arguments) + self.enforce_function_access_expression(cs, &function, arguments) } // _ => unimplemented!(), } } diff --git a/src/program/constraints/resolved_value.rs b/src/program/constraints/resolved_value.rs index 65ec2c278e..b159c0ab10 100644 --- a/src/program/constraints/resolved_value.rs +++ b/src/program/constraints/resolved_value.rs @@ -4,7 +4,7 @@ //! @author Collin Chin //! @date 2020 -use crate::program::types::{Function, Struct, StructMember, Type, Variable}; +use crate::program::types::{Function, Struct, Type, Variable}; use snarkos_models::curves::{Field, PrimeField}; use snarkos_models::gadgets::{utilities::boolean::Boolean, utilities::uint32::UInt32}; @@ -17,11 +17,14 @@ pub enum ResolvedValue { Boolean(Boolean), Array(Vec>), StructDefinition(Struct), - StructExpression(Variable, Vec>), + StructExpression(Variable, Vec>), Function(Function), Return(Vec>), // add Null for function returns } +#[derive(Clone)] +pub struct ResolvedStructMember(pub Variable, pub ResolvedValue); + impl ResolvedValue { pub(crate) fn match_type(&self, ty: &Type) -> bool { match (self, ty) { @@ -64,7 +67,7 @@ impl fmt::Display for ResolvedValue { ResolvedValue::StructExpression(ref variable, ref members) => { write!(f, "{} {{", variable)?; for (i, member) in members.iter().enumerate() { - write!(f, "{}: {}", member.variable, member.expression)?; + write!(f, "{}: {}", member.0, member.1)?; if i < members.len() - 1 { write!(f, ", ")?; } diff --git a/src/program/constraints/statement.rs b/src/program/constraints/statement.rs index 4b0f02c8fc..8e85da2aab 100644 --- a/src/program/constraints/statement.rs +++ b/src/program/constraints/statement.rs @@ -21,28 +21,21 @@ impl> ResolvedProgram { } } - pub(crate) fn enforce_definition_statement( + fn enforce_definition( &mut self, cs: &mut CS, scope: String, assignee: Assignee, - expression: Expression, + return_value: &mut ResolvedValue, ) { - // Create or modify the lhs variable in the current function scope match assignee { Assignee::Variable(name) => { // Store the variable in the current scope let definition_name = new_scope_from_variable(scope.clone(), &name); - // Evaluate the rhs expression in the current function scope - let result = self.enforce_expression(cs, scope, expression); - - self.store(definition_name, result); + self.store(definition_name, return_value.to_owned()); } Assignee::Array(array, index_expression) => { - // Evaluate the rhs expression in the current function scope - let result = &mut self.enforce_expression(cs, scope.clone(), expression); - // Check that array exists let expected_array_name = self.resolve_assignee(scope.clone(), *array); @@ -55,7 +48,7 @@ impl> ResolvedProgram { match self.get_mut(&expected_array_name) { Some(value) => match value { ResolvedValue::Array(old) => { - old[index] = result.to_owned(); + old[index] = return_value.to_owned(); } _ => { unimplemented!("Cannot assign single index to array of values ") @@ -79,7 +72,7 @@ impl> ResolvedProgram { // Modify the range of values of the array in place match self.get_mut(&expected_array_name) { - Some(value) => match (value, result) { + Some(value) => match (value, return_value) { (ResolvedValue::Array(old), ResolvedValue::Array(new)) => { let to_index = to_index_option.unwrap_or(old.len()); old.splice(from_index..to_index, new.iter().cloned()); @@ -104,11 +97,10 @@ impl> ResolvedProgram { Some(value) => match value { ResolvedValue::StructExpression(_variable, members) => { // Modify the struct member in place - let matched_member = members - .into_iter() - .find(|member| member.variable == struct_member); + let matched_member = + members.into_iter().find(|member| member.0 == struct_member); match matched_member { - Some(mut member) => member.expression = expression, + Some(mut member) => member.1 = return_value.to_owned(), None => unimplemented!( "struct member {} does not exist in {}", struct_member, @@ -126,7 +118,43 @@ impl> ResolvedProgram { } } } + } + } + + pub(crate) fn enforce_definition_statement( + &mut self, + cs: &mut CS, + scope: String, + assignee: Assignee, + expression: Expression, + ) { + let result_value = &mut self.enforce_expression(cs, scope.clone(), expression); + + self.enforce_definition(cs, scope, assignee, result_value); + } + + pub(crate) fn enforce_multiple_definition_statement( + &mut self, + cs: &mut CS, + scope: String, + assignees: Vec>, + function: Expression, + ) { + // Expect return values from function + let return_values = match self.enforce_expression(cs, scope.clone(), function) { + ResolvedValue::Return(values) => values, + value => unimplemented!( + "multiple assignment only implemented for functions, got {}", + value + ), }; + + assignees + .into_iter() + .zip(return_values.into_iter()) + .for_each(|(assignee, mut return_value)| { + self.enforce_definition(cs, scope.clone(), assignee, &mut return_value); + }); } pub(crate) fn enforce_return_statement( @@ -160,15 +188,18 @@ impl> ResolvedProgram { return_types: Vec>, ) { match statement { - Statement::Definition(variable, expression) => { - self.enforce_definition_statement(cs, scope, variable, expression); + Statement::Return(statements) => { + // TODO: add support for early termination + let _res = self.enforce_return_statement(cs, scope, statements, return_types); } Statement::For(index, start, stop, statements) => { self.enforce_for_statement(cs, scope, index, start, stop, statements); } - Statement::Return(statements) => { - // TODO: add support for early termination - let _res = self.enforce_return_statement(cs, scope, statements, return_types); + Statement::MultipleDefinition(assignees, function) => { + self.enforce_multiple_definition_statement(cs, scope, assignees, function); + } + Statement::Definition(variable, expression) => { + self.enforce_definition_statement(cs, scope, variable, expression); } }; } diff --git a/src/program/types.rs b/src/program/types.rs index d8fda43738..c1c0fe4c51 100644 --- a/src/program/types.rs +++ b/src/program/types.rs @@ -90,7 +90,7 @@ pub enum Expression { StructMemberAccess(Box>, Variable), // (struct name, struct member name) // Functions - FunctionCall(Box>, Vec>), + FunctionCall(Variable, Vec>), } /// Definition assignee: v, arr[0..2], Point p.x @@ -105,9 +105,10 @@ pub enum Assignee { #[derive(Clone)] pub enum Statement { // Declaration(Variable), + Return(Vec>), Definition(Assignee, Expression), For(Variable, Integer, Integer, Vec>), - Return(Vec>), + MultipleDefinition(Vec>, Expression), } /// Explicit type used for defining struct members and function parameters diff --git a/src/program/types_display.rs b/src/program/types_display.rs index baf974a6a2..a3609547f1 100644 --- a/src/program/types_display.rs +++ b/src/program/types_display.rs @@ -145,8 +145,24 @@ impl fmt::Display for Assignee { impl fmt::Display for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Statement::Definition(ref variable, ref statement) => { - write!(f, "{} = {}", variable, statement) + Statement::Return(ref statements) => { + write!(f, "return ")?; + for (i, value) in statements.iter().enumerate() { + write!(f, "{}", value)?; + if i < statements.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "\n") + } + Statement::MultipleDefinition(ref assignees, ref function) => { + for (i, id) in assignees.iter().enumerate() { + write!(f, "{}", id)?; + if i < assignees.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, " = {}", function) } Statement::For(ref var, ref start, ref stop, ref list) => { write!(f, "for {} in {}..{} do\n", var, start, stop)?; @@ -155,11 +171,8 @@ impl fmt::Display for Statement { } write!(f, "\tendfor") } - Statement::Return(ref statements) => { - statements.iter().for_each(|statement| { - write!(f, "return {}", statement).unwrap(); - }); - write!(f, "\n") + Statement::Definition(ref variable, ref statement) => { + write!(f, "{} = {}", variable, statement) } } } @@ -168,8 +181,24 @@ impl fmt::Display for Statement { impl fmt::Debug for Statement { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - Statement::Definition(ref variable, ref statement) => { - write!(f, "{} = {}", variable, statement) + Statement::Return(ref statements) => { + write!(f, "return ")?; + for (i, value) in statements.iter().enumerate() { + write!(f, "{}", value)?; + if i < statements.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, "\n") + } + Statement::MultipleDefinition(ref assignees, ref function) => { + for (i, id) in assignees.iter().enumerate() { + write!(f, "{}", id)?; + if i < assignees.len() - 1 { + write!(f, ", ")?; + } + } + write!(f, " = {}()", function) } Statement::For(ref var, ref start, ref stop, ref list) => { write!(f, "for {:?} in {:?}..{:?} do\n", var, start, stop)?; @@ -178,11 +207,8 @@ impl fmt::Debug for Statement { } write!(f, "\tendfor") } - Statement::Return(ref statements) => { - statements.iter().for_each(|statement| { - write!(f, "return {}", statement).unwrap(); - }); - write!(f, "\n") + Statement::Definition(ref variable, ref statement) => { + write!(f, "{} = {}", variable, statement) } } } diff --git a/src/program/types_from.rs b/src/program/types_from.rs index 5aa64de313..bce72baeb8 100644 --- a/src/program/types_from.rs +++ b/src/program/types_from.rs @@ -209,8 +209,8 @@ impl<'ast, F: Field + PrimeField> From> for types:: .into_iter() .fold(variable, |acc, access| match access { ast::Access::Call(function) => match acc { - types::Expression::Variable(_) => types::Expression::FunctionCall( - Box::new(acc), + types::Expression::Variable(variable) => types::Expression::FunctionCall( + variable, function .expressions .into_iter() @@ -349,24 +349,6 @@ impl<'ast, F: Field + PrimeField> From> for types::Assignee< /// pest ast -> types::Statement -impl<'ast, F: Field + PrimeField> From> for types::Statement { - fn from(statement: ast::AssignStatement<'ast>) -> Self { - types::Statement::Definition( - types::Assignee::from(statement.assignee), - types::Expression::from(statement.expression), - ) - } -} - -impl<'ast, F: Field + PrimeField> From> for types::Statement { - fn from(statement: ast::DefinitionStatement<'ast>) -> Self { - types::Statement::Definition( - types::Assignee::from(statement.variable), - types::Expression::from_type(statement.ty, statement.expression), - ) - } -} - impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::ReturnStatement<'ast>) -> Self { types::Statement::Return( @@ -403,13 +385,56 @@ impl<'ast, F: Field + PrimeField> From> for types::State } } +impl<'ast, F: Field + PrimeField> From> + for types::Statement +{ + fn from(statement: ast::MultipleAssignmentStatement<'ast>) -> Self { + let assignees = statement + .assignees + .into_iter() + .map(|i| types::Assignee::Variable(types::Variable::from(i.id))) + .collect(); + + types::Statement::MultipleDefinition( + assignees, + types::Expression::FunctionCall( + types::Variable::from(statement.function_name), + statement + .arguments + .into_iter() + .map(|e| types::Expression::from(e)) + .collect(), + ), + ) + } +} + +impl<'ast, F: Field + PrimeField> From> for types::Statement { + fn from(statement: ast::AssignStatement<'ast>) -> Self { + types::Statement::Definition( + types::Assignee::from(statement.assignee), + types::Expression::from(statement.expression), + ) + } +} + +impl<'ast, F: Field + PrimeField> From> for types::Statement { + fn from(statement: ast::DefinitionStatement<'ast>) -> Self { + types::Statement::Definition( + types::Assignee::from(statement.variable), + types::Expression::from_type(statement.ty, statement.expression), + ) + } +} + impl<'ast, F: Field + PrimeField> From> for types::Statement { fn from(statement: ast::Statement<'ast>) -> Self { match statement { + ast::Statement::Return(statement) => types::Statement::from(statement), + ast::Statement::Iteration(statement) => types::Statement::from(statement), + ast::Statement::MultipleAssignment(statement) => types::Statement::from(statement), ast::Statement::Assign(statement) => types::Statement::from(statement), ast::Statement::Definition(statement) => types::Statement::from(statement), - ast::Statement::Iteration(statement) => types::Statement::from(statement), - ast::Statement::Return(statement) => types::Statement::from(statement), } } }