diff --git a/compiler/ast/src/passes/visitor_director.rs b/compiler/ast/src/passes/visitor_director.rs index eab9b36870..5ebe2bfd8d 100644 --- a/compiler/ast/src/passes/visitor_director.rs +++ b/compiler/ast/src/passes/visitor_director.rs @@ -29,68 +29,77 @@ pub trait VisitorDirector<'a> { } pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> { + type AdditionalInput: Default; type Output; - fn visit_expression(&mut self, input: &'a Expression) -> Option { + fn visit_expression(&mut self, input: &'a Expression, additional: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input) { match input { - Expression::Identifier(expr) => self.visit_identifier(expr), - Expression::Value(expr) => self.visit_value(expr), - Expression::Binary(expr) => self.visit_binary(expr), - Expression::Unary(expr) => self.visit_unary(expr), - Expression::Ternary(expr) => self.visit_ternary(expr), - Expression::Call(expr) => self.visit_call(expr), - Expression::Err(expr) => self.visit_err(expr), + Expression::Identifier(expr) => self.visit_identifier(expr, additional), + Expression::Value(expr) => self.visit_value(expr, additional), + Expression::Binary(expr) => self.visit_binary(expr, additional), + Expression::Unary(expr) => self.visit_unary(expr, additional), + Expression::Ternary(expr) => self.visit_ternary(expr, additional), + Expression::Call(expr) => self.visit_call(expr, additional), + Expression::Err(expr) => self.visit_err(expr, additional), }; } None } - fn visit_identifier(&mut self, input: &'a Identifier) -> Option { + fn visit_identifier(&mut self, input: &'a Identifier, _additional: &Self::AdditionalInput) -> Option { self.visitor_ref().visit_identifier(input); None } - fn visit_value(&mut self, input: &'a ValueExpression) -> Option { + fn visit_value(&mut self, input: &'a ValueExpression, _additional: &Self::AdditionalInput) -> Option { self.visitor_ref().visit_value(input); None } - fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option { + fn visit_binary( + &mut self, + input: &'a BinaryExpression, + additional: &Self::AdditionalInput, + ) -> Option { if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input) { - self.visit_expression(&input.left); - self.visit_expression(&input.right); + self.visit_expression(&input.left, additional); + self.visit_expression(&input.right, additional); } None } - fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option { + fn visit_unary(&mut self, input: &'a UnaryExpression, additional: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input) { - self.visit_expression(&input.inner); + self.visit_expression(&input.inner, additional); } None } - fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option { + fn visit_ternary( + &mut self, + input: &'a TernaryExpression, + additional: &Self::AdditionalInput, + ) -> Option { if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input) { - self.visit_expression(&input.condition); - self.visit_expression(&input.if_true); - self.visit_expression(&input.if_false); + self.visit_expression(&input.condition, additional); + self.visit_expression(&input.if_true, additional); + self.visit_expression(&input.if_false, additional); } None } - fn visit_call(&mut self, input: &'a CallExpression) -> Option { + fn visit_call(&mut self, input: &'a CallExpression, additional: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input) { input.arguments.iter().for_each(|expr| { - self.visit_expression(expr); + self.visit_expression(expr, additional); }); } None } - fn visit_err(&mut self, input: &'a ErrExpression) -> Option { + fn visit_err(&mut self, input: &'a ErrExpression, _additional: &Self::AdditionalInput) -> Option { self.visitor_ref().visit_err(input); None } @@ -113,25 +122,25 @@ pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorD fn visit_return(&mut self, input: &'a ReturnStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) { - self.visit_expression(&input.expression); + self.visit_expression(&input.expression, &Default::default()); } } fn visit_definition(&mut self, input: &'a DefinitionStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_definition(input) { - self.visit_expression(&input.value); + self.visit_expression(&input.value, &Default::default()); } } fn visit_assign(&mut self, input: &'a AssignStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_assign(input) { - self.visit_expression(&input.value); + self.visit_expression(&input.value, &Default::default()); } } fn visit_conditional(&mut self, input: &'a ConditionalStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_conditional(input) { - self.visit_expression(&input.condition); + self.visit_expression(&input.condition, &Default::default()); self.visit_block(&input.block); if let Some(stmt) = input.next.as_ref() { self.visit_statement(stmt); @@ -141,8 +150,8 @@ pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorD fn visit_iteration(&mut self, input: &'a IterationStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_iteration(input) { - self.visit_expression(&input.start); - self.visit_expression(&input.stop); + self.visit_expression(&input.start, &Default::default()); + self.visit_expression(&input.stop, &Default::default()); self.visit_block(&input.block); } } @@ -150,10 +159,10 @@ pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorD fn visit_console(&mut self, input: &'a ConsoleStatement) { if let VisitResult::VisitChildren = self.visitor_ref().visit_console(input) { match &input.function { - ConsoleFunction::Assert(expr) => self.visit_expression(expr), + ConsoleFunction::Assert(expr) => self.visit_expression(expr, &Default::default()), ConsoleFunction::Error(fmt) | ConsoleFunction::Log(fmt) => { fmt.parameters.iter().for_each(|expr| { - self.visit_expression(expr); + self.visit_expression(expr, &Default::default()); }); None } diff --git a/compiler/passes/src/symbol_table/director.rs b/compiler/passes/src/symbol_table/director.rs index 3627ea4d8f..8dfe4190fa 100644 --- a/compiler/passes/src/symbol_table/director.rs +++ b/compiler/passes/src/symbol_table/director.rs @@ -44,6 +44,7 @@ impl<'a> VisitorDirector<'a> for Director<'a> { } impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { + type AdditionalInput = (); type Output = (); } diff --git a/compiler/passes/src/type_checker/check_expressions.rs b/compiler/passes/src/type_checker/check_expressions.rs index cd1501e03e..e34361a247 100644 --- a/compiler/passes/src/type_checker/check_expressions.rs +++ b/compiler/passes/src/type_checker/check_expressions.rs @@ -23,12 +23,12 @@ use super::director::Director; impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {} -fn return_incorrect_type(t1: Option, t2: Option, expected: Option) -> Option { +fn return_incorrect_type(t1: Option, t2: Option, expected: &Option) -> Option { match (t1, t2) { (Some(t1), Some(t2)) if t1 == t2 => Some(t1), (Some(t1), Some(t2)) => { if let Some(expected) = expected { - if t1 != expected { + if &t1 != expected { Some(t1) } else { Some(t2) @@ -42,28 +42,29 @@ fn return_incorrect_type(t1: Option, t2: Option, expected: Option ExpressionVisitorDirector<'a> for Director<'a> { + type AdditionalInput = Option; type Output = Type; - fn visit_expression(&mut self, input: &'a Expression) -> Option { + fn visit_expression(&mut self, input: &'a Expression, expected: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor.visit_expression(input) { return match input { - Expression::Identifier(expr) => self.visit_identifier(expr), - Expression::Value(expr) => self.visit_value(expr), - Expression::Binary(expr) => self.visit_binary(expr), - Expression::Unary(expr) => self.visit_unary(expr), - Expression::Ternary(expr) => self.visit_ternary(expr), - Expression::Call(expr) => self.visit_call(expr), - Expression::Err(expr) => self.visit_err(expr), + Expression::Identifier(expr) => self.visit_identifier(expr, expected), + Expression::Value(expr) => self.visit_value(expr, expected), + Expression::Binary(expr) => self.visit_binary(expr, expected), + Expression::Unary(expr) => self.visit_unary(expr, expected), + Expression::Ternary(expr) => self.visit_ternary(expr, expected), + Expression::Call(expr) => self.visit_call(expr, expected), + Expression::Err(expr) => self.visit_err(expr, expected), }; } None } - fn visit_identifier(&mut self, input: &'a Identifier) -> Option { + fn visit_identifier(&mut self, input: &'a Identifier, expected: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor.visit_identifier(input) { return if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&input.name) { - Some(self.visitor.assert_type(*var.type_, self.visitor.expected_type)) + Some(self.visitor.assert_type(*var.type_, expected, var.span)) } else { self.visitor .handler @@ -75,12 +76,12 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { None } - fn visit_value(&mut self, input: &'a ValueExpression) -> Option { + fn visit_value(&mut self, input: &'a ValueExpression, expected: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor.visit_value(input) { return Some(match input { - ValueExpression::Address(_, _) => self.visitor.assert_type(Type::Address, self.visitor.expected_type), - ValueExpression::Boolean(_, _) => self.visitor.assert_type(Type::Boolean, self.visitor.expected_type), - ValueExpression::Field(_, _) => self.visitor.assert_type(Type::Field, self.visitor.expected_type), + ValueExpression::Address(_, _) => self.visitor.assert_type(Type::Address, expected, input.span()), + ValueExpression::Boolean(_, _) => self.visitor.assert_type(Type::Boolean, expected, input.span()), + ValueExpression::Field(_, _) => self.visitor.assert_type(Type::Field, expected, input.span()), ValueExpression::Integer(type_, str_content, _) => { match type_ { IntegerType::I8 => { @@ -171,10 +172,10 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { _ => {} } self.visitor - .assert_type(Type::IntegerType(*type_), self.visitor.expected_type) + .assert_type(Type::IntegerType(*type_), expected, input.span()) } - ValueExpression::Group(_) => self.visitor.assert_type(Type::Group, self.visitor.expected_type), - ValueExpression::Scalar(_, _) => self.visitor.assert_type(Type::Scalar, self.visitor.expected_type), + ValueExpression::Group(_) => self.visitor.assert_type(Type::Group, expected, input.span()), + ValueExpression::Scalar(_, _) => self.visitor.assert_type(Type::Scalar, expected, input.span()), ValueExpression::String(_, _) => unreachable!("String types are not reachable"), }); } @@ -182,80 +183,73 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { None } - fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option { + fn visit_binary(&mut self, input: &'a BinaryExpression, expected: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor.visit_binary(input) { return match input.op { BinaryOperation::And | BinaryOperation::Or => { - self.visitor.assert_type(Type::Boolean, self.visitor.expected_type); - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); + self.visitor.assert_type(Type::Boolean, expected, input.span()); + let t1 = self.visit_expression(&input.left, expected); + let t2 = self.visit_expression(&input.right, expected); - return_incorrect_type(t1, t2, self.visitor.expected_type) + return_incorrect_type(t1, t2, expected) } BinaryOperation::Add => { - self.visitor - .assert_field_group_scalar_int_type(self.visitor.expected_type, input.span()); - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); + self.visitor.assert_field_group_scalar_int_type(expected, input.span()); + let t1 = self.visit_expression(&input.left, expected); + let t2 = self.visit_expression(&input.right, expected); - return_incorrect_type(t1, t2, self.visitor.expected_type) + return_incorrect_type(t1, t2, expected) } BinaryOperation::Sub => { - self.visitor - .assert_field_group_int_type(self.visitor.expected_type, input.span()); - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); + self.visitor.assert_field_group_int_type(expected, input.span()); + let t1 = self.visit_expression(&input.left, expected); + let t2 = self.visit_expression(&input.right, expected); - return_incorrect_type(t1, t2, self.visitor.expected_type) + return_incorrect_type(t1, t2, expected) } BinaryOperation::Mul => { - self.visitor - .assert_field_group_int_type(self.visitor.expected_type, input.span()); + self.visitor.assert_field_group_int_type(expected, input.span()); - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = None; - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); - self.visitor.expected_type = prev_expected_type; + let t1 = self.visit_expression(&input.left, &None); + let t2 = self.visit_expression(&input.right, &None); // Allow `group` * `scalar` multiplication. match (t1.as_ref(), t2.as_ref()) { - (Some(Type::Group), Some(other)) - | (Some(other), Some(Type::Group)) => { - self.visitor.assert_type(Type::Group, self.visitor.expected_type); - self.visitor.assert_type(*other, Some(Type::Scalar)); + (Some(Type::Group), Some(other)) => { + self.visitor.assert_type(Type::Group, expected, input.left.span()); + self.visitor + .assert_type(*other, &Some(Type::Scalar), input.right.span()); + Some(Type::Group) + } + (Some(other), Some(Type::Group)) => { + self.visitor.assert_type(*other, &Some(Type::Scalar), input.left.span()); + self.visitor.assert_type(Type::Group, expected, input.right.span()); Some(Type::Group) } _ => { - self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type); - self.visitor.assert_type(t2.unwrap(), self.visitor.expected_type); - return_incorrect_type(t1, t2, self.visitor.expected_type) + self.visitor.assert_type(t1.unwrap(), expected, input.left.span()); + self.visitor.assert_type(t2.unwrap(), expected, input.right.span()); + return_incorrect_type(t1, t2, expected) } } } BinaryOperation::Div => { - self.visitor - .assert_field_int_type(self.visitor.expected_type, input.span()); + self.visitor.assert_field_int_type(expected, input.span()); - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); - - return_incorrect_type(t1, t2, self.visitor.expected_type) + let t1 = self.visit_expression(&input.left, expected); + let t2 = self.visit_expression(&input.right, expected); + + return_incorrect_type(t1, t2, expected) } BinaryOperation::Pow => { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = None; - - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); - - self.visitor.expected_type = prev_expected_type; + let t1 = self.visit_expression(&input.left, &None); + let t2 = self.visit_expression(&input.right, &None); match (t1.as_ref(), t2.as_ref()) { // Type A must be an int. // Type B must be a unsigned int. (Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => { - self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type); + self.visitor.assert_type(t1.unwrap(), expected, input.left.span()); } // Type A was an int. // But Type B was not a unsigned int. @@ -268,7 +262,7 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { // Type A must be a field. // Type B must be an int. (Some(Type::Field), Some(Type::IntegerType(_))) => { - self.visitor.assert_type(Type::Field, self.visitor.expected_type); + self.visitor.assert_type(Type::Field, expected, input.left.span()); } // Type A was a field. // But Type B was not an int. @@ -289,28 +283,20 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { t1 } BinaryOperation::Eq | BinaryOperation::Ne => { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = None; - - let t1 = self.visit_expression(&input.left); - let t2 = self.visit_expression(&input.right); - - self.visitor.expected_type = prev_expected_type; + let t1 = self.visit_expression(&input.left, &None); + let t2 = self.visit_expression(&input.right, &None); + self.visitor.assert_eq_types(t1, t2, input.span()); Some(Type::Boolean) } BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = None; - - let t1 = self.visit_expression(&input.left); - self.visitor.assert_field_scalar_int_type(t1, input.left.span()); + let t1 = self.visit_expression(&input.left, &None); + self.visitor.assert_field_scalar_int_type(&t1, input.left.span()); - let t2 = self.visit_expression(&input.right); - self.visitor.assert_field_scalar_int_type(t2, input.right.span()); + let t2 = self.visit_expression(&input.right, &None); + self.visitor.assert_field_scalar_int_type(&t2, input.right.span()); - self.visitor.expected_type = prev_expected_type; self.visitor.assert_eq_types(t1, t2, input.span()); Some(Type::Boolean) @@ -321,17 +307,17 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { None } - fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option { + fn visit_unary(&mut self, input: &'a UnaryExpression, expected: &Self::AdditionalInput) -> Option { match input.op { UnaryOperation::Not => { - self.visitor.assert_type(Type::Boolean, self.visitor.expected_type); - self.visit_expression(&input.inner) + self.visitor.assert_type(Type::Boolean, expected, input.span()); + self.visit_expression(&input.inner, expected) } UnaryOperation::Negate => { let prior_negate_state = self.visitor.negate; self.visitor.negate = true; - let type_ = self.visit_expression(&input.inner); + let type_ = self.visit_expression(&input.inner, expected); self.visitor.negate = prior_negate_state; match type_.as_ref() { Some( @@ -356,27 +342,28 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { } } - fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option { + fn visit_ternary( + &mut self, + input: &'a TernaryExpression, + expected: &Self::AdditionalInput, + ) -> Option { if let VisitResult::VisitChildren = self.visitor.visit_ternary(input) { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(Type::Boolean); - self.visit_expression(&input.condition); - self.visitor.expected_type = prev_expected_type; + self.visit_expression(&input.condition, &Some(Type::Boolean)); - let t1 = self.visit_expression(&input.if_true); - let t2 = self.visit_expression(&input.if_false); + let t1 = self.visit_expression(&input.if_true, expected); + let t2 = self.visit_expression(&input.if_false, expected); - return return_incorrect_type(t1, t2, self.visitor.expected_type); + return return_incorrect_type(t1, t2, expected); } None } - fn visit_call(&mut self, input: &'a CallExpression) -> Option { + fn visit_call(&mut self, input: &'a CallExpression, expected: &Self::AdditionalInput) -> Option { match &*input.function { Expression::Identifier(ident) => { if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) { - let ret = self.visitor.assert_type(func.output, self.visitor.expected_type); + let ret = self.visitor.assert_type(func.output, expected, func.span()); if func.input.len() != input.arguments.len() { self.visitor.handler.emit_err( @@ -393,10 +380,7 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { .iter() .zip(input.arguments.iter()) .for_each(|(expected, argument)| { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(expected.get_variable().type_); - self.visit_expression(argument); - self.visitor.expected_type = prev_expected_type; + self.visit_expression(argument, &Some(expected.get_variable().type_)); }); Some(ret) @@ -407,7 +391,7 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { None } } - expr => self.visit_expression(expr), + expr => self.visit_expression(expr, expected), } } } diff --git a/compiler/passes/src/type_checker/check_statements.rs b/compiler/passes/src/type_checker/check_statements.rs index 888bb42f8f..4205462cc8 100644 --- a/compiler/passes/src/type_checker/check_statements.rs +++ b/compiler/passes/src/type_checker/check_statements.rs @@ -29,10 +29,10 @@ impl<'a> StatementVisitorDirector<'a> for Director<'a> { // statements should always have some parent block let parent = self.visitor.parent.unwrap(); - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = self.visitor.symbol_table.lookup_fn(&parent).map(|f| f.output); - self.visit_expression(&input.expression); - self.visitor.expected_type = prev_expected_type; + self.visit_expression( + &input.expression, + &self.visitor.symbol_table.lookup_fn(&parent).map(|f| f.output), + ); } fn visit_definition(&mut self, input: &'a DefinitionStatement) { @@ -54,10 +54,7 @@ impl<'a> StatementVisitorDirector<'a> for Director<'a> { self.visitor.handler.emit_err(err); } - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(input.type_); - self.visit_expression(&input.value); - self.visitor.expected_type = prev_expected_type; + self.visit_expression(&input.value, &Some(input.type_)); }); } @@ -86,18 +83,12 @@ impl<'a> StatementVisitorDirector<'a> for Director<'a> { }; if var_type.is_some() { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = var_type; - self.visit_expression(&input.value); - self.visitor.expected_type = prev_expected_type; + self.visit_expression(&input.value, &var_type); } } fn visit_conditional(&mut self, input: &'a ConditionalStatement) { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(Type::Boolean); - self.visit_expression(&input.condition); - self.visitor.expected_type = prev_expected_type; + self.visit_expression(&input.condition, &Some(Type::Boolean)); } fn visit_iteration(&mut self, input: &'a IterationStatement) { @@ -112,21 +103,14 @@ impl<'a> StatementVisitorDirector<'a> for Director<'a> { self.visitor.handler.emit_err(err); } - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(input.type_); - self.visit_expression(&input.start); - self.visit_expression(&input.stop); - self.visitor.expected_type = prev_expected_type; + self.visit_expression(&input.start, &Some(input.type_)); + self.visit_expression(&input.stop, &Some(input.type_)); } fn visit_console(&mut self, input: &'a ConsoleStatement) { match &input.function { ConsoleFunction::Assert(expr) => { - let prev_expected_type = self.visitor.expected_type; - self.visitor.expected_type = Some(Type::Boolean); - self.visit_expression(expr); - self.visitor.expected_type = prev_expected_type; - // self.compare_expr_type(expr, Some(Type::Boolean), expr.span()); + self.visit_expression(expr, &Some(Type::Boolean)); } ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => { // TODO: undetermined diff --git a/compiler/passes/src/type_checker/checker.rs b/compiler/passes/src/type_checker/checker.rs index 171ef54c93..5807fd7ee6 100644 --- a/compiler/passes/src/type_checker/checker.rs +++ b/compiler/passes/src/type_checker/checker.rs @@ -25,8 +25,6 @@ pub struct TypeChecker<'a> { pub(crate) handler: &'a Handler, pub(crate) parent: Option, pub(crate) negate: bool, - pub(crate) expected_type: Option, - pub(crate) span: Span, } const INT_TYPES: [Type; 10] = [ @@ -76,8 +74,6 @@ impl<'a> TypeChecker<'a> { handler, parent: None, negate: false, - expected_type: None, - span: Default::default(), } } @@ -95,11 +91,11 @@ impl<'a> TypeChecker<'a> { } /// Returns the given type if it equals the expected type or the expected type is none. - pub(crate) fn assert_type(&mut self, type_: Type, expected: Option) -> Type { + pub(crate) fn assert_type(&mut self, type_: Type, expected: &Option, span: Span) -> Type { if let Some(expected) = expected { - if type_ != expected { + if &type_ != expected { self.handler - .emit_err(TypeCheckerError::type_should_be(type_, expected, self.span).into()); + .emit_err(TypeCheckerError::type_should_be(type_, expected, span).into()); } } @@ -107,9 +103,9 @@ impl<'a> TypeChecker<'a> { } /// Emits an error to the error handler if the given type is not equal to any of the expected types. - pub(crate) fn assert_one_of_types(&self, type_: Option, expected: &[Type], span: Span) { + pub(crate) fn assert_one_of_types(&self, type_: &Option, expected: &[Type], span: Span) { if let Some(type_) = type_ { - if !expected.iter().any(|t: &Type| t == &type_) { + if !expected.iter().any(|t: &Type| t == type_) { self.handler.emit_err( TypeCheckerError::expected_one_type_of( expected.iter().map(|t| t.to_string() + ",").collect::(), @@ -123,22 +119,22 @@ impl<'a> TypeChecker<'a> { } /// Emits an error to the handler if the given type is not a field or integer. - pub(crate) fn assert_field_int_type(&self, type_: Option, span: Span) { + pub(crate) fn assert_field_int_type(&self, type_: &Option, span: Span) { self.assert_one_of_types(type_, &FIELD_INT_TYPES, span) } /// Emits an error to the handler if the given type is not a field, scalar, or integer. - pub(crate) fn assert_field_scalar_int_type(&self, type_: Option, span: Span) { + pub(crate) fn assert_field_scalar_int_type(&self, type_: &Option, span: Span) { self.assert_one_of_types(type_, &FIELD_SCALAR_INT_TYPES, span) } /// Emits an error to the handler if the given type is not a field, group, or integer. - pub(crate) fn assert_field_group_int_type(&self, type_: Option, span: Span) { + pub(crate) fn assert_field_group_int_type(&self, type_: &Option, span: Span) { self.assert_one_of_types(type_, &FIELD_GROUP_INT_TYPES, span) } /// Emits an error to the handler if the given type is not a field, group, scalar or integer. - pub(crate) fn assert_field_group_scalar_int_type(&self, type_: Option, span: Span) { + pub(crate) fn assert_field_group_scalar_int_type(&self, type_: &Option, span: Span) { self.assert_one_of_types(type_, &FIELD_GROUP_SCALAR_INT_TYPES, span) } } diff --git a/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out b/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out index 81b8873b74..17bc264a35 100644 --- a/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out +++ b/tests/expectations/compiler/compiler/group/mult_by_group_fail.leo.out @@ -2,4 +2,4 @@ namespace: Compile expectation: Fail outputs: - - "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n" + - "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:4:26\n |\n 4 | return (_, _)group * a;\n | ^\n" diff --git a/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out b/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out index de17a6ae9a..a3376fe4ec 100644 --- a/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out +++ b/tests/expectations/compiler/compiler/statements/compare_diff_types_fail.leo.out @@ -2,4 +2,4 @@ namespace: Compile expectation: Fail outputs: - - "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n" + - "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:10:18\n |\n 10 | let h: u32 = a * 1u8;\n | ^\nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:10:22\n |\n 10 | let h: u32 = a * 1u8;\n | ^^^\n"