diff --git a/compiler/ast/src/expressions/circuit_init.rs b/compiler/ast/src/expressions/circuit_init.rs index 46e5ad4306..81e8fec94f 100644 --- a/compiler/ast/src/expressions/circuit_init.rs +++ b/compiler/ast/src/expressions/circuit_init.rs @@ -38,7 +38,7 @@ impl fmt::Display for CircuitVariableInitializer { /// A circuit initialization expression, e.g., `Foo { bar: 42, baz }`. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct CircuitInitExpression { +pub struct CircuitExpression { /// The name of the structure type to initialize. pub name: Identifier, /// Initializer expressions for each of the fields in the circuit. @@ -50,7 +50,7 @@ pub struct CircuitInitExpression { pub span: Span, } -impl fmt::Display for CircuitInitExpression { +impl fmt::Display for CircuitExpression { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{} {{ ", self.name)?; for member in self.members.iter() { @@ -61,4 +61,4 @@ impl fmt::Display for CircuitInitExpression { } } -crate::simple_node_impl!(CircuitInitExpression); +crate::simple_node_impl!(CircuitExpression); diff --git a/compiler/ast/src/expressions/mod.rs b/compiler/ast/src/expressions/mod.rs index 4fcca54b4b..e896696dac 100644 --- a/compiler/ast/src/expressions/mod.rs +++ b/compiler/ast/src/expressions/mod.rs @@ -38,6 +38,9 @@ pub use err::*; mod ternary; pub use ternary::*; +mod tuple_init; +pub use tuple_init::*; + mod unary; pub use unary::*; @@ -49,21 +52,23 @@ pub use value::*; pub enum Expression { /// A circuit access expression, e.g., `Foo.bar`. Access(AccessExpression), - /// An identifier expression. - Identifier(Identifier), - /// A literal expression. - Literal(LiteralExpression), /// A binary expression, e.g., `42 + 24`. Binary(BinaryExpression), /// A call expression, e.g., `my_fun(args)`. Call(CallExpression), /// An expression constructing a circuit like `Foo { bar: 42, baz }`. - CircuitInit(CircuitInitExpression), + Circuit(CircuitExpression), /// An expression of type "error". /// Will result in a compile error eventually. Err(ErrExpression), + /// An identifier. + Identifier(Identifier), + /// A literal expression. + Literal(LiteralExpression), /// A ternary conditional expression `cond ? if_expr : else_expr`. Ternary(TernaryExpression), + /// A tuple expression e.g., `(foo, 42, true)`. + Tuple(TupleExpression), /// An unary expression. Unary(UnaryExpression), } @@ -73,13 +78,14 @@ impl Node for Expression { use Expression::*; match self { Access(n) => n.span(), - Identifier(n) => n.span(), - Literal(n) => n.span(), Binary(n) => n.span(), Call(n) => n.span(), - CircuitInit(n) => n.span(), + Circuit(n) => n.span(), Err(n) => n.span(), + Identifier(n) => n.span(), + Literal(n) => n.span(), Ternary(n) => n.span(), + Tuple(n) => n.span(), Unary(n) => n.span(), } } @@ -88,13 +94,14 @@ impl Node for Expression { use Expression::*; match self { Access(n) => n.set_span(span), - Identifier(n) => n.set_span(span), - Literal(n) => n.set_span(span), Binary(n) => n.set_span(span), Call(n) => n.set_span(span), - CircuitInit(n) => n.set_span(span), + Circuit(n) => n.set_span(span), + Identifier(n) => n.set_span(span), + Literal(n) => n.set_span(span), Err(n) => n.set_span(span), Ternary(n) => n.set_span(span), + Tuple(n) => n.set_span(span), Unary(n) => n.set_span(span), } } @@ -105,13 +112,14 @@ impl fmt::Display for Expression { use Expression::*; match &self { Access(n) => n.fmt(f), - Identifier(n) => n.fmt(f), - Literal(n) => n.fmt(f), Binary(n) => n.fmt(f), Call(n) => n.fmt(f), - CircuitInit(n) => n.fmt(f), + Circuit(n) => n.fmt(f), Err(n) => n.fmt(f), + Identifier(n) => n.fmt(f), + Literal(n) => n.fmt(f), Ternary(n) => n.fmt(f), + Tuple(n) => n.fmt(f), Unary(n) => n.fmt(f), } } diff --git a/compiler/ast/src/expressions/tuple_init.rs b/compiler/ast/src/expressions/tuple_init.rs new file mode 100644 index 0000000000..7be5a38a97 --- /dev/null +++ b/compiler/ast/src/expressions/tuple_init.rs @@ -0,0 +1,43 @@ +// Copyright (C) 2019-2022 Aleo Systems Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use super::*; + +/// A tuple construction expression, e.g., `(foo, false, 42)`. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TupleExpression { + /// The elements of the tuple. + /// In the example above, it would be `foo`, `false`, and `42`. + pub elements: Vec, + /// The span from `(` to `)`. + pub span: Span, +} + +impl fmt::Display for TupleExpression { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "({})", + self.elements + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(",") + ) + } +} + +crate::simple_node_impl!(TupleExpression); diff --git a/compiler/ast/src/passes/reconstructor.rs b/compiler/ast/src/passes/reconstructor.rs index 24d4a0874e..048c27143b 100644 --- a/compiler/ast/src/passes/reconstructor.rs +++ b/compiler/ast/src/passes/reconstructor.rs @@ -27,33 +27,22 @@ pub trait ExpressionReconstructor { fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) { match input { Expression::Access(access) => self.reconstruct_access(access), - Expression::Identifier(identifier) => self.reconstruct_identifier(identifier), - Expression::Literal(value) => self.reconstruct_literal(value), Expression::Binary(binary) => self.reconstruct_binary(binary), Expression::Call(call) => self.reconstruct_call(call), - Expression::CircuitInit(circuit) => self.reconstruct_circuit_init(circuit), - Expression::Unary(unary) => self.reconstruct_unary(unary), - Expression::Ternary(ternary) => self.reconstruct_ternary(ternary), + Expression::Circuit(circuit) => self.reconstruct_circuit_init(circuit), Expression::Err(err) => self.reconstruct_err(err), + Expression::Identifier(identifier) => self.reconstruct_identifier(identifier), + Expression::Literal(value) => self.reconstruct_literal(value), + Expression::Ternary(ternary) => self.reconstruct_ternary(ternary), + Expression::Tuple(tuple) => self.reconstruct_tuple(tuple), + Expression::Unary(unary) => self.reconstruct_unary(unary), } } - fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) { - (Expression::Identifier(input), Default::default()) - } - - fn reconstruct_literal(&mut self, input: LiteralExpression) -> (Expression, Self::AdditionalOutput) { - (Expression::Literal(input), Default::default()) - } - fn reconstruct_access(&mut self, input: AccessExpression) -> (Expression, Self::AdditionalOutput) { (Expression::Access(input), Default::default()) } - fn reconstruct_circuit_init(&mut self, input: CircuitInitExpression) -> (Expression, Self::AdditionalOutput) { - (Expression::CircuitInit(input), Default::default()) - } - fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) { ( Expression::Binary(BinaryExpression { @@ -66,28 +55,6 @@ pub trait ExpressionReconstructor { ) } - fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) { - ( - Expression::Unary(UnaryExpression { - receiver: Box::new(self.reconstruct_expression(*input.receiver).0), - op: input.op, - span: input.span, - }), - Default::default(), - ) - } - - fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) { - ( - Expression::Ternary(TernaryExpression { - condition: Box::new(self.reconstruct_expression(*input.condition).0), - if_true: Box::new(self.reconstruct_expression(*input.if_true).0), - if_false: Box::new(self.reconstruct_expression(*input.if_false).0), - span: input.span, - }), - Default::default(), - ) - } fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) { ( @@ -104,9 +71,58 @@ pub trait ExpressionReconstructor { ) } + fn reconstruct_circuit_init(&mut self, input: CircuitExpression) -> (Expression, Self::AdditionalOutput) { + (Expression::Circuit(input), Default::default()) + } + fn reconstruct_err(&mut self, input: ErrExpression) -> (Expression, Self::AdditionalOutput) { (Expression::Err(input), Default::default()) } + + fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) { + (Expression::Identifier(input), Default::default()) + } + + fn reconstruct_literal(&mut self, input: LiteralExpression) -> (Expression, Self::AdditionalOutput) { + (Expression::Literal(input), Default::default()) + } + + fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) { + ( + Expression::Ternary(TernaryExpression { + condition: Box::new(self.reconstruct_expression(*input.condition).0), + if_true: Box::new(self.reconstruct_expression(*input.if_true).0), + if_false: Box::new(self.reconstruct_expression(*input.if_false).0), + span: input.span, + }), + Default::default(), + ) + } + + fn reconstruct_tuple(&mut self, input: TupleExpression) -> (Expression, Self::AdditionalOutput) { + ( + Expression::Tuple(TupleExpression { + elements: input + .elements + .into_iter() + .map(|element| self.reconstruct_expression(element).0) + .collect(), + span: input.span, + }), + Default::default(), + ) + } + + fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) { + ( + Expression::Unary(UnaryExpression { + receiver: Box::new(self.reconstruct_expression(*input.receiver).0), + op: input.op, + span: input.span, + }), + Default::default(), + ) + } } /// A Reconstructor trait for statements in the AST. diff --git a/compiler/ast/src/passes/visitor.rs b/compiler/ast/src/passes/visitor.rs index 3a1b0b3d18..ea7e74e17f 100644 --- a/compiler/ast/src/passes/visitor.rs +++ b/compiler/ast/src/passes/visitor.rs @@ -27,15 +27,16 @@ pub trait ExpressionVisitor<'a> { fn visit_expression(&mut self, input: &'a Expression, additional: &Self::AdditionalInput) -> Self::Output { match input { - Expression::Access(expr) => self.visit_access(expr, additional), - Expression::CircuitInit(expr) => self.visit_circuit_init(expr, additional), - Expression::Identifier(expr) => self.visit_identifier(expr, additional), - Expression::Literal(expr) => self.visit_literal(expr, additional), - Expression::Binary(expr) => self.visit_binary(expr, additional), - Expression::Unary(expr) => self.visit_unary(expr, additional), - Expression::Ternary(expr) => self.visit_ternary(expr, additional), - Expression::Call(expr) => self.visit_call(expr, additional), - Expression::Err(expr) => self.visit_err(expr, additional), + Expression::Access(access) => self.visit_access(access, additional), + Expression::Binary(binary) => self.visit_binary(binary, additional), + Expression::Call(call) => self.visit_call(call, additional), + Expression::Circuit(circuit) => self.visit_circuit_init(circuit, additional), + Expression::Err(err) => self.visit_err(err, additional), + Expression::Identifier(identifier) => self.visit_identifier(identifier, additional), + Expression::Literal(literal) => self.visit_literal(literal, additional), + Expression::Ternary(ternary) => self.visit_ternary(ternary, additional), + Expression::Tuple(tuple) => self.visit_tuple(tuple, additional), + Expression::Unary(unary) => self.visit_unary(unary, additional), } } @@ -43,40 +44,12 @@ pub trait ExpressionVisitor<'a> { Default::default() } - fn visit_circuit_init( - &mut self, - _input: &'a CircuitInitExpression, - _additional: &Self::AdditionalInput, - ) -> Self::Output { - Default::default() - } - - fn visit_identifier(&mut self, _input: &'a Identifier, _additional: &Self::AdditionalInput) -> Self::Output { - Default::default() - } - - fn visit_literal(&mut self, _input: &'a LiteralExpression, _additional: &Self::AdditionalInput) -> Self::Output { - Default::default() - } - fn visit_binary(&mut self, input: &'a BinaryExpression, additional: &Self::AdditionalInput) -> Self::Output { self.visit_expression(&input.left, additional); self.visit_expression(&input.right, additional); Default::default() } - fn visit_unary(&mut self, input: &'a UnaryExpression, additional: &Self::AdditionalInput) -> Self::Output { - self.visit_expression(&input.receiver, additional); - Default::default() - } - - fn visit_ternary(&mut self, input: &'a TernaryExpression, additional: &Self::AdditionalInput) -> Self::Output { - self.visit_expression(&input.condition, additional); - self.visit_expression(&input.if_true, additional); - self.visit_expression(&input.if_false, additional); - Default::default() - } - fn visit_call(&mut self, input: &'a CallExpression, additional: &Self::AdditionalInput) -> Self::Output { input.arguments.iter().for_each(|expr| { self.visit_expression(expr, additional); @@ -87,6 +60,40 @@ pub trait ExpressionVisitor<'a> { fn visit_err(&mut self, _input: &'a ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output { Default::default() } + fn visit_identifier(&mut self, _input: &'a Identifier, _additional: &Self::AdditionalInput) -> Self::Output { + Default::default() + } + + fn visit_literal(&mut self, _input: &'a LiteralExpression, _additional: &Self::AdditionalInput) -> Self::Output { + Default::default() + } + + fn visit_circuit_init( + &mut self, + _input: &'a CircuitExpression, + _additional: &Self::AdditionalInput, + ) -> Self::Output { + Default::default() + } + + fn visit_ternary(&mut self, input: &'a TernaryExpression, additional: &Self::AdditionalInput) -> Self::Output { + self.visit_expression(&input.condition, additional); + self.visit_expression(&input.if_true, additional); + self.visit_expression(&input.if_false, additional); + Default::default() + } + + fn visit_tuple(&mut self, input: &'a TupleExpression, additional: &Self::AdditionalInput) -> Self::Output { + input.elements.iter().for_each(|expr| { + self.visit_expression(expr, additional); + }); + Default::default() + } + + fn visit_unary(&mut self, input: &'a UnaryExpression, additional: &Self::AdditionalInput) -> Self::Output { + self.visit_expression(&input.receiver, additional); + Default::default() + } } /// A Visitor trait for statements in the AST. diff --git a/compiler/ast/src/types/tuple.rs b/compiler/ast/src/types/tuple.rs index 1f861be121..427fe58608 100644 --- a/compiler/ast/src/types/tuple.rs +++ b/compiler/ast/src/types/tuple.rs @@ -19,10 +19,7 @@ use leo_errors::{AstError, Result}; use leo_span::Span; use serde::{Deserialize, Serialize}; -use std::{ - fmt, - ops::Deref -}; +use std::{fmt, ops::Deref}; /// A type list of at least two types. #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -34,7 +31,7 @@ impl Tuple { match elements.len() { 0 => Err(AstError::empty_tuple(span).into()), 1 => Err(AstError::one_element_tuple(span).into()), - _ => Ok(Type::Tuple(Tuple(elements))) + _ => Ok(Type::Tuple(Tuple(elements))), } } } @@ -49,6 +46,10 @@ impl Deref for Tuple { impl fmt::Display for Tuple { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "({})", self.0.iter().map(|x| x.to_string()).collect::>().join(",")) + write!( + f, + "({})", + self.0.iter().map(|x| x.to_string()).collect::>().join(",") + ) } -} \ No newline at end of file +} diff --git a/compiler/parser/src/parser/expression.rs b/compiler/parser/src/parser/expression.rs index 1cdf2233e7..6d0689aee5 100644 --- a/compiler/parser/src/parser/expression.rs +++ b/compiler/parser/src/parser/expression.rs @@ -470,7 +470,7 @@ impl ParserContext<'_> { p.parse_circuit_member().map(Some) })?; - Ok(Expression::CircuitInit(CircuitInitExpression { + Ok(Expression::Circuit(CircuitExpression { span: identifier.span + end, name: identifier, members, diff --git a/compiler/passes/src/type_checker/check_expressions.rs b/compiler/passes/src/type_checker/check_expressions.rs index 09a2c4c5ff..8089db9e20 100644 --- a/compiler/passes/src/type_checker/check_expressions.rs +++ b/compiler/passes/src/type_checker/check_expressions.rs @@ -43,14 +43,15 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { fn visit_expression(&mut self, input: &'a Expression, expected: &Self::AdditionalInput) -> Self::Output { match input { - Expression::Access(expr) => self.visit_access(expr, expected), - Expression::Identifier(expr) => self.visit_identifier(expr, expected), - Expression::Literal(expr) => self.visit_literal(expr, expected), - Expression::Binary(expr) => self.visit_binary(expr, expected), - Expression::Call(expr) => self.visit_call(expr, expected), - Expression::CircuitInit(expr) => self.visit_circuit_init(expr, expected), - Expression::Err(expr) => self.visit_err(expr, expected), - Expression::Ternary(expr) => self.visit_ternary(expr, expected), + Expression::Access(access) => self.visit_access(access, expected), + Expression::Binary(binary) => self.visit_binary(binary, expected), + Expression::Call(call) => self.visit_call(call, expected), + Expression::Circuit(circuit) => self.visit_circuit_init(circuit, expected), + Expression::Identifier(identifier) => self.visit_identifier(identifier, expected), + Expression::Err(err) => self.visit_err(err, expected), + Expression::Literal(literal) => self.visit_literal(literal, expected), + Expression::Ternary(ternary) => self.visit_ternary(ternary, expected), + Expression::Tuple(tuple) => self.visit_tuple(tuple, expected), Expression::Unary(expr) => self.visit_unary(expr, expected), } } @@ -97,155 +98,6 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } } - fn visit_circuit_init( - &mut self, - input: &'a CircuitInitExpression, - additional: &Self::AdditionalInput, - ) -> Self::Output { - let circ = self.symbol_table.borrow().lookup_circuit(&input.name.name).cloned(); - if let Some(circ) = circ { - // Check circuit type name. - let ret = self.check_expected_circuit(circ.identifier, additional, input.name.span()); - - // Check number of circuit members. - if circ.members.len() != input.members.len() { - self.handler.emit_err( - TypeCheckerError::incorrect_num_circuit_members( - circ.members.len(), - input.members.len(), - input.span(), - ) - .into(), - ); - } - - // Check circuit member types. - circ.members - .iter() - .for_each(|CircuitMember::CircuitVariable(name, ty)| { - // Lookup circuit variable name. - if let Some(actual) = input.members.iter().find(|member| member.identifier.name == name.name) { - if let Some(expr) = &actual.expression { - self.visit_expression(expr, &Some(ty.clone())); - } - } else { - self.handler.emit_err( - TypeCheckerError::unknown_sym("circuit member variable", name, name.span()).into(), - ); - }; - }); - - Some(ret) - } else { - self.handler - .emit_err(TypeCheckerError::unknown_sym("circuit", &input.name.name, input.name.span()).into()); - None - } - } - - fn visit_identifier(&mut self, var: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output { - if let Some(circuit) = self.symbol_table.borrow().lookup_circuit(&var.name) { - Some(self.assert_and_return_type(Type::Identifier(circuit.identifier), expected, var.span)) - } else if let Some(var) = self.symbol_table.borrow().lookup_variable(&var.name) { - Some(self.assert_and_return_type(var.type_.clone(), expected, var.span)) - } else { - self.handler - .emit_err(TypeCheckerError::unknown_sym("variable", var.name, var.span()).into()); - None - } - } - - fn visit_literal(&mut self, input: &'a LiteralExpression, expected: &Self::AdditionalInput) -> Self::Output { - Some(match input { - LiteralExpression::Address(_, _) => self.assert_and_return_type(Type::Address, expected, input.span()), - LiteralExpression::Boolean(_, _) => self.assert_and_return_type(Type::Boolean, expected, input.span()), - LiteralExpression::Field(_, _) => self.assert_and_return_type(Type::Field, expected, input.span()), - LiteralExpression::Integer(type_, str_content, _) => { - match type_ { - IntegerType::I8 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into()); - } - } - IntegerType::I16 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into()); - } - } - IntegerType::I32 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into()); - } - } - IntegerType::I64 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into()); - } - } - IntegerType::I128 => { - let int = if self.negate { - format!("-{str_content}") - } else { - str_content.clone() - }; - - if int.parse::().is_err() { - self.handler - .emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into()); - } - } - IntegerType::U8 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()), - IntegerType::U16 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()), - IntegerType::U32 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()), - IntegerType::U64 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()), - IntegerType::U128 if str_content.parse::().is_err() => self - .handler - .emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()), - _ => {} - } - self.assert_and_return_type(Type::IntegerType(*type_), expected, input.span()) - } - LiteralExpression::Group(_) => self.assert_and_return_type(Type::Group, expected, input.span()), - LiteralExpression::Scalar(_, _) => self.assert_and_return_type(Type::Scalar, expected, input.span()), - LiteralExpression::String(_, _) => self.assert_and_return_type(Type::String, expected, input.span()), - }) - } fn visit_binary(&mut self, input: &'a BinaryExpression, destination: &Self::AdditionalInput) -> Self::Output { match input.op { @@ -470,6 +322,228 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } } + + fn visit_call(&mut self, input: &'a CallExpression, expected: &Self::AdditionalInput) -> Self::Output { + match &*input.function { + Expression::Identifier(ident) => { + let func = self.symbol_table.borrow().lookup_fn(&ident.name).cloned(); + if let Some(func) = func { + let ret = self.assert_and_return_type(func.output, expected, func.span); + + // Check number of function arguments. + if func.input.len() != input.arguments.len() { + self.handler.emit_err( + TypeCheckerError::incorrect_num_args_to_call( + func.input.len(), + input.arguments.len(), + input.span(), + ) + .into(), + ); + } + + // Check function argument types. + func.input + .iter() + .zip(input.arguments.iter()) + .for_each(|(expected, argument)| { + self.visit_expression(argument, &Some(expected.get_variable().type_.clone())); + }); + + Some(ret) + } else { + self.handler + .emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into()); + None + } + } + expr => self.visit_expression(expr, expected), + } + } + + fn visit_circuit_init( + &mut self, + input: &'a CircuitExpression, + additional: &Self::AdditionalInput, + ) -> Self::Output { + let circ = self.symbol_table.borrow().lookup_circuit(&input.name.name).cloned(); + if let Some(circ) = circ { + // Check circuit type name. + let ret = self.check_expected_circuit(circ.identifier, additional, input.name.span()); + + // Check number of circuit members. + if circ.members.len() != input.members.len() { + self.emit_err( + TypeCheckerError::incorrect_num_circuit_members( + circ.members.len(), + input.members.len(), + input.span(), + ) + .into(), + ); + } + + // Check circuit member types. + circ.members + .iter() + .for_each(|CircuitMember::CircuitVariable(name, ty)| { + // Lookup circuit variable name. + if let Some(actual) = input.members.iter().find(|member| member.identifier.name == name.name) { + if let Some(expr) = &actual.expression { + self.visit_expression(expr, &Some(ty.clone())); + } + } else { + self.handler.emit_err( + TypeCheckerError::unknown_sym("circuit member variable", name, name.span()).into(), + ); + }; + }); + + Some(ret) + } else { + self.emit_err(TypeCheckerError::unknown_sym("circuit", &input.name.name, input.name.span()).into()); + None + } + } + + fn visit_identifier(&mut self, var: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output { + if let Some(circuit) = self.symbol_table.borrow().lookup_circuit(&var.name) { + Some(self.assert_and_return_type(Type::Identifier(circuit.identifier), expected, var.span)) + } else if let Some(var) = self.symbol_table.borrow().lookup_variable(&var.name) { + Some(self.assert_and_return_type(var.type_.clone(), expected, var.span)) + } else { + self.emit_err(TypeCheckerError::unknown_sym("variable", var.name, var.span()).into()); + None + } + } + + fn visit_literal(&mut self, input: &'a LiteralExpression, expected: &Self::AdditionalInput) -> Self::Output { + Some(match input { + LiteralExpression::Address(_, _) => self.assert_and_return_type(Type::Address, expected, input.span()), + LiteralExpression::Boolean(_, _) => self.assert_and_return_type(Type::Boolean, expected, input.span()), + LiteralExpression::Field(_, _) => self.assert_and_return_type(Type::Field, expected, input.span()), + LiteralExpression::Integer(type_, str_content, _) => { + match type_ { + IntegerType::I8 => { + let int = if self.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; + + if int.parse::().is_err() { + self.handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i8", input.span()).into()); + } + } + IntegerType::I16 => { + let int = if self.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; + + if int.parse::().is_err() { + self.handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i16", input.span()).into()); + } + } + IntegerType::I32 => { + let int = if self.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; + + if int.parse::().is_err() { + self.handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i32", input.span()).into()); + } + } + IntegerType::I64 => { + let int = if self.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; + + if int.parse::().is_err() { + self.handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i64", input.span()).into()); + } + } + IntegerType::I128 => { + let int = if self.negate { + format!("-{str_content}") + } else { + str_content.clone() + }; + + if int.parse::().is_err() { + self.handler + .emit_err(TypeCheckerError::invalid_int_value(int, "i128", input.span()).into()); + } + } + IntegerType::U8 if str_content.parse::().is_err() => self + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u8", input.span()).into()), + IntegerType::U16 if str_content.parse::().is_err() => self + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u16", input.span()).into()), + IntegerType::U32 if str_content.parse::().is_err() => self + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u32", input.span()).into()), + IntegerType::U64 if str_content.parse::().is_err() => self + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u64", input.span()).into()), + IntegerType::U128 if str_content.parse::().is_err() => self + .handler + .emit_err(TypeCheckerError::invalid_int_value(str_content, "u128", input.span()).into()), + _ => {} + } + self.assert_and_return_type(Type::IntegerType(*type_), expected, input.span()) + } + LiteralExpression::Group(_) => self.assert_and_return_type(Type::Group, expected, input.span()), + LiteralExpression::Scalar(_, _) => self.assert_and_return_type(Type::Scalar, expected, input.span()), + LiteralExpression::String(_, _) => self.assert_and_return_type(Type::String, expected, input.span()), + }) + } + + + + fn visit_ternary(&mut self, input: &'a TernaryExpression, expected: &Self::AdditionalInput) -> Self::Output { + self.visit_expression(&input.condition, &Some(Type::Boolean)); + + let t1 = self.visit_expression(&input.if_true, expected); + let t2 = self.visit_expression(&input.if_false, expected); + + return_incorrect_type(t1, t2, expected) + } + + fn visit_tuple(&mut self, input: &'a TupleExpression, expected: &Self::AdditionalInput) -> Self::Output { + // Check the expected tuple types if they are known. + if let Some(Type::Tuple(expected_types)) = expected { + // Check actual length is equal to expected length. + if expected_types.len() != input.elements.len() { + self.emit_err(TypeCheckerError::incorrect_tuple_length(expected_types.len(), input.elements.len(), input.span())); + } + + expected_types + .iter() + .zip(input.elements.iter()) + .for_each(|(expected, expr)| { + self.visit_expression(expr, &Some(expected.clone())); + }); + + Some(Type::Tuple(expected_types.clone())) + } else { + // Tuples must be explicitly typed in testnet3. + self.emit_err(TypeCheckerError::invalid_tuple(input.span())); + + None + } + } + fn visit_unary(&mut self, input: &'a UnaryExpression, destination: &Self::AdditionalInput) -> Self::Output { match input.op { UnaryOperation::Abs => { @@ -521,50 +595,4 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } } - fn visit_ternary(&mut self, input: &'a TernaryExpression, expected: &Self::AdditionalInput) -> Self::Output { - self.visit_expression(&input.condition, &Some(Type::Boolean)); - - let t1 = self.visit_expression(&input.if_true, expected); - let t2 = self.visit_expression(&input.if_false, expected); - - return_incorrect_type(t1, t2, expected) - } - - fn visit_call(&mut self, input: &'a CallExpression, expected: &Self::AdditionalInput) -> Self::Output { - match &*input.function { - Expression::Identifier(ident) => { - let func = self.symbol_table.borrow().lookup_fn(&ident.name).cloned(); - if let Some(func) = func { - let ret = self.assert_and_return_type(func.output, expected, func.span); - - // Check number of function arguments. - if func.input.len() != input.arguments.len() { - self.handler.emit_err( - TypeCheckerError::incorrect_num_args_to_call( - func.input.len(), - input.arguments.len(), - input.span(), - ) - .into(), - ); - } - - // Check function argument types. - func.input - .iter() - .zip(input.arguments.iter()) - .for_each(|(expected, argument)| { - self.visit_expression(argument, &Some(expected.get_variable().type_.clone())); - }); - - Some(ret) - } else { - self.handler - .emit_err(TypeCheckerError::unknown_sym("function", &ident.name, ident.span()).into()); - None - } - } - expr => self.visit_expression(expr, expected), - } - } } diff --git a/compiler/passes/src/type_checker/check_program.rs b/compiler/passes/src/type_checker/check_program.rs index 1a87d44994..95c3c0d1e5 100644 --- a/compiler/passes/src/type_checker/check_program.rs +++ b/compiler/passes/src/type_checker/check_program.rs @@ -52,8 +52,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { self.visit_block(&input.block); if !self.has_return { - self.handler - .emit_err(TypeCheckerError::function_has_no_return(input.name(), input.span()).into()); + self.emit_err(TypeCheckerError::function_has_no_return(input.name(), input.span()).into()); } let prev_st = *self.symbol_table.borrow_mut().parent.take().unwrap(); diff --git a/compiler/passes/src/type_checker/check_statements.rs b/compiler/passes/src/type_checker/check_statements.rs index e3e1778bf4..5aa68701ad 100644 --- a/compiler/passes/src/type_checker/check_statements.rs +++ b/compiler/passes/src/type_checker/check_statements.rs @@ -83,8 +83,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { Some(var.type_.clone()) } else { - self.handler - .emit_err(TypeCheckerError::unknown_sym("variable", var_name.name, var_name.span).into()); + self.emit_err(TypeCheckerError::unknown_sym("variable", var_name.name, var_name.span).into()); None }; diff --git a/compiler/passes/src/type_checker/checker.rs b/compiler/passes/src/type_checker/checker.rs index 40cff35838..e6a240f6f7 100644 --- a/compiler/passes/src/type_checker/checker.rs +++ b/compiler/passes/src/type_checker/checker.rs @@ -85,7 +85,7 @@ impl<'a> TypeChecker<'a> { } /// Emits a type checker error. - fn emit_err(&self, err: TypeCheckerError) { + pub fn emit_err(&self, err: TypeCheckerError) { self.handler.emit_err(err.into()); } diff --git a/leo/errors/src/errors/type_checker/type_checker_error.rs b/leo/errors/src/errors/type_checker/type_checker_error.rs index 1aeeaa1404..4a1ba2d500 100644 --- a/leo/errors/src/errors/type_checker/type_checker_error.rs +++ b/leo/errors/src/errors/type_checker/type_checker_error.rs @@ -271,4 +271,18 @@ create_messages!( msg: format!("Comparison `{operator}` is not supported for the address type."), help: None, } + + @formatted + incorrect_tuple_length { + args: (expected: impl Display, actual: impl Display), + msg: format!("Expected a tuple of length `{expected}` got `{actual}`"), + help: None, + } + + @formatted + invalid_tuple { + args: (), + msg: format!("Tuples must be explicitly typed in Leo"), + help: Some("The function definition must match the function return statement".to_string()), + } );