mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-12-24 02:31:44 +03:00
trying to modify visitor pattern to better fit type checking
This commit is contained in:
parent
bc174419f7
commit
97ef64aa66
@ -30,35 +30,37 @@ impl Default for VisitResult {
|
||||
}
|
||||
|
||||
pub trait ExpressionVisitor<'a> {
|
||||
fn visit_expression(&mut self, _input: &'a Expression) -> VisitResult {
|
||||
type Output;
|
||||
|
||||
fn visit_expression(&mut self, _input: &'a Expression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_identifier(&mut self, _input: &'a Identifier) -> VisitResult {
|
||||
fn visit_identifier(&mut self, _input: &'a Identifier) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_value(&mut self, _input: &'a ValueExpression) -> VisitResult {
|
||||
fn visit_value(&mut self, _input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> VisitResult {
|
||||
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> VisitResult {
|
||||
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> VisitResult {
|
||||
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_call(&mut self, _input: &'a CallExpression) -> VisitResult {
|
||||
fn visit_call(&mut self, _input: &'a CallExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn visit_err(&mut self, _input: &'a ErrExpression) -> VisitResult {
|
||||
fn visit_err(&mut self, _input: &'a ErrExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
Default::default()
|
||||
}
|
||||
}
|
||||
|
@ -18,72 +18,87 @@
|
||||
//! It implements default methods for each node to be made
|
||||
//! given the type of node its visiting.
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::*;
|
||||
|
||||
pub struct VisitorDirector<'a, V: ExpressionVisitor<'a>> {
|
||||
visitor: V,
|
||||
lifetime: PhantomData<&'a ()>,
|
||||
pub trait VisitorDirector<'a> {
|
||||
type Visitor: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>;
|
||||
|
||||
fn visitor(self) -> Self::Visitor;
|
||||
|
||||
fn visitor_ref(&mut self) -> &mut Self::Visitor;
|
||||
}
|
||||
|
||||
impl<'a, V: ExpressionVisitor<'a>> VisitorDirector<'a, V> {
|
||||
pub fn new(visitor: V) -> Self {
|
||||
Self {
|
||||
visitor,
|
||||
lifetime: PhantomData,
|
||||
}
|
||||
}
|
||||
pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
|
||||
type Output;
|
||||
|
||||
pub fn visitor(self) -> V {
|
||||
self.visitor
|
||||
}
|
||||
|
||||
pub fn visit_expression(&mut self, input: &'a Expression) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_expression(input) {
|
||||
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input).0 {
|
||||
match input {
|
||||
Expression::Identifier(_) => {}
|
||||
Expression::Value(_) => {}
|
||||
Expression::Identifier(expr) => self.visit_identifier(expr),
|
||||
Expression::Value(expr) => self.visit_value(expr),
|
||||
Expression::Binary(expr) => self.visit_binary(expr),
|
||||
Expression::Unary(expr) => self.visit_unary(expr),
|
||||
Expression::Ternary(expr) => self.visit_ternary(expr),
|
||||
Expression::Call(expr) => self.visit_call(expr),
|
||||
Expression::Err(_) => {}
|
||||
}
|
||||
Expression::Err(expr) => self.visit_err(expr),
|
||||
};
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub fn visit_binary(&mut self, input: &'a BinaryExpression) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_binary(input) {
|
||||
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
|
||||
self.visitor_ref().visit_identifier(input);
|
||||
None
|
||||
}
|
||||
|
||||
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
|
||||
self.visitor_ref().visit_value(input);
|
||||
None
|
||||
}
|
||||
|
||||
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input).0 {
|
||||
self.visit_expression(&input.left);
|
||||
self.visit_expression(&input.right);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn visit_unary(&mut self, input: &'a UnaryExpression) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_unary(input) {
|
||||
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input).0 {
|
||||
self.visit_expression(&input.inner);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn visit_ternary(&mut self, input: &'a TernaryExpression) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input) {
|
||||
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input).0 {
|
||||
self.visit_expression(&input.condition);
|
||||
self.visit_expression(&input.if_true);
|
||||
self.visit_expression(&input.if_false);
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn visit_call(&mut self, input: &'a CallExpression) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_call(input) {
|
||||
input.arguments.iter().for_each(|expr| self.visit_expression(expr));
|
||||
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input).0 {
|
||||
input.arguments.iter().for_each(|expr| {
|
||||
self.visit_expression(expr);
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn visit_err(&mut self, input: &'a ErrExpression) -> Option<Self::Output> {
|
||||
self.visitor_ref().visit_err(input);
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> {
|
||||
pub fn visit_statement(&mut self, input: &'a Statement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_statement(input) {
|
||||
pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorDirector<'a> {
|
||||
fn visit_statement(&mut self, input: &'a Statement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_statement(input) {
|
||||
match input {
|
||||
Statement::Return(stmt) => self.visit_return(stmt),
|
||||
Statement::Definition(stmt) => self.visit_definition(stmt),
|
||||
@ -96,26 +111,26 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V>
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_return(&mut self, input: &'a ReturnStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_return(input) {
|
||||
self.visit_expression(&input.expression);
|
||||
fn visit_return(&mut self, input: &'a ReturnStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) {
|
||||
self.visitor_ref().visit_expression(&input.expression);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_definition(&mut self, input: &'a DefinitionStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_definition(input) {
|
||||
fn visit_definition(&mut self, input: &'a DefinitionStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_definition(input) {
|
||||
self.visit_expression(&input.value);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_assign(&mut self, input: &'a AssignStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_assign(input) {
|
||||
fn visit_assign(&mut self, input: &'a AssignStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_assign(input) {
|
||||
self.visit_expression(&input.value);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_conditional(input) {
|
||||
fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_conditional(input) {
|
||||
self.visit_expression(&input.condition);
|
||||
self.visit_block(&input.block);
|
||||
if let Some(stmt) = input.next.as_ref() {
|
||||
@ -124,35 +139,38 @@ impl<'a, V: ExpressionVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V>
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_iteration(&mut self, input: &'a IterationStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_iteration(input) {
|
||||
fn visit_iteration(&mut self, input: &'a IterationStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_iteration(input) {
|
||||
self.visit_expression(&input.start);
|
||||
self.visit_expression(&input.stop);
|
||||
self.visit_block(&input.block);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_console(&mut self, input: &'a ConsoleStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_console(input) {
|
||||
fn visit_console(&mut self, input: &'a ConsoleStatement) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_console(input) {
|
||||
match &input.function {
|
||||
ConsoleFunction::Assert(expr) => self.visit_expression(expr),
|
||||
ConsoleFunction::Error(fmt) | ConsoleFunction::Log(fmt) => {
|
||||
fmt.parameters.iter().for_each(|expr| self.visit_expression(expr));
|
||||
fmt.parameters.iter().for_each(|expr| {
|
||||
self.visit_expression(expr);
|
||||
});
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_block(&mut self, input: &'a Block) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_block(input) {
|
||||
fn visit_block(&mut self, input: &'a Block) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_block(input) {
|
||||
input.statements.iter().for_each(|stmt| self.visit_statement(stmt));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, V: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>> VisitorDirector<'a, V> {
|
||||
pub fn visit_program(&mut self, input: &'a Program) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_program(input) {
|
||||
pub trait ProgramVisitorDirector<'a>: VisitorDirector<'a> + StatementVisitorDirector<'a> {
|
||||
fn visit_program(&mut self, input: &'a Program) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_program(input) {
|
||||
input
|
||||
.functions
|
||||
.values()
|
||||
@ -160,8 +178,8 @@ impl<'a, V: ExpressionVisitor<'a> + ProgramVisitor<'a> + StatementVisitor<'a>> V
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_function(&mut self, input: &'a Function) {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_function(input) {
|
||||
fn visit_function(&mut self, input: &'a Function) {
|
||||
if let VisitResult::VisitChildren = self.visitor_ref().visit_function(input) {
|
||||
self.visit_block(&input.block);
|
||||
}
|
||||
}
|
||||
|
@ -36,7 +36,9 @@ impl<'a> CreateSymbolTable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {}
|
||||
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {
|
||||
type Output = ();
|
||||
}
|
||||
|
||||
impl<'a> StatementVisitor<'a> for CreateSymbolTable<'a> {}
|
||||
|
||||
|
52
compiler/passes/src/symbol_table/director.rs
Normal file
52
compiler/passes/src/symbol_table/director.rs
Normal file
@ -0,0 +1,52 @@
|
||||
// Copyright (C) 2019-2022 Aleo Systems Inc.
|
||||
// This file is part of the Leo library.
|
||||
|
||||
// The Leo library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
|
||||
// The Leo library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
use leo_ast::{ExpressionVisitorDirector, ProgramVisitorDirector, StatementVisitorDirector, VisitorDirector};
|
||||
use leo_errors::emitter::Handler;
|
||||
|
||||
use crate::CreateSymbolTable;
|
||||
|
||||
pub(crate) struct Director<'a> {
|
||||
visitor: CreateSymbolTable<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Director<'a> {
|
||||
pub(crate) fn new(handler: &'a Handler) -> Self {
|
||||
Self {
|
||||
visitor: CreateSymbolTable::new(handler),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VisitorDirector<'a> for Director<'a> {
|
||||
type Visitor = CreateSymbolTable<'a>;
|
||||
|
||||
fn visitor(self) -> Self::Visitor {
|
||||
self.visitor
|
||||
}
|
||||
|
||||
fn visitor_ref(&mut self) -> &mut Self::Visitor {
|
||||
&mut self.visitor
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
|
||||
type Output = ();
|
||||
}
|
||||
|
||||
impl<'a> StatementVisitorDirector<'a> for Director<'a> {}
|
||||
|
||||
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}
|
@ -17,6 +17,9 @@
|
||||
pub mod create;
|
||||
pub use create::*;
|
||||
|
||||
pub mod director;
|
||||
use director::*;
|
||||
|
||||
pub mod table;
|
||||
pub use table::*;
|
||||
|
||||
@ -28,7 +31,7 @@ pub use variable_symbol::*;
|
||||
|
||||
use crate::Pass;
|
||||
|
||||
use leo_ast::{Ast, VisitorDirector};
|
||||
use leo_ast::{Ast, ProgramVisitorDirector, VisitorDirector};
|
||||
use leo_errors::{emitter::Handler, Result};
|
||||
|
||||
impl<'a> Pass<'a> for CreateSymbolTable<'a> {
|
||||
@ -36,7 +39,7 @@ impl<'a> Pass<'a> for CreateSymbolTable<'a> {
|
||||
type Output = Result<SymbolTable<'a>>;
|
||||
|
||||
fn do_pass((ast, handler): Self::Input) -> Self::Output {
|
||||
let mut visitor = VisitorDirector::new(CreateSymbolTable::new(handler));
|
||||
let mut visitor = Director::new(handler);
|
||||
visitor.visit_program(ast.as_repr());
|
||||
handler.last_err()?;
|
||||
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
use leo_ast::*;
|
||||
use leo_errors::TypeCheckerError;
|
||||
use leo_span::Span;
|
||||
|
||||
use crate::TypeChecker;
|
||||
|
||||
@ -38,292 +37,239 @@ fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Ty
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> TypeChecker<'a> {
|
||||
pub(crate) fn compare_expr_type(&mut self, expr: &Expression, expected: Option<Type>, span: Span) -> Option<Type> {
|
||||
match expr {
|
||||
Expression::Identifier(ident) => {
|
||||
if let Some(var) = self.symbol_table.lookup_variable(&ident.name) {
|
||||
Some(self.assert_type(*var.type_, expected, span))
|
||||
} else {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::unknown_sym("variable", ident.name, span).into());
|
||||
None
|
||||
}
|
||||
}
|
||||
Expression::Value(value) => match value {
|
||||
ValueExpression::Address(_, _) => Some(self.assert_type(Type::Address, expected, value.span())),
|
||||
ValueExpression::Boolean(_, _) => Some(self.assert_type(Type::Boolean, expected, value.span())),
|
||||
ValueExpression::Field(_, _) => Some(self.assert_type(Type::Field, expected, value.span())),
|
||||
ValueExpression::Integer(type_, str_content, _) => {
|
||||
match type_ {
|
||||
IntegerType::I8 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
|
||||
type Output = Type;
|
||||
|
||||
if int.parse::<i8>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", value.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I16 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
fn visit_identifier(&mut self, input: &'a Identifier) -> (VisitResult, Option<Self::Output>) {
|
||||
let type_ = if let Some(var) = self.symbol_table.lookup_variable(&input.name) {
|
||||
Some(self.assert_type(*var.type_, self.expected_type))
|
||||
} else {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::unknown_sym("variable", input.name, self.span).into());
|
||||
None
|
||||
};
|
||||
|
||||
if int.parse::<i16>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", value.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I32 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
(VisitResult::VisitChildren, type_)
|
||||
}
|
||||
|
||||
if int.parse::<i32>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", value.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I64 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
fn visit_value(&mut self, input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
let prev_span = self.span;
|
||||
self.span = input.span();
|
||||
|
||||
if int.parse::<i64>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", value.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I128 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
let type_ = Some(match input {
|
||||
ValueExpression::Address(_, _) => self.assert_type(Type::Address, self.expected_type),
|
||||
ValueExpression::Boolean(_, _) => self.assert_type(Type::Boolean, self.expected_type),
|
||||
ValueExpression::Field(_, _) => self.assert_type(Type::Field, self.expected_type),
|
||||
ValueExpression::Integer(type_, str_content, _) => {
|
||||
match type_ {
|
||||
IntegerType::I8 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
|
||||
if int.parse::<i128>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", value.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", value.span()).into()),
|
||||
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", value.span()).into()),
|
||||
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", value.span()).into()),
|
||||
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", value.span()).into()),
|
||||
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", value.span()).into()),
|
||||
_ => {}
|
||||
}
|
||||
Some(self.assert_type(Type::IntegerType(*type_), expected, value.span()))
|
||||
}
|
||||
ValueExpression::Group(_) => Some(self.assert_type(Type::Group, expected, value.span())),
|
||||
ValueExpression::Scalar(_, _) => Some(self.assert_type(Type::Scalar, expected, value.span())),
|
||||
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
|
||||
},
|
||||
Expression::Binary(binary) => match binary.op {
|
||||
BinaryOperation::And | BinaryOperation::Or => {
|
||||
self.assert_type(Type::Boolean, expected, binary.span());
|
||||
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
|
||||
|
||||
return_incorrect_type(t1, t2, expected)
|
||||
}
|
||||
BinaryOperation::Add => {
|
||||
self.assert_field_group_scalar_int_type(expected, binary.span());
|
||||
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
|
||||
|
||||
return_incorrect_type(t1, t2, expected)
|
||||
}
|
||||
BinaryOperation::Sub => {
|
||||
self.assert_field_group_int_type(expected, binary.span());
|
||||
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
|
||||
|
||||
return_incorrect_type(t1, t2, expected)
|
||||
}
|
||||
BinaryOperation::Mul => {
|
||||
self.assert_field_group_int_type(expected, binary.span());
|
||||
|
||||
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
|
||||
|
||||
// Allow `group` * `scalar` multiplication.
|
||||
match (t1.as_ref(), t2.as_ref()) {
|
||||
(Some(Type::Group), Some(other)) => {
|
||||
self.assert_type(Type::Group, expected, binary.left.span());
|
||||
self.assert_type(*other, Some(Type::Scalar), binary.span());
|
||||
Some(Type::Group)
|
||||
}
|
||||
(Some(other), Some(Type::Group)) => {
|
||||
self.assert_type(Type::Group, expected, binary.left.span());
|
||||
self.assert_type(*other, Some(Type::Scalar), binary.span());
|
||||
Some(Type::Group)
|
||||
}
|
||||
_ => {
|
||||
self.assert_type(t1.unwrap(), expected, binary.left.span());
|
||||
self.assert_type(t2.unwrap(), expected, binary.right.span());
|
||||
return_incorrect_type(t1, t2, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
BinaryOperation::Div => {
|
||||
self.assert_field_int_type(expected, binary.span());
|
||||
|
||||
let t1 = self.compare_expr_type(&binary.left, expected, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, expected, binary.right.span());
|
||||
return_incorrect_type(t1, t2, expected)
|
||||
}
|
||||
BinaryOperation::Pow => {
|
||||
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
|
||||
|
||||
match (t1.as_ref(), t2.as_ref()) {
|
||||
// Type A must be an int.
|
||||
// Type B must be a unsigned int.
|
||||
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
|
||||
self.assert_type(t1.unwrap(), expected, binary.span());
|
||||
}
|
||||
// Type A was an int.
|
||||
// But Type B was not a unsigned int.
|
||||
(Some(Type::IntegerType(_)), Some(t)) => {
|
||||
self.handler.emit_err(
|
||||
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, binary.right.span())
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
// Type A must be a field.
|
||||
// Type B must be an int.
|
||||
(Some(Type::Field), Some(Type::IntegerType(_))) => {
|
||||
self.assert_type(Type::Field, expected, binary.span());
|
||||
}
|
||||
// Type A was a field.
|
||||
// But Type B was not an int.
|
||||
(Some(Type::Field), Some(t)) => {
|
||||
self.handler.emit_err(
|
||||
TypeCheckerError::incorrect_pow_exponent_type("int", t, binary.right.span()).into(),
|
||||
);
|
||||
}
|
||||
// The base is some type thats not an int or field.
|
||||
(Some(t), _) => {
|
||||
if int.parse::<i8>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, binary.left.span()).into());
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
IntegerType::I16 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
|
||||
t1
|
||||
if int.parse::<i16>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I32 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
|
||||
if int.parse::<i32>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I64 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
|
||||
if int.parse::<i64>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::I128 => {
|
||||
let int = if self.negate {
|
||||
format!("-{str_content}")
|
||||
} else {
|
||||
str_content.clone()
|
||||
};
|
||||
|
||||
if int.parse::<i128>().is_err() {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into());
|
||||
}
|
||||
}
|
||||
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()),
|
||||
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()),
|
||||
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()),
|
||||
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()),
|
||||
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()),
|
||||
_ => {}
|
||||
}
|
||||
BinaryOperation::Eq | BinaryOperation::Ne => {
|
||||
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
|
||||
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
|
||||
|
||||
self.assert_eq_types(t1, t2, binary.span());
|
||||
|
||||
Some(Type::Boolean)
|
||||
}
|
||||
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
|
||||
let t1 = self.compare_expr_type(&binary.left, None, binary.left.span());
|
||||
self.assert_field_scalar_int_type(t1, binary.left.span());
|
||||
|
||||
let t2 = self.compare_expr_type(&binary.right, None, binary.right.span());
|
||||
self.assert_field_scalar_int_type(t2, binary.right.span());
|
||||
|
||||
self.assert_eq_types(t1, t2, binary.span());
|
||||
|
||||
Some(Type::Boolean)
|
||||
}
|
||||
},
|
||||
Expression::Unary(unary) => match unary.op {
|
||||
UnaryOperation::Not => {
|
||||
self.assert_type(Type::Boolean, expected, unary.span());
|
||||
self.compare_expr_type(&unary.inner, expected, unary.inner.span())
|
||||
}
|
||||
UnaryOperation::Negate => {
|
||||
let prior_negate_state = self.negate;
|
||||
self.negate = true;
|
||||
let type_ = self.compare_expr_type(&unary.inner, expected, unary.inner.span());
|
||||
self.negate = prior_negate_state;
|
||||
match type_.as_ref() {
|
||||
Some(
|
||||
Type::IntegerType(
|
||||
IntegerType::I8
|
||||
| IntegerType::I16
|
||||
| IntegerType::I32
|
||||
| IntegerType::I64
|
||||
| IntegerType::I128,
|
||||
)
|
||||
| Type::Field
|
||||
| Type::Group,
|
||||
) => {}
|
||||
Some(t) => self
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::type_is_not_negatable(t, unary.inner.span()).into()),
|
||||
_ => {}
|
||||
};
|
||||
type_
|
||||
}
|
||||
},
|
||||
Expression::Ternary(ternary) => {
|
||||
self.compare_expr_type(&ternary.condition, Some(Type::Boolean), ternary.condition.span());
|
||||
let t1 = self.compare_expr_type(&ternary.if_true, expected, ternary.if_true.span());
|
||||
let t2 = self.compare_expr_type(&ternary.if_false, expected, ternary.if_false.span());
|
||||
return_incorrect_type(t1, t2, expected)
|
||||
self.assert_type(Type::IntegerType(*type_), self.expected_type)
|
||||
}
|
||||
Expression::Call(call) => match &*call.function {
|
||||
Expression::Identifier(ident) => {
|
||||
if let Some(func) = self.symbol_table.lookup_fn(&ident.name) {
|
||||
let ret = self.assert_type(func.output, expected, ident.span());
|
||||
ValueExpression::Group(_) => self.assert_type(Type::Group, self.expected_type),
|
||||
ValueExpression::Scalar(_, _) => self.assert_type(Type::Scalar, self.expected_type),
|
||||
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
|
||||
});
|
||||
|
||||
if func.input.len() != call.arguments.len() {
|
||||
self.handler.emit_err(
|
||||
TypeCheckerError::incorrect_num_args_to_call(
|
||||
func.input.len(),
|
||||
call.arguments.len(),
|
||||
call.span(),
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
self.span = prev_span;
|
||||
(VisitResult::VisitChildren, type_)
|
||||
}
|
||||
|
||||
func.input
|
||||
.iter()
|
||||
.zip(call.arguments.iter())
|
||||
.for_each(|(expected, argument)| {
|
||||
self.compare_expr_type(argument, Some(expected.get_variable().type_), argument.span());
|
||||
});
|
||||
fn visit_binary(&mut self, input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) {
|
||||
let prev_span = self.span;
|
||||
self.span = input.span();
|
||||
|
||||
Some(ret)
|
||||
} else {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
|
||||
None
|
||||
/* let type_ = match input.op {
|
||||
BinaryOperation::And | BinaryOperation::Or => {
|
||||
self.assert_type(Type::Boolean, self.expected_type);
|
||||
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
|
||||
|
||||
return_incorrect_type(t1, t2, self.expected_type)
|
||||
}
|
||||
BinaryOperation::Add => {
|
||||
self.assert_field_group_scalar_int_type(self.expected_type, input.span());
|
||||
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
|
||||
|
||||
return_incorrect_type(t1, t2, self.expected_type)
|
||||
}
|
||||
BinaryOperation::Sub => {
|
||||
self.assert_field_group_int_type(self.expected_type, input.span());
|
||||
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
|
||||
|
||||
return_incorrect_type(t1, t2, self.expected_type)
|
||||
}
|
||||
BinaryOperation::Mul => {
|
||||
self.assert_field_group_int_type(self.expected_type, input.span());
|
||||
|
||||
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
|
||||
|
||||
// Allow `group` * `scalar` multiplication.
|
||||
match (t1.as_ref(), t2.as_ref()) {
|
||||
(Some(Type::Group), Some(other)) => {
|
||||
self.assert_type(Type::Group, self.expected_type);
|
||||
self.assert_type(*other, Some(Type::Scalar));
|
||||
Some(Type::Group)
|
||||
}
|
||||
(Some(other), Some(Type::Group)) => {
|
||||
self.assert_type(Type::Group, self.expected_type);
|
||||
self.assert_type(*other, Some(Type::Scalar));
|
||||
Some(Type::Group)
|
||||
}
|
||||
_ => {
|
||||
self.assert_type(t1.unwrap(), self.expected_type);
|
||||
self.assert_type(t2.unwrap(), self.expected_type);
|
||||
return_incorrect_type(t1, t2, self.expected_type)
|
||||
}
|
||||
}
|
||||
expr => self.compare_expr_type(expr, expected, call.span()),
|
||||
},
|
||||
Expression::Err(_) => None,
|
||||
}
|
||||
}
|
||||
BinaryOperation::Div => {
|
||||
self.assert_field_int_type(self.expected_type, input.span());
|
||||
|
||||
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
|
||||
return_incorrect_type(t1, t2, self.expected_type)
|
||||
}
|
||||
BinaryOperation::Pow => {
|
||||
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
|
||||
|
||||
match (t1.as_ref(), t2.as_ref()) {
|
||||
// Type A must be an int.
|
||||
// Type B must be a unsigned int.
|
||||
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
|
||||
self.assert_type(t1.unwrap(), self.expected_type);
|
||||
}
|
||||
// Type A was an int.
|
||||
// But Type B was not a unsigned int.
|
||||
(Some(Type::IntegerType(_)), Some(t)) => {
|
||||
self.handler.emit_err(
|
||||
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span())
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
// Type A must be a field.
|
||||
// Type B must be an int.
|
||||
(Some(Type::Field), Some(Type::IntegerType(_))) => {
|
||||
self.assert_type(Type::Field, self.expected_type);
|
||||
}
|
||||
// Type A was a field.
|
||||
// But Type B was not an int.
|
||||
(Some(Type::Field), Some(t)) => {
|
||||
self.handler.emit_err(
|
||||
TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(),
|
||||
);
|
||||
}
|
||||
// The base is some type thats not an int or field.
|
||||
(Some(t), _) => {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
t1
|
||||
}
|
||||
BinaryOperation::Eq | BinaryOperation::Ne => {
|
||||
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
|
||||
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
|
||||
|
||||
self.assert_eq_types(t1, t2, input.span());
|
||||
|
||||
Some(Type::Boolean)
|
||||
}
|
||||
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
|
||||
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
|
||||
self.assert_field_scalar_int_type(t1, input.left.span());
|
||||
|
||||
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
|
||||
self.assert_field_scalar_int_type(t2, input.right.span());
|
||||
|
||||
self.assert_eq_types(t1, t2, input.span());
|
||||
|
||||
Some(Type::Boolean)
|
||||
}
|
||||
}; */
|
||||
|
||||
self.span = prev_span;
|
||||
(VisitResult::VisitChildren, None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {}
|
||||
|
@ -27,7 +27,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
|
||||
// Would never be None.
|
||||
let func_output_type = self.symbol_table.lookup_fn(&parent).map(|f| f.output);
|
||||
self.compare_expr_type(&input.expression, func_output_type, input.expression.span());
|
||||
// self.compare_expr_type(&input.expression, func_output_type, input.expression.span());
|
||||
|
||||
VisitResult::VisitChildren
|
||||
}
|
||||
@ -51,7 +51,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
self.handler.emit_err(err);
|
||||
}
|
||||
|
||||
self.compare_expr_type(&input.value, Some(input.type_), input.value.span());
|
||||
// self.compare_expr_type(&input.value, Some(input.type_), input.value.span());
|
||||
});
|
||||
|
||||
VisitResult::VisitChildren
|
||||
@ -80,14 +80,14 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
};
|
||||
|
||||
if var_type.is_some() {
|
||||
self.compare_expr_type(&input.value, var_type, input.value.span());
|
||||
// self.compare_expr_type(&input.value, var_type, input.value.span());
|
||||
}
|
||||
|
||||
VisitResult::VisitChildren
|
||||
}
|
||||
|
||||
fn visit_conditional(&mut self, input: &'a ConditionalStatement) -> VisitResult {
|
||||
self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span());
|
||||
// self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span());
|
||||
|
||||
VisitResult::VisitChildren
|
||||
}
|
||||
@ -104,8 +104,8 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
self.handler.emit_err(err);
|
||||
}
|
||||
|
||||
self.compare_expr_type(&input.start, Some(input.type_), input.start.span());
|
||||
self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span());
|
||||
// self.compare_expr_type(&input.start, Some(input.type_), input.start.span());
|
||||
// self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span());
|
||||
|
||||
VisitResult::VisitChildren
|
||||
}
|
||||
@ -113,7 +113,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
fn visit_console(&mut self, input: &'a ConsoleStatement) -> VisitResult {
|
||||
match &input.function {
|
||||
ConsoleFunction::Assert(expr) => {
|
||||
self.compare_expr_type(expr, Some(Type::Boolean), expr.span());
|
||||
// self.compare_expr_type(expr, Some(Type::Boolean), expr.span());
|
||||
}
|
||||
ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => {
|
||||
// TODO: undetermined
|
||||
|
@ -25,6 +25,8 @@ pub struct TypeChecker<'a> {
|
||||
pub(crate) handler: &'a Handler,
|
||||
pub(crate) parent: Option<Symbol>,
|
||||
pub(crate) negate: bool,
|
||||
pub(crate) expected_type: Option<Type>,
|
||||
pub(crate) span: Span,
|
||||
}
|
||||
|
||||
const INT_TYPES: [Type; 10] = [
|
||||
@ -74,6 +76,8 @@ impl<'a> TypeChecker<'a> {
|
||||
handler,
|
||||
parent: None,
|
||||
negate: false,
|
||||
expected_type: None,
|
||||
span: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,11 +95,11 @@ impl<'a> TypeChecker<'a> {
|
||||
}
|
||||
|
||||
/// Returns the given type if it equals the expected type or the expected type is none.
|
||||
pub(crate) fn assert_type(&self, type_: Type, expected: Option<Type>, span: Span) -> Type {
|
||||
pub(crate) fn assert_type(&mut self, type_: Type, expected: Option<Type>) -> Type {
|
||||
if let Some(expected) = expected {
|
||||
if type_ != expected {
|
||||
self.handler
|
||||
.emit_err(TypeCheckerError::type_should_be(type_, expected, span).into());
|
||||
.emit_err(TypeCheckerError::type_should_be(type_, expected, self.span).into());
|
||||
}
|
||||
}
|
||||
|
||||
|
200
compiler/passes/src/type_checker/director.rs
Normal file
200
compiler/passes/src/type_checker/director.rs
Normal file
@ -0,0 +1,200 @@
|
||||
// Copyright (C) 2019-2022 Aleo Systems Inc.
|
||||
// This file is part of the Leo library.
|
||||
|
||||
// The Leo library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
|
||||
// The Leo library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
use leo_ast::*;
|
||||
use leo_errors::{emitter::Handler, TypeCheckerError};
|
||||
|
||||
use crate::{SymbolTable, TypeChecker};
|
||||
|
||||
pub(crate) struct Director<'a> {
|
||||
visitor: TypeChecker<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Director<'a> {
|
||||
pub(crate) fn new(symbol_table: &'a mut SymbolTable<'a>, handler: &'a Handler) -> Self {
|
||||
Self {
|
||||
visitor: TypeChecker::new(symbol_table, handler),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VisitorDirector<'a> for Director<'a> {
|
||||
type Visitor = TypeChecker<'a>;
|
||||
|
||||
fn visitor(self) -> Self::Visitor {
|
||||
self.visitor
|
||||
}
|
||||
|
||||
fn visitor_ref(&mut self) -> &mut Self::Visitor {
|
||||
&mut self.visitor
|
||||
}
|
||||
}
|
||||
|
||||
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> {
|
||||
match (t1, t2) {
|
||||
(Some(t1), Some(t2)) if t1 == t2 => Some(t1),
|
||||
(Some(t1), Some(t2)) => {
|
||||
if let Some(expected) = expected {
|
||||
if t1 != expected {
|
||||
Some(t1)
|
||||
} else {
|
||||
Some(t2)
|
||||
}
|
||||
} else {
|
||||
Some(t1)
|
||||
}
|
||||
}
|
||||
(None, Some(_)) | (Some(_), None) | (None, None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
|
||||
type Output = Type;
|
||||
|
||||
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_expression(input).0 {
|
||||
return match input {
|
||||
Expression::Identifier(expr) => self.visit_identifier(expr),
|
||||
Expression::Value(expr) => self.visit_value(expr),
|
||||
Expression::Binary(expr) => self.visit_binary(expr),
|
||||
Expression::Unary(expr) => self.visit_unary(expr),
|
||||
Expression::Ternary(expr) => self.visit_ternary(expr),
|
||||
Expression::Call(expr) => self.visit_call(expr),
|
||||
Expression::Err(expr) => self.visit_err(expr),
|
||||
};
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
|
||||
self.visitor.visit_identifier(input).1
|
||||
}
|
||||
|
||||
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
|
||||
self.visitor.visit_value(input).1
|
||||
}
|
||||
|
||||
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
|
||||
match self.visitor.visit_binary(input) {
|
||||
(VisitResult::VisitChildren, expected) => {
|
||||
let t1 = self.visit_expression(&input.left);
|
||||
let t2 = self.visit_expression(&input.right);
|
||||
|
||||
return_incorrect_type(t1, t2, self.visitor.expected_type)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
|
||||
match input.op {
|
||||
UnaryOperation::Not => {
|
||||
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
|
||||
self.visit_expression(&input.inner)
|
||||
}
|
||||
UnaryOperation::Negate => {
|
||||
let prior_negate_state = self.visitor.negate;
|
||||
self.visitor.negate = true;
|
||||
|
||||
let type_ = self.visit_expression(&input.inner);
|
||||
self.visitor.negate = prior_negate_state;
|
||||
match type_.as_ref() {
|
||||
Some(
|
||||
Type::IntegerType(
|
||||
IntegerType::I8
|
||||
| IntegerType::I16
|
||||
| IntegerType::I32
|
||||
| IntegerType::I64
|
||||
| IntegerType::I128,
|
||||
)
|
||||
| Type::Field
|
||||
| Type::Group,
|
||||
) => {}
|
||||
Some(t) => self
|
||||
.visitor
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()),
|
||||
_ => {}
|
||||
};
|
||||
type_
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
|
||||
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input).0 {
|
||||
let prev_expected_type = self.visitor.expected_type;
|
||||
self.visitor.expected_type = Some(Type::Boolean);
|
||||
self.visit_expression(&input.condition);
|
||||
self.visitor.expected_type = prev_expected_type;
|
||||
|
||||
let t1 = self.visit_expression(&input.if_true);
|
||||
let t2 = self.visit_expression(&input.if_false);
|
||||
|
||||
return return_incorrect_type(t1, t2, self.visitor.expected_type);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
|
||||
match &*input.function {
|
||||
Expression::Identifier(ident) => {
|
||||
if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) {
|
||||
let ret = self.visitor.assert_type(func.output, self.visitor.expected_type);
|
||||
|
||||
if func.input.len() != input.arguments.len() {
|
||||
self.visitor.handler.emit_err(
|
||||
TypeCheckerError::incorrect_num_args_to_call(
|
||||
func.input.len(),
|
||||
input.arguments.len(),
|
||||
input.span(),
|
||||
)
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
|
||||
func.input
|
||||
.iter()
|
||||
.zip(input.arguments.iter())
|
||||
.for_each(|(expected, argument)| {
|
||||
let prev_expected_type = self.visitor.expected_type;
|
||||
self.visitor.expected_type = Some(expected.get_variable().type_);
|
||||
self.visit_expression(argument);
|
||||
self.visitor.expected_type = prev_expected_type;
|
||||
});
|
||||
|
||||
Some(ret)
|
||||
} else {
|
||||
self.visitor
|
||||
.handler
|
||||
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
|
||||
None
|
||||
}
|
||||
}
|
||||
expr => self.visit_expression(expr),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_err(&mut self, input: &'a ErrExpression) -> Option<Self::Output> {
|
||||
self.visitor.visit_err(input).1
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> StatementVisitorDirector<'a> for Director<'a> {}
|
||||
|
||||
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}
|
@ -26,9 +26,12 @@ pub use check_statements::*;
|
||||
pub mod checker;
|
||||
pub use checker::*;
|
||||
|
||||
pub mod director;
|
||||
use director::*;
|
||||
|
||||
use crate::{Pass, SymbolTable};
|
||||
|
||||
use leo_ast::{Ast, VisitorDirector};
|
||||
use leo_ast::{Ast, ProgramVisitorDirector};
|
||||
use leo_errors::{emitter::Handler, Result};
|
||||
|
||||
impl<'a> Pass<'a> for TypeChecker<'a> {
|
||||
@ -36,7 +39,7 @@ impl<'a> Pass<'a> for TypeChecker<'a> {
|
||||
type Output = Result<()>;
|
||||
|
||||
fn do_pass((ast, symbol_table, handler): Self::Input) -> Self::Output {
|
||||
let mut visitor = VisitorDirector::new(TypeChecker::new(symbol_table, handler));
|
||||
let mut visitor = Director::new(symbol_table, handler);
|
||||
visitor.visit_program(ast.as_repr());
|
||||
handler.last_err()?;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user