refa:ctored visitor pattern to better on an the AST

This commit is contained in:
gluax 2022-05-26 13:29:51 -07:00
parent 97ef64aa66
commit 044b2a10a4
9 changed files with 421 additions and 430 deletions

View File

@ -30,37 +30,35 @@ impl Default for VisitResult {
} }
pub trait ExpressionVisitor<'a> { pub trait ExpressionVisitor<'a> {
type Output; fn visit_expression(&mut self, _input: &'a Expression) -> VisitResult {
fn visit_expression(&mut self, _input: &'a Expression) -> (VisitResult, Option<Self::Output>) {
Default::default() Default::default()
} }
fn visit_identifier(&mut self, _input: &'a Identifier) -> (VisitResult, Option<Self::Output>) { fn visit_identifier(&mut self, _input: &'a Identifier) -> VisitResult {
Default::default() Default::default()
} }
fn visit_value(&mut self, _input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) { fn visit_value(&mut self, _input: &'a ValueExpression) -> VisitResult {
Default::default() Default::default()
} }
fn visit_binary(&mut self, _input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) { fn visit_binary(&mut self, _input: &'a BinaryExpression) -> VisitResult {
Default::default() Default::default()
} }
fn visit_unary(&mut self, _input: &'a UnaryExpression) -> (VisitResult, Option<Self::Output>) { fn visit_unary(&mut self, _input: &'a UnaryExpression) -> VisitResult {
Default::default() Default::default()
} }
fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> (VisitResult, Option<Self::Output>) { fn visit_ternary(&mut self, _input: &'a TernaryExpression) -> VisitResult {
Default::default() Default::default()
} }
fn visit_call(&mut self, _input: &'a CallExpression) -> (VisitResult, Option<Self::Output>) { fn visit_call(&mut self, _input: &'a CallExpression) -> VisitResult {
Default::default() Default::default()
} }
fn visit_err(&mut self, _input: &'a ErrExpression) -> (VisitResult, Option<Self::Output>) { fn visit_err(&mut self, _input: &'a ErrExpression) -> VisitResult {
Default::default() Default::default()
} }
} }

View File

@ -32,7 +32,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
type Output; type Output;
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> { fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input).0 { if let VisitResult::VisitChildren = self.visitor_ref().visit_expression(input) {
match input { match input {
Expression::Identifier(expr) => self.visit_identifier(expr), Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr), Expression::Value(expr) => self.visit_value(expr),
@ -58,7 +58,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
} }
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> { fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input).0 { if let VisitResult::VisitChildren = self.visitor_ref().visit_binary(input) {
self.visit_expression(&input.left); self.visit_expression(&input.left);
self.visit_expression(&input.right); self.visit_expression(&input.right);
} }
@ -66,14 +66,14 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
} }
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> { fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input).0 { if let VisitResult::VisitChildren = self.visitor_ref().visit_unary(input) {
self.visit_expression(&input.inner); self.visit_expression(&input.inner);
} }
None None
} }
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> { fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input).0 { if let VisitResult::VisitChildren = self.visitor_ref().visit_ternary(input) {
self.visit_expression(&input.condition); self.visit_expression(&input.condition);
self.visit_expression(&input.if_true); self.visit_expression(&input.if_true);
self.visit_expression(&input.if_false); self.visit_expression(&input.if_false);
@ -82,7 +82,7 @@ pub trait ExpressionVisitorDirector<'a>: VisitorDirector<'a> {
} }
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> { fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input).0 { if let VisitResult::VisitChildren = self.visitor_ref().visit_call(input) {
input.arguments.iter().for_each(|expr| { input.arguments.iter().for_each(|expr| {
self.visit_expression(expr); self.visit_expression(expr);
}); });
@ -113,7 +113,7 @@ pub trait StatementVisitorDirector<'a>: VisitorDirector<'a> + ExpressionVisitorD
fn visit_return(&mut self, input: &'a ReturnStatement) { fn visit_return(&mut self, input: &'a ReturnStatement) {
if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) { if let VisitResult::VisitChildren = self.visitor_ref().visit_return(input) {
self.visitor_ref().visit_expression(&input.expression); self.visit_expression(&input.expression);
} }
} }

View File

@ -36,9 +36,7 @@ impl<'a> CreateSymbolTable<'a> {
} }
} }
impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> { impl<'a> ExpressionVisitor<'a> for CreateSymbolTable<'a> {}
type Output = ();
}
impl<'a> StatementVisitor<'a> for CreateSymbolTable<'a> {} impl<'a> StatementVisitor<'a> for CreateSymbolTable<'a> {}

