Implement flattening phase

This commit is contained in:
Pranav Gaddamadugu 2022-09-01 12:36:56 -07:00
parent f74bfb034c
commit 9d5aa9d08d
12 changed files with 210 additions and 115 deletions

View File

@ -48,6 +48,10 @@ impl<'a> CodeGenerator<'a> {
(self.variable_mapping.get(&input.name).unwrap().clone(), String::new())
}
fn visit_err(&mut self, _input: &'a ErrExpression) -> (String, String) {
unreachable!("`ErrExpression`s should not be in the AST at this phase of compilation.")
}
fn visit_value(&mut self, input: &'a Literal) -> (String, String) {
(format!("{}", input), String::new())
}

View File

@ -14,24 +14,37 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{Flattener};
use crate::Flattener;
use itertools::Itertools;
use leo_ast::{
AccessExpression, CircuitExpression, CircuitMember,
CircuitVariableInitializer, ErrExpression, Expression, ExpressionReconstructor,
MemberAccess, Statement, TernaryExpression, TupleExpression,
AccessExpression, CircuitExpression, CircuitMember, CircuitVariableInitializer, Expression,
ExpressionReconstructor, MemberAccess, Statement, TernaryExpression, TupleExpression,
};
// TODO: Document
impl ExpressionReconstructor for Flattener<'_> {
type AdditionalOutput = Vec<Statement>;
/// Reconstructs ternary expressions over circuits, accumulating any statements that are generated.
/// Reconstructs ternary expressions over tuples and circuits, accumulating any statements that are generated.
/// This is necessary because Aleo instructions does not support ternary expressions over composite data types.
/// For example, the ternary expression `cond ? (a, b) : (c, d)` is flattened into the following:
/// ```leo
/// let var$0 = cond ? a : c;
/// let var$1 = cond ? b : d;
/// (var$0, var$1)
/// ```
/// For circuits, the ternary expression `cond ? a : b`, where `a` and `b` are both circuits `Foo { bar: u8, baz: u8 }`, is flattened into the following:
/// ```leo
/// let var$0 = cond ? a.bar : b.bar;
/// let var$1 = cond ? a.baz : b.baz;
/// let var$2 = Foo { bar: var$0, baz: var$1 };
/// var$2
/// ```
fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
let mut statements = Vec::new();
match (*input.if_true, *input.if_false) {
// Folds ternary expressions over tuples into a tuple of ternary expression.
// Note that this branch is only invoked when folding a conditional returns.
(Expression::Tuple(first), Expression::Tuple(second)) => {
let tuple = Expression::Tuple(TupleExpression {
elements: first
@ -39,26 +52,33 @@ impl ExpressionReconstructor for Flattener<'_> {
.into_iter()
.zip_eq(second.elements.into_iter())
.map(|(if_true, if_false)| {
// Construct a new ternary expression for the tuple element.
let (ternary, stmts) = self.reconstruct_ternary(TernaryExpression {
condition: input.condition,
condition: input.condition.clone(),
if_true: Box::new(if_true),
if_false: Box::new(if_false),
span: input.span,
});
// Accumulate any statements generated.
statements.extend(stmts);
ternary
// Create and accumulate an intermediate assignment statement for the ternary expression corresponding to the tuple element.
let (identifier, statement) = self.assigner.unique_simple_assign_statement(ternary);
statements.push(statement);
// Return the identifier associated with the folded tuple element.
identifier
})
.collect(),
span: Default::default(),
});
(tuple, statements)
}
// If the `true` and `false` cases are circuits, handle them, appropriately.
// Note that type checking guarantees that both expressions have the same same type.
// If both expressions are circuits, construct ternary expression for each of the members and a circuit expression for the result.
(Expression::Identifier(first), Expression::Identifier(second))
if self.circuits.contains_key(&first.name) && self.circuits.contains_key(&second.name) =>
{
// TODO: Document.
let first_circuit = self
.symbol_table
.lookup_circuit(*self.circuits.get(&first.name).unwrap())
@ -67,6 +87,7 @@ impl ExpressionReconstructor for Flattener<'_> {
.symbol_table
.lookup_circuit(*self.circuits.get(&second.name).unwrap())
.unwrap();
// Note that type checking guarantees that both expressions have the same same type.
assert_eq!(first_circuit, second_circuit);
// For each circuit member, construct a new ternary expression.
@ -74,8 +95,9 @@ impl ExpressionReconstructor for Flattener<'_> {
.members
.iter()
.map(|CircuitMember::CircuitVariable(id, _)| {
// Construct a new ternary expression for the circuit member.
let (expression, stmts) = self.reconstruct_ternary(TernaryExpression {
condition: input.condition,
condition: input.condition.clone(),
if_true: Box::new(Expression::Access(AccessExpression::Member(MemberAccess {
inner: Box::new(Expression::Identifier(first)),
name: *id,
@ -88,11 +110,17 @@ impl ExpressionReconstructor for Flattener<'_> {
}))),
span: Default::default(),
});
// Accumulate any statements generated.
statements.extend(stmts);
// Create and accumulate an intermediate assignment statement for the ternary expression corresponding to the circuit member.
let (identifier, statement) = self.assigner.unique_simple_assign_statement(expression);
statements.push(statement);
CircuitVariableInitializer {
identifier: *id,
expression: Some(expression),
expression: Some(identifier),
}
})
.collect();
@ -103,17 +131,40 @@ impl ExpressionReconstructor for Flattener<'_> {
span: Default::default(),
});
// Accumulate any statements generated.
statements.extend(stmts);
(expr, statements)
// Create a new assignment statement for the circuit expression.
let (identifier, statement) = self.assigner.unique_simple_assign_statement(expr);
// Mark the lhs of the assignment as a circuit.
match identifier {
Expression::Identifier(identifier) => {
self.circuits.insert(identifier.name, first_circuit.identifier.name)
}
_ => unreachable!(
"`unique_simple_assign_statement` always produces an identifier on the left hand size."
),
};
statements.push(statement);
(identifier, statements)
}
// Otherwise, create a new intermediate assignment for the ternary expression are return the assigned variable.
// Note that a new assignment must be created to flattened nested ternary expressions.
(if_true, if_false) => {
let (identifier, statement) =
self.assigner
.unique_simple_assign_statement(Expression::Ternary(TernaryExpression {
condition: input.condition,
if_true: Box::new(if_true),
if_false: Box::new(if_false),
span: input.span,
}));
(identifier, vec![statement])
}
// Otherwise, return the original expression.
(if_true, if_false) => (Expression::Ternary(TernaryExpression {
condition: input.condition,
if_true: Box::new(if_true),
if_false: Box::new(if_false),
span: input.span,
}), Default::default())
}
}
}

