trying to modify visitor pattern to better fit type checking

This commit is contained in:
gluax 2022-05-26 10:05:04 -07:00
parent bc174419f7
commit 97ef64aa66
10 changed files with 580 additions and 350 deletions

View File

@ -30,35 +30,37 @@ impl Default for VisitResult {
}
pub trait ExpressionVisitor<'a> {
fn visit_expression(&mut self, _input: &'a Expression) -> VisitResult {
type Output;
fn visit_expression(&mut self, _input: &'a Expression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_identifier(&mut self, _input: &'a Identifier) -> VisitResult {
fn visit_identifier(&mut self, _input: &'a Identifier) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_value(&mut self, _input: &'a ValueExpression) -> VisitResult {
fn visit_value(&mut self, _input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> VisitResult {
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> VisitResult {
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> VisitResult {
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_call(&mut self, _input: &'a CallExpression) -> VisitResult {
fn visit_call(&mut self, _input: &'a CallExpression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
fn visit_err(&mut self, _input: &'a ErrExpression) -> VisitResult {
fn visit_err(&mut self, _input: &'a ErrExpression) -> (VisitResult, Option<Self::Output>) {
Default::default()
}
}

View File

@ -18,72 +18,87 @@
//! It implements default methods for each node to be made
//! given the type of node its visiting.
use std::marker::PhantomData;
use crate::*;
pub struct VisitorDirector<'a, V: ExpressionVisitor<'a>> {
visitor: V,
lifetime: PhantomData<&'a ()>,
pub trait VisitorDirector<'a> {
type Visitor: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>;
fn visitor(self) -> Self::Visitor;
fn visitor_ref(&mut self) -> &mut Self::Visitor;
}
impl<'a, V: ExpressionVisitor<'a>> VisitorDirector<'a, V> {
pub fn new(visitor: V) -> Self {
Self {
visitor,
lifetime: PhantomData,
}
}
pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
type Output;
pub fn visitor(self) -> V {
self.visitor
}
pub fn visit_expression(&mut self, input: &'a Expression) {
if let VisitResult::VisitChildren = self.visitor.visit_expression(input) {
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input).0 {
match input {
Expression::Identifier(_) => {}
Expression::Value(_) => {}
Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr),
Expression::Binary(expr) => self.visit_binary(expr),
Expression::Unary(expr) => self.visit_unary(expr),
Expression::Ternary(expr) => self.visit_ternary(expr),
Expression::Call(expr) => self.visit_call(expr),
Expression::Err(_) => {}
}
Expression::Err(expr) => self.visit_err(expr),
};
}
None
}
pub fn visit_binary(&mut self, input: &'a BinaryExpression) {
if let VisitResult::VisitChildren = self.visitor.visit_binary(input) {
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
self.visitor_ref().visit_identifier(input);
None
}
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
self.visitor_ref().visit_value(input);
None
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input).0 {
self.visit_expression(&input.left);
self.visit_expression(&input.right);
}
None
}
pub fn visit_unary(&mut self, input: &'a UnaryExpression) {
if let VisitResult::VisitChildren = self.visitor.visit_unary(input) {
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input).0 {
self.visit_expression(&input.inner);
}
None
}
pub fn visit_ternary(&mut self, input: &'a TernaryExpression) {
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input) {
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input).0 {
self.visit_expression(&input.condition);
self.visit_expression(&input.if_true);
self.visit_expression(&input.if_false);
}
None
}
pub fn visit_call(&mut self, input: &'a CallExpression) {
if let VisitResult::VisitChildren = self.visitor.visit_call(input) {
input.arguments.iter().for_each(|expr| self.visit_expression(expr));
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input).0 {
input.arguments.iter().for_each(|expr| {
self.visit_expression(expr);
});
}
None
}
fn visit_err(&mut self, input: &'a ErrExpression) -> Option<Self::Output> {
self.visitor_ref().visit_err(input);
None
}
}
impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> {
pub fn visit_statement(&mut self, input: &'a Statement) {
if let VisitResult::VisitChildren = self.visitor.visit_statement(input) {
pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorDirector<'a> {
fn visit_statement(&mut self, input: &'a Statement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_statement(input) {
match input {
Statement::Return(stmt) => self.visit_return(stmt),
Statement::Definition(stmt) => self.visit_definition(stmt),
@ -96,26 +111,26 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V>
}
}
pub fn visit_return(&mut self, input: &'a ReturnStatement) {
if let VisitResult::VisitChildren = self.visitor.visit_return(input) {
self.visit_expression(&input.expression);
fn visit_return(&mut self, input: &'a ReturnStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) {
self.visitor_ref().visit_expression(&input.expression);
}
}
pub fn visit_definition(&mut self, input: &'a DefinitionStatement) {
if let VisitResult::VisitChildren = self.visitor.visit_definition(input) {
fn visit_definition(&mut self, input: &'a DefinitionStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_definition(input) {
self.visit_expression(&input.value);
}
}
pub fn visit_assign(&mut self, input: &'a AssignStatement) {
if let VisitResult::VisitChildren = self.visitor.visit_assign(input) {
fn visit_assign(&mut self, input: &'a AssignStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_assign(input) {
self.visit_expression(&input.value);
}
}
pub fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
if let VisitResult::VisitChildren = self.visitor.visit_conditional(input) {
fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_conditional(input) {
self.visit_expression(&input.condition);
self.visit_block(&input.block);
if let Some(stmt) = input.next.as_ref() {
@ -124,35 +139,38 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V>
}
}
pub fn visit_iteration(&mut self, input: &'a IterationStatement) {
if let VisitResult::VisitChildren = self.visitor.visit_iteration(input) {
fn visit_iteration(&mut self, input: &'a IterationStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_iteration(input) {
self.visit_expression(&input.start);
self.visit_expression(&input.stop);
self.visit_block(&input.block);
}
}
pub fn visit_console(&mut self, input: &'a ConsoleStatement) {
if let VisitResult::VisitChildren = self.visitor.visit_console(input) {
fn visit_console(&mut self, input: &'a ConsoleStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_console(input) {
match &input.function {
ConsoleFunction::Assert(expr) => self.visit_expression(expr),
ConsoleFunction::Error(fmt) | ConsoleFunction::Log(fmt) => {
fmt.parameters.iter().for_each(|expr| self.visit_expression(expr));
fmt.parameters.iter().for_each(|expr| {
self.visit_expression(expr);
});
None
}
}
};
}
}
pub fn visit_block(&mut self, input: &'a Block) {
if let VisitResult::VisitChildren = self.visitor.visit_block(input) {
fn visit_block(&mut self, input: &'a Block) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_block(input) {
input.statements.iter().for_each(|stmt| self.visit_statement(stmt));
}
}
}
impl<'a, V: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> {
pub fn visit_program(&mut self, input: &'a Program) {
if let VisitResult::VisitChildren = self.visitor.visit_program(input) {
pub trait ProgramVisitorDirector<'a>: VisitorDirector<'a> + StatementVisitorDirector<'a> {
fn visit_program(&mut self, input: &'a Program) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_program(input) {
input
.functions
.values()
@ -160,8 +178,8 @@ impl<'a, V: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>> V
}
}
pub fn visit_function(&mut self, input: &'a Function) {
if let VisitResult::VisitChildren = self.visitor.visit_function(input) {
fn visit_function(&mut self, input: &'a Function) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_function(input) {
self.visit_block(&input.block);
}
}

View File

@ -36,7 +36,9 @@ impl<'a> CreateSymbolTable<'a> {
}
}
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {}
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {
type Output = ();
}
impl<'a> StatementVisitor<'a> for CreateSymbolTable<'a> {}

View File

@ -0,0 +1,52 @@
// Copyright (C) 2019-2022 Aleo Systems Inc.
// This file is part of the Leo library.
// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use leo_ast::{ExpressionVisitorDirector, ProgramVisitorDirector, StatementVisitorDirector, VisitorDirector};
use leo_errors::emitter::Handler;
use crate::CreateSymbolTable;
pub(crate) struct Director<'a> {
visitor: CreateSymbolTable<'a>,
}
impl<'a> Director<'a> {
pub(crate) fn new(handler: &'a Handler) -> Self {
Self {
visitor: CreateSymbolTable::new(handler),
}
}
}
impl<'a> VisitorDirector<'a> for Director<'a> {
type Visitor = CreateSymbolTable<'a>;
fn visitor(self) -> Self::Visitor {
self.visitor
}
fn visitor_ref(&mut self) -> &mut Self::Visitor {
&mut self.visitor
}
}
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
type Output = ();
}
impl<'a> StatementVisitorDirector<'a> for Director<'a> {}
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}

View File

@ -17,6 +17,9 @@
pub mod create;
pub use create::*;
pub mod director;
use director::*;
pub mod table;
pub use table::*;
@ -28,7 +31,7 @@ pub use variable_symbol::*;
use crate::Pass;
use leo_ast::{Ast, VisitorDirector};
use leo_ast::{Ast, ProgramVisitorDirector, VisitorDirector};
use leo_errors::{emitter::Handler, Result};
impl<'a> Pass<'a> for CreateSymbolTable<'a> {
@ -36,7 +39,7 @@ impl<'a> Pass<'a> for CreateSymbolTable<'a> {
type Output = Result<SymbolTable<'a>>;
fn do_pass((ast, handler): Self::Input) -> Self::Output {
let mut visitor = VisitorDirector::new(CreateSymbolTable::new(handler));
let mut visitor = Director::new(handler);
visitor.visit_program(ast.as_repr());
handler.last_err()?;

View File

@ -16,7 +16,6 @@
use leo_ast::*;
use leo_errors::TypeCheckerError;
use leo_span::Span;
use crate::TypeChecker;
@ -38,292 +37,239 @@ fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Ty
}
}
impl<'a> TypeChecker<'a> {
pub(crate) fn compare_expr_type(&mut self, expr: &Expression, expected: Option<Type>, span: Span) -> Option<Type> {
match expr {
Expression::Identifier(ident) => {
if let Some(var) = self.symbol_table.lookup_variable(&ident.name) {
Some(self.assert_type(*var.type_, expected, span))
} else {
self.handler
.emit_err(TypeCheckerError::unknown_sym("variable", ident.name, span).into());
None
}
}
Expression::Value(value) => match value {
ValueExpression::Address(_, _) => Some(self.assert_type(Type::Address, expected, value.span())),
ValueExpression::Boolean(_, _) => Some(self.assert_type(Type::Boolean, expected, value.span())),
ValueExpression::Field(_, _) => Some(self.assert_type(Type::Field, expected, value.span())),
ValueExpression::Integer(type_, str_content, _) => {
match type_ {
IntegerType::I8 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
type Output = Type;
if int.parse::<i8>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", value.span()).into());
}
}
IntegerType::I16 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
fn visit_identifier(&mut self, input: &'a Identifier) -> (VisitResult, Option<Self::Output>) {
let type_ = if let Some(var) = self.symbol_table.lookup_variable(&input.name) {
Some(self.assert_type(*var.type_, self.expected_type))
} else {
self.handler
.emit_err(TypeCheckerError::unknown_sym("variable", input.name, self.span).into());
None
};
if int.parse::<i16>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", value.span()).into());
}
}
IntegerType::I32 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
(VisitResult::VisitChildren, type_)
}
if int.parse::<i32>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", value.span()).into());
}
}
IntegerType::I64 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
fn visit_value(&mut self, input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) {
let prev_span = self.span;
self.span = input.span();
if int.parse::<i64>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", value.span()).into());
}
}
IntegerType::I128 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
let type_ = Some(match input {
ValueExpression::Address(_, _) => self.assert_type(Type::Address, self.expected_type),
ValueExpression::Boolean(_, _) => self.assert_type(Type::Boolean, self.expected_type),
ValueExpression::Field(_, _) => self.assert_type(Type::Field, self.expected_type),
ValueExpression::Integer(type_, str_content, _) => {
match type_ {
IntegerType::I8 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i128>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", value.span()).into());
}
}
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", value.span()).into()),
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", value.span()).into()),
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", value.span()).into()),
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", value.span()).into()),
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", value.span()).into()),
_ => {}
}
Some(self.assert_type(Type::IntegerType(*type_), expected, value.span()))
}
ValueExpression::Group(_) => Some(self.assert_type(Type::Group, expected, value.span())),
ValueExpression::Scalar(_, _) => Some(self.assert_type(Type::Scalar, expected, value.span())),
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
},
Expression::Binary(binary) => match binary.op {
BinaryOperation::And | BinaryOperation::Or => {
self.assert_type(Type::Boolean, expected, binary.span());
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
return_incorrect_type(t1, t2, expected)
}
BinaryOperation::Add => {
self.assert_field_group_scalar_int_type(expected, binary.span());
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
return_incorrect_type(t1, t2, expected)
}
BinaryOperation::Sub => {
self.assert_field_group_int_type(expected, binary.span());
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
return_incorrect_type(t1, t2, expected)
}
BinaryOperation::Mul => {
self.assert_field_group_int_type(expected, binary.span());
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
// Allow `group` * `scalar` multiplication.
match (t1.as_ref(), t2.as_ref()) {
(Some(Type::Group), Some(other)) => {
self.assert_type(Type::Group, expected, binary.left.span());
self.assert_type(*other, Some(Type::Scalar), binary.span());
Some(Type::Group)
}
(Some(other), Some(Type::Group)) => {
self.assert_type(Type::Group, expected, binary.left.span());
self.assert_type(*other, Some(Type::Scalar), binary.span());
Some(Type::Group)
}
_ => {
self.assert_type(t1.unwrap(), expected, binary.left.span());
self.assert_type(t2.unwrap(), expected, binary.right.span());
return_incorrect_type(t1, t2, expected)
}
}
}
BinaryOperation::Div => {
self.assert_field_int_type(expected, binary.span());
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
return_incorrect_type(t1, t2, expected)
}
BinaryOperation::Pow => {
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
match (t1.as_ref(), t2.as_ref()) {
// Type A must be an int.
// Type B must be a unsigned int.
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
self.assert_type(t1.unwrap(), expected, binary.span());
}
// Type A was an int.
// But Type B was not a unsigned int.
(Some(Type::IntegerType(_)), Some(t)) => {
self.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, binary.right.span())
.into(),
);
}
// Type A must be a field.
// Type B must be an int.
(Some(Type::Field), Some(Type::IntegerType(_))) => {
self.assert_type(Type::Field, expected, binary.span());
}
// Type A was a field.
// But Type B was not an int.
(Some(Type::Field), Some(t)) => {
self.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("int", t, binary.right.span()).into(),
);
}
// The base is some type thats not an int or field.
(Some(t), _) => {
if int.parse::<i8>().is_err() {
self.handler
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, binary.left.span()).into());
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into());
}
_ => {}
}
IntegerType::I16 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
t1
if int.parse::<i16>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into());
}
}
IntegerType::I32 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i32>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into());
}
}
IntegerType::I64 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i64>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into());
}
}
IntegerType::I128 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i128>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into());
}
}
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()),
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()),
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()),
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()),
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()),
_ => {}
}
BinaryOperation::Eq | BinaryOperation::Ne => {
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
self.assert_eq_types(t1, t2, binary.span());
Some(Type::Boolean)
}
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
self.assert_field_scalar_int_type(t1, binary.left.span());
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
self.assert_field_scalar_int_type(t2, binary.right.span());
self.assert_eq_types(t1, t2, binary.span());
Some(Type::Boolean)
}
},
Expression::Unary(unary) => match unary.op {
UnaryOperation::Not => {
self.assert_type(Type::Boolean, expected, unary.span());
self.compare_expr_type(&unary.inner, expected, unary.inner.span())
}
UnaryOperation::Negate => {
let prior_negate_state = self.negate;
self.negate = true;
let type_ = self.compare_expr_type(&unary.inner, expected, unary.inner.span());
self.negate = prior_negate_state;
match type_.as_ref() {
Some(
Type::IntegerType(
IntegerType::I8
| IntegerType::I16
| IntegerType::I32
| IntegerType::I64
| IntegerType::I128,
)
| Type::Field
| Type::Group,
) => {}
Some(t) => self
.handler
.emit_err(TypeCheckerError::type_is_not_negatable(t, unary.inner.span()).into()),
_ => {}
};
type_
}
},
Expression::Ternary(ternary) => {
self.compare_expr_type(&ternary.condition, Some(Type::Boolean), ternary.condition.span());
let t1 = self.compare_expr_type(&ternary.if_true, expected, ternary.if_true.span());
let t2 = self.compare_expr_type(&ternary.if_false, expected, ternary.if_false.span());
return_incorrect_type(t1, t2, expected)
self.assert_type(Type::IntegerType(*type_), self.expected_type)
}
Expression::Call(call) => match &*call.function {
Expression::Identifier(ident) => {
if let Some(func) = self.symbol_table.lookup_fn(&ident.name) {
let ret = self.assert_type(func.output, expected, ident.span());
ValueExpression::Group(_) => self.assert_type(Type::Group, self.expected_type),
ValueExpression::Scalar(_, _) => self.assert_type(Type::Scalar, self.expected_type),
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
});
if func.input.len() != call.arguments.len() {
self.handler.emit_err(
TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
call.arguments.len(),
call.span(),
)
.into(),
);
}
self.span = prev_span;
(VisitResult::VisitChildren, type_)
}
func.input
.iter()
.zip(call.arguments.iter())
.for_each(|(expected, argument)| {
self.compare_expr_type(argument, Some(expected.get_variable().type_), argument.span());
});
fn visit_binary(&mut self, input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) {
let prev_span = self.span;
self.span = input.span();
Some(ret)
} else {
self.handler
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
None
/* let type_ = match input.op {
BinaryOperation::And | BinaryOperation::Or => {
self.assert_type(Type::Boolean, self.expected_type);
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Add => {
self.assert_field_group_scalar_int_type(self.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Sub => {
self.assert_field_group_int_type(self.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Mul => {
self.assert_field_group_int_type(self.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
// Allow `group` * `scalar` multiplication.
match (t1.as_ref(), t2.as_ref()) {
(Some(Type::Group), Some(other)) => {
self.assert_type(Type::Group, self.expected_type);
self.assert_type(*other, Some(Type::Scalar));
Some(Type::Group)
}
(Some(other), Some(Type::Group)) => {
self.assert_type(Type::Group, self.expected_type);
self.assert_type(*other, Some(Type::Scalar));
Some(Type::Group)
}
_ => {
self.assert_type(t1.unwrap(), self.expected_type);
self.assert_type(t2.unwrap(), self.expected_type);
return_incorrect_type(t1, t2, self.expected_type)
}
}
expr => self.compare_expr_type(expr, expected, call.span()),
},
Expression::Err(_) => None,
}
}
BinaryOperation::Div => {
self.assert_field_int_type(self.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Pow => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
match (t1.as_ref(), t2.as_ref()) {
// Type A must be an int.
// Type B must be a unsigned int.
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
self.assert_type(t1.unwrap(), self.expected_type);
}
// Type A was an int.
// But Type B was not a unsigned int.
(Some(Type::IntegerType(_)), Some(t)) => {
self.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span())
.into(),
);
}
// Type A must be a field.
// Type B must be an int.
(Some(Type::Field), Some(Type::IntegerType(_))) => {
self.assert_type(Type::Field, self.expected_type);
}
// Type A was a field.
// But Type B was not an int.
(Some(Type::Field), Some(t)) => {
self.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(),
);
}
// The base is some type thats not an int or field.
(Some(t), _) => {
self.handler
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into());
}
_ => {}
}
t1
}
BinaryOperation::Eq | BinaryOperation::Ne => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
self.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
self.assert_field_scalar_int_type(t1, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
self.assert_field_scalar_int_type(t2, input.right.span());
self.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
}; */
self.span = prev_span;
(VisitResult::VisitChildren, None)
}
}
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {}

