More tyc and ssa for finalize

This commit is contained in:
Pranav Gaddamadugu 2022-08-29 10:13:19 -07:00
parent 791463c82f
commit 3efb4c5108
7 changed files with 238 additions and 108 deletions

View File

@ -15,11 +15,10 @@
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::StaticSingleAssigner;
use itertools::Itertools;
use leo_ast::{
Block, CircuitExpression, CircuitVariableInitializer, Expression, Function, FunctionConsumer, Identifier, Program,
ProgramConsumer, ReturnStatement, Statement, StatementConsumer, TernaryExpression, TupleExpression,
Block, FinalizeStatement, Function, FunctionConsumer, Program, ProgramConsumer, ReturnStatement, Statement,
StatementConsumer,
};
impl FunctionConsumer for StaticSingleAssigner<'_> {
@ -39,99 +38,39 @@ impl FunctionConsumer for StaticSingleAssigner<'_> {
let mut statements = self.consume_block(function.block);
// Add the `ReturnStatement` to the end of the block.
let mut returns = self.clear_early_returns();
// Get all of the guards and return expression.
let returns = self.clear_early_returns();
// Type checking guarantees that there exists at least one return statement in the function body.
let (_, last_return_expression) = returns.pop().unwrap();
// If the function contains return statements, then we fold them into a single return statement.
if !returns.is_empty() {
let (stmts, expression) = self.fold_guards("ret$", returns);
// Produce a chain of ternary expressions and assignments for the set of early returns.
let mut stmts = Vec::with_capacity(returns.len());
// Add all of the accumulated statements to the end of the block.
statements.extend(stmts);
// Helper to construct and store ternary assignments. e.g `$ret$0 = $var$0 ? $var$1 : $var$2`
let mut construct_ternary_assignment = |guard: Expression, if_true: Expression, if_false: Expression| {
let place = Expression::Identifier(Identifier {
name: self.unique_symbol("$ret"),
// Add the `ReturnStatement` to the end of the block.
statements.push(Statement::Return(ReturnStatement {
expression,
span: Default::default(),
});
stmts.push(Self::simple_assign_statement(
place.clone(),
Expression::Ternary(TernaryExpression {
condition: Box::new(guard),
if_true: Box::new(if_true),
if_false: Box::new(if_false),
span: Default::default(),
}),
));
place
};
}));
}
let expression = returns
.into_iter()
.rev()
.fold(last_return_expression, |acc, (guard, expr)| match guard {
None => unreachable!("All return statements except for the last one must have a guard."),
// Note that type checking guarantees that all expressions in return statements in the function body have the same type.
Some(guard) => match (expr, acc) {
// If the function returns tuples, fold the return expressions into a tuple of ternary expressions.
// Note that `expr` and `acc` are correspond to the `if` and `else` cases of the ternary expression respectively.
(Expression::Tuple(expr_tuple), Expression::Tuple(acc_tuple)) => {
Expression::Tuple(TupleExpression {
elements: expr_tuple
.elements
.into_iter()
.zip_eq(acc_tuple.elements.into_iter())
.map(|(if_true, if_false)| {
construct_ternary_assignment(guard.clone(), if_true, if_false)
})
.collect(),
span: Default::default(),
})
}
// If the function returns circuits, fold the return expressions into a circuit of ternary expressions.
// Note that `expr` and `acc` are correspond to the `if` and `else` cases of the ternary expression respectively.
(Expression::Circuit(expr_circuit), Expression::Circuit(acc_circuit)) => {
Expression::Circuit(CircuitExpression {
name: acc_circuit.name,
span: acc_circuit.span,
members: expr_circuit
.members
.into_iter()
.zip_eq(acc_circuit.members.into_iter())
.map(|(if_true, if_false)| {
let expression = construct_ternary_assignment(
guard.clone(),
match if_true.expression {
None => Expression::Identifier(if_true.identifier),
Some(expr) => expr,
},
match if_false.expression {
None => Expression::Identifier(if_false.identifier),
Some(expr) => expr,
},
);
CircuitVariableInitializer {
identifier: if_true.identifier,
expression: Some(expression),
}
})
.collect(),
})
}
// Otherwise, fold the return expressions into a single ternary expression.
// Note that `expr` and `acc` are correspond to the `if` and `else` cases of the ternary expression respectively.
(expr, acc) => construct_ternary_assignment(guard, expr, acc),
},
});
// Get all of the guards and finalize expression.
let finalizes = self.clear_early_finalizes();
// Add all of the accumulated statements to the end of the block.
statements.extend(stmts);
// If the function contains finalize statements, then we fold them into a single finalize statement.
if !finalizes.is_empty() {
let (stmts, expression) = self.fold_guards("fin$", finalizes);
// Add the `ReturnStatement` to the end of the block.
statements.push(Statement::Return(ReturnStatement {
expression,
span: Default::default(),
}));
// Add all of the accumulated statements to the end of the block.
statements.extend(stmts);
// Add the `FinalizeStatement` to the end of the block.
statements.push(Statement::Finalize(FinalizeStatement {
expression,
span: Default::default(),
}));
}
// Remove the `RenameTable` for the function.
self.pop();

