refa:ctored visitor pattern to better on an the AST

This commit is contained in:
gluax 2022-05-26 13:29:51 -07:00
parent 97ef64aa66
commit 044b2a10a4
9 changed files with 421 additions and 430 deletions

View File

@ -30,37 +30,35 @@ impl Default for VisitResult {
}
pub trait ExpressionVisitor<'a> {
type Output;
fn visit_expression(&mut self, _input: &'a Expression) -> (VisitResult, Option<Self::Output>) {
fn visit_expression(&mut self, _input: &'a Expression) -> VisitResult {
Default::default()
}
fn visit_identifier(&mut self, _input: &'a Identifier) -> (VisitResult, Option<Self::Output>) {
fn visit_identifier(&mut self, _input: &'a Identifier) -> VisitResult {
Default::default()
}
fn visit_value(&mut self, _input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) {
fn visit_value(&mut self, _input: &'a ValueExpression) -> VisitResult {
Default::default()
}
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) {
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> VisitResult {
Default::default()
}
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> (VisitResult, Option<Self::Output>) {
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> VisitResult {
Default::default()
}
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> (VisitResult, Option<Self::Output>) {
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> VisitResult {
Default::default()
}
fn visit_call(&mut self, _input: &'a CallExpression) -> (VisitResult, Option<Self::Output>) {
fn visit_call(&mut self, _input: &'a CallExpression) -> VisitResult {
Default::default()
}
fn visit_err(&mut self, _input: &'a ErrExpression) -> (VisitResult, Option<Self::Output>) {
fn visit_err(&mut self, _input: &'a ErrExpression) -> VisitResult {
Default::default()
}
}

View File

@ -32,7 +32,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
type Output;
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input).0 {
if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input) {
match input {
Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr),
@ -58,7 +58,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input).0 {
if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input) {
self.visit_expression(&input.left);
self.visit_expression(&input.right);
}
@ -66,14 +66,14 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
}
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input).0 {
if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input) {
self.visit_expression(&input.inner);
}
None
}
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input).0 {
if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input) {
self.visit_expression(&input.condition);
self.visit_expression(&input.if_true);
self.visit_expression(&input.if_false);
@ -82,7 +82,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
}
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input).0 {
if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input) {
input.arguments.iter().for_each(|expr| {
self.visit_expression(expr);
});
@ -113,7 +113,7 @@ pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorD
fn visit_return(&mut self, input: &'a ReturnStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) {
self.visitor_ref().visit_expression(&input.expression);
self.visit_expression(&input.expression);
}
}

View File

@ -36,9 +36,7 @@ impl<'a> CreateSymbolTable<'a> {
}
}
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {
type Output = ();
}
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {}
impl<'a> StatementVisitor<'a> for CreateSymbolTable<'a> {}

View File

