diff --git a/compiler/ast/src/expression/binary.rs b/compiler/ast/src/expression/binary.rs index 7c75f97433..c266dec7cd 100644 --- a/compiler/ast/src/expression/binary.rs +++ b/compiler/ast/src/expression/binary.rs @@ -15,6 +15,7 @@ // along with the Leo library. If not, see . use super::*; +use leo_span::Symbol; /// A binary operator. /// @@ -23,41 +24,88 @@ use super::*; pub enum BinaryOperation { /// Addition, i.e. `+`, `.add()`. Add, - /// Wrapped addition, i.e. `.add_wrapped()`. + /// Wrapping addition, i.e. `.add_wrapped()`. AddWrapped, - /// Subtraction, i.e. `-`. - Sub, - /// Multiplication, i.e. `*`. - Mul, - /// Division, i.e. `/`. - Div, - /// Exponentiation, i.e. `**` in `a ** b`. - Pow, - /// Logical-or, i.e., `||`. - Or, - /// Logical-and, i.e., `&&`. + /// Bitwise AND, i.e. `&&`, `.and()`. And, - /// Equality relation, i.e., `==`. + /// Division, i.e. `/`, `.div()`. + Div, + /// Wrapping division, i.e. `.div_wrapped()`. + DivWrapped, + /// Equality relation, i.e. `==`, `.eq()`. Eq, - /// In-equality relation, i.e. `!=`. - Ne, - /// Greater-or-equal relation, i.e. `>=`. + /// Greater-or-equal relation, i.e. `>=`, `.ge()`. Ge, - /// Greater-than relation, i.e. `>=`. + /// Greater-than relation, i.e. `>`, `.gt()`. Gt, - /// Lesser-or-equal relation, i.e. `<=`. + /// Lesser-or-equal relation, i.e. `<=`, `.le()`. Le, - /// Lesser-than relation, i.e. `<`. + /// Lesser-than relation, i.e. `<`, `.lt()`. Lt, + /// Multiplication, i.e. `*`, `.mul()`. + Mul, + /// Wrapping multiplication, i.e. `.mul_wrapped()`. + MulWrapped, + /// Boolean NAND, i.e. `.nand()`. + Nand, + /// In-equality relation, i.e. `!=`, `.neq()`. + Neq, + /// Boolean NOR, i.e. `.nor()`. + Nor, + /// Logical-or, i.e. `||`. + Or, + /// Exponentiation, i.e. `**` in `a ** b`, `.pow()`. + Pow, + /// Wrapping exponentiation, i.e. `.pow_wrapped()`. + PowWrapped, + /// Shift left operation, i.e. `<<`, `.shl()`. + Shl, + /// Wrapping shift left operation, i.e. `<<`, `.shl_wrapped()`. + ShlWrapped, + /// Shift right operation, i.e. >>, `.shr()`. + Shr, + /// Wrapping shift right operation, i.e. >>, `.shr_wrapped()`. + ShrWrapped, + /// Subtraction, i.e. `-`, `.sub()`. + Sub, + /// Wrapped subtraction, i.e. `.sub_wrapped()`. + SubWrapped, + /// Bitwise XOR, i.e. `.xor()`. + Xor, } -/// The category a binary operation belongs to. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum BinaryOperationClass { - /// A numeric one, that is, the result is numeric. - Numeric, - /// A boolean one, meaning the result is of type `bool`. - Boolean, +impl BinaryOperation { + /// Returns a `BinaryOperation` from the given `Symbol`. + pub fn from_symbol(symbol: &Symbol) -> Option { + Some(match symbol.as_u32() { + 8 => BinaryOperation::Add, + 9 => BinaryOperation::AddWrapped, + 10 => BinaryOperation::And, + 11 => BinaryOperation::Div, + 12 => BinaryOperation::DivWrapped, + 13 => BinaryOperation::Eq, + 14 => BinaryOperation::Ge, + 15 => BinaryOperation::Gt, + 16 => BinaryOperation::Le, + 17 => BinaryOperation::Lt, + 18 => BinaryOperation::Mul, + 19 => BinaryOperation::MulWrapped, + 20 => BinaryOperation::Nand, + 21 => BinaryOperation::Neq, + 22 => BinaryOperation::Nor, + 23 => BinaryOperation::Or, + 24 => BinaryOperation::Pow, + 25 => BinaryOperation::PowWrapped, + 26 => BinaryOperation::Shl, + 27 => BinaryOperation::ShlWrapped, + 28 => BinaryOperation::Shr, + 29 => BinaryOperation::ShrWrapped, + 30 => BinaryOperation::Sub, + 31 => BinaryOperation::SubWrapped, + 32 => BinaryOperation::Xor, + _ => return None + }) + } } impl AsRef for BinaryOperation { @@ -65,18 +113,29 @@ impl AsRef for BinaryOperation { match self { BinaryOperation::Add => "add", BinaryOperation::AddWrapped => "add_wrapped", - BinaryOperation::Sub => "sub", - BinaryOperation::Mul => "mul", - BinaryOperation::Div => "div", - BinaryOperation::Pow => "pow", - BinaryOperation::Or => "or", BinaryOperation::And => "and", + BinaryOperation::Div => "div", + BinaryOperation::DivWrapped => "div_wrapped", BinaryOperation::Eq => "eq", - BinaryOperation::Ne => "ne", BinaryOperation::Ge => "ge", BinaryOperation::Gt => "gt", BinaryOperation::Le => "le", BinaryOperation::Lt => "lt", + BinaryOperation::Mul => "mul", + BinaryOperation::MulWrapped => "mul_wrapped", + BinaryOperation::Nand => "nand", + BinaryOperation::Neq => "neq", + BinaryOperation::Nor => "nor", + BinaryOperation::Or => "or", + BinaryOperation::Pow => "pow", + BinaryOperation::PowWrapped => "pow_wrapped", + BinaryOperation::Shl => "shl", + BinaryOperation::ShlWrapped => "shl_wrapped", + BinaryOperation::Shr => "shr", + BinaryOperation::ShrWrapped => "shr_wrapped", + BinaryOperation::Sub => "sub", + BinaryOperation::SubWrapped => "sub_wrapped", + BinaryOperation::Xor => "xor", } } } diff --git a/compiler/parser/src/parser/expression.rs b/compiler/parser/src/parser/expression.rs index 2820a2d415..797ef85e63 100644 --- a/compiler/parser/src/parser/expression.rs +++ b/compiler/parser/src/parser/expression.rs @@ -120,7 +120,7 @@ impl ParserContext<'_> { fn eat_bin_op(&mut self, tokens: &[Token]) -> Option { self.eat_any(tokens).then(|| match &self.prev_token.token { Token::Eq => BinaryOperation::Eq, - Token::NotEq => BinaryOperation::Ne, + Token::NotEq => BinaryOperation::Neq, Token::Lt => BinaryOperation::Lt, Token::LtEq => BinaryOperation::Le, Token::Gt => BinaryOperation::Gt, @@ -224,15 +224,8 @@ impl ParserContext<'_> { self.bump(); // Check if the method exists. - let index = method.as_u32(); - - if index <= 1 { - // Binary operators. - let operator = match index { - 0 => BinaryOperation::Add, - 1 => BinaryOperation::AddWrapped, - _ => unimplemented!("throw error for invalid method call"), - }; + if let Some(operator) = BinaryOperation::from_symbol(&method) { + // Handle binary operators. // Parse left parenthesis `(`. self.expect(&Token::LeftParen)?; diff --git a/compiler/passes/src/type_checker/check_expressions.rs b/compiler/passes/src/type_checker/check_expressions.rs index 0e6b3b062f..b343214afe 100644 --- a/compiler/passes/src/type_checker/check_expressions.rs +++ b/compiler/passes/src/type_checker/check_expressions.rs @@ -186,13 +186,20 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { fn visit_binary(&mut self, input: &'a BinaryExpression, expected: &Self::AdditionalInput) -> Option { if let VisitResult::VisitChildren = self.visitor.visit_binary(input) { return match input.op { - BinaryOperation::And | BinaryOperation::Or => { + BinaryOperation::Nand | BinaryOperation::Nor => { self.visitor.assert_type(Type::Boolean, expected, input.span()); let t1 = self.visit_expression(&input.left, expected); let t2 = self.visit_expression(&input.right, expected); return_incorrect_type(t1, t2, expected) } + BinaryOperation::And | BinaryOperation::Or | BinaryOperation::Xor => { + self.visitor.assert_bool_int_type(expected, input.span()); + let t1 = self.visit_expression(&input.left, expected); + let t2 = self.visit_expression(&input.right, expected); + + return_incorrect_type(t1, t2, expected) + } BinaryOperation::Add => { self.visitor.assert_field_group_scalar_int_type(expected, input.span()); let t1 = self.visit_expression(&input.left, expected); @@ -291,7 +298,7 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { t1 } - BinaryOperation::Eq | BinaryOperation::Ne => { + BinaryOperation::Eq | BinaryOperation::Neq => { let t1 = self.visit_expression(&input.left, &None); let t2 = self.visit_expression(&input.right, &None); @@ -310,7 +317,15 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { Some(Type::Boolean) } - BinaryOperation::AddWrapped => { + BinaryOperation::AddWrapped | BinaryOperation::SubWrapped | BinaryOperation::DivWrapped | BinaryOperation::MulWrapped | BinaryOperation::PowWrapped => { + self.visitor.assert_int_type(expected, input.span); + let t1 = self.visit_expression(&input.left, expected); + let t2 = self.visit_expression(&input.right, expected); + + return_incorrect_type(t1, t2, expected) + } + BinaryOperation::Shl| BinaryOperation::ShlWrapped | BinaryOperation::Shr | BinaryOperation::ShrWrapped => { + // todo @collinc97: add magnitude check for second operand (u8, u16, u32). self.visitor.assert_int_type(expected, input.span); let t1 = self.visit_expression(&input.left, expected); let t2 = self.visit_expression(&input.right, expected); diff --git a/compiler/passes/src/type_checker/checker.rs b/compiler/passes/src/type_checker/checker.rs index 547f515f41..c0bad022d2 100644 --- a/compiler/passes/src/type_checker/checker.rs +++ b/compiler/passes/src/type_checker/checker.rs @@ -63,6 +63,8 @@ const fn create_type_superset( superset } +const BOOL_INT_TYPES: [Type; 11] = create_type_superset(INT_TYPES, [Type::Boolean]); + const FIELD_INT_TYPES: [Type; 11] = create_type_superset(INT_TYPES, [Type::Field]); const FIELD_SCALAR_INT_TYPES: [Type; 12] = create_type_superset(FIELD_INT_TYPES, [Type::Scalar]); @@ -136,6 +138,11 @@ impl<'a> TypeChecker<'a> { } } + /// Emits an error to the handler if the given type is not a boolean or an integer. + pub(crate) fn assert_bool_int_type(&self, type_: &Option, span: Span) { + self.assert_one_of_types(type_, &BOOL_INT_TYPES, span) + } + /// Emits an error to the handler if the given type is not an integer. pub(crate) fn assert_int_type(&self, type_: &Option, span: Span) { self.assert_one_of_types(type_, &INT_TYPES, span) diff --git a/leo/span/src/symbol.rs b/leo/span/src/symbol.rs index 2470027927..7422830158 100644 --- a/leo/span/src/symbol.rs +++ b/leo/span/src/symbol.rs @@ -100,13 +100,56 @@ macro_rules! symbols { } symbols! { - // unary operators + // unary operators index 0-7 + abs, + abs_wrapped, + double, + inv, + neg, + not, + square, + sqrt, - // binary operators + // binary operators index 8-32 add, add_wrapped, + and, + div, + div_wrapped, + eq, + ge, + gt, + le, + lt, + mul, + mul_wrapped, + nand, + neq, + nor, + or, + pow, + pow_wrapped, + shl, + shl_wrapped, + shr, + shr_wrapped, + sub, + sub_wrapped, + xor, - // arity three operators + // arity three operators 33-44 + bhp256, + bhp512, + bhp768, + bhp1024, + commit, + hash, + ped64, + ped128, + prf, + psd2, + psd4, + psd8, // types address,