diff --git a/compiler/passes/src/flattening/flatten_program.rs b/compiler/passes/src/flattening/flatten_program.rs index d6290de9a7..064b7995a3 100644 --- a/compiler/passes/src/flattening/flatten_program.rs +++ b/compiler/passes/src/flattening/flatten_program.rs @@ -16,7 +16,7 @@ use crate::Flattener; -use leo_ast::{Function, ProgramReconstructor, ProgramScope, Statement, StatementReconstructor}; +use leo_ast::{Finalize, Function, ProgramReconstructor, ProgramScope, Statement, StatementReconstructor}; impl ProgramReconstructor for Flattener<'_> { /// Flattens a program scope. @@ -59,7 +59,23 @@ impl ProgramReconstructor for Flattener<'_> { output: function.output, output_type: function.output_type, block, - finalize: function.finalize, + finalize: function.finalize.map(|finalize| { + // Set the `is_finalize` flag before reconstructing the finalize block. + self.is_finalize = true; + // Reconstruct the finalize block. + let finalize = Finalize { + identifier: finalize.identifier, + input: finalize.input, + output: finalize.output, + output_type: finalize.output_type, + block: self.reconstruct_block(finalize.block).0, + span: finalize.span, + id: finalize.id, + }; + // Reset the `is_finalize` flag. + self.is_finalize = false; + finalize + }), span: function.span, id: function.id, } diff --git a/compiler/passes/src/flattening/flatten_statement.rs b/compiler/passes/src/flattening/flatten_statement.rs index 9a47421112..ca5bdbccc4 100644 --- a/compiler/passes/src/flattening/flatten_statement.rs +++ b/compiler/passes/src/flattening/flatten_statement.rs @@ -61,6 +61,11 @@ impl StatementReconstructor for Flattener<'_> { fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) { let mut statements = Vec::new(); + // If traversing a `finalize` block, return the assert as is. + if self.is_finalize { + return (Statement::Assert(input), statements); + } + // Flatten the arguments of the assert statement. let assert = AssertStatement { span: input.span, @@ -222,6 +227,26 @@ impl StatementReconstructor for Flattener<'_> { fn reconstruct_conditional(&mut self, conditional: ConditionalStatement) -> (Statement, Self::AdditionalOutput) { let mut statements = Vec::with_capacity(conditional.then.statements.len()); + // If traversing a `finalize` block, only reconstruct the if and else blocks of the conditional statement. + if self.is_finalize { + let then_block = self.reconstruct_block(conditional.then).0; + let otherwise_block = conditional.otherwise.map(|statement| match *statement { + Statement::Block(block) => Box::new(Statement::Block(self.reconstruct_block(block).0)), + _ => unreachable!("Parsing guarantees that the `otherwise` is always a `Block`"), + }); + + return ( + Statement::Conditional(ConditionalStatement { + condition: conditional.condition, + then: then_block, + otherwise: otherwise_block, + span: conditional.span, + id: conditional.id, + }), + statements, + ); + } + // Add condition to the condition stack. self.condition_stack.push(conditional.condition.clone()); @@ -269,6 +294,10 @@ impl StatementReconstructor for Flattener<'_> { /// Transforms a return statement into an empty block statement. /// Stores the arguments to the return statement, which are later folded into a single return statement at the end of the function. fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) { + // If traversing a `finalize` block, return as is. + if self.is_finalize { + return (Statement::Return(input), Default::default()); + } // Construct the associated guard. let guard = self.construct_guard(); diff --git a/compiler/passes/src/flattening/flattener.rs b/compiler/passes/src/flattening/flattener.rs index 4e51de0397..772b4a1c2e 100644 --- a/compiler/passes/src/flattening/flattener.rs +++ b/compiler/passes/src/flattening/flattener.rs @@ -67,6 +67,8 @@ pub struct Flattener<'a> { pub(crate) returns: Vec<(Option, ReturnStatement)>, /// The program name. pub(crate) program: Option, + /// Whether we are currently traversing a `finalize` block. + pub(crate) is_finalize: bool, } impl<'a> Flattener<'a> { @@ -84,6 +86,7 @@ impl<'a> Flattener<'a> { condition_stack: Vec::new(), returns: Vec::new(), program: None, + is_finalize: false, } }