Use NodeBuilder in LoopUnroller

This commit is contained in:
Pranav Gaddamadugu 2023-08-17 16:02:01 -04:00
parent 8f63fcdf80
commit 19ba799d21
9 changed files with 113 additions and 67 deletions

View File

@ -24,7 +24,7 @@ use leo_ast::{
ExpressionReconstructor,
Identifier,
IterationStatement,
NodeID,
NodeReconstructor,
ProgramReconstructor,
Statement,
StatementReconstructor,
@ -36,15 +36,15 @@ use leo_span::Symbol;
// TODO: Generalize the functionality of this reconstructor to be used in other passes.
/// An `AssignmentRenamer` renames the left-hand side of all assignment statements in an AST node.
/// The new names are propagated to all following identifiers.
pub struct AssignmentRenamer {
pub assigner: Assigner,
pub struct AssignmentRenamer<'a> {
pub assigner: &'a Assigner,
pub rename_table: RenameTable,
pub is_lhs: bool,
}
impl AssignmentRenamer {
impl<'a> AssignmentRenamer<'a> {
/// Initialize a new `AssignmentRenamer`.
pub fn new(assigner: Assigner) -> Self {
pub fn new(assigner: &'a Assigner) -> Self {
Self { assigner, rename_table: RenameTable::new(None), is_lhs: false }
}
@ -61,7 +61,9 @@ impl AssignmentRenamer {
}
}
impl ExpressionReconstructor for AssignmentRenamer {
impl NodeReconstructor for AssignmentRenamer<'_> {}
impl ExpressionReconstructor for AssignmentRenamer<'_> {
type AdditionalOutput = ();
/// Rename the identifier if it is the left-hand side of an assignment, otherwise look up for a new name in the internal rename table.
@ -80,7 +82,10 @@ impl ExpressionReconstructor for AssignmentRenamer {
false => *self.rename_table.lookup(input.name).unwrap_or(&input.name),
};
(Expression::Identifier(Identifier { name, span: input.span, id: NodeID::default() }), Default::default())
(
Expression::Identifier(Identifier { name, span: input.span, id: self.reconstruct_node_id(input.id) }),
Default::default(),
)
}
/// Rename the variable initializers in the struct expression.
@ -100,18 +105,18 @@ impl ExpressionReconstructor for AssignmentRenamer {
),
},
span: member.span,
id: NodeID::default(),
id: self.reconstruct_node_id(member.id),
})
.collect(),
span: input.span,
id: NodeID::default(),
id: self.reconstruct_node_id(input.id),
}),
Default::default(),
)
}
}
impl StatementReconstructor for AssignmentRenamer {
impl StatementReconstructor for AssignmentRenamer<'_> {
/// Rename the left-hand side of the assignment statement.
fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
// First rename the right-hand-side of the assignment.
@ -124,7 +129,12 @@ impl StatementReconstructor for AssignmentRenamer {
self.is_lhs = false;
(
Statement::Assign(Box::new(AssignStatement { place, value, span: input.span, id: NodeID::default() })),
Statement::Assign(Box::new(AssignStatement {
place,
value,
span: input.span,
id: self.reconstruct_node_id(input.id),
})),
Default::default(),
)
}
@ -150,4 +160,4 @@ impl StatementReconstructor for AssignmentRenamer {
}
}
impl ProgramReconstructor for AssignmentRenamer {}
impl ProgramReconstructor for AssignmentRenamer<'_> {}

View File