View File

@ -19,6 +19,10 @@ use leo_errors::TypeCheckerError;
use crate::TypeChecker; use crate::TypeChecker;
use super::director::Director;
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {}
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> { fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> {
match (t1, t2) { match (t1, t2) {
(Some(t1), Some(t2)) if t1 == t2 => Some(t1), (Some(t1), Some(t2)) if t1 == t2 => Some(t1),
@ -37,239 +41,373 @@ fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Ty
} }
} }
impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
type Output = Type; type Output = Type;
fn visit_identifier(&mut self, input: &'a Identifier) -> (VisitResult, Option<Self::Output>) { fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
let type_ = if let Some(var) = self.symbol_table.lookup_variable(&input.name) { if let VisitResult::VisitChildren = self.visitor.visit_expression(input) {
Some(self.assert_type(*var.type_, self.expected_type)) return match input {
} else { Expression::Identifier(expr) => self.visit_identifier(expr),
self.handler Expression::Value(expr) => self.visit_value(expr),
.emit_err(TypeCheckerError::unknown_sym("variable", input.name, self.span).into()); Expression::Binary(expr) => self.visit_binary(expr),
None Expression::Unary(expr) => self.visit_unary(expr),
}; Expression::Ternary(expr) => self.visit_ternary(expr),
Expression::Call(expr) => self.visit_call(expr),
Expression::Err(expr) => self.visit_err(expr),
};
}
(VisitResult::VisitChildren, type_) None
} }
fn visit_value(&mut self, input: &'a ValueExpression) -> (VisitResult, Option<Self::Output>) { fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
let prev_span = self.span; if let VisitResult::VisitChildren = self.visitor.visit_identifier(input) {
self.span = input.span(); return if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&input.name) {
Some(self.visitor.assert_type(*var.type_, self.visitor.expected_type))
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()).into());
None
};
}
let type_ = Some(match input { None
ValueExpression::Address(_, _) => self.assert_type(Type::Address, self.expected_type),
ValueExpression::Boolean(_, _) => self.assert_type(Type::Boolean, self.expected_type),
ValueExpression::Field(_, _) => self.assert_type(Type::Field, self.expected_type),
ValueExpression::Integer(type_, str_content, _) => {
match type_ {
IntegerType::I8 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i8>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into());
}
}
IntegerType::I16 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i16>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into());
}
}
IntegerType::I32 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i32>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into());
}
}
IntegerType::I64 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i64>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into());
}
}
IntegerType::I128 => {
let int = if self.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
if int.parse::<i128>().is_err() {
self.handler
.emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into());
}
}
IntegerType::U8 if str_content.parse::<u8>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()),
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()),
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()),
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()),
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()),
_ => {}
}
self.assert_type(Type::IntegerType(*type_), self.expected_type)
}
ValueExpression::Group(_) => self.assert_type(Type::Group, self.expected_type),
ValueExpression::Scalar(_, _) => self.assert_type(Type::Scalar, self.expected_type),
ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
});
self.span = prev_span;
(VisitResult::VisitChildren, type_)
} }
fn visit_binary(&mut self, input: &'a BinaryExpression) -> (VisitResult, Option<Self::Output>) { fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
let prev_span = self.span; if let VisitResult::VisitChildren = self.visitor.visit_value(input) {
self.span = input.span(); return Some(match input {
ValueExpression::Address(_, _) => self.visitor.assert_type(Type::Address, self.visitor.expected_type),
ValueExpression::Boolean(_, _) => self.visitor.assert_type(Type::Boolean, self.visitor.expected_type),
ValueExpression::Field(_, _) => self.visitor.assert_type(Type::Field, self.visitor.expected_type),
ValueExpression::Integer(type_, str_content, _) => {
match type_ {
IntegerType::I8 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
/* let type_ = match input.op { if int.parse::<i8>().is_err() {
BinaryOperation::And | BinaryOperation::Or => { self.visitor
self.assert_type(Type::Boolean, self.expected_type); .handler
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); .emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into());
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); }
}
IntegerType::I16 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
return_incorrect_type(t1, t2, self.expected_type) if int.parse::<i16>().is_err() {
} self.visitor
BinaryOperation::Add => { .handler
self.assert_field_group_scalar_int_type(self.expected_type, input.span()); .emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); }
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); }
IntegerType::I32 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
return_incorrect_type(t1, t2, self.expected_type) if int.parse::<i32>().is_err() {
} self.visitor
BinaryOperation::Sub => { .handler
self.assert_field_group_int_type(self.expected_type, input.span()); .emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); }
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); }
IntegerType::I64 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
return_incorrect_type(t1, t2, self.expected_type) if int.parse::<i64>().is_err() {
} self.visitor
BinaryOperation::Mul => { .handler
self.assert_field_group_int_type(self.expected_type, input.span()); .emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into());
}
}
IntegerType::I128 => {
let int = if self.visitor.negate {
format!("-{str_content}")
} else {
str_content.clone()
};
let t1 = self.compare_expr_type(&input.left, None, input.left.span()); if int.parse::<i128>().is_err() {
let t2 = self.compare_expr_type(&input.right, None, input.right.span()); self.visitor
.handler
// Allow `group` * `scalar` multiplication. .emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into());
match (t1.as_ref(), t2.as_ref()) { }
(Some(Type::Group), Some(other)) => { }
self.assert_type(Type::Group, self.expected_type); IntegerType::U8 if str_content.parse::<u8>().is_err() => self
self.assert_type(*other, Some(Type::Scalar)); .visitor
Some(Type::Group) .handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()),
IntegerType::U16 if str_content.parse::<u16>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()),
IntegerType::U32 if str_content.parse::<u32>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()),
IntegerType::U64 if str_content.parse::<u64>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()),
IntegerType::U128 if str_content.parse::<u128>().is_err() => self
.visitor
.handler
.emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()),
_ => {}
} }
(Some(other), Some(Type::Group)) => { self.visitor
self.assert_type(Type::Group, self.expected_type); .assert_type(Type::IntegerType(*type_), self.visitor.expected_type)
self.assert_type(*other, Some(Type::Scalar)); }
Some(Type::Group) ValueExpression::Group(_) => self.visitor.assert_type(Type::Group, self.visitor.expected_type),
} ValueExpression::Scalar(_, _) => self.visitor.assert_type(Type::Scalar, self.visitor.expected_type),
_ => { ValueExpression::String(_, _) => unreachable!("String types are not reachable"),
self.assert_type(t1.unwrap(), self.expected_type); });
self.assert_type(t2.unwrap(), self.expected_type); }
return_incorrect_type(t1, t2, self.expected_type)
None
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_binary(input) {
return match input.op {
BinaryOperation::And | BinaryOperation::Or => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Add => {
self.visitor
.assert_field_group_scalar_int_type(self.visitor.expected_type, input.span());
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Sub => {
self.visitor
.assert_field_group_int_type(self.visitor.expected_type, input.span());
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
BinaryOperation::Mul => {
self.visitor
.assert_field_group_int_type(self.visitor.expected_type, input.span());
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
self.visitor.expected_type = prev_expected_type;
// Allow `group` * `scalar` multiplication.
match (t1.as_ref(), t2.as_ref()) {
(Some(Type::Group), Some(other))
| (Some(other), Some(Type::Group)) => {
self.visitor.assert_type(Type::Group, self.visitor.expected_type);
self.visitor.assert_type(*other, Some(Type::Scalar));
Some(Type::Group)
}
_ => {
self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type);
self.visitor.assert_type(t2.unwrap(), self.visitor.expected_type);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
} }
} }
} BinaryOperation::Div => {
BinaryOperation::Div => { self.visitor
self.assert_field_int_type(self.expected_type, input.span()); .assert_field_int_type(self.visitor.expected_type, input.span());
let t1 = self.compare_expr_type(&input.left, self.expected_type, input.left.span()); let t1 = self.visit_expression(&input.left);
let t2 = self.compare_expr_type(&input.right, self.expected_type, input.right.span()); let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.expected_type)
}
BinaryOperation::Pow => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
match (t1.as_ref(), t2.as_ref()) { return_incorrect_type(t1, t2, self.visitor.expected_type)
// Type A must be an int. }
// Type B must be a unsigned int. BinaryOperation::Pow => {
(Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => { let prev_expected_type = self.visitor.expected_type;
self.assert_type(t1.unwrap(), self.expected_type); self.visitor.expected_type = None;
}
// Type A was an int. let t1 = self.visit_expression(&input.left);
// But Type B was not a unsigned int. let t2 = self.visit_expression(&input.right);
(Some(Type::IntegerType(_)), Some(t)) => {
self.handler.emit_err( self.visitor.expected_type = prev_expected_type;
TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span())
.into(), match (t1.as_ref(), t2.as_ref()) {
); // Type A must be an int.
} // Type B must be a unsigned int.
// Type A must be a field. (Some(Type::IntegerType(_)), Some(Type::IntegerType(itype))) if !itype.is_signed() => {
// Type B must be an int. self.visitor.assert_type(t1.unwrap(), self.visitor.expected_type);
(Some(Type::Field), Some(Type::IntegerType(_))) => { }
self.assert_type(Type::Field, self.expected_type); // Type A was an int.
} // But Type B was not a unsigned int.
// Type A was a field. (Some(Type::IntegerType(_)), Some(t)) => {
// But Type B was not an int. self.visitor.handler.emit_err(
(Some(Type::Field), Some(t)) => { TypeCheckerError::incorrect_pow_exponent_type("unsigned int", t, input.right.span())
self.handler.emit_err( .into(),
TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(), );
); }
} // Type A must be a field.
// The base is some type thats not an int or field. // Type B must be an int.
(Some(t), _) => { (Some(Type::Field), Some(Type::IntegerType(_))) => {
self.handler self.visitor.assert_type(Type::Field, self.visitor.expected_type);
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into()); }
// Type A was a field.
// But Type B was not an int.
(Some(Type::Field), Some(t)) => {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_pow_exponent_type("int", t, input.right.span()).into(),
);
}
// The base is some type thats not an int or field.
(Some(t), _) => {
self.visitor
.handler
.emit_err(TypeCheckerError::incorrect_pow_base_type(t, input.left.span()).into());
}
_ => {}
} }
t1
}
BinaryOperation::Eq | BinaryOperation::Ne => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
self.visitor.expected_type = prev_expected_type;
self.visitor.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = None;
let t1 = self.visit_expression(&input.left);
self.visitor.assert_field_scalar_int_type(t1, input.left.span());
let t2 = self.visit_expression(&input.right);
self.visitor.assert_field_scalar_int_type(t2, input.right.span());
self.visitor.expected_type = prev_expected_type;
self.visitor.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
};
}
None
}
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
match input.op {
UnaryOperation::Not => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
self.visit_expression(&input.inner)
}
UnaryOperation::Negate => {
let prior_negate_state = self.visitor.negate;
self.visitor.negate = true;
let type_ = self.visit_expression(&input.inner);
self.visitor.negate = prior_negate_state;
match type_.as_ref() {
Some(
Type::IntegerType(
IntegerType::I8
| IntegerType::I16
| IntegerType::I32
| IntegerType::I64
| IntegerType::I128,
)
| Type::Field
| Type::Group,
) => {}
Some(t) => self
.visitor
.handler
.emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()),
_ => {} _ => {}
};
type_
}
}
}
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input) {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
let t1 = self.visit_expression(&input.if_true);
let t2 = self.visit_expression(&input.if_false);
return return_incorrect_type(t1, t2, self.visitor.expected_type);
}
None
}
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
match &*input.function {
Expression::Identifier(ident) => {
if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) {
let ret = self.visitor.assert_type(func.output, self.visitor.expected_type);
if func.input.len() != input.arguments.len() {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
input.arguments.len(),
input.span(),
)
.into(),
);
}
func.input
.iter()
.zip(input.arguments.iter())
.for_each(|(expected, argument)| {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(expected.get_variable().type_);
self.visit_expression(argument);
self.visitor.expected_type = prev_expected_type;
});
Some(ret)
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
None
} }
t1
} }
BinaryOperation::Eq | BinaryOperation::Ne => { expr => self.visit_expression(expr),
let t1 = self.compare_expr_type(&input.left, None, input.left.span()); }
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
self.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Le | BinaryOperation::Ge => {
let t1 = self.compare_expr_type(&input.left, None, input.left.span());
self.assert_field_scalar_int_type(t1, input.left.span());
let t2 = self.compare_expr_type(&input.right, None, input.right.span());
self.assert_field_scalar_int_type(t2, input.right.span());
self.assert_eq_types(t1, t2, input.span());
Some(Type::Boolean)
}
}; */
self.span = prev_span;
(VisitResult::VisitChildren, None)
} }
} }

View File

@ -18,6 +18,8 @@ use leo_ast::*;
use crate::{Declaration, TypeChecker, VariableSymbol}; use crate::{Declaration, TypeChecker, VariableSymbol};
use super::director::Director;
impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
fn visit_function(&mut self, input: &'a Function) -> VisitResult { fn visit_function(&mut self, input: &'a Function) -> VisitResult {
self.symbol_table.clear_variables(); self.symbol_table.clear_variables();
@ -40,3 +42,5 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
VisitResult::VisitChildren VisitResult::VisitChildren
} }
} }
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}