View File

@ -14,14 +14,53 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{Flattener};
use crate::Flattener;
use leo_ast::{Finalize, FinalizeStatement, Function, ProgramReconstructor, ReturnStatement, Statement, StatementReconstructor};
use leo_ast::{
Finalize, FinalizeStatement, Function, ProgramReconstructor, ReturnStatement, Statement, StatementReconstructor,
};
// TODO: Document.
impl ProgramReconstructor for Flattener<'_> {
/// Flattens a function's body and finalize block, if it exists.
fn reconstruct_function(&mut self, function: Function) -> Function {
// First, flatten the finalize block. This allows us to initialize self.finalizes correctly.
// Note that this is safe since the finalize block is independent of the function body.
let finalize = function.finalize.map(|finalize| {
// Flatten the finalize block.
let mut block = self.reconstruct_block(finalize.block).0;
// Get all of the guards and return expression.
let returns = self.clear_early_returns();
// If the finalize block contains return statements, then we fold them into a single return statement.
if !returns.is_empty() {
let (expression, stmts) = self.fold_guards("ret$", returns);
// Add all of the accumulated statements to the end of the block.
block.statements.extend(stmts);
// Add the `ReturnStatement` to the end of the block.
block.statements.push(Statement::Return(ReturnStatement {
expression,
span: Default::default(),
}));
}
// Initialize `self.finalizes` with the appropriate number of vectors.
self.finalizes = vec![vec![]; finalize.input.len()];
Finalize {
input: finalize.input,
output: finalize.output,
output_type: finalize.output_type,
block,
span: finalize.span,
}
});
// Flatten the function body.
let mut block = self.reconstruct_block(function.block).0;
// Get all of the guards and return expression.
@ -29,7 +68,7 @@ impl ProgramReconstructor for Flattener<'_> {
// If the function contains return statements, then we fold them into a single return statement.
if !returns.is_empty() {
let (stmts, expression) = self.fold_guards("ret$", returns);
let (expression, stmts) = self.fold_guards("ret$", returns);
// Add all of the accumulated statements to the end of the block.
block.statements.extend(stmts);
@ -46,14 +85,22 @@ impl ProgramReconstructor for Flattener<'_> {
// If the function contains finalize statements, then we fold them into a single finalize statement.
if !finalizes.is_empty() {
let (stmts, expression) = self.fold_guards("fin$", finalizes);
let arguments = finalizes
.into_iter()
.enumerate()
.map(|(i, component)| {
let (expression, stmts) = self.fold_guards(format!("fin${i}$").as_str(), component);
// Add all of the accumulated statements to the end of the block.
block.statements.extend(stmts);
// Add all of the accumulated statements to the end of the block.
block.statements.extend(stmts);
expression
})
.collect();
// Add the `FinalizeStatement` to the end of the block.
block.statements.push(Statement::Finalize(FinalizeStatement {
expression,
arguments,
span: Default::default(),
}));
}
@ -65,34 +112,7 @@ impl ProgramReconstructor for Flattener<'_> {
output: function.output,
output_type: function.output_type,
block,
finalize: function.finalize.map(|finalize| {
let mut block = self.reconstruct_block(finalize.block).0;
// Get all of the guards and return expression.
let returns = self.clear_early_returns();
// If the function contains return statements, then we fold them into a single return statement.
if !returns.is_empty() {
let (stmts, expression) = self.fold_guards("ret$", returns);
// Add all of the accumulated statements to the end of the block.
block.statements.extend(stmts);
// Add the `ReturnStatement` to the end of the block.
block.statements.push(Statement::Return(ReturnStatement {
expression,
span: Default::default(),
}));
}
Finalize {
input: finalize.input,
output: finalize.output,
output_type: finalize.output_type,
block,
span: finalize.span
}
}),
finalize,
span: function.span,
}
}

