Fix flattening logic

This commit is contained in:
d0cd 2022-11-12 16:23:16 -08:00
parent 8048d2754b
commit 682b67e184
10 changed files with 74 additions and 55 deletions

View File

@ -302,7 +302,7 @@ pub trait StatementReconstructor: ExpressionReconstructor {
(
Statement::Return(ReturnStatement {
expression: self.reconstruct_expression(input.expression).0,
finalize_args: input.finalize_args.map(|arguments| {
finalize_arguments: input.finalize_arguments.map(|arguments| {
arguments
.into_iter()
.map(|argument| self.reconstruct_expression(argument).0)

View File

@ -190,7 +190,7 @@ pub trait StatementVisitor<'a>: ExpressionVisitor<'a> {
fn visit_return(&mut self, input: &'a ReturnStatement) {
self.visit_expression(&input.expression, &Default::default());
if let Some(arguments) = &input.finalize_args {
if let Some(arguments) = &input.finalize_arguments {
arguments.iter().for_each(|argument| {
self.visit_expression(argument, &Default::default());
})

View File

@ -26,7 +26,7 @@ pub struct ReturnStatement {
/// The expression to return to the function caller.
pub expression: Expression,
/// Arguments to the finalize block.
pub finalize_args: Option<Vec<Expression>>,
pub finalize_arguments: Option<Vec<Expression>>,
/// The span of `return expression` excluding the semicolon.
pub span: Span,
}

View File

@ -142,7 +142,7 @@ impl ParserContext<'_> {
Ok(ReturnStatement {
span,
expression,
finalize_args,
finalize_arguments: finalize_args,
})
}

View File

@ -18,7 +18,8 @@ use crate::CodeGenerator;
use leo_ast::{
AssignStatement, Block, ConditionalStatement, ConsoleFunction, ConsoleStatement, DecrementStatement,
DefinitionStatement, Expression, IncrementStatement, IterationStatement, Mode, Output, ReturnStatement, Statement,
DefinitionStatement, Expression, ExpressionStatement, IncrementStatement, IterationStatement, Mode, Output,
ReturnStatement, Statement,
};
use itertools::Itertools;
@ -102,7 +103,7 @@ impl<'a> CodeGenerator<'a> {
// Output a finalize instruction if needed.
// TODO: Check formatting.
if let Some(arguments) = &input.finalize_args {
if let Some(arguments) = &input.finalize_arguments {
let mut finalize_instruction = "\n finalize".to_string();
for argument in arguments.iter() {
@ -152,7 +153,6 @@ impl<'a> CodeGenerator<'a> {
instructions
}
fn visit_assign(&mut self, input: &'a AssignStatement) -> String {
match (&input.place, &input.value) {
(Expression::Identifier(identifier), _) => {

View File

@ -16,7 +16,7 @@
use crate::Flattener;
use leo_ast::{Finalize, Function, ProgramReconstructor, ReturnStatement, Statement, StatementReconstructor, Type};
use leo_ast::{Finalize, Function, ProgramReconstructor, StatementReconstructor, Type};
impl ProgramReconstructor for Flattener<'_> {
/// Flattens a function's body and finalize block, if it exists.
@ -37,12 +37,9 @@ impl ProgramReconstructor for Flattener<'_> {
// Get all of the guards and return expression.
let returns = self.clear_early_returns();
// If the finalize block contains return statements, then we fold them into a single return statement.
// Fold the return statements into the block.
self.fold_returns(&mut block, returns);
// Initialize `self.finalizes` with the appropriate number of vectors.
self.finalizes = vec![vec![]; finalize.input.len()];
Finalize {
identifier: finalize.identifier,
input: finalize.input,
@ -65,14 +62,11 @@ impl ProgramReconstructor for Flattener<'_> {
let mut block = self.reconstruct_block(function.block).0;
// Get all of the guards and return expression.
// TODO: Verify that there is always at least one
let returns = self.clear_early_returns();
// If the function contains return statements, then we fold them into a single return statement.
// Fold the return statements into the block.
self.fold_returns(&mut block, returns);
Function {
annotations: function.annotations,
call_type: function.call_type,

View File

@ -19,9 +19,9 @@ use itertools::Itertools;
use std::borrow::Borrow;
use leo_ast::{
AssignStatement, BinaryExpression, BinaryOperation, Block, ConditionalStatement, DefinitionStatement, Expression,
ExpressionReconstructor, IterationStatement, Node, ReturnStatement, Statement, StatementReconstructor,
UnaryExpression, UnaryOperation,
AssignStatement, BinaryExpression, BinaryOperation, Block, ConditionalStatement, ConsoleFunction, ConsoleStatement,
DefinitionStatement, Expression, ExpressionReconstructor, Identifier, IterationStatement, Node, ReturnStatement,
Statement, StatementReconstructor, TupleExpression, Type, UnaryExpression, UnaryOperation,
};
impl StatementReconstructor for Flattener<'_> {
@ -381,26 +381,24 @@ impl StatementReconstructor for Flattener<'_> {
// Add it to `self.returns`.
// Note that SSA guarantees that `input.expression` is either a literal or identifier.
match input.expression {
// If the input is an identifier that maps to a tuple, add the corresponding tuple to `self.returns`
// If the input is an identifier that maps to a tuple,
// construct a `ReturnStatement` with the tuple and add it to `self.returns`
Expression::Identifier(identifier) if self.tuples.contains_key(&identifier.name) => {
// Note that the `unwrap` is safe since the match arm checks that the entry exists in `self.tuples`.
let tuple = self.tuples.get(&identifier.name).unwrap().clone();
self.returns.push((guard.clone(), Expression::Tuple(tuple)))
self.returns.push((
guard,
ReturnStatement {
span: input.span,
expression: Expression::Tuple(tuple),
finalize_arguments: input.finalize_arguments,
},
));
}
// Otherwise, add the expression directly.
_ => self.returns.push((guard.clone(), input.expression)),
_ => self.returns.push((guard, input)),
};
// Add each finalize argument to the list of finalize arguments.
if let Some(arguments) = input.finalize_args {
// For each finalize argument, add it and its associated guard to the appropriate list of finalize arguments.
// Note that type checking guarantees that the number of arguments in a finalize statement is equal to the number of arguments in to the finalize block.
for (i, argument) in arguments.into_iter().enumerate() {
// Note that this unwrap is safe since we initialize `self.finalizes` with a number of vectors equal to the number of finalize arguments.
self.finalizes.get_mut(i).unwrap().push((guard.clone(), argument));
}
}
(Statement::dummy(Default::default()), Default::default())
}
}

View File

@ -37,12 +37,7 @@ pub struct Flattener<'a> {
/// A guard is an expression that evaluates to true on the execution path of the `ReturnStatement`.
/// Note that returns are inserted in the order they are encountered during a pre-order traversal of the AST.
/// Note that type checking guarantees that there is at most one return in a basic block.
pub(crate) returns: Vec<(Option<Expression>, Expression)>,
/// A list containing tuples of guards and expressions associated with finalize arguments.
/// A guard is an expression that evaluates to true on the execution path of the finalize argument.
/// Note that finalizes are inserted in the order they are encountered during a pre-order traversal of the AST.
/// Note that type checking guarantees that there is at most one finalize in a basic block.
pub(crate) finalizes: Vec<Vec<(Option<Expression>, Expression)>>,
pub(crate) returns: Vec<(Option<Expression>, ReturnStatement)>,
/// A mapping between variables and flattened tuple expressions.
pub(crate) tuples: IndexMap<Symbol, TupleExpression>,
}
@ -55,21 +50,15 @@ impl<'a> Flattener<'a> {
structs: IndexMap::new(),
condition_stack: Vec::new(),
returns: Vec::new(),
finalizes: Vec::new(),
tuples: IndexMap::new(),
}
}
/// Clears the state associated with `ReturnStatements`, returning the ones that were previously stored.
pub(crate) fn clear_early_returns(&mut self) -> Vec<(Option<Expression>, Expression)> {
pub(crate) fn clear_early_returns(&mut self) -> Vec<(Option<Expression>, ReturnStatement)> {
core::mem::take(&mut self.returns)
}
/// Clears the state associated with `FinalizeStatements`, returning the ones that were previously stored.
pub(crate) fn clear_early_finalizes(&mut self) -> Vec<Vec<(Option<Expression>, Expression)>> {
core::mem::take(&mut self.finalizes)
}
/// Constructs a guard from the current state of the condition stack.
pub(crate) fn construct_guard(&mut self) -> Option<Expression> {
match self.condition_stack.is_empty() {
@ -197,18 +186,56 @@ impl<'a> Flattener<'a> {
}
/// Folds a list of return statements into a single return statement and adds the produced statements to the block.
pub(crate) fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, Expression)>) {
pub(crate) fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, ReturnStatement)>) {
if !returns.is_empty() {
let (expression, stmts) = self.fold_guards("ret$", returns);
let mut return_expressions = Vec::with_capacity(returns.len());
// TODO: Flatten tuples in the return statements once they are allowed.
// Construct a vector for each argument position.
// Note that the indexing is safe since we check that `returns` is not empty.
let (has_finalize, number_of_finalize_arguments) = match &returns[0].1.finalize_arguments {
None => (false, 0),
Some(args) => (true, args.len()),
};
let mut finalize_arguments: Vec<Vec<(Option<Expression>, Expression)>> =
Vec::with_capacity(number_of_finalize_arguments);
// Aggregate the return expressions and finalize arguments and their respective guards.
for (guard, return_statement) in returns {
return_expressions.push((guard.clone(), return_statement.expression));
if let Some(arguments) = return_statement.finalize_arguments {
for (i, argument) in arguments.into_iter().enumerate() {
// Note that the indexing is safe since we initialize `finalize_arguments` with the correct length.
finalize_arguments[i].push((guard.clone(), argument));
}
}
}
// Fold the return expressions into a single expression.
let (expression, stmts) = self.fold_guards("$ret", return_expressions);
// Add all of the accumulated statements to the end of the block.
block.statements.extend(stmts);
// For each position in the finalize call, fold the corresponding arguments into a single expression.
let finalize_arguments = match has_finalize {
false => None,
true => Some(
finalize_arguments
.into_iter()
.enumerate()
.map(|(i, arguments)| {
let (expression, stmts) = self.fold_guards(&format!("finalize${i}$"), arguments);
block.statements.extend(stmts);
expression
})
.collect(),
),
};
// Add the `ReturnStatement` to the end of the block.
block.statements.push(Statement::Return(ReturnStatement {
expression,
finalize_arguments,
span: Default::default(),
}));
}

View File

@ -17,9 +17,10 @@
use crate::{RenameTable, StaticSingleAssigner};
use leo_ast::{
AssignStatement, Block, ConditionalStatement, ConsoleFunction, ConsoleStatement, DecrementStatement,
DefinitionStatement, Expression, ExpressionConsumer, Identifier, IncrementStatement, IterationStatement,
ReturnStatement, Statement, StatementConsumer, TernaryExpression,
AssignStatement, Block, CallExpression, ConditionalStatement, ConsoleFunction, ConsoleStatement,
DecrementStatement, DefinitionStatement, Expression, ExpressionConsumer, ExpressionStatement, Identifier,
IncrementStatement, IterationStatement, ReturnStatement, Statement, StatementConsumer, TernaryExpression,
TupleExpression,
};
use leo_span::Symbol;
@ -328,7 +329,7 @@ impl StatementConsumer for StaticSingleAssigner<'_> {
// Consume the finalize arguments if they exist.
// Process the arguments, accumulating any statements produced.
let finalize_args = input.finalize_args.map(|arguments| {
let finalize_args = input.finalize_arguments.map(|arguments| {
arguments
.into_iter()
.map(|argument| {
@ -342,7 +343,7 @@ impl StatementConsumer for StaticSingleAssigner<'_> {
// Add the simplified return statement to the list of produced statements.
statements.push(Statement::Return(ReturnStatement {
expression,
finalize_args,
finalize_arguments: finalize_args,
span: input.span,
}));

View File

@ -397,7 +397,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Unset the `is_return` flag.
self.is_return = false;
if let Some(arguments) = &input.finalize_args {
if let Some(arguments) = &input.finalize_arguments {
if self.is_finalize {
self.emit_err(TypeCheckerError::finalize_in_finalize(input.span()));
}
@ -438,6 +438,5 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
}
}
}
}
}