View File

@ -19,20 +19,23 @@ use leo_errors::TypeCheckerError;
use crate::{Declaration, TypeChecker, VariableSymbol}; use crate::{Declaration, TypeChecker, VariableSymbol};
impl<'a> StatementVisitor<'a> for TypeChecker<'a> { use super::director::Director;
fn visit_return(&mut self, input: &'a ReturnStatement) -> VisitResult {
impl<'a> StatementVisitor<'a> for TypeChecker<'a> {}
impl<'a> StatementVisitorDirector<'a> for Director<'a> {
fn visit_return(&mut self, input: &'a ReturnStatement) {
// we can safely unwrap all self.parent instances because // we can safely unwrap all self.parent instances because
// statements should always have some parent block // statements should always have some parent block
let parent = self.parent.unwrap(); let parent = self.visitor.parent.unwrap();
// Would never be None. let prev_expected_type = self.visitor.expected_type;
let func_output_type = self.symbol_table.lookup_fn(&parent).map(|f| f.output); self.visitor.expected_type = self.visitor.symbol_table.lookup_fn(&parent).map(|f| f.output);
// self.compare_expr_type(&input.expression, func_output_type, input.expression.span()); self.visit_expression(&input.expression);
self.visitor.expected_type = prev_expected_type;
VisitResult::VisitChildren
} }
fn visit_definition(&mut self, input: &'a DefinitionStatement) -> VisitResult { fn visit_definition(&mut self, input: &'a DefinitionStatement) {
let declaration = if input.declaration_type == Declare::Const { let declaration = if input.declaration_type == Declare::Const {
Declaration::Const Declaration::Const
} else { } else {
@ -40,7 +43,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
}; };
input.variable_names.iter().for_each(|v| { input.variable_names.iter().for_each(|v| {
if let Err(err) = self.symbol_table.insert_variable( if let Err(err) = self.visitor.symbol_table.insert_variable(
v.identifier.name, v.identifier.name,
VariableSymbol { VariableSymbol {
type_: &input.type_, type_: &input.type_,
@ -48,23 +51,26 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
declaration: declaration.clone(), declaration: declaration.clone(),
}, },
) { ) {
self.handler.emit_err(err); self.visitor.handler.emit_err(err);
} }
// self.compare_expr_type(&input.value, Some(input.type_), input.value.span()); let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(input.type_);
self.visit_expression(&input.value);
self.visitor.expected_type = prev_expected_type;
}); });
VisitResult::VisitChildren
} }
fn visit_assign(&mut self, input: &'a AssignStatement) -> VisitResult { fn visit_assign(&mut self, input: &'a AssignStatement) {
let var_name = &input.assignee.identifier.name; let var_name = &input.assignee.identifier.name;
let var_type = if let Some(var) = self.symbol_table.lookup_variable(var_name) { let var_type = if let Some(var) = self.visitor.symbol_table.lookup_variable(var_name) {
match &var.declaration { match &var.declaration {
Declaration::Const => self Declaration::Const => self
.visitor
.handler .handler
.emit_err(TypeCheckerError::cannont_assign_to_const_var(var_name, var.span).into()), .emit_err(TypeCheckerError::cannont_assign_to_const_var(var_name, var.span).into()),
Declaration::Input(ParamMode::Constant) => self Declaration::Input(ParamMode::Constant) => self
.visitor
.handler .handler
.emit_err(TypeCheckerError::cannont_assign_to_const_input(var_name, var.span).into()), .emit_err(TypeCheckerError::cannont_assign_to_const_input(var_name, var.span).into()),
_ => {} _ => {}
@ -72,7 +78,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
Some(*var.type_) Some(*var.type_)
} else { } else {
self.handler.emit_err( self.visitor.handler.emit_err(
TypeCheckerError::unknown_sym("variable", &input.assignee.identifier.name, input.assignee.span).into(), TypeCheckerError::unknown_sym("variable", &input.assignee.identifier.name, input.assignee.span).into(),
); );
@ -80,20 +86,22 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
}; };
if var_type.is_some() { if var_type.is_some() {
// self.compare_expr_type(&input.value, var_type, input.value.span()); let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = var_type;
self.visit_expression(&input.value);
self.visitor.expected_type = prev_expected_type;
} }
VisitResult::VisitChildren
} }
fn visit_conditional(&mut self, input: &'a ConditionalStatement) -> VisitResult { fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
// self.compare_expr_type(&input.condition, Some(Type::Boolean), input.condition.span()); let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
VisitResult::VisitChildren self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
} }
fn visit_iteration(&mut self, input: &'a IterationStatement) -> VisitResult { fn visit_iteration(&mut self, input: &'a IterationStatement) {
if let Err(err) = self.symbol_table.insert_variable( if let Err(err) = self.visitor.symbol_table.insert_variable(
input.variable.name, input.variable.name,
VariableSymbol { VariableSymbol {
type_: &input.type_, type_: &input.type_,
@ -101,30 +109,33 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
declaration: Declaration::Const, declaration: Declaration::Const,
}, },
) { ) {
self.handler.emit_err(err); self.visitor.handler.emit_err(err);
} }
// self.compare_expr_type(&input.start, Some(input.type_), input.start.span()); let prev_expected_type = self.visitor.expected_type;
// self.compare_expr_type(&input.stop, Some(input.type_), input.stop.span()); self.visitor.expected_type = Some(input.type_);
self.visit_expression(&input.start);
VisitResult::VisitChildren self.visit_expression(&input.stop);
self.visitor.expected_type = prev_expected_type;
} }
fn visit_console(&mut self, input: &'a ConsoleStatement) -> VisitResult { fn visit_console(&mut self, input: &'a ConsoleStatement) {
match &input.function { match &input.function {
ConsoleFunction::Assert(expr) => { ConsoleFunction::Assert(expr) => {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(expr);
self.visitor.expected_type = prev_expected_type;
// self.compare_expr_type(expr, Some(Type::Boolean), expr.span()); // self.compare_expr_type(expr, Some(Type::Boolean), expr.span());
} }
ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => { ConsoleFunction::Error(_) | ConsoleFunction::Log(_) => {
// TODO: undetermined // TODO: undetermined
} }
} }
VisitResult::VisitChildren
} }
fn visit_block(&mut self, input: &'a Block) -> VisitResult { fn visit_block(&mut self, input: &'a Block) {
self.symbol_table.push_variable_scope(); self.visitor.symbol_table.push_variable_scope();
// have to redo the logic here so we have scoping // have to redo the logic here so we have scoping
input.statements.iter().for_each(|stmt| { input.statements.iter().for_each(|stmt| {
match stmt { match stmt {
@ -137,8 +148,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
Statement::Block(stmt) => self.visit_block(stmt), Statement::Block(stmt) => self.visit_block(stmt),
}; };
}); });
self.symbol_table.pop_variable_scope(); self.visitor.symbol_table.pop_variable_scope();
VisitResult::SkipChildren
} }
} }