View File

@ -14,40 +14,56 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{Flattener};
use crate::Flattener;
use leo_ast::{AssignStatement, BinaryExpression, BinaryOperation, Block, ConditionalStatement, DefinitionStatement, Expression, FinalizeStatement, IterationStatement, Node, ReturnStatement, Statement, StatementReconstructor, UnaryExpression, UnaryOperation};
use leo_ast::{
AssignStatement, BinaryExpression, BinaryOperation, Block, ConditionalStatement, DefinitionStatement, Expression,
ExpressionReconstructor, FinalizeStatement, IterationStatement, Node, ReturnStatement, Statement,
StatementReconstructor, UnaryExpression, UnaryOperation,
};
// TODO: Document
impl StatementReconstructor for Flattener<'_> {
/// Flattens an assign statement, if necessary.
/// Marks variables as circuits as necessary.
/// Note that new statements are only produced if the right hand side is a ternary expression over circuits.
/// Otherwise, the statement is returned as is.
fn reconstruct_assign(&mut self, assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
let lhs = match assign.place {
Expression::Identifier(identifier) => identifier,
_ => unreachable!("`AssignStatement`s can only have `Identifier`s on the left hand side."),
};
match &assign.value {
let (value, statements) = match assign.value {
// If the rhs of the assignment is a circuit, add it to `self.circuits`.
Expression::Circuit(rhs) => {
self.circuits.insert(lhs.name, rhs.name.name);
(Statement::Assign(Box::new(assign)), Default::default())
(Expression::Circuit(rhs), Default::default())
}
// If the rhs of the assignment is an identifier that is a circuit, add it to `self.circuits`.
Expression::Identifier(rhs) if self.circuits.contains_key(&rhs.name) => {
self.circuits.insert(lhs.name, rhs.name);
(Statement::Assign(Box::new(assign)), Default::default())
(Expression::Identifier(rhs), Default::default())
}
// If the rhs of the assignment is ternary expression, reconstruct it.
Expression::Ternary(ternary) => {
todo!()
}
Expression::Ternary(ternary) => self.reconstruct_ternary(ternary),
// Otherwise return the original statement.
_ => (Statement::Assign(Box::new(assign)), Default::default()),
}
value => (value, Default::default()),
};
(
Statement::Assign(Box::new(AssignStatement {
place: Expression::Identifier(lhs),
value,
span: assign.span,
})),
statements,
)
}
// TODO: Do we want to flatten nested blocks? They do not affect code generation but it would regularize the AST structure.
/// Flattens the statements inside a basic block.
/// The resulting block does not contain any conditional statements.
fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
let mut statements = Vec::with_capacity(block.statements.len());
@ -58,10 +74,13 @@ impl StatementReconstructor for Flattener<'_> {
statements.push(reconstructed_statement);
}
(Block {
span: block.span,
statements,
}, Default::default())
(
Block {
span: block.span,
statements,
},
Default::default(),
)
}
/// Flatten a conditional statement into a list of statements.
@ -99,10 +118,13 @@ impl StatementReconstructor for Flattener<'_> {
(Statement::dummy(Default::default()), statements)
}
/// Static single assignment converts definition statements into assignment statements.
fn reconstruct_definition(&mut self, _definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
unreachable!("`DefinitionStatement`s should not exist in the AST at this phase of compilation.")
}
/// Replaces a finalize statement with an empty block statement.
/// Stores the arguments to the finalize statement, which are later folded into a single finalize statement at the end of the function.
fn reconstruct_finalize(&mut self, input: FinalizeStatement) -> (Statement, Self::AdditionalOutput) {
// Construct the associated guard.
let guard = match self.condition_stack.is_empty() {
@ -120,18 +142,22 @@ impl StatementReconstructor for Flattener<'_> {
}
};
// TODO: Add to finalize guards.
// Note that type checking guarantees that the number of arguments in a finalize statement is equal to the number of arguments in to the finalize block.
for (i, argument) in input.arguments.into_iter().enumerate() {
// Note that this unwrap is safe since we initialize `self.finalizes` with a number of vectors equal to the number of finalize arguments.
self.finalizes.get_mut(i).unwrap().push((guard.clone(), argument));
}
(Statement::dummy(Default::default()), Default::default())
}
// TODO: Error message
// TODO: Error message requesting the user to enable loop-unrolling.
fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
unreachable!("`IterationStatement`s should not be in the AST at this phase of compilation.");
}
/// Transforms a `ReturnStatement` into an empty `BlockStatement`,
/// storing the expression and the associated guard in `self.early_returns`.
/// Transforms a return statement into an empty block statement.
/// Stores the arguments to the return statement, which are later folded into a single return statement at the end of the function.
fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
// Construct the associated guard.
let guard = match self.condition_stack.is_empty() {
@ -149,7 +175,7 @@ impl StatementReconstructor for Flattener<'_> {
}
};
// TODO: Add to return guards.
self.returns.push((guard, input.expression));
(Statement::dummy(Default::default()), Default::default())
}