@ -19,6 +19,10 @@ use leo_errors::TypeCheckerError;
use crate::TypeChecker;
use super::director::Director;
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {}
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> {
match (t1, t2) {
(Some(t1), Some(t2)) if t1 == t2 => Some(t1),
@ -37,239 +41,373 @@ fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Ty
}
}
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
type Output = Type;
fn visit_identifier(&mut self, input: &'a Identifier) -> (VisitResult, Option<Self::Output>) {
let type_ = if let Some(var) = self.symbol_table.lookup_variable(&input.name) {
Some(self.assert_type(*var.type_, self.expected_type))
} else {
self.handler
.emit_err(TypeCheckerError::unknown_sym("variable", input.name, self.span).into());
None
};
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_expression(input) {
return match input {
Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr),
Expression::Binary(expr) => self.visit_binary(expr),
Expression::Unary(expr) => self.visit_unary(expr),
Expression::Ternary(expr) => self.visit_ternary(expr),
Expression::Call(expr) => self.visit_call(expr),
Expression::Err(expr) => self.visit_err(expr),
};
}
(VisitResult::VisitChildren, type_)
None
}
fn visit_value(&mut self, input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) {
let prev_span = self.span;
self.span = input.span();
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_identifier(input) {
return if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&input.name) {
Some(self.visitor.assert_type(*var.type_, self.visitor.expected_type))
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()).into());
None
};
}
let type_ = Some(match input {
ValueExpression::Address(_, _) => self.assert_type(Type::Address, self.expected_type),
ValueExpression::Boolean(_, _) => self.assert_type(Type::Boolean, self.expected_type),
ValueExpression::Field(_, _) => self.assert_type(Type::Field, self.expected_type),
ValueExpression::Integer(type_, str_content, _) => {
match type_ {
IntegerType::I8 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i8>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into());
}
}
IntegerType::I16 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i16>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into());
}
}
IntegerType::I32 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i32>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into());
}
}
IntegerType::I64 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i64>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into());
}
}
IntegerType::I128 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i128>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into());
}
}
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()),
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()),
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()),
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()),
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()),
_ => {}
}
self.assert_type(Type::IntegerType(*type_), self.expected_type)
}
ValueExpression::Group(_) => self.assert_type(Type::Group, self.expected_type),
ValueExpression::Scalar(_, _) => self.assert_type(Type::Scalar, self.expected_type),
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
});
self.span = prev_span;
(VisitResult::VisitChildren, type_)
None
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) {
let prev_span = self.span;
self.span = input.span();
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_value(input) {
return Some(match input {
ValueExpression::Address(_, _) => self.visitor.assert_type(Type::Address, self.visitor.expected_type),
ValueExpression::Boolean(_, _) => self.visitor.assert_type(Type::Boolean, self.visitor.expected_type),
ValueExpression::Field(_, _) => self.visitor.assert_type(Type::Field, self.visitor.expected_type),
ValueExpression::Integer(type_, str_content, _) => {
match type_ {
IntegerType::I8 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
/* let type_ = match input.op {
BinaryOperation::And | BinaryOperation::Or => {
self.assert_type(Type::Boolean, self.expected_type);
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
if int.parse::<i8>().is_err() {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into());
}
}
IntegerType::I16 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Add => {
self.assert_field_group_scalar_int_type(self.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
if int.parse::<i16>().is_err() {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into());
}
}
IntegerType::I32 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Sub => {
self.assert_field_group_int_type(self.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
if int.parse::<i32>().is_err() {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into());
}
}
IntegerType::I64 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Mul => {
self.assert_field_group_int_type(self.expected_type, input.span());
if int.parse::<i64>().is_err() {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into());
}
}
IntegerType::I128 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
// Allow `group` * `scalar` multiplication.
match (t1.as_ref(), t2.as_ref()) {
(Some(Type::Group), Some(other)) => {
self.assert_type(Type::Group, self.expected_type);
self.assert_type(*other, Some(Type::Scalar));
Some(Type::Group)
if int.parse::<i128>().is_err() {
self.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into());
}
}
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()),
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()),
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()),
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()),
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()),
_ => {}
}
(Some(other), Some(Type::Group)) => {
self.assert_type(Type::Group, self.expected_type);
self.assert_type(*other, Some(Type::Scalar));
Some(Type::Group)
}
_ => {
self.assert_type(t1.unwrap(), self.expected_type);
self.assert_type(t2.unwrap(), self.expected_type);
return_incorrect_type(t1, t2, self.expected_type)
self.visitor
.assert_type(Type::IntegerType(*type_), self.visitor.expected_type)
}
ValueExpression::Group(_) => self.visitor.assert_type(Type::Group, self.visitor.expected_type),
ValueExpression::Scalar(_, _) => self.visitor.assert_type(Type::Scalar, self.visitor.expected_type),
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
});
}
None
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_binary(input) {
return match input.op {
BinaryOperation::And | BinaryOperation::Or => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Add => {
self.visitor
.assert_field_group_scalar_int_type(self.visitor.expected_type, input.span());
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Sub => {
self.visitor
.assert_field_group_int_type(self.visitor.expected_type, input.span());
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Mul => {
self.visitor
.assert_field_group_int_type(self.visitor.expected_type, input.span());
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
self.visitor.expected_type = prev_expected_type;
// Allow `group` * `scalar` multiplication.
match (t1.as_ref(), t2.as_ref()) {
(Some(Type::Group), Some(other))
| (Some(other), Some(Type::Group)) => {
self.visitor.assert_type(Type::Group, self.visitor.expected_type);
self.visitor.assert_type(*other, Some(Type::Scalar));
Some(Type::Group)
}
_ => {
self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type);
self.visitor.assert_type(t2.unwrap(), self.visitor.expected_type);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
}
}
}
BinaryOperation::Div => {
self.assert_field_int_type(self.expected_type, input.span());
BinaryOperation::Div => {
self.visitor
.assert_field_int_type(self.visitor.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span());
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Pow => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Pow => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
self.visitor.expected_type = prev_expected_type;
match (t1.as_ref(), t2.as_ref()) {
// Type A must be an int.
// Type B must be a unsigned int.
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
self.assert_type(t1.unwrap(), self.expected_type);
}
// Type A was an int.
// But Type B was not a unsigned int.
(Some(Type::IntegerType(_)), Some(t)) => {
self.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span())
.into(),
);
}
// Type A must be a field.
// Type B must be an int.
(Some(Type::Field), Some(Type::IntegerType(_))) => {
self.assert_type(Type::Field, self.expected_type);
}
// Type A was a field.
// But Type B was not an int.
(Some(Type::Field), Some(t)) => {
self.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(),
);
}
// The base is some type thats not an int or field.
(Some(t), _) => {
self.handler
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into());
match (t1.as_ref(), t2.as_ref()) {
// Type A must be an int.
// Type B must be a unsigned int.
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type);
}
// Type A was an int.
// But Type B was not a unsigned int.
(Some(Type::IntegerType(_)), Some(t)) => {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span())
.into(),
);
}
// Type A must be a field.
// Type B must be an int.
(Some(Type::Field), Some(Type::IntegerType(_))) => {
self.visitor.assert_type(Type::Field, self.visitor.expected_type);
}
// Type A was a field.
// But Type B was not an int.
(Some(Type::Field), Some(t)) => {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(),
);
}
// The base is some type thats not an int or field.
(Some(t), _) => {
self.visitor
.handler
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into());
}
_ => {}
}
t1
}
BinaryOperation::Eq | BinaryOperation::Ne => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
self.visitor.expected_type = prev_expected_type;
self.visitor.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
self.visitor.assert_field_scalar_int_type(t1, input.left.span());
let t2 = self.visit_expression(&input.right);
self.visitor.assert_field_scalar_int_type(t2, input.right.span());
self.visitor.expected_type = prev_expected_type;
self.visitor.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
};
}
None
}
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
match input.op {
UnaryOperation::Not => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
self.visit_expression(&input.inner)
}
UnaryOperation::Negate => {
let prior_negate_state = self.visitor.negate;
self.visitor.negate = true;
let type_ = self.visit_expression(&input.inner);
self.visitor.negate = prior_negate_state;
match type_.as_ref() {
Some(
Type::IntegerType(
IntegerType::I8
| IntegerType::I16
| IntegerType::I32
| IntegerType::I64
| IntegerType::I128,
)
| Type::Field
| Type::Group,
) => {}
Some(t) => self
.visitor
.handler
.emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()),
_ => {}
};
type_
}
}
}
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input) {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
let t1 = self.visit_expression(&input.if_true);
let t2 = self.visit_expression(&input.if_false);
return return_incorrect_type(t1, t2, self.visitor.expected_type);
}
None
}
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
match &*input.function {
Expression::Identifier(ident) => {
if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) {
let ret = self.visitor.assert_type(func.output, self.visitor.expected_type);
if func.input.len() != input.arguments.len() {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
input.arguments.len(),
input.span(),
)
.into(),
);
}
func.input
.iter()
.zip(input.arguments.iter())
.for_each(|(expected, argument)| {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(expected.get_variable().type_);
self.visit_expression(argument);
self.visitor.expected_type = prev_expected_type;
});
Some(ret)
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
None
}
t1
}
BinaryOperation::Eq | BinaryOperation::Ne => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
self.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
self.assert_field_scalar_int_type(t1, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
self.assert_field_scalar_int_type(t2, input.right.span());
self.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
}; */
self.span = prev_span;
(VisitResult::VisitChildren, None)
expr => self.visit_expression(expr),
}
}
}

View File

@ -18,6 +18,8 @@ use leo_ast::*;
use crate::{Declaration, TypeChecker, VariableSymbol};
use super::director::Director;
impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
fn visit_function(&mut self, input: &'a Function) -> VisitResult {
self.symbol_table.clear_variables();
@ -40,3 +42,5 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
VisitResult::VisitChildren
}
}
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}

View File

@ -19,20 +19,23 @@ use leo_errors::TypeCheckerError;
use crate::{Declaration, TypeChecker, VariableSymbol};
impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
fn visit_return(&mut self, input: &'a ReturnStatement) -> VisitResult {
use super::director::Director;
impl<'a> StatementVisitor<'a> for TypeChecker<'a> {}
impl<'a> StatementVisitorDirector<'a> for Director<'a> {
fn visit_return(&mut self, input: &'a ReturnStatement) {
// we can safely unwrap all self.parent instances because
// statements should always have some parent block
let parent = self.parent.unwrap();
let parent = self.visitor.parent.unwrap();
// Would never be None.
let func_output_type = self.symbol_table.lookup_fn(&parent).map(|f| f.output);
// self.compare_expr_type(&input.expression, func_output_type, input.expression.span());
VisitResult::VisitChildren
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = self.visitor.symbol_table.lookup_fn(&parent).map(|f| f.output);
self.visit_expression(&input.expression);
self.visitor.expected_type = prev_expected_type;
}
fn visit_definition(&mut self, input: &'a DefinitionStatement) -> VisitResult {
fn visit_definition(&mut self, input: &'a DefinitionStatement) {
let declaration = if input.declaration_type == Declare::Const {
Declaration::Const
} else {
@ -40,7 +43,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
};
input.variable_names.iter().for_each(|v| {
if let Err(err) = self.symbol_table.insert_variable(
if let Err(err) = self.visitor.symbol_table.insert_variable(
v.identifier.name,
VariableSymbol {
type_: &input.type_,
@ -48,23 +51,26 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
declaration: declaration.clone(),
},
) {
self.handler.emit_err(err);
self.visitor.handler.emit_err(err);
}
// self.compare_expr_type(&input.value, Some(input.type_), input.value.span());
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(input.type_);
self.visit_expression(&input.value);
self.visitor.expected_type = prev_expected_type;
});
VisitResult::VisitChildren
}
fn visit_assign(&mut self, input: &'a AssignStatement) -> VisitResult {
fn visit_assign(&mut self, input: &'a AssignStatement) {
let var_name = &input.assignee.identifier.name;
let var_type = if let Some(var) = self.symbol_table.lookup_variable(var_name) {
let var_type = if let Some(var) = self.visitor.symbol_table.lookup_variable(var_name) {
match &var.declaration {
Declaration::Const => self
.visitor
.handler
.emit_err(TypeCheckerError::cannont_assign_to_const_var(var_name, var.span).into()),
Declaration::Input(ParamMode::Constant) => self
.visitor
.handler
.emit_err(TypeCheckerError::cannont_assign_to_const_input(var_name, var.span).into()),
_ => {}
@ -72,7 +78,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
Some(*var.type_)
} else {
self.handler.emit_err(
self.visitor.handler.emit_err(
TypeCheckerError::unknown_sym("variable", &input.assignee.identifier.name, input.assignee.span).into(),
);
@ -80,20 +86,22 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
};
if var_type.is_some() {
// self.compare_expr_type(&input.value, var_type, input.value.span());
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = var_type;
self.visit_expression(&input.value);
self.visitor.expected_type = prev_expected_type;
}
VisitResult::VisitChildren
}
fn visit_conditional(&mut self, input: &'a ConditionalStatement) -> VisitResult {
// self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span());
VisitResult::VisitChildren
fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
}
fn visit_iteration(&mut self, input: &'a IterationStatement) -> VisitResult {
if let Err(err) = self.symbol_table.insert_variable(
fn visit_iteration(&mut self, input: &'a IterationStatement) {
if let Err(err) = self.visitor.symbol_table.insert_variable(
input.variable.name,
VariableSymbol {
type_: &input.type_,
@ -101,30 +109,33 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
declaration: Declaration::Const,
},
) {
self.handler.emit_err(err);
self.visitor.handler.emit_err(err);
}
// self.compare_expr_type(&input.start, Some(input.type_), input.start.span());
// self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span());
VisitResult::VisitChildren
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(input.type_);
self.visit_expression(&input.start);
self.visit_expression(&input.stop);
self.visitor.expected_type = prev_expected_type;
}
fn visit_console(&mut self, input: &'a ConsoleStatement) -> VisitResult {
fn visit_console(&mut self, input: &'a ConsoleStatement) {
match &input.function {
ConsoleFunction::Assert(expr) => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(expr);
self.visitor.expected_type = prev_expected_type;
// self.compare_expr_type(expr, Some(Type::Boolean), expr.span());
}
ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => {
// TODO: undetermined
}
}
VisitResult::VisitChildren
}
fn visit_block(&mut self, input: &'a Block) -> VisitResult {
self.symbol_table.push_variable_scope();
fn visit_block(&mut self, input: &'a Block) {
self.visitor.symbol_table.push_variable_scope();
// have to redo the logic here so we have scoping
input.statements.iter().for_each(|stmt| {
match stmt {
@ -137,8 +148,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
Statement::Block(stmt) => self.visit_block(stmt),
};
});
self.symbol_table.pop_variable_scope();
VisitResult::SkipChildren
self.visitor.symbol_table.pop_variable_scope();
}
}

View File

@ -15,12 +15,12 @@
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use leo_ast::*;
use leo_errors::{emitter::Handler, TypeCheckerError};
use leo_errors::emitter::Handler;
use crate::{SymbolTable, TypeChecker};
pub(crate) struct Director<'a> {
visitor: TypeChecker<'a>,
pub(crate) visitor: TypeChecker<'a>,
}
impl<'a> Director<'a> {
@ -42,159 +42,3 @@ impl<'a> VisitorDirector<'a> for Director<'a> {
&mut self.visitor
}
}
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> {
match (t1, t2) {
(Some(t1), Some(t2)) if t1 == t2 => Some(t1),
(Some(t1), Some(t2)) => {
if let Some(expected) = expected {
if t1 != expected {
Some(t1)
} else {
Some(t2)
}
} else {
Some(t1)
}
}
(None, Some(_)) | (Some(_), None) | (None, None) => None,
}
}
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
type Output = Type;
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_expression(input).0 {
return match input {
Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr),
Expression::Binary(expr) => self.visit_binary(expr),
Expression::Unary(expr) => self.visit_unary(expr),
Expression::Ternary(expr) => self.visit_ternary(expr),
Expression::Call(expr) => self.visit_call(expr),
Expression::Err(expr) => self.visit_err(expr),
};
}
None
}
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
self.visitor.visit_identifier(input).1
}
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
self.visitor.visit_value(input).1
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
match self.visitor.visit_binary(input) {
(VisitResult::VisitChildren, expected) => {
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
_ => None,
}
}
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
match input.op {
UnaryOperation::Not => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
self.visit_expression(&input.inner)
}
UnaryOperation::Negate => {
let prior_negate_state = self.visitor.negate;
self.visitor.negate = true;
let type_ = self.visit_expression(&input.inner);
self.visitor.negate = prior_negate_state;
match type_.as_ref() {
Some(
Type::IntegerType(
IntegerType::I8
| IntegerType::I16
| IntegerType::I32
| IntegerType::I64
| IntegerType::I128,
)
| Type::Field
| Type::Group,
) => {}
Some(t) => self
.visitor
.handler
.emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()),
_ => {}
};
type_
}
}
}
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input).0 {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
let t1 = self.visit_expression(&input.if_true);
let t2 = self.visit_expression(&input.if_false);
return return_incorrect_type(t1, t2, self.visitor.expected_type);
}
None
}
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
match &*input.function {
Expression::Identifier(ident) => {
if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) {
let ret = self.visitor.assert_type(func.output, self.visitor.expected_type);
if func.input.len() != input.arguments.len() {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
input.arguments.len(),
input.span(),
)
.into(),
);
}
func.input
.iter()
.zip(input.arguments.iter())
.for_each(|(expected, argument)| {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(expected.get_variable().type_);
self.visit_expression(argument);
self.visitor.expected_type = prev_expected_type;
});
Some(ret)
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
None
}
}
expr => self.visit_expression(expr),
}
}
fn visit_err(&mut self, input: &'a ErrExpression) -> Option<Self::Output> {
self.visitor.visit_err(input).1
}
}
impl<'a> StatementVisitorDirector<'a> for Director<'a> {}
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}

View File

@ -2,4 +2,4 @@
namespace: Compile
expectation: Fail
outputs:
- "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:4:12\n |\n 4 | return (_, _)group * a;\n | ^^^^^^^^^^^^^^^\n"
- "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n"

View File

@ -2,4 +2,4 @@
namespace: Compile
expectation: Fail
outputs:
- "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:10:18\n |\n 10 | let h: u32 = a * 1u8;\n | ^\nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:10:22\n |\n 10 | let h: u32 = a * 1u8;\n | ^^^\n"
- "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n"