View File

@ -27,7 +27,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Would never be None.
let func_output_type = self.symbol_table.lookup_fn(&parent).map(|f| f.output);
self.compare_expr_type(&input.expression, func_output_type, input.expression.span());
// self.compare_expr_type(&input.expression, func_output_type, input.expression.span());
VisitResult::VisitChildren
}
@ -51,7 +51,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
self.handler.emit_err(err);
}
self.compare_expr_type(&input.value, Some(input.type_), input.value.span());
// self.compare_expr_type(&input.value, Some(input.type_), input.value.span());
});
VisitResult::VisitChildren
@ -80,14 +80,14 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
};
if var_type.is_some() {
self.compare_expr_type(&input.value, var_type, input.value.span());
// self.compare_expr_type(&input.value, var_type, input.value.span());
}
VisitResult::VisitChildren
}
fn visit_conditional(&mut self, input: &'a ConditionalStatement) -> VisitResult {
self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span());
// self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span());
VisitResult::VisitChildren
}
@ -104,8 +104,8 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
self.handler.emit_err(err);
}
self.compare_expr_type(&input.start, Some(input.type_), input.start.span());
self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span());
// self.compare_expr_type(&input.start, Some(input.type_), input.start.span());
// self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span());
VisitResult::VisitChildren
}
@ -113,7 +113,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
fn visit_console(&mut self, input: &'a ConsoleStatement) -> VisitResult {
match &input.function {
ConsoleFunction::Assert(expr) => {
self.compare_expr_type(expr, Some(Type::Boolean), expr.span());
// self.compare_expr_type(expr, Some(Type::Boolean), expr.span());
}
ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => {
// TODO: undetermined

View File

@ -25,6 +25,8 @@ pub struct TypeChecker<'a> {
pub(crate) handler: &'a Handler,
pub(crate) parent: Option<Symbol>,
pub(crate) negate: bool,
pub(crate) expected_type: Option<Type>,
pub(crate) span: Span,
}
const INT_TYPES: [Type; 10] = [
@ -74,6 +76,8 @@ impl<'a> TypeChecker<'a> {
handler,
parent: None,
negate: false,
expected_type: None,
span: Default::default(),
}
}
@ -91,11 +95,11 @@ impl<'a> TypeChecker<'a> {
}
/// Returns the given type if it equals the expected type or the expected type is none.
pub(crate) fn assert_type(&self, type_: Type, expected: Option<Type>, span: Span) -> Type {
pub(crate) fn assert_type(&mut self, type_: Type, expected: Option<Type>) -> Type {
if let Some(expected) = expected {
if type_ != expected {
self.handler
.emit_err(TypeCheckerError::type_should_be(type_, expected, span).into());
.emit_err(TypeCheckerError::type_should_be(type_, expected, self.span).into());
}
}

View File

@ -0,0 +1,200 @@
// Copyright (C) 2019-2022 Aleo Systems Inc.
// This file is part of the Leo library.
// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use leo_ast::*;
use leo_errors::{emitter::Handler, TypeCheckerError};
use crate::{SymbolTable, TypeChecker};
pub(crate) struct Director<'a> {
visitor: TypeChecker<'a>,
}
impl<'a> Director<'a> {
pub(crate) fn new(symbol_table: &'a mut SymbolTable<'a>, handler: &'a Handler) -> Self {
Self {
visitor: TypeChecker::new(symbol_table, handler),
}
}
}
impl<'a> VisitorDirector<'a> for Director<'a> {
type Visitor = TypeChecker<'a>;
fn visitor(self) -> Self::Visitor {
self.visitor
}
fn visitor_ref(&mut self) -> &mut Self::Visitor {
&mut self.visitor
}
}
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> {
match (t1, t2) {
(Some(t1), Some(t2)) if t1 == t2 => Some(t1),
(Some(t1), Some(t2)) => {
if let Some(expected) = expected {
if t1 != expected {
Some(t1)
} else {
Some(t2)
}
} else {
Some(t1)
}
}
(None, Some(_)) | (Some(_), None) | (None, None) => None,
}
}
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
type Output = Type;
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_expression(input).0 {
return match input {
Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr),
Expression::Binary(expr) => self.visit_binary(expr),
Expression::Unary(expr) => self.visit_unary(expr),
Expression::Ternary(expr) => self.visit_ternary(expr),
Expression::Call(expr) => self.visit_call(expr),
Expression::Err(expr) => self.visit_err(expr),
};
}
None
}
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
self.visitor.visit_identifier(input).1
}
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
self.visitor.visit_value(input).1
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
match self.visitor.visit_binary(input) {
(VisitResult::VisitChildren, expected) => {
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
_ => None,
}
}
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
match input.op {
UnaryOperation::Not => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
self.visit_expression(&input.inner)
}
UnaryOperation::Negate => {
let prior_negate_state = self.visitor.negate;
self.visitor.negate = true;
let type_ = self.visit_expression(&input.inner);
self.visitor.negate = prior_negate_state;
match type_.as_ref() {
Some(
Type::IntegerType(
IntegerType::I8
| IntegerType::I16
| IntegerType::I32
| IntegerType::I64
| IntegerType::I128,
)
| Type::Field
| Type::Group,
) => {}
Some(t) => self
.visitor
.handler
.emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()),
_ => {}
};
type_
}
}
}
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input).0 {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
let t1 = self.visit_expression(&input.if_true);
let t2 = self.visit_expression(&input.if_false);
return return_incorrect_type(t1, t2, self.visitor.expected_type);
}
None
}
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
match &*input.function {
Expression::Identifier(ident) => {
if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) {
let ret = self.visitor.assert_type(func.output, self.visitor.expected_type);
if func.input.len() != input.arguments.len() {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
input.arguments.len(),
input.span(),
)
.into(),
);
}
func.input
.iter()
.zip(input.arguments.iter())
.for_each(|(expected, argument)| {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(expected.get_variable().type_);
self.visit_expression(argument);
self.visitor.expected_type = prev_expected_type;
});
Some(ret)
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
None
}
}
expr => self.visit_expression(expr),
}
}
fn visit_err(&mut self, input: &'a ErrExpression) -> Option<Self::Output> {
self.visitor.visit_err(input).1
}
}
impl<'a> StatementVisitorDirector<'a> for Director<'a> {}
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}

View File

@ -26,9 +26,12 @@ pub use check_statements::*;
pub mod checker;
pub use checker::*;
pub mod director;
use director::*;
use crate::{Pass, SymbolTable};
use leo_ast::{Ast, VisitorDirector};
use leo_ast::{Ast, ProgramVisitorDirector};
use leo_errors::{emitter::Handler, Result};
impl<'a> Pass<'a> for TypeChecker<'a> {
@ -36,7 +39,7 @@ impl<'a> Pass<'a> for TypeChecker<'a> {
type Output = Result<()>;
fn do_pass((ast, symbol_table, handler): Self::Input) -> Self::Output {
let mut visitor = VisitorDirector::new(TypeChecker::new(symbol_table, handler));
let mut visitor = Director::new(symbol_table, handler);
visitor.visit_program(ast.as_repr());
handler.last_err()?;