View File

@ -16,9 +16,7 @@
use crate::{Assigner, SymbolTable};
use leo_ast::{
Expression, ExpressionReconstructor, Identifier, Statement, TernaryExpression,
};
use leo_ast::{Expression, ExpressionReconstructor, Identifier, Statement, TernaryExpression};
use leo_span::Symbol;
use indexmap::IndexMap;
@ -32,12 +30,16 @@ pub struct Flattener<'a> {
pub(crate) circuits: IndexMap<Symbol, Symbol>,
/// A stack of condition `Expression`s visited up to the current point in the AST.
pub(crate) condition_stack: Vec<Expression>,
/// A list containing tuples of guards and expressions associated with early `ReturnStatement`s.
/// Note that early returns are inserted in the order they are encountered during a pre-order traversal of the AST.
pub(crate) early_returns: Vec<(Option<Expression>, Expression)>,
/// A list containing tuples of guards and expressions associated with early `FinalizeStatement`s.
/// Note that early finalizes are inserted in the order they are encountered during a pre-order traversal of the AST.
pub(crate) early_finalizes: Vec<(Option<Expression>, Expression)>,
/// A list containing tuples of guards and expressions associated `ReturnStatement`s.
/// A guard is an expression that evaluates to true on the execution path of the `ReturnStatement`.
/// Note that returns are inserted in the order they are encountered during a pre-order traversal of the AST.
/// Note that type checking guarantees that there is at most one return in a basic block.
pub(crate) returns: Vec<(Option<Expression>, Expression)>,
/// A list containing tuples of guards and expressions associated with `FinalizeStatement`s.
/// A guard is an expression that evaluates to true on the execution path of the `FinalizeStatement`.
/// Note that finalizes are inserted in the order they are encountered during a pre-order traversal of the AST.
/// Note that type checking guarantees that there is at most one finalize in a basic block.
pub(crate) finalizes: Vec<Vec<(Option<Expression>, Expression)>>,
}
impl<'a> Flattener<'a> {
@ -47,28 +49,29 @@ impl<'a> Flattener<'a> {
assigner,
circuits: IndexMap::new(),
condition_stack: Vec::new(),
early_returns: Vec::new(),
early_finalizes: Vec::new(),
returns: Vec::new(),
finalizes: Vec::new(),
}
}
/// Clears the state associated with `ReturnStatements`, returning the ones that were previously produced.
/// Clears the state associated with `ReturnStatements`, returning the ones that were previously stored.
pub(crate) fn clear_early_returns(&mut self) -> Vec<(Option<Expression>, Expression)> {
core::mem::take(&mut self.early_returns)
core::mem::take(&mut self.returns)
}
// Clears the state associated with `FinalizeStatements`, returning the ones that were previously produced.
pub(crate) fn clear_early_finalizes(&mut self) -> Vec<(Option<Expression>, Expression)> {
core::mem::take(&mut self.early_finalizes)
/// Clears the state associated with `FinalizeStatements`, returning the ones that were previously stored.
pub(crate) fn clear_early_finalizes(&mut self) -> Vec<Vec<(Option<Expression>, Expression)>> {
core::mem::take(&mut self.finalizes)
}
/// Fold guards and expressions into a single expression.
// TODO: Remove below assumption.
/// Note that this function assumes that at least one guard is present.
pub(crate) fn fold_guards(
&mut self,
prefix: &str,
mut guards: Vec<(Option<Expression>, Expression)>,
) -> (Vec<Statement>, Expression) {
) -> (Expression, Vec<Statement>) {
// Type checking guarantees that there exists at least one return statement in the function body.
let (_, last_expression) = guards.pop().unwrap();
@ -102,6 +105,6 @@ impl<'a> Flattener<'a> {
Some(guard) => construct_ternary_assignment(guard, expr, acc),
});
(statements, expression)
(expression, statements)
}
}

