diff --git a/compiler/passes/src/loop_unrolling/unroll_statement.rs b/compiler/passes/src/loop_unrolling/unroll_statement.rs index d15a443478..f1f5c0fbc4 100644 --- a/compiler/passes/src/loop_unrolling/unroll_statement.rs +++ b/compiler/passes/src/loop_unrolling/unroll_statement.rs @@ -32,6 +32,7 @@ impl StatementReconstructor for Unroller<'_> { VariableType::Mut }; + // TODO: Do we need to obey shadowing rules? input.variable_names.iter().for_each(|v| { if let Err(err) = self.symbol_table.borrow_mut().insert_variable( v.identifier.name, @@ -64,10 +65,7 @@ impl StatementReconstructor for Unroller<'_> { Ok(val_as_usize) => Ok(val_as_usize), Err(err) => { self.handler.emit_err(err); - Err(Statement::Block(Block { - statements: Vec::new(), - span: input.span, - })) + Err(Statement::dummy(input.span)) } } }; @@ -97,12 +95,7 @@ impl StatementReconstructor for Unroller<'_> { Default::default() }; - // Create the iteration scope if it does not exist, otherwise grab the existing one. - let scope_index = if self.is_unrolling { - self.symbol_table.borrow_mut().insert_block() - } else { - self.block_index - }; + let scope_index = self.get_current_block(); // Enter the scope of the loop body. let prev_st = std::mem::take(&mut self.symbol_table); @@ -171,12 +164,8 @@ impl StatementReconstructor for Unroller<'_> { self.is_unrolling = prev_create_iter_scopes; - // TODO: Should this be removed? - // self.symbol_table.borrow_mut().variables.remove(&input.variable.name); - // Restore the previous symbol table. let prev_st = *self.symbol_table.borrow_mut().parent.take().unwrap(); - // TODO: Is this swap necessary? self.symbol_table.swap(prev_st.get_block_scope(scope_index).unwrap()); self.symbol_table = RefCell::new(prev_st); @@ -188,42 +177,20 @@ impl StatementReconstructor for Unroller<'_> { // Restore the previous symbol table. let prev_st = *self.symbol_table.borrow_mut().parent.take().unwrap(); - // TODO: Is this swap necessary? self.symbol_table.swap(prev_st.get_block_scope(scope_index).unwrap()); self.symbol_table = RefCell::new(prev_st); self.block_index = scope_index + 1; iter_blocks } - (None, Some(_)) => { - self.handler - .emit_err(FlattenError::non_const_loop_bounds("start", input.start.span())); - Statement::Iteration(Box::from(input)) - } - (Some(_), None) => { - self.handler - .emit_err(FlattenError::non_const_loop_bounds("stop", input.stop.span())); - Statement::Iteration(Box::from(input)) - } - (None, None) => { - self.handler - .emit_err(FlattenError::non_const_loop_bounds("start", input.start.span())); - self.handler - .emit_err(FlattenError::non_const_loop_bounds("stop", input.stop.span())); - Statement::Iteration(Box::from(input)) - } + // If both loop bounds are not constant, then the loop is not unrolled. + _ => Statement::Iteration(Box::from(input)) } } fn reconstruct_block(&mut self, input: Block) -> Block { - // If we are in an iteration scope we create any sub scopes for it. - // This is because in TYC we remove all its sub scopes to avoid clashing variables - // during flattening. - let current_block = if self.is_unrolling { - self.symbol_table.borrow_mut().insert_block() - } else { - self.block_index - }; + + self.get_current_block(); // Enter block scope. let prev_st = std::mem::take(&mut self.symbol_table); diff --git a/compiler/passes/src/loop_unrolling/unroller.rs b/compiler/passes/src/loop_unrolling/unroller.rs index d6501024d3..bb6aebfdd4 100644 --- a/compiler/passes/src/loop_unrolling/unroller.rs +++ b/compiler/passes/src/loop_unrolling/unroller.rs @@ -40,4 +40,15 @@ impl<'a> Unroller<'a> { is_unrolling: false, } } + + + /// Returns the index of the current block scope. + /// Note that if we are in the midst of unrolling an IterationStatement, a new scope is created. + pub(crate) fn get_current_block(&mut self) -> usize { + if self.is_unrolling { + self.symbol_table.borrow_mut().insert_block() + } else { + self.block_index + } + } } diff --git a/compiler/passes/src/symbol_table/table.rs b/compiler/passes/src/symbol_table/table.rs index 8f94d2f743..baab297135 100644 --- a/compiler/passes/src/symbol_table/table.rs +++ b/compiler/passes/src/symbol_table/table.rs @@ -174,6 +174,17 @@ impl SymbolTable { } } + /// Returns the index associated with the function symbol, if it exists in the symbol table. + pub fn get_fn_index(&self, symbol: &Symbol) -> Option { + if let Some(func) = self.functions.get(symbol) { + Some(func.id) + } else if let Some(parent) = self.parent.as_ref() { + parent.get_fn_id(symbol) + } else { + None + } + } + /// Returns the scope associated with `index`, if it exists in the symbol table. pub fn get_block_scope(&self, index: usize) -> Option<&RefCell> { self.scopes.get(index) diff --git a/compiler/passes/src/type_checking/check_expressions.rs b/compiler/passes/src/type_checking/check_expressions.rs index a533deee81..885a837b83 100644 --- a/compiler/passes/src/type_checking/check_expressions.rs +++ b/compiler/passes/src/type_checking/check_expressions.rs @@ -221,16 +221,21 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } fn visit_literal(&mut self, input: &'a Literal, expected: &Self::AdditionalInput) -> Self::Output { + + let negate_int = |str_content: &string| { + if self.negate { + format!("-{str_content}") + } else { + str_content.clone() + } + }; + Some(match input { Literal::Address(_, _) => self.assert_and_return_type(Type::Address, expected, input.span()), Literal::Boolean(_, _) => self.assert_and_return_type(Type::Boolean, expected, input.span()), Literal::Field(_, _) => self.assert_and_return_type(Type::Field, expected, input.span()), Literal::I8(str_content, _) => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; + let int = negate_int(str_content); if int.parse::().is_err() { self.handler @@ -239,11 +244,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { self.assert_and_return_type(Type::I8, expected, input.span()) } Literal::I16(str_content, _) => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; + let int = negate_int(str_content); if int.parse::().is_err() { self.handler @@ -252,11 +253,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { self.assert_and_return_type(Type::I16, expected, input.span()) } Literal::I32(str_content, _) => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; + let int = negate_int(str_content); if int.parse::().is_err() { self.handler @@ -265,11 +262,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { self.assert_and_return_type(Type::I32, expected, input.span()) } Literal::I64(str_content, _) => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; + let int = negate_int(str_content); if int.parse::().is_err() { self.handler @@ -278,11 +271,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { self.assert_and_return_type(Type::I64, expected, input.span()) } Literal::I128(str_content, _) => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; + let int = negate_int(str_content); if int.parse::().is_err() { self.handler diff --git a/compiler/passes/src/type_checking/check_statements.rs b/compiler/passes/src/type_checking/check_statements.rs index 65faf5526d..2e85576fd8 100644 --- a/compiler/passes/src/type_checking/check_statements.rs +++ b/compiler/passes/src/type_checking/check_statements.rs @@ -127,7 +127,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { // Restore the previous scope. let prev_st = *self.symbol_table.borrow_mut().parent.take().unwrap(); - // TODO: Is this swap necessary? self.symbol_table.swap(prev_st.get_block_scope(scope_index).unwrap()); self.symbol_table = RefCell::new(prev_st); @@ -168,7 +167,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { input.statements.iter().for_each(|stmt| self.visit_statement(stmt)); let previous_symbol_table = *self.symbol_table.borrow_mut().parent.take().unwrap(); - // TODO: Is this swap necessary? self.symbol_table .swap(previous_symbol_table.get_block_scope(scope_index).unwrap()); self.symbol_table = RefCell::new(previous_symbol_table);