@ -16,27 +16,33 @@
use crate::{Assigner, AssignmentRenamer, CallGraph};
use leo_ast::Function;
use leo_ast::{Function, NodeBuilder};
use leo_span::Symbol;
use indexmap::IndexMap;
pub struct FunctionInliner<'a> {
/// A counter used to create unique NodeIDs.
pub(crate) node_builder: &'a NodeBuilder,
/// The call graph for the program.
pub(crate) call_graph: &'a CallGraph,
/// A wrapper around an Assigner used to create unique variable assignments.
pub(crate) assignment_renamer: AssignmentRenamer,
pub(crate) assignment_renamer: AssignmentRenamer<'a>,
/// A map of reconstructed functions in the current program scope.
pub(crate) reconstructed_functions: IndexMap<Symbol, Function>,
/// Whether or not we are currently inlining a function.
pub(crate) inlining: bool,
}
impl<'a> FunctionInliner<'a> {
/// Initializes a new `FunctionInliner`.
pub fn new(call_graph: &'a CallGraph, assigner: Assigner) -> Self {
pub fn new(node_builder: &'a NodeBuilder, call_graph: &'a CallGraph, assigner: &'a Assigner) -> Self {
Self {
node_builder,
call_graph,
assignment_renamer: AssignmentRenamer::new(assigner),
reconstructed_functions: Default::default(),
inlining: false,
}
}
}

View File

@ -66,17 +66,17 @@ pub use function_inliner::*;
use crate::{Assigner, CallGraph, Pass};
use leo_ast::{Ast, ProgramReconstructor};
use leo_ast::{Ast, NodeBuilder, ProgramReconstructor};
use leo_errors::Result;
impl<'a> Pass for FunctionInliner<'a> {
type Input = (Ast, &'a CallGraph, Assigner);
type Output = Result<(Ast, Assigner)>;
type Input = (Ast, &'a NodeBuilder, &'a CallGraph, &'a Assigner);
type Output = Result<Ast>;
fn do_pass((ast, call_graph, assigner): Self::Input) -> Self::Output {
let mut reconstructor = FunctionInliner::new(call_graph, assigner);
fn do_pass((ast, node_builder, call_graph, assigner): Self::Input) -> Self::Output {
let mut reconstructor = FunctionInliner::new(node_builder, call_graph, assigner);
let program = reconstructor.reconstruct_program(ast.into_repr());
Ok((Ast::new(program), reconstructor.assignment_renamer.assigner))
Ok(Ast::new(program))
}
}

View File

@ -31,15 +31,15 @@ pub use unroll_statement::*;
use crate::{Pass, SymbolTable};
use leo_ast::{Ast, ProgramReconstructor};
use leo_ast::{Ast, NodeBuilder, ProgramReconstructor};
use leo_errors::{emitter::Handler, Result};
impl<'a> Pass for Unroller<'a> {
type Input = (Ast, &'a Handler, SymbolTable);
type Input = (Ast, &'a Handler, &'a NodeBuilder, SymbolTable);
type Output = Result<(Ast, SymbolTable)>;
fn do_pass((ast, handler, st): Self::Input) -> Self::Output {
let mut reconstructor = Self::new(st, handler);
fn do_pass((ast, handler, node_builder, st): Self::Input) -> Self::Output {
let mut reconstructor = Self::new(st, handler, node_builder);
let program = reconstructor.reconstruct_program(ast.into_repr());
handler.last_err().map_err(|e| *e)?;

View File

@ -47,7 +47,7 @@ impl ProgramReconstructor for Unroller<'_> {
output_type: finalize.output_type,
block,
span: finalize.span,
id: NodeID::default(),
id: finalize.id,
}
});
@ -62,7 +62,7 @@ impl ProgramReconstructor for Unroller<'_> {
block,
finalize,
span: function.span,
id: NodeID::default(),
id: function.id,
};
// Exit the function's scope.

View File

@ -30,7 +30,7 @@ impl StatementReconstructor for Unroller<'_> {
let block = Block {
statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
span: input.span,
id: NodeID::default(),
id: input.id,
};
// Exit the block scope.

View File

@ -22,7 +22,7 @@ use leo_ast::{
IntegerType,
IterationStatement,
Literal,
NodeID,
NodeBuilder,
Statement,
StatementReconstructor,
Type,
@ -41,13 +41,15 @@ pub struct Unroller<'a> {
pub(crate) scope_index: usize,
/// An error handler used for any errors found during unrolling.
pub(crate) handler: &'a Handler,
/// A counter used to generate unique node IDs.
pub(crate) node_builder: &'a NodeBuilder,
/// Are we in the midst of unrolling a loop?
pub(crate) is_unrolling: bool,
}
impl<'a> Unroller<'a> {
pub(crate) fn new(symbol_table: SymbolTable, handler: &'a Handler) -> Self {
Self { symbol_table: RefCell::new(symbol_table), scope_index: 0, handler, is_unrolling: false }
pub(crate) fn new(symbol_table: SymbolTable, handler: &'a Handler, node_builder: &'a NodeBuilder) -> Self {
Self { symbol_table: RefCell::new(symbol_table), scope_index: 0, handler, node_builder, is_unrolling: false }
}
/// Returns the index of the current scope.
@ -86,7 +88,7 @@ impl<'a> Unroller<'a> {
Ok(val_as_u128) => Ok(val_as_u128),
Err(err) => {
self.handler.emit_err(err);
Err(Statement::dummy(input.span))
Err(Statement::dummy(input.span, self.node_builder.next_id()))
}
}
};
@ -128,7 +130,7 @@ impl<'a> Unroller<'a> {
iter.map(|iteration_count| self.unroll_single_iteration(&input, iteration_count)).collect()
}
},
id: NodeID::default(),
id: input.id,
});
// Exit the scope of the loop body.
@ -148,36 +150,66 @@ impl<'a> Unroller<'a> {
// Reconstruct `iteration_count` as a `Literal`.
let value = match input.type_ {
Type::Integer(IntegerType::I8) => {
Literal::Integer(IntegerType::I8, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::I16) => {
Literal::Integer(IntegerType::I16, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::I32) => {
Literal::Integer(IntegerType::I32, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::I64) => {
Literal::Integer(IntegerType::I64, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::I128) => {
Literal::Integer(IntegerType::I128, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::U8) => {
Literal::Integer(IntegerType::U8, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::U16) => {
Literal::Integer(IntegerType::U16, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::U32) => {
Literal::Integer(IntegerType::U32, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::U64) => {
Literal::Integer(IntegerType::U64, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::U128) => {
Literal::Integer(IntegerType::U128, iteration_count.to_string(), Default::default(), NodeID::default())
}
Type::Integer(IntegerType::I8) => Literal::Integer(
IntegerType::I8,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::I16) => Literal::Integer(
IntegerType::I16,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::I32) => Literal::Integer(
IntegerType::I32,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::I64) => Literal::Integer(
IntegerType::I64,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::I128) => Literal::Integer(
IntegerType::I128,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::U8) => Literal::Integer(
IntegerType::U8,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::U16) => Literal::Integer(
IntegerType::U16,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::U32) => Literal::Integer(
IntegerType::U32,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::U64) => Literal::Integer(
IntegerType::U64,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
Type::Integer(IntegerType::U128) => Literal::Integer(
IntegerType::U128,
iteration_count.to_string(),
Default::default(),
self.node_builder.next_id(),
),
_ => unreachable!(
"The iteration variable must be an integer type. This should be enforced by type checking."
),
@ -191,7 +223,7 @@ impl<'a> Unroller<'a> {
value: Expression::Literal(value),
span: Default::default(),
place: Expression::Identifier(input.variable),
id: NodeID::default(),
id: self.node_builder.next_id(),
})
.0,
];
@ -201,7 +233,7 @@ impl<'a> Unroller<'a> {
statements.push(self.reconstruct_statement(s).0);
});
let block = Statement::Block(Block { statements, span: input.block.span, id: NodeID::default() });
let block = Statement::Block(Block { statements, span: input.block.span, id: input.block.id });
self.is_unrolling = prior_is_unrolling;

View File

@ -27,7 +27,6 @@ use leo_ast::{
Identifier,
Literal,
MemberAccess,
NodeID,
Statement,
Struct,
StructExpression,

View File

@ -22,7 +22,6 @@ use leo_ast::{
Function,
FunctionConsumer,
Member,
NodeID,
Program,
ProgramConsumer,
ProgramScope,