View File

@ -59,7 +59,7 @@ pub use flattener::*;
use crate::{Assigner, Pass, SymbolTable};
use leo_ast::{Ast, ProgramReconstructor};
use leo_errors::{Result};
use leo_errors::Result;
impl<'a> Pass for Flattener<'a> {
type Input = (Ast, &'a SymbolTable, Assigner);

View File

@ -19,17 +19,13 @@ use leo_span::Symbol;
use std::fmt::Display;
/// A struct used to create assignment statements.
#[derive(Default)]
pub struct Assigner {
/// A strictly increasing counter, used to ensure that new variable names are unique.
pub(crate) counter: usize,
}
impl Assigner {
/// Initializes a new assigner.
pub fn new() -> Self {
Self { counter: 0 }
}
/// Return a new unique `Symbol` from a `&str`.
pub(crate) fn unique_symbol(&mut self, arg: impl Display) -> Symbol {
self.counter += 1;

View File

@ -60,10 +60,10 @@ pub(crate) use rename_table::*;
pub mod static_single_assigner;
pub use static_single_assigner::*;
use crate::{Pass};
use crate::Pass;
use leo_ast::{Ast, ProgramConsumer};
use leo_errors::{Result};
use leo_errors::Result;
impl Pass for StaticSingleAssigner {
type Input = Ast;

View File

@ -18,8 +18,8 @@ use crate::StaticSingleAssigner;
use leo_ast::{
AccessExpression, AssociatedFunction, BinaryExpression, CallExpression, CircuitExpression,
CircuitVariableInitializer, ErrExpression, Expression, ExpressionConsumer, Identifier, Literal, MemberAccess,
Statement, TernaryExpression, TupleAccess, TupleExpression, UnaryExpression,
CircuitVariableInitializer, Expression, ExpressionConsumer, Identifier, Literal, MemberAccess, Statement,
TernaryExpression, TupleAccess, TupleExpression, UnaryExpression,
};
impl ExpressionConsumer for StaticSingleAssigner {

View File

@ -16,10 +16,7 @@
use crate::StaticSingleAssigner;
use leo_ast::{
Block, Finalize, Function, FunctionConsumer, Program, ProgramConsumer,
StatementConsumer,
};
use leo_ast::{Block, Finalize, Function, FunctionConsumer, Program, ProgramConsumer, StatementConsumer};
impl FunctionConsumer for StaticSingleAssigner {
type Output = Function;

View File

@ -17,9 +17,9 @@
use crate::{RenameTable, StaticSingleAssigner};
use leo_ast::{
AssignStatement, Block, ConditionalStatement, ConsoleFunction, ConsoleStatement,
DecrementStatement, DefinitionStatement, Expression, ExpressionConsumer, FinalizeStatement, Identifier,
IncrementStatement, IterationStatement, ReturnStatement, Statement, StatementConsumer, TernaryExpression,
AssignStatement, Block, ConditionalStatement, ConsoleFunction, ConsoleStatement, DecrementStatement,
DefinitionStatement, Expression, ExpressionConsumer, FinalizeStatement, Identifier, IncrementStatement,
IterationStatement, ReturnStatement, Statement, StatementConsumer, TernaryExpression,
};
use leo_span::Symbol;
@ -101,7 +101,7 @@ impl StatementConsumer for StaticSingleAssigner {
// Add reconstructed conditional statement to the list of produced statements.
statements.push(Statement::Conditional(ConditionalStatement {
span: conditional.span,
condition,
condition: condition.clone(),
then,
otherwise,
}));

View File

@ -16,8 +16,6 @@
use crate::{Assigner, RenameTable};
// TODO: Consider refactoring out an Assigner struct that produces (unique) assignment statements.
pub struct StaticSingleAssigner {
/// The `RenameTable` for the current basic block in the AST
pub(crate) rename_table: RenameTable,
@ -33,7 +31,7 @@ impl StaticSingleAssigner {
Self {
rename_table: RenameTable::new(None),
is_lhs: false,
assigner: Assigner::new(),
assigner: Assigner::default(),
}
}