View File

@ -187,8 +187,22 @@ impl StatementConsumer for StaticSingleAssigner<'_> {
statements
}
fn consume_decrement(&mut self, _input: DecrementStatement) -> Self::Output {
todo!()
fn consume_decrement(&mut self, input: DecrementStatement) -> Self::Output {
// First consume the expression associated with the amount.
let (amount, mut statements) = self.consume_expression(input.amount);
// Then, consume the expression associated with the index.
let (index, index_statements) = self.consume_expression(input.index);
statements.extend(index_statements);
statements.push(Statement::Decrement(DecrementStatement {
mapping: input.mapping,
index,
amount,
span: input.span,
}));
statements
}
/// Consumes the `DefinitionStatement` into an `AssignStatement`, renaming the left-hand-side as appropriate.
@ -210,12 +224,48 @@ impl StatementConsumer for StaticSingleAssigner<'_> {
statements
}
fn consume_finalize(&mut self, _input: FinalizeStatement) -> Self::Output {
todo!()
fn consume_finalize(&mut self, input: FinalizeStatement) -> Self::Output {
// Construct the associated guard.
let guard = match self.condition_stack.is_empty() {
true => None,
false => {
let (first, rest) = self.condition_stack.split_first().unwrap();
Some(rest.iter().cloned().fold(first.clone(), |acc, condition| {
Expression::Binary(BinaryExpression {
op: BinaryOperation::And,
left: Box::new(acc),
right: Box::new(condition),
span: Default::default(),
})
}))
}
};
// Consume the expression and add it to `early_finalizes`.
let (expression, statements) = self.consume_expression(input.expression);
// Note that this is the only place where `self.early_finalizes` is appended.
// Furthermore, `expression` will always be an identifier or tuple expression.
self.early_finalizes.push((guard, expression));
statements
}
fn consume_increment(&mut self, _input: IncrementStatement) -> Self::Output {
todo!()
fn consume_increment(&mut self, input: IncrementStatement) -> Self::Output {
// First consume the expression associated with the amount.
let (amount, mut statements) = self.consume_expression(input.amount);
// Then, consume the expression associated with the index.
let (index, index_statements) = self.consume_expression(input.index);
statements.extend(index_statements);
statements.push(Statement::Increment(IncrementStatement {
mapping: input.mapping,
index,
amount,
span: input.span,
}));
statements
}
// TODO: Error message
@ -245,7 +295,7 @@ impl StatementConsumer for StaticSingleAssigner<'_> {
// Consume the expression and add it to `early_returns`.
let (expression, statements) = self.consume_expression(input.expression);
// Note that this is the only place where `self.early_returns` is mutated.
// Note that this is the only place where `self.early_returns` is appended.
// Furthermore, `expression` will always be an identifier or tuple expression.
self.early_returns.push((guard, expression));

View File

@ -15,9 +15,13 @@
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::RenameTable;
use itertools::Itertools;
use std::fmt::Display;
use leo_ast::{AssignStatement, Expression, Identifier, Statement};
use leo_ast::{
AssignStatement, CircuitExpression, CircuitVariableInitializer, Expression, Identifier, Statement,
TernaryExpression, TupleExpression,
};
use leo_errors::emitter::Handler;
use leo_span::Symbol;
@ -35,6 +39,9 @@ pub struct StaticSingleAssigner<'a> {
/// A list containing tuples of guards and expressions associated with early `ReturnStatement`s.
/// Note that early returns are inserted in the order they are encountered during a pre-order traversal of the AST.
pub(crate) early_returns: Vec<(Option<Expression>, Expression)>,
/// A list containing tuples of guards and expressions associated with early `FinalizeStatement`s.
/// Note that early finalizes are inserted in the order they are encountered during a pre-order traversal of the AST.
pub(crate) early_finalizes: Vec<(Option<Expression>, Expression)>,
}
impl<'a> StaticSingleAssigner<'a> {
@ -46,6 +53,7 @@ impl<'a> StaticSingleAssigner<'a> {
is_lhs: false,
condition_stack: Vec::new(),
early_returns: Vec::new(),
early_finalizes: Vec::new(),
}
}
@ -84,6 +92,11 @@ impl<'a> StaticSingleAssigner<'a> {
core::mem::take(&mut self.early_returns)
}
// Clears the state associated with `FinalizeStatements`, returning the ones that were previously produced.
pub(crate) fn clear_early_finalizes(&mut self) -> Vec<(Option<Expression>, Expression)> {
core::mem::take(&mut self.early_finalizes)
}
/// Pushes a new scope, setting the current scope as the new scope's parent.
pub(crate) fn push(&mut self) {
let parent_table = core::mem::take(&mut self.rename_table);
@ -95,4 +108,96 @@ impl<'a> StaticSingleAssigner<'a> {
let parent = self.rename_table.parent.clone().unwrap_or_default();
core::mem::replace(&mut self.rename_table, *parent)
}
/// Fold guards and expressions into a single expression.
/// Note that this function assumes that at least one guard is present.
pub(crate) fn fold_guards(
&mut self,
prefix: &str,
mut guards: Vec<(Option<Expression>, Expression)>,
) -> (Vec<Statement>, Expression) {
// Type checking guarantees that there exists at least one return statement in the function body.
let (_, last_expression) = guards.pop().unwrap();
// Produce a chain of ternary expressions and assignments for the guards.
let mut stmts = Vec::with_capacity(guards.len());
// Helper to construct and store ternary assignments. e.g `$ret$0 = $var$0 ? $var$1 : $var$2`
let mut construct_ternary_assignment = |guard: Expression, if_true: Expression, if_false: Expression| {
let place = Expression::Identifier(Identifier {
name: self.unique_symbol(prefix),
span: Default::default(),
});
stmts.push(Self::simple_assign_statement(
place.clone(),
Expression::Ternary(TernaryExpression {
condition: Box::new(guard),
if_true: Box::new(if_true),
if_false: Box::new(if_false),
span: Default::default(),
}),
));
place
};
let expression = guards
.into_iter()
.rev()
.fold(last_expression, |acc, (guard, expr)| match guard {
None => unreachable!("All expression except for the last one must have a guard."),
// Note that type checking guarantees that all expressions have the same type.
Some(guard) => match (expr, acc) {
// If the function returns tuples, fold the expressions into a tuple of ternary expressions.
// Note that `expr` and `acc` are correspond to the `if` and `else` cases of the ternary expression respectively.
(Expression::Tuple(expr_tuple), Expression::Tuple(acc_tuple)) => {
Expression::Tuple(TupleExpression {
elements: expr_tuple
.elements
.into_iter()
.zip_eq(acc_tuple.elements.into_iter())
.map(|(if_true, if_false)| {
construct_ternary_assignment(guard.clone(), if_true, if_false)
})
.collect(),
span: Default::default(),
})
}
// If the expression is a circuit, fold the expressions into a circuit of ternary expressions.
// Note that `expr` and `acc` are correspond to the `if` and `else` cases of the ternary expression respectively.
(Expression::Circuit(expr_circuit), Expression::Circuit(acc_circuit)) => {
Expression::Circuit(CircuitExpression {
name: acc_circuit.name,
span: acc_circuit.span,
members: expr_circuit
.members
.into_iter()
.zip_eq(acc_circuit.members.into_iter())
.map(|(if_true, if_false)| {
let expression = construct_ternary_assignment(
guard.clone(),
match if_true.expression {
None => Expression::Identifier(if_true.identifier),
Some(expr) => expr,
},
match if_false.expression {
None => Expression::Identifier(if_false.identifier),
Some(expr) => expr,
},
);
CircuitVariableInitializer {
identifier: if_true.identifier,
expression: Some(expression),
}
})
.collect(),
})
}
// Otherwise, fold the return expressions into a single ternary expression.
// Note that `expr` and `acc` are correspond to the `if` and `else` cases of the ternary expression respectively.
(expr, acc) => construct_ternary_assignment(guard, expr, acc),
},
});
(stmts, expression)
}
}

View File

@ -128,6 +128,9 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// The function's body does not have a return statement.
self.has_return = false;
// The function's body does not have a finalize statement.
self.has_finalize = false;
// Store the name of the function.
self.function = Some(function.name());
@ -161,16 +164,8 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
self.visit_block(&function.block);
if !self.has_return {
self.emit_err(TypeCheckerError::function_has_no_return(
function.name(),
function.span(),
));
} else {
// Check that the return type is valid.
// TODO: Span should be just for the return type.
self.assert_type_is_valid(function.span, &function.output);
}
// Check that the return type is valid.
self.assert_type_is_valid(function.span, &function.output);
// Ensure there are no nested tuples in the return type.
if let Type::Tuple(tys) = &function.output {
@ -185,6 +180,8 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Traverse and check the finalize block if it exists.
if let Some(finalize) = &function.finalize {
self.is_finalize = true;
// The function's finalize block does not have a return statement.
self.has_return = false;
if !self.is_program_function {
self.emit_err(TypeCheckerError::only_program_functions_can_have_finalize(
@ -221,6 +218,9 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Type check the finalize block.
self.visit_block(&finalize.block);
// Check that the return type is valid.
self.assert_type_is_valid(finalize.span, &finalize.output);
// Exit the scope for the finalize block.
self.exit_scope(scope_index);

View File

@ -21,6 +21,7 @@ use leo_errors::TypeCheckerError;
impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
fn visit_statement(&mut self, input: &'a Statement) {
// No statements can follow a return statement.
if self.has_return {
self.emit_err(TypeCheckerError::unreachable_code_after_return(input.span()));
return;
@ -80,8 +81,13 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
let mut then_block_has_return = false;
let mut otherwise_block_has_return = false;
let mut then_block_has_finalize = false;
let mut otherwise_block_has_finalize = false;
// Set the `has_return` flag for the then-block.
let previous_has_return = core::mem::replace(&mut self.has_return, then_block_has_return);
// Set the `has_finalize` flag for the then-block.
let previous_has_finalize = core::mem::replace(&mut self.has_finalize, then_block_has_finalize);
// Create a new scope for the then-block.
let scope_index = self.symbol_table.borrow_mut().insert_block();
@ -93,10 +99,14 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Store the `has_return` flag for the then-block.
then_block_has_return = self.has_return;
// Store the `has_finalize` flag for the then-block.
then_block_has_finalize = self.has_finalize;
if let Some(otherwise) = &input.otherwise {
// Set the `has_return` flag for the otherwise-block.
self.has_return = otherwise_block_has_return;
// Set the `has_finalize` flag for the otherwise-block.
self.has_finalize = otherwise_block_has_finalize;
match &**otherwise {
Statement::Block(stmt) => {
@ -115,10 +125,14 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Store the `has_return` flag for the otherwise-block.
otherwise_block_has_return = self.has_return;
// Store the `has_finalize` flag for the otherwise-block.
otherwise_block_has_finalize = self.has_finalize;
}
// Restore the previous `has_return` flag.
self.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return);
// Restore the previous `has_finalize` flag.
self.has_finalize = previous_has_finalize || (then_block_has_finalize && otherwise_block_has_finalize);
}
fn visit_console(&mut self, input: &'a ConsoleStatement) {
@ -214,6 +228,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
None => self.emit_err(TypeCheckerError::finalize_without_finalize_block(input.span())),
Some(finalize) => {
let type_ = self.visit_expression(&input.expression, &None);
// TODO: Check that the finalize type is correct.
self.assert_and_return_type(finalize.output, &type_, input.expression.span());
}
}
@ -273,6 +288,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
}
let prior_has_return = core::mem::take(&mut self.has_return);
let prior_has_finalize = core::mem::take(&mut self.has_finalize);
self.visit_block(&input.block);
@ -280,7 +296,12 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
self.emit_err(TypeCheckerError::loop_body_contains_return(input.span()));
}
if self.has_finalize {
self.emit_err(TypeCheckerError::loop_body_contains_finalize(input.span()));
}
self.has_return = prior_has_return;
self.has_finalize = prior_has_finalize;
// Exit the scope.
self.exit_scope(scope_index);
@ -308,7 +329,12 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
.symbol_table
.borrow()
.lookup_fn_symbol(parent)
.map(|f| f.output.clone());
.map(|f| match self.is_finalize {
// TODO: Check this.
// Note that this `unwrap()` is safe since we checked that the function has a finalize block.
true => f.finalize.as_ref().unwrap().output.clone(),
false => f.output.clone(),
});
self.has_return = true;

View File

@ -33,6 +33,8 @@ pub struct TypeChecker<'a> {
pub(crate) function: Option<Symbol>,
/// Whether or not the function that we are currently traversing has a return statement.
pub(crate) has_return: bool,
/// Whether or not the function that we are currently traversing has a finalize statement.
pub(crate) has_finalize: bool,
/// Are we traversing a program function?
/// A "program function" is a function that can be invoked by a user or another program.
pub(crate) is_program_function: bool,
@ -92,6 +94,7 @@ impl<'a> TypeChecker<'a> {
handler,
function: None,
has_return: false,
has_finalize: false,
is_finalize: false,
}
}

View File

@ -337,4 +337,11 @@ create_messages!(
msg: format!("Cannot use a `finalize` statement without a `finalize` block."),
help: None,
}
@formatted
loop_body_contains_finalize {
args: (),
msg: format!("Loop body contains a finalize statement."),
help: Some("Remove the finalize statement.".to_string()),
}
);