diff --git a/compiler/ast/src/passes/visitor_director.rs b/compiler/ast/src/passes/visitor_director.rs index 9fee6cdda1..ede3a24a75 100644 --- a/compiler/ast/src/passes/visitor_director.rs +++ b/compiler/ast/src/passes/visitor_director.rs @@ -76,9 +76,7 @@ impl<'a, V: ExpressionVisitor<'a>> VisitorDirector<'a, V> { pub fn visit_call(&mut self, input: &'a CallExpression) { if let VisitResult::VisitChildren = self.visitor.visit_call(input) { - for expr in input.arguments.iter() { - self.visit_expression(expr); - } + input.arguments.iter().for_each(|expr| self.visit_expression(expr)); } } } @@ -137,8 +135,11 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> pub fn visit_console(&mut self, input: &'a ConsoleStatement) { if let VisitResult::VisitChildren = self.visitor.visit_console(input) { - if let ConsoleFunction::Assert(expr) = &input.function { - self.visit_expression(expr); + match &input.function { + ConsoleFunction::Assert(expr) => self.visit_expression(expr), + ConsoleFunction::Error(fmt) | ConsoleFunction::Log(fmt) => { + fmt.parameters.iter().for_each(|expr| self.visit_expression(expr)); + } } } } @@ -151,9 +152,7 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> pub fn visit_block(&mut self, input: &'a Block) { if let VisitResult::VisitChildren = self.visitor.visit_block(input) { - for stmt in input.statements.iter() { - self.visit_statement(stmt); - } + input.statements.iter().for_each(|stmt| self.visit_statement(stmt)); } } } @@ -161,9 +160,10 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> impl<'a, V: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> { pub fn visit_program(&mut self, input: &'a Program) { if let VisitResult::VisitChildren = self.visitor.visit_program(input) { - for function in input.functions.values() { - self.visit_function(function); - } + input + .functions + .values() + .for_each(|function| self.visit_function(function)); } } diff --git a/compiler/passes/src/symbol_table/table.rs b/compiler/passes/src/symbol_table/table.rs index 3ec37efa08..cc46ec0d60 100644 --- a/compiler/passes/src/symbol_table/table.rs +++ b/compiler/passes/src/symbol_table/table.rs @@ -77,6 +77,21 @@ impl<'a> SymbolTable<'a> { pub fn lookup_var(&self, symbol: &Symbol) -> Option<&&'a DefinitionStatement> { self.variables.variables.get(symbol) } + + pub fn push_variable_scope(&mut self) { + let current_scope = self.variables.clone(); + self.variables = VariableSymbol { + parent: Some(Box::new(current_scope)), + inputs: Default::default(), + variables: Default::default(), + }; + } + + pub fn pop_variable_scope(&mut self) { + let parent = self.variables.parent.clone().unwrap(); + + self.variables = *parent; + } } impl<'a> Display for SymbolTable<'a> { diff --git a/compiler/passes/src/type_checker/check.rs b/compiler/passes/src/type_checker/check.rs index de3ba95ab0..a9070927b2 100644 --- a/compiler/passes/src/type_checker/check.rs +++ b/compiler/passes/src/type_checker/check.rs @@ -34,6 +34,17 @@ impl<'a> TypeChecker<'a> { self.handler.emit_err(err1); self.handler.emit_err(err2); } + // Types match + _ => {} + } + } + + fn assert_type(&self, type_: Result>, expected: Type, span: &Span) { + match type_ { + Ok(Some(type_)) if type_ != expected => self + .handler + .emit_err(TypeCheckerError::type_should_be(type_, expected, span).into()), + // Types match _ => {} } } @@ -103,27 +114,58 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { _ => {} } - Default::default() + VisitResult::VisitChildren } - fn visit_conditional(&mut self, _input: &'a ConditionalStatement) -> VisitResult { - Default::default() + fn visit_conditional(&mut self, input: &'a ConditionalStatement) -> VisitResult { + self.assert_type(input.condition.get_type(), Type::Boolean, input.span()); + + VisitResult::VisitChildren } - fn visit_iteration(&mut self, _input: &'a IterationStatement) -> VisitResult { - Default::default() + fn visit_iteration(&mut self, input: &'a IterationStatement) -> VisitResult { + // TODO: need to change symbol table to some other repr for variables. + // self.symbol_table.insert_variable(input.variable.name, input); + + let iter_var_type = input.get_type(); + + self.compare_types(iter_var_type.clone(), input.start.get_type(), input.span()); + self.compare_types(iter_var_type, input.stop.get_type(), input.span()); + + VisitResult::VisitChildren } - fn visit_console(&mut self, _input: &'a ConsoleStatement) -> VisitResult { - Default::default() + fn visit_console(&mut self, input: &'a ConsoleStatement) -> VisitResult { + match &input.function { + ConsoleFunction::Assert(expr) => { + self.assert_type(expr.get_type(), Type::Boolean, expr.span()); + } + ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => { + todo!("need to discuss this"); + } + } + + VisitResult::VisitChildren } - fn visit_expression_statement(&mut self, _input: &'a ExpressionStatement) -> VisitResult { - Default::default() - } + fn visit_block(&mut self, input: &'a Block) -> VisitResult { + self.symbol_table.push_variable_scope(); + // have to redo the logic here so we have scoping + input.statements.iter().for_each(|stmt| { + match stmt { + Statement::Return(stmt) => self.visit_return(stmt), + Statement::Definition(stmt) => self.visit_definition(stmt), + Statement::Assign(stmt) => self.visit_assign(stmt), + Statement::Conditional(stmt) => self.visit_conditional(stmt), + Statement::Iteration(stmt) => self.visit_iteration(stmt), + Statement::Console(stmt) => self.visit_console(stmt), + Statement::Expression(stmt) => self.visit_expression_statement(stmt), + Statement::Block(stmt) => self.visit_block(stmt), + }; + }); + self.symbol_table.pop_variable_scope(); - fn visit_block(&mut self, _input: &'a Block) -> VisitResult { - Default::default() + VisitResult::SkipChildren } } diff --git a/leo/errors/src/errors/type_checker/type_checker_error.rs b/leo/errors/src/errors/type_checker/type_checker_error.rs index 5fd062bb85..cd6ce3b152 100644 --- a/leo/errors/src/errors/type_checker/type_checker_error.rs +++ b/leo/errors/src/errors/type_checker/type_checker_error.rs @@ -53,23 +53,33 @@ create_messages!( help: None, } - /// For when the user tries to assign to a const input. - @formatted - cannont_assign_to_const_input { - args: (input: impl Display), - msg: format!( - "Cannot assign to const input `{input}`", - ), - help: None, - } + /// For when the user tries to assign to a const input. + @formatted + cannont_assign_to_const_input { + args: (input: impl Display), + msg: format!( + "Cannot assign to const input `{input}`", + ), + help: None, + } - /// For when the user tries to assign to a const input. - @formatted - cannont_assign_to_const_var { - args: (var: impl Display), - msg: format!( - "Cannot assign to const variable `{var}`", - ), - help: None, - } + /// For when the user tries to assign to a const input. + @formatted + cannont_assign_to_const_var { + args: (var: impl Display), + msg: format!( + "Cannot assign to const variable `{var}`", + ), + help: None, + } + + /// For when the user tries to assign to a const input. + @formatted + type_should_be { + args: (type_: impl Display, expected: impl Display), + msg: format!( + "Found type `{type_}` but type `{expected}` was expected", + ), + help: None, + } );