From 044b2a10a453497938407d6ff8784364a222f147 Mon Sep 17 00:00:00 2001 From: gluax <16431709+gluax@users.noreply.github.com> Date: Thu, 26 May 2022 13:29:51 -0700 Subject: [PATCH] refa:ctored visitor pattern to better on an the AST --- compiler/ast/src/passes/visitor.rs | 18 +- compiler/ast/src/passes/visitor_director.rs | 12 +- compiler/passes/src/symbol_table/create.rs | 4 +- .../src/type_checker/check_expressions.rs | 562 +++++++++++------- .../passes/src/type_checker/check_file.rs | 4 + .../src/type_checker/check_statements.rs | 87 +-- compiler/passes/src/type_checker/director.rs | 160 +---- .../compiler/group/mult_by_group_fail.leo.out | 2 +- .../compare_diff_types_fail.leo.out | 2 +- 9 files changed, 421 insertions(+), 430 deletions(-) diff --git a/compiler/ast/src/passes/visitor.rs b/compiler/ast/src/passes/visitor.rs index 47ff7736f4..79c449cd3d 100644 --- a/compiler/ast/src/passes/visitor.rs +++ b/compiler/ast/src/passes/visitor.rs @@ -30,37 +30,35 @@ impl Default for VisitResult { } pub trait ExpressionVisitor<'a> { - type Output; - - fn visit_expression(&mut self, _input: &'a Expression) -> (VisitResult, Option) { + fn visit_expression(&mut self, _input: &'a Expression) -> VisitResult { Default::default() } - fn visit_identifier(&mut self, _input: &'a Identifier) -> (VisitResult, Option) { + fn visit_identifier(&mut self, _input: &'a Identifier) -> VisitResult { Default::default() } - fn visit_value(&mut self, _input: &'a ValueExpression) -> (VisitResult, Option) { + fn visit_value(&mut self, _input: &'a ValueExpression) -> VisitResult { Default::default() } - fn visit_binary(&mut self, _input: &'a BinaryExpression) -> (VisitResult, Option) { + fn visit_binary(&mut self, _input: &'a BinaryExpression) -> VisitResult { Default::default() } - fn visit_unary(&mut self, _input: &'a UnaryExpression) -> (VisitResult, Option) { + fn visit_unary(&mut self, _input: &'a UnaryExpression) -> VisitResult { Default::default() } - fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> (VisitResult, Option) { + fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> VisitResult { Default::default() } - fn visit_call(&mut self, _input: &'a CallExpression) -> (VisitResult, Option) { + fn visit_call(&mut self, _input: &'a CallExpression) -> VisitResult { Default::default() } - fn visit_err(&mut self, _input: &'a ErrExpression) -> (VisitResult, Option) { + fn visit_err(&mut self, _input: &'a ErrExpression) -> VisitResult { Default::default() } } diff --git a/compiler/ast/src/passes/visitor_director.rs b/compiler/ast/src/passes/visitor_director.rs index e427511431..eab9b36870 100644 --- a/compiler/ast/src/passes/visitor_director.rs +++ b/compiler/ast/src/passes/visitor_director.rs @@ -32,7 +32,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> { type Output; fn visit_expression(&mut self, input: &'a Expression) -> Option { - if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input).0 { + if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input) { match input { Expression::Identifier(expr) => self.visit_identifier(expr), Expression::Value(expr) => self.visit_value(expr), @@ -58,7 +58,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> { } fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option { - if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input).0 { + if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input) { self.visit_expression(&input.left); self.visit_expression(&input.right); } @@ -66,14 +66,14 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> { } fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option { - if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input).0 { + if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input) { self.visit_expression(&input.inner); } None } fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option { - if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input).0 { + if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input) { self.visit_expression(&input.condition); self.visit_expression(&input.if_true); self.visit_expression(&input.if_false); @@ -82,7 +82,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> { } fn visit_call(&mut self, input: &'a CallExpression) -> Option { - if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input).0 { + if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input) { input.arguments.iter().for_each(|expr| { self.visit_expression(expr); }); @@ -113,7 +113,7 @@ pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorD fn visit_return(&mut self, input: &'a ReturnStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) { - self.visitor_ref().visit_expression(&input.expression); + self.visit_expression(&input.expression); } } diff --git a/compiler/passes/src/symbol_table/create.rs b/compiler/passes/src/symbol_table/create.rs index d24c054ffc..4d2b414a48 100644 --- a/compiler/passes/src/symbol_table/create.rs +++ b/compiler/passes/src/symbol_table/create.rs @@ -36,9 +36,7 @@ impl<'a> CreateSymbolTable<'a> { } } -impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> { - type Output = (); -} +impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {} impl<'a> StatementVisitor<'a> for CreateSymbolTable<'a> {} diff --git a/compiler/passes/src/type_checker/check_expressions.rs b/compiler/passes/src/type_checker/check_expressions.rs index 5c6dc02019..cd1501e03e 100644 --- a/compiler/passes/src/type_checker/check_expressions.rs +++ b/compiler/passes/src/type_checker/check_expressions.rs @@ -19,6 +19,10 @@ use leo_errors::TypeCheckerError; use crate::TypeChecker; +use super::director::Director; + +impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {} + fn return_incorrect_type(t1: Option, t2: Option, expected: Option) -> Option { match (t1, t2) { (Some(t1), Some(t2)) if t1 == t2 => Some(t1), @@ -37,239 +41,373 @@ fn return_incorrect_type(t1: Option, t2: Option, expected: Option ExpressionVisitor<'a> for TypeChecker<'a> { +impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { type Output = Type; - fn visit_identifier(&mut self, input: &'a Identifier) -> (VisitResult, Option) { - let type_ = if let Some(var) = self.symbol_table.lookup_variable(&input.name) { - Some(self.assert_type(*var.type_, self.expected_type)) - } else { - self.handler - .emit_err(TypeCheckerError::unknown_sym("variable", input.name, self.span).into()); - None - }; + fn visit_expression(&mut self, input: &'a Expression) -> Option { + if let VisitResult::VisitChildren = self.visitor.visit_expression(input) { + return match input { + Expression::Identifier(expr) => self.visit_identifier(expr), + Expression::Value(expr) => self.visit_value(expr), + Expression::Binary(expr) => self.visit_binary(expr), + Expression::Unary(expr) => self.visit_unary(expr), + Expression::Ternary(expr) => self.visit_ternary(expr), + Expression::Call(expr) => self.visit_call(expr), + Expression::Err(expr) => self.visit_err(expr), + }; + } - (VisitResult::VisitChildren, type_) + None } - fn visit_value(&mut self, input: &'a ValueExpression) -> (VisitResult, Option) { - let prev_span = self.span; - self.span = input.span(); + fn visit_identifier(&mut self, input: &'a Identifier) -> Option { + if let VisitResult::VisitChildren = self.visitor.visit_identifier(input) { + return if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&input.name) { + Some(self.visitor.assert_type(*var.type_, self.visitor.expected_type)) + } else { + self.visitor + .handler + .emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()).into()); + None + }; + } - let type_ = Some(match input { - ValueExpression::Address(_, _) => self.assert_type(Type::Address, self.expected_type), - ValueExpression::Boolean(_, _) => self.assert_type(Type::Boolean, self.expected_type), - ValueExpression::Field(_, _) => self.assert_type(Type::Field, self.expected_type), - ValueExpression::Integer(type_, str_content, _) => { - match type_ { - IntegerType::I8 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into()); - } - } - IntegerType::I16 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into()); - } - } - IntegerType::I32 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into()); - } - } - IntegerType::I64 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into()); - } - } - IntegerType::I128 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into()); - } - } - IntegerType::U8 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()), - IntegerType::U16 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()), - IntegerType::U32 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()), - IntegerType::U64 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()), - IntegerType::U128 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()), - _ => {} - } - self.assert_type(Type::IntegerType(*type_), self.expected_type) - } - ValueExpression::Group(_) => self.assert_type(Type::Group, self.expected_type), - ValueExpression::Scalar(_, _) => self.assert_type(Type::Scalar, self.expected_type), - ValueExpression::String(_, _) => unreachable!("String types are not reachable"), - }); - - self.span = prev_span; - (VisitResult::VisitChildren, type_) + None } - fn visit_binary(&mut self, input: &'a BinaryExpression) -> (VisitResult, Option) { - let prev_span = self.span; - self.span = input.span(); + fn visit_value(&mut self, input: &'a ValueExpression) -> Option { + if let VisitResult::VisitChildren = self.visitor.visit_value(input) { + return Some(match input { + ValueExpression::Address(_, _) => self.visitor.assert_type(Type::Address, self.visitor.expected_type), + ValueExpression::Boolean(_, _) => self.visitor.assert_type(Type::Boolean, self.visitor.expected_type), + ValueExpression::Field(_, _) => self.visitor.assert_type(Type::Field, self.visitor.expected_type), + ValueExpression::Integer(type_, str_content, _) => { + match type_ { + IntegerType::I8 => { + let int = if self.visitor.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; - /* let type_ = match input.op { - BinaryOperation::And | BinaryOperation::Or => { - self.assert_type(Type::Boolean, self.expected_type); - let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); - let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); + if int.parse::().is_err() { + self.visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into()); + } + } + IntegerType::I16 => { + let int = if self.visitor.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; - return_incorrect_type(t1, t2, self.expected_type) - } - BinaryOperation::Add => { - self.assert_field_group_scalar_int_type(self.expected_type, input.span()); - let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); - let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); + if int.parse::().is_err() { + self.visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into()); + } + } + IntegerType::I32 => { + let int = if self.visitor.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; - return_incorrect_type(t1, t2, self.expected_type) - } - BinaryOperation::Sub => { - self.assert_field_group_int_type(self.expected_type, input.span()); - let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); - let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); + if int.parse::().is_err() { + self.visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into()); + } + } + IntegerType::I64 => { + let int = if self.visitor.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; - return_incorrect_type(t1, t2, self.expected_type) - } - BinaryOperation::Mul => { - self.assert_field_group_int_type(self.expected_type, input.span()); + if int.parse::().is_err() { + self.visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into()); + } + } + IntegerType::I128 => { + let int = if self.visitor.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; - let t1 = self.compare_expr_type(&input.left, None, input.left.span()); - let t2 = self.compare_expr_type(&input.right, None, input.right.span()); - - // Allow `group` * `scalar` multiplication. - match (t1.as_ref(), t2.as_ref()) { - (Some(Type::Group), Some(other)) => { - self.assert_type(Type::Group, self.expected_type); - self.assert_type(*other, Some(Type::Scalar)); - Some(Type::Group) + if int.parse::().is_err() { + self.visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into()); + } + } + IntegerType::U8 if str_content.parse::().is_err() => self + .visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()), + IntegerType::U16 if str_content.parse::().is_err() => self + .visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()), + IntegerType::U32 if str_content.parse::().is_err() => self + .visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()), + IntegerType::U64 if str_content.parse::().is_err() => self + .visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()), + IntegerType::U128 if str_content.parse::().is_err() => self + .visitor + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()), + _ => {} } - (Some(other), Some(Type::Group)) => { - self.assert_type(Type::Group, self.expected_type); - self.assert_type(*other, Some(Type::Scalar)); - Some(Type::Group) - } - _ => { - self.assert_type(t1.unwrap(), self.expected_type); - self.assert_type(t2.unwrap(), self.expected_type); - return_incorrect_type(t1, t2, self.expected_type) + self.visitor + .assert_type(Type::IntegerType(*type_), self.visitor.expected_type) + } + ValueExpression::Group(_) => self.visitor.assert_type(Type::Group, self.visitor.expected_type), + ValueExpression::Scalar(_, _) => self.visitor.assert_type(Type::Scalar, self.visitor.expected_type), + ValueExpression::String(_, _) => unreachable!("String types are not reachable"), + }); + } + + None + } + + fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option { + if let VisitResult::VisitChildren = self.visitor.visit_binary(input) { + return match input.op { + BinaryOperation::And | BinaryOperation::Or => { + self.visitor.assert_type(Type::Boolean, self.visitor.expected_type); + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + + return_incorrect_type(t1, t2, self.visitor.expected_type) + } + BinaryOperation::Add => { + self.visitor + .assert_field_group_scalar_int_type(self.visitor.expected_type, input.span()); + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + + return_incorrect_type(t1, t2, self.visitor.expected_type) + } + BinaryOperation::Sub => { + self.visitor + .assert_field_group_int_type(self.visitor.expected_type, input.span()); + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + + return_incorrect_type(t1, t2, self.visitor.expected_type) + } + BinaryOperation::Mul => { + self.visitor + .assert_field_group_int_type(self.visitor.expected_type, input.span()); + + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = None; + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + self.visitor.expected_type = prev_expected_type; + + // Allow `group` * `scalar` multiplication. + match (t1.as_ref(), t2.as_ref()) { + (Some(Type::Group), Some(other)) + | (Some(other), Some(Type::Group)) => { + self.visitor.assert_type(Type::Group, self.visitor.expected_type); + self.visitor.assert_type(*other, Some(Type::Scalar)); + Some(Type::Group) + } + _ => { + self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type); + self.visitor.assert_type(t2.unwrap(), self.visitor.expected_type); + return_incorrect_type(t1, t2, self.visitor.expected_type) + } } } - } - BinaryOperation::Div => { - self.assert_field_int_type(self.expected_type, input.span()); + BinaryOperation::Div => { + self.visitor + .assert_field_int_type(self.visitor.expected_type, input.span()); - let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); - let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); - return_incorrect_type(t1, t2, self.expected_type) - } - BinaryOperation::Pow => { - let t1 = self.compare_expr_type(&input.left, None, input.left.span()); - let t2 = self.compare_expr_type(&input.right, None, input.right.span()); + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + + return_incorrect_type(t1, t2, self.visitor.expected_type) + } + BinaryOperation::Pow => { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = None; + + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + + self.visitor.expected_type = prev_expected_type; - match (t1.as_ref(), t2.as_ref()) { - // Type A must be an int. - // Type B must be a unsigned int. - (Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => { - self.assert_type(t1.unwrap(), self.expected_type); - } - // Type A was an int. - // But Type B was not a unsigned int. - (Some(Type::IntegerType(_)), Some(t)) => { - self.handler.emit_err( - TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span()) - .into(), - ); - } - // Type A must be a field. - // Type B must be an int. - (Some(Type::Field), Some(Type::IntegerType(_))) => { - self.assert_type(Type::Field, self.expected_type); - } - // Type A was a field. - // But Type B was not an int. - (Some(Type::Field), Some(t)) => { - self.handler.emit_err( - TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(), - ); - } - // The base is some type thats not an int or field. - (Some(t), _) => { - self.handler - .emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into()); + match (t1.as_ref(), t2.as_ref()) { + // Type A must be an int. + // Type B must be a unsigned int. + (Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => { + self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type); + } + // Type A was an int. + // But Type B was not a unsigned int. + (Some(Type::IntegerType(_)), Some(t)) => { + self.visitor.handler.emit_err( + TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span()) + .into(), + ); + } + // Type A must be a field. + // Type B must be an int. + (Some(Type::Field), Some(Type::IntegerType(_))) => { + self.visitor.assert_type(Type::Field, self.visitor.expected_type); + } + // Type A was a field. + // But Type B was not an int. + (Some(Type::Field), Some(t)) => { + self.visitor.handler.emit_err( + TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(), + ); + } + // The base is some type thats not an int or field. + (Some(t), _) => { + self.visitor + .handler + .emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into()); + } + _ => {} } + + t1 + } + BinaryOperation::Eq | BinaryOperation::Ne => { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = None; + + let t1 = self.visit_expression(&input.left); + let t2 = self.visit_expression(&input.right); + + self.visitor.expected_type = prev_expected_type; + self.visitor.assert_eq_types(t1, t2, input.span()); + + Some(Type::Boolean) + } + BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = None; + + let t1 = self.visit_expression(&input.left); + self.visitor.assert_field_scalar_int_type(t1, input.left.span()); + + let t2 = self.visit_expression(&input.right); + self.visitor.assert_field_scalar_int_type(t2, input.right.span()); + + self.visitor.expected_type = prev_expected_type; + self.visitor.assert_eq_types(t1, t2, input.span()); + + Some(Type::Boolean) + } + }; + } + + None + } + + fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option { + match input.op { + UnaryOperation::Not => { + self.visitor.assert_type(Type::Boolean, self.visitor.expected_type); + self.visit_expression(&input.inner) + } + UnaryOperation::Negate => { + let prior_negate_state = self.visitor.negate; + self.visitor.negate = true; + + let type_ = self.visit_expression(&input.inner); + self.visitor.negate = prior_negate_state; + match type_.as_ref() { + Some( + Type::IntegerType( + IntegerType::I8 + | IntegerType::I16 + | IntegerType::I32 + | IntegerType::I64 + | IntegerType::I128, + ) + | Type::Field + | Type::Group, + ) => {} + Some(t) => self + .visitor + .handler + .emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()), _ => {} + }; + type_ + } + } + } + + fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option { + if let VisitResult::VisitChildren = self.visitor.visit_ternary(input) { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = Some(Type::Boolean); + self.visit_expression(&input.condition); + self.visitor.expected_type = prev_expected_type; + + let t1 = self.visit_expression(&input.if_true); + let t2 = self.visit_expression(&input.if_false); + + return return_incorrect_type(t1, t2, self.visitor.expected_type); + } + + None + } + + fn visit_call(&mut self, input: &'a CallExpression) -> Option { + match &*input.function { + Expression::Identifier(ident) => { + if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) { + let ret = self.visitor.assert_type(func.output, self.visitor.expected_type); + + if func.input.len() != input.arguments.len() { + self.visitor.handler.emit_err( + TypeCheckerError::incorrect_num_args_to_call( + func.input.len(), + input.arguments.len(), + input.span(), + ) + .into(), + ); + } + + func.input + .iter() + .zip(input.arguments.iter()) + .for_each(|(expected, argument)| { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = Some(expected.get_variable().type_); + self.visit_expression(argument); + self.visitor.expected_type = prev_expected_type; + }); + + Some(ret) + } else { + self.visitor + .handler + .emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into()); + None } - - t1 } - BinaryOperation::Eq | BinaryOperation::Ne => { - let t1 = self.compare_expr_type(&input.left, None, input.left.span()); - let t2 = self.compare_expr_type(&input.right, None, input.right.span()); - - self.assert_eq_types(t1, t2, input.span()); - - Some(Type::Boolean) - } - BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => { - let t1 = self.compare_expr_type(&input.left, None, input.left.span()); - self.assert_field_scalar_int_type(t1, input.left.span()); - - let t2 = self.compare_expr_type(&input.right, None, input.right.span()); - self.assert_field_scalar_int_type(t2, input.right.span()); - - self.assert_eq_types(t1, t2, input.span()); - - Some(Type::Boolean) - } - }; */ - - self.span = prev_span; - (VisitResult::VisitChildren, None) + expr => self.visit_expression(expr), + } } } diff --git a/compiler/passes/src/type_checker/check_file.rs b/compiler/passes/src/type_checker/check_file.rs index 394079a211..b8c61e38ff 100644 --- a/compiler/passes/src/type_checker/check_file.rs +++ b/compiler/passes/src/type_checker/check_file.rs @@ -18,6 +18,8 @@ use leo_ast::*; use crate::{Declaration, TypeChecker, VariableSymbol}; +use super::director::Director; + impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { fn visit_function(&mut self, input: &'a Function) -> VisitResult { self.symbol_table.clear_variables(); @@ -40,3 +42,5 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { VisitResult::VisitChildren } } + +impl<'a> ProgramVisitorDirector<'a> for Director<'a> {} diff --git a/compiler/passes/src/type_checker/check_statements.rs b/compiler/passes/src/type_checker/check_statements.rs index 5294165bc5..888bb42f8f 100644 --- a/compiler/passes/src/type_checker/check_statements.rs +++ b/compiler/passes/src/type_checker/check_statements.rs @@ -19,20 +19,23 @@ use leo_errors::TypeCheckerError; use crate::{Declaration, TypeChecker, VariableSymbol}; -impl<'a> StatementVisitor<'a> for TypeChecker<'a> { - fn visit_return(&mut self, input: &'a ReturnStatement) -> VisitResult { +use super::director::Director; + +impl<'a> StatementVisitor<'a> for TypeChecker<'a> {} + +impl<'a> StatementVisitorDirector<'a> for Director<'a> { + fn visit_return(&mut self, input: &'a ReturnStatement) { // we can safely unwrap all self.parent instances because // statements should always have some parent block - let parent = self.parent.unwrap(); + let parent = self.visitor.parent.unwrap(); - // Would never be None. - let func_output_type = self.symbol_table.lookup_fn(&parent).map(|f| f.output); - // self.compare_expr_type(&input.expression, func_output_type, input.expression.span()); - - VisitResult::VisitChildren + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = self.visitor.symbol_table.lookup_fn(&parent).map(|f| f.output); + self.visit_expression(&input.expression); + self.visitor.expected_type = prev_expected_type; } - fn visit_definition(&mut self, input: &'a DefinitionStatement) -> VisitResult { + fn visit_definition(&mut self, input: &'a DefinitionStatement) { let declaration = if input.declaration_type == Declare::Const { Declaration::Const } else { @@ -40,7 +43,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { }; input.variable_names.iter().for_each(|v| { - if let Err(err) = self.symbol_table.insert_variable( + if let Err(err) = self.visitor.symbol_table.insert_variable( v.identifier.name, VariableSymbol { type_: &input.type_, @@ -48,23 +51,26 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { declaration: declaration.clone(), }, ) { - self.handler.emit_err(err); + self.visitor.handler.emit_err(err); } - // self.compare_expr_type(&input.value, Some(input.type_), input.value.span()); + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = Some(input.type_); + self.visit_expression(&input.value); + self.visitor.expected_type = prev_expected_type; }); - - VisitResult::VisitChildren } - fn visit_assign(&mut self, input: &'a AssignStatement) -> VisitResult { + fn visit_assign(&mut self, input: &'a AssignStatement) { let var_name = &input.assignee.identifier.name; - let var_type = if let Some(var) = self.symbol_table.lookup_variable(var_name) { + let var_type = if let Some(var) = self.visitor.symbol_table.lookup_variable(var_name) { match &var.declaration { Declaration::Const => self + .visitor .handler .emit_err(TypeCheckerError::cannont_assign_to_const_var(var_name, var.span).into()), Declaration::Input(ParamMode::Constant) => self + .visitor .handler .emit_err(TypeCheckerError::cannont_assign_to_const_input(var_name, var.span).into()), _ => {} @@ -72,7 +78,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { Some(*var.type_) } else { - self.handler.emit_err( + self.visitor.handler.emit_err( TypeCheckerError::unknown_sym("variable", &input.assignee.identifier.name, input.assignee.span).into(), ); @@ -80,20 +86,22 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { }; if var_type.is_some() { - // self.compare_expr_type(&input.value, var_type, input.value.span()); + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = var_type; + self.visit_expression(&input.value); + self.visitor.expected_type = prev_expected_type; } - - VisitResult::VisitChildren } - fn visit_conditional(&mut self, input: &'a ConditionalStatement) -> VisitResult { - // self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span()); - - VisitResult::VisitChildren + fn visit_conditional(&mut self, input: &'a ConditionalStatement) { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = Some(Type::Boolean); + self.visit_expression(&input.condition); + self.visitor.expected_type = prev_expected_type; } - fn visit_iteration(&mut self, input: &'a IterationStatement) -> VisitResult { - if let Err(err) = self.symbol_table.insert_variable( + fn visit_iteration(&mut self, input: &'a IterationStatement) { + if let Err(err) = self.visitor.symbol_table.insert_variable( input.variable.name, VariableSymbol { type_: &input.type_, @@ -101,30 +109,33 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { declaration: Declaration::Const, }, ) { - self.handler.emit_err(err); + self.visitor.handler.emit_err(err); } - // self.compare_expr_type(&input.start, Some(input.type_), input.start.span()); - // self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span()); - - VisitResult::VisitChildren + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = Some(input.type_); + self.visit_expression(&input.start); + self.visit_expression(&input.stop); + self.visitor.expected_type = prev_expected_type; } - fn visit_console(&mut self, input: &'a ConsoleStatement) -> VisitResult { + fn visit_console(&mut self, input: &'a ConsoleStatement) { match &input.function { ConsoleFunction::Assert(expr) => { + let prev_expected_type = self.visitor.expected_type; + self.visitor.expected_type = Some(Type::Boolean); + self.visit_expression(expr); + self.visitor.expected_type = prev_expected_type; // self.compare_expr_type(expr, Some(Type::Boolean), expr.span()); } ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => { // TODO: undetermined } } - - VisitResult::VisitChildren } - fn visit_block(&mut self, input: &'a Block) -> VisitResult { - self.symbol_table.push_variable_scope(); + fn visit_block(&mut self, input: &'a Block) { + self.visitor.symbol_table.push_variable_scope(); // have to redo the logic here so we have scoping input.statements.iter().for_each(|stmt| { match stmt { @@ -137,8 +148,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { Statement::Block(stmt) => self.visit_block(stmt), }; }); - self.symbol_table.pop_variable_scope(); - - VisitResult::SkipChildren + self.visitor.symbol_table.pop_variable_scope(); } } diff --git a/compiler/passes/src/type_checker/director.rs b/compiler/passes/src/type_checker/director.rs index dac353a5cc..30c813bb84 100644 --- a/compiler/passes/src/type_checker/director.rs +++ b/compiler/passes/src/type_checker/director.rs @@ -15,12 +15,12 @@ // along with the Leo library. If not, see . use leo_ast::*; -use leo_errors::{emitter::Handler, TypeCheckerError}; +use leo_errors::emitter::Handler; use crate::{SymbolTable, TypeChecker}; pub(crate) struct Director<'a> { - visitor: TypeChecker<'a>, + pub(crate) visitor: TypeChecker<'a>, } impl<'a> Director<'a> { @@ -42,159 +42,3 @@ impl<'a> VisitorDirector<'a> for Director<'a> { &mut self.visitor } } - -fn return_incorrect_type(t1: Option, t2: Option, expected: Option) -> Option { - match (t1, t2) { - (Some(t1), Some(t2)) if t1 == t2 => Some(t1), - (Some(t1), Some(t2)) => { - if let Some(expected) = expected { - if t1 != expected { - Some(t1) - } else { - Some(t2) - } - } else { - Some(t1) - } - } - (None, Some(_)) | (Some(_), None) | (None, None) => None, - } -} - -impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { - type Output = Type; - - fn visit_expression(&mut self, input: &'a Expression) -> Option { - if let VisitResult::VisitChildren = self.visitor.visit_expression(input).0 { - return match input { - Expression::Identifier(expr) => self.visit_identifier(expr), - Expression::Value(expr) => self.visit_value(expr), - Expression::Binary(expr) => self.visit_binary(expr), - Expression::Unary(expr) => self.visit_unary(expr), - Expression::Ternary(expr) => self.visit_ternary(expr), - Expression::Call(expr) => self.visit_call(expr), - Expression::Err(expr) => self.visit_err(expr), - }; - } - - None - } - - fn visit_identifier(&mut self, input: &'a Identifier) -> Option { - self.visitor.visit_identifier(input).1 - } - - fn visit_value(&mut self, input: &'a ValueExpression) -> Option { - self.visitor.visit_value(input).1 - } - - fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option { - match self.visitor.visit_binary(input) { - (VisitResult::VisitChildren, expected) => { - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); - - return_incorrect_type(t1, t2, self.visitor.expected_type) - } - _ => None, - } - } - - fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option { - match input.op { - UnaryOperation::Not => { - self.visitor.assert_type(Type::Boolean, self.visitor.expected_type); - self.visit_expression(&input.inner) - } - UnaryOperation::Negate => { - let prior_negate_state = self.visitor.negate; - self.visitor.negate = true; - - let type_ = self.visit_expression(&input.inner); - self.visitor.negate = prior_negate_state; - match type_.as_ref() { - Some( - Type::IntegerType( - IntegerType::I8 - | IntegerType::I16 - | IntegerType::I32 - | IntegerType::I64 - | IntegerType::I128, - ) - | Type::Field - | Type::Group, - ) => {} - Some(t) => self - .visitor - .handler - .emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()), - _ => {} - }; - type_ - } - } - } - - fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option { - if let VisitResult::VisitChildren = self.visitor.visit_ternary(input).0 { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(Type::Boolean); - self.visit_expression(&input.condition); - self.visitor.expected_type = prev_expected_type; - - let t1 = self.visit_expression(&input.if_true); - let t2 = self.visit_expression(&input.if_false); - - return return_incorrect_type(t1, t2, self.visitor.expected_type); - } - - None - } - - fn visit_call(&mut self, input: &'a CallExpression) -> Option { - match &*input.function { - Expression::Identifier(ident) => { - if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) { - let ret = self.visitor.assert_type(func.output, self.visitor.expected_type); - - if func.input.len() != input.arguments.len() { - self.visitor.handler.emit_err( - TypeCheckerError::incorrect_num_args_to_call( - func.input.len(), - input.arguments.len(), - input.span(), - ) - .into(), - ); - } - - func.input - .iter() - .zip(input.arguments.iter()) - .for_each(|(expected, argument)| { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(expected.get_variable().type_); - self.visit_expression(argument); - self.visitor.expected_type = prev_expected_type; - }); - - Some(ret) - } else { - self.visitor - .handler - .emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into()); - None - } - } - expr => self.visit_expression(expr), - } - } - - fn visit_err(&mut self, input: &'a ErrExpression) -> Option { - self.visitor.visit_err(input).1 - } -} - -impl<'a> StatementVisitorDirector<'a> for Director<'a> {} - -impl<'a> ProgramVisitorDirector<'a> for Director<'a> {} diff --git a/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out b/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out index bb55a5caca..81b8873b74 100644 --- a/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out +++ b/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out @@ -2,4 +2,4 @@ namespace: Compile expectation: Fail outputs: - - "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:4:12\n |\n 4 | return (_, _)group * a;\n | ^^^^^^^^^^^^^^^\n" + - "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n" diff --git a/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out b/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out index a3376fe4ec..de17a6ae9a 100644 --- a/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out +++ b/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out @@ -2,4 +2,4 @@ namespace: Compile expectation: Fail outputs: - - "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:10:18\n |\n 10 | let h: u32 = a * 1u8;\n | ^\nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:10:22\n |\n 10 | let h: u32 = a * 1u8;\n | ^^^\n" + - "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n"