View File

@ -15,12 +15,12 @@
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>. // along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use leo_ast::*; use leo_ast::*;
use leo_errors::{emitter::Handler, TypeCheckerError}; use leo_errors::emitter::Handler;
use crate::{SymbolTable, TypeChecker}; use crate::{SymbolTable, TypeChecker};
pub(crate) struct Director<'a> { pub(crate) struct Director<'a> {
visitor: TypeChecker<'a>, pub(crate) visitor: TypeChecker<'a>,
} }
impl<'a> Director<'a> { impl<'a> Director<'a> {
@ -42,159 +42,3 @@ impl<'a> VisitorDirector<'a> for Director<'a> {
&mut self.visitor &mut self.visitor
} }
} }
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: Option<Type>) -> Option<Type> {
match (t1, t2) {
(Some(t1), Some(t2)) if t1 == t2 => Some(t1),
(Some(t1), Some(t2)) => {
if let Some(expected) = expected {
if t1 != expected {
Some(t1)
} else {
Some(t2)
}
} else {
Some(t1)
}
}
(None, Some(_)) | (Some(_), None) | (None, None) => None,
}
}
impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
type Output = Type;
fn visit_expression(&mut self, input: &'a Expression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_expression(input).0 {
return match input {
Expression::Identifier(expr) => self.visit_identifier(expr),
Expression::Value(expr) => self.visit_value(expr),
Expression::Binary(expr) => self.visit_binary(expr),
Expression::Unary(expr) => self.visit_unary(expr),
Expression::Ternary(expr) => self.visit_ternary(expr),
Expression::Call(expr) => self.visit_call(expr),
Expression::Err(expr) => self.visit_err(expr),
};
}
None
}
fn visit_identifier(&mut self, input: &'a Identifier) -> Option<Self::Output> {
self.visitor.visit_identifier(input).1
}
fn visit_value(&mut self, input: &'a ValueExpression) -> Option<Self::Output> {
self.visitor.visit_value(input).1
}
fn visit_binary(&mut self, input: &'a BinaryExpression) -> Option<Self::Output> {
match self.visitor.visit_binary(input) {
(VisitResult::VisitChildren, expected) => {
let t1 = self.visit_expression(&input.left);
let t2 = self.visit_expression(&input.right);
return_incorrect_type(t1, t2, self.visitor.expected_type)
}
_ => None,
}
}
fn visit_unary(&mut self, input: &'a UnaryExpression) -> Option<Self::Output> {
match input.op {
UnaryOperation::Not => {
self.visitor.assert_type(Type::Boolean, self.visitor.expected_type);
self.visit_expression(&input.inner)
}
UnaryOperation::Negate => {
let prior_negate_state = self.visitor.negate;
self.visitor.negate = true;
let type_ = self.visit_expression(&input.inner);
self.visitor.negate = prior_negate_state;
match type_.as_ref() {
Some(
Type::IntegerType(
IntegerType::I8
| IntegerType::I16
| IntegerType::I32
| IntegerType::I64
| IntegerType::I128,
)
| Type::Field
| Type::Group,
) => {}
Some(t) => self
.visitor
.handler
.emit_err(TypeCheckerError::type_is_not_negatable(t, input.inner.span()).into()),
_ => {}
};
type_
}
}
}
fn visit_ternary(&mut self, input: &'a TernaryExpression) -> Option<Self::Output> {
if let VisitResult::VisitChildren = self.visitor.visit_ternary(input).0 {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(Type::Boolean);
self.visit_expression(&input.condition);
self.visitor.expected_type = prev_expected_type;
let t1 = self.visit_expression(&input.if_true);
let t2 = self.visit_expression(&input.if_false);
return return_incorrect_type(t1, t2, self.visitor.expected_type);
}
None
}
fn visit_call(&mut self, input: &'a CallExpression) -> Option<Self::Output> {
match &*input.function {
Expression::Identifier(ident) => {
if let Some(func) = self.visitor.symbol_table.clone().lookup_fn(&ident.name) {
let ret = self.visitor.assert_type(func.output, self.visitor.expected_type);
if func.input.len() != input.arguments.len() {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_num_args_to_call(
func.input.len(),
input.arguments.len(),
input.span(),
)
.into(),
);
}
func.input
.iter()
.zip(input.arguments.iter())
.for_each(|(expected, argument)| {
let prev_expected_type = self.visitor.expected_type;
self.visitor.expected_type = Some(expected.get_variable().type_);
self.visit_expression(argument);
self.visitor.expected_type = prev_expected_type;
});
Some(ret)
} else {
self.visitor
.handler
.emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into());
None
}
}
expr => self.visit_expression(expr),
}
}
fn visit_err(&mut self, input: &'a ErrExpression) -> Option<Self::Output> {
self.visitor.visit_err(input).1
}
}
impl<'a> StatementVisitorDirector<'a> for Director<'a> {}
impl<'a> ProgramVisitorDirector<'a> for Director<'a> {}

View File

@ -2,4 +2,4 @@
namespace: Compile namespace: Compile
expectation: Fail expectation: Fail
outputs: outputs:
- "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:4:12\n |\n 4 | return (_, _)group * a;\n | ^^^^^^^^^^^^^^^\n" - "Error [ETYC0372002]: Found type `group` but type `scalar` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n"

View File

@ -2,4 +2,4 @@
namespace: Compile namespace: Compile
expectation: Fail expectation: Fail
outputs: outputs:
- "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:10:18\n |\n 10 | let h: u32 = a * 1u8;\n | ^\nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:10:22\n |\n 10 | let h: u32 = a * 1u8;\n | ^^^\n" - "Error [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:4:19\n |\n 4 | let b: bool = a == 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:5:19\n |\n 5 | let c: bool = a != 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:6:19\n |\n 6 | let d: bool = a > 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:7:19\n |\n 7 | let e: bool = a < 1u8;\n | ^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:8:19\n |\n 8 | let f: bool = a >= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u8` was expected\n --> compiler-test:9:19\n |\n 9 | let g: bool = a <= 1u8;\n | ^^^^^^^^\nError [ETYC0372002]: Found type `i8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \nError [ETYC0372002]: Found type `u8` but type `u32` was expected\n --> compiler-test:1:1\n |\n 1 | \n | \n"