diff --git a/compiler/ast/src/functions/mod.rs b/compiler/ast/src/functions/mod.rs index 463db8e30e..6daee616eb 100644 --- a/compiler/ast/src/functions/mod.rs +++ b/compiler/ast/src/functions/mod.rs @@ -46,8 +46,6 @@ use std::fmt; pub struct Function { /// Annotations on the function. pub annotations: Vec, - /// Is this function asynchronous or synchronous? - pub is_async: bool, /// Is this function a transition, inlined, or a regular function?. pub variant: Variant, /// The function identifier, e.g., `foo` in `function foo(...) { ... }`. @@ -79,7 +77,6 @@ impl Function { #[allow(clippy::too_many_arguments)] pub fn new( annotations: Vec, - is_async: bool, variant: Variant, identifier: Identifier, input: Vec, @@ -100,7 +97,7 @@ impl Function { _ => Type::Tuple(TupleType::new(output.iter().map(get_output_type).collect())), }; - Function { annotations, is_async, variant, identifier, input, output, output_type, block, span, id } + Function { annotations, variant, identifier, input, output, output_type, block, span, id } } /// Returns function name. @@ -114,8 +111,8 @@ impl Function { fn format(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.variant { Variant::Inline => write!(f, "inline ")?, - Variant::Standard => write!(f, "function ")?, - Variant::Transition => write!(f, "transition ")?, + Variant::Function | Variant::AsyncFunction => write!(f, "function ")?, + Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?, } write!(f, "{}", self.identifier)?; @@ -135,7 +132,6 @@ impl From for Function { fn from(function: FunctionStub) -> Self { Self { annotations: function.annotations, - is_async: function.is_async, variant: function.variant, identifier: function.identifier, input: function.input, diff --git a/compiler/ast/src/functions/variant.rs b/compiler/ast/src/functions/variant.rs index 407246de9b..bbbeeec689 100644 --- a/compiler/ast/src/functions/variant.rs +++ b/compiler/ast/src/functions/variant.rs @@ -16,13 +16,43 @@ use serde::{Deserialize, Serialize}; -/// Functions are always one of three variants. +/// Functions are always one of five variants. /// A transition function is permitted the ability to manipulate records. +/// An asynchronous transition function is a transition function that calls an asynchronous function. /// A regular function is not permitted to manipulate records. +/// An asynchronous function contains on-chain operations. /// An inline function is directly copied at the call site. #[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum Variant { Inline, - Standard, + Function, Transition, + AsyncTransition, + AsyncFunction, +} + +impl Variant { + /// Returns true if the variant is async. + pub fn is_async(self) -> bool { + match self { + Variant::AsyncFunction | Variant::AsyncTransition => true, + _ => false, + } + } + + /// Returns true if the variant is a transition. + pub fn is_transition(self) -> bool { + match self { + Variant::Transition | Variant::AsyncTransition => true, + _ => false, + } + } + + /// Returns true if the variant is a function. + pub fn is_function(self) -> bool { + match self { + Variant::Function | Variant::AsyncFunction => true, + _ => false, + } + } } diff --git a/compiler/ast/src/passes/reconstructor.rs b/compiler/ast/src/passes/reconstructor.rs index 1b74a1da8d..441c5afcbe 100644 --- a/compiler/ast/src/passes/reconstructor.rs +++ b/compiler/ast/src/passes/reconstructor.rs @@ -475,7 +475,6 @@ pub trait ProgramReconstructor: StatementReconstructor { fn reconstruct_function(&mut self, input: Function) -> Function { Function { annotations: input.annotations, - is_async: input.is_async, variant: input.variant, identifier: input.identifier, input: input.input, diff --git a/compiler/ast/src/stub/function_stub.rs b/compiler/ast/src/stub/function_stub.rs index 3fe6dd1a93..c65c90d463 100644 --- a/compiler/ast/src/stub/function_stub.rs +++ b/compiler/ast/src/stub/function_stub.rs @@ -54,8 +54,6 @@ use std::fmt; pub struct FunctionStub { /// Annotations on the function. pub annotations: Vec, - /// Is this function asynchronous or synchronous? - pub is_async: bool, /// Is this function a transition, inlined, or a regular function?. pub variant: Variant, /// The function identifier, e.g., `foo` in `function foo(...) { ... }`. @@ -109,7 +107,6 @@ impl FunctionStub { FunctionStub { annotations, - is_async, variant, identifier, future_locations: Vec::new(), @@ -137,8 +134,8 @@ impl FunctionStub { fn format(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.variant { Variant::Inline => write!(f, "inline ")?, - Variant::Standard => write!(f, "function ")?, - Variant::Transition => write!(f, "transition ")?, + Variant::Function | Variant::AsyncFunction => write!(f, "function ")?, + Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?, } write!(f, "{}", self.identifier)?; @@ -218,8 +215,10 @@ impl FunctionStub { Self { annotations: Vec::new(), - is_async: function.finalize_logic().is_some(), - variant: Variant::Transition, + variant: match function.finalize_logic().is_some() { + true => Variant::AsyncTransition, + false => Variant::Transition, + }, identifier: Identifier::from(function.name()), future_locations: Vec::new(), input: function @@ -281,8 +280,7 @@ impl FunctionStub { ) -> Self { Self { annotations: Vec::new(), - is_async: true, - variant: Variant::Standard, + variant: Variant::AsyncFunction, identifier: Identifier::new(name, Default::default()), future_locations: function .finalize_logic() @@ -291,7 +289,7 @@ impl FunctionStub { .iter() .filter_map(|input| match input.finalize_type() { FinalizeType::Future(val) => Some(Location::new( - Identifier::from(val.program_id().name()).name, + Some(Identifier::from(val.program_id().name()).name), Symbol::intern(&format!("finalize/{}", val.resource())), )), _ => None, @@ -361,8 +359,7 @@ impl FunctionStub { }; Self { annotations: Vec::new(), - is_async: false, - variant: Variant::Standard, + variant: Variant::Function, identifier: Identifier::from(closure.name()), future_locations: Vec::new(), input: closure @@ -397,7 +394,6 @@ impl From for FunctionStub { fn from(function: Function) -> Self { Self { annotations: function.annotations, - is_async: function.is_async, variant: function.variant, identifier: function.identifier, future_locations: Vec::new(), diff --git a/compiler/parser/src/parser/file.rs b/compiler/parser/src/parser/file.rs index b24281d85f..1cffe90369 100644 --- a/compiler/parser/src/parser/file.rs +++ b/compiler/parser/src/parser/file.rs @@ -138,7 +138,7 @@ impl ParserContext<'_> { let (id, function) = self.parse_function()?; // Partition into transitions and functions so that don't have to sort later. - if function.variant == Variant::Transition { + if function.variant.is_transition() { transitions.push((id, function)); } else { functions.push((id, function)); @@ -409,10 +409,12 @@ impl ParserContext<'_> { let (is_async, start_async) = if self.token.token == Token::Async { (true, self.expect(&Token::Async)?) } else { (false, Span::dummy()) }; // Parse ` IDENT`, where `` is `function`, `transition`, or `inline`. - let (variant, start) = match self.token.token { + let (variant, start) = match self.token.token.clone() { Token::Inline => (Variant::Inline, self.expect(&Token::Inline)?), - Token::Function => (Variant::Standard, self.expect(&Token::Function)?), - Token::Transition => (Variant::Transition, self.expect(&Token::Transition)?), + Token::Function => { + (if is_async { Variant::AsyncFunction } else { Variant::Function }, self.expect(&Token::Function)?) + } + Token::Transition => (if is_async { Variant::AsyncTransition } else { Variant::Transition }, self.expect(&Token::Transition)?), _ => self.unexpected("'function', 'transition', or 'inline'")?, }; let name = self.expect_identifier()?; @@ -450,7 +452,6 @@ impl ParserContext<'_> { name.name, Function::new( annotations, - is_async, variant, name, inputs, diff --git a/compiler/passes/src/code_generation/visit_expressions.rs b/compiler/passes/src/code_generation/visit_expressions.rs index 54bd72e35a..4acadab384 100644 --- a/compiler/passes/src/code_generation/visit_expressions.rs +++ b/compiler/passes/src/code_generation/visit_expressions.rs @@ -14,8 +14,35 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{CodeGenerator}; -use leo_ast::{AccessExpression, ArrayAccess, ArrayExpression, AssociatedConstant, AssociatedFunction, BinaryExpression, BinaryOperation, CallExpression, CastExpression, ErrExpression, Expression, Identifier, Literal, Location, LocatorExpression, MemberAccess, MethodCall, Node, StructExpression, TernaryExpression, TupleExpression, Type, UnaryExpression, UnaryOperation, UnitExpression, Variant}; +use crate::CodeGenerator; +use leo_ast::{ + AccessExpression, + ArrayAccess, + ArrayExpression, + AssociatedConstant, + AssociatedFunction, + BinaryExpression, + BinaryOperation, + CallExpression, + CastExpression, + ErrExpression, + Expression, + Identifier, + Literal, + Location, + LocatorExpression, + MemberAccess, + MethodCall, + Node, + StructExpression, + TernaryExpression, + TupleExpression, + Type, + UnaryExpression, + UnaryOperation, + UnitExpression, + Variant, +}; use leo_span::sym; use std::borrow::Borrow; @@ -513,7 +540,7 @@ impl<'a> CodeGenerator<'a> { } else { // Lookup in symbol table to determine if its an async function. if let Some(func) = self.symbol_table.lookup_fn_symbol(Location::new(input.program, function_name)) { - if func.is_async && input.program.unwrap() == self.program_id.unwrap().name.name { + if func.variant.is_async() && input.program.unwrap() == self.program_id.unwrap().name.name { format!(" async {}", self.current_function.unwrap().identifier) } else { format!(" call {}", input.function) @@ -534,8 +561,7 @@ impl<'a> CodeGenerator<'a> { let mut destinations = Vec::new(); // Create operands for the output registers. - let func = - &self.symbol_table.lookup_fn_symbol(Location::new(Some(main_program), function_name)).unwrap(); + let func = &self.symbol_table.lookup_fn_symbol(Location::new(Some(main_program), function_name)).unwrap(); match func.output_type.clone() { Type::Unit => {} // Do nothing Type::Tuple(tuple) => match tuple.length() { @@ -556,7 +582,7 @@ impl<'a> CodeGenerator<'a> { } // Add a register for async functions to represent the future created. - if func.is_async && func.variant == Variant::Standard { + if func.variant == Variant::AsyncFunction { let destination_register = format!("r{}", self.next_register); destinations.push(destination_register); self.next_register += 1; diff --git a/compiler/passes/src/code_generation/visit_program.rs b/compiler/passes/src/code_generation/visit_program.rs index 8b58ee82d7..395c8d7bfa 100644 --- a/compiler/passes/src/code_generation/visit_program.rs +++ b/compiler/passes/src/code_generation/visit_program.rs @@ -16,7 +16,7 @@ use crate::CodeGenerator; -use leo_ast::{functions, Composite, Function, Mapping, Mode, Program, ProgramScope, Type, Variant, Location}; +use leo_ast::{functions, Composite, Function, Location, Mapping, Mode, Program, ProgramScope, Type, Variant}; use indexmap::IndexMap; use itertools::Itertools; @@ -84,7 +84,7 @@ impl<'a> CodeGenerator<'a> { .functions .iter() .map(|(_, function)| { - if !(function.is_async && function.variant == Variant::Standard) { + if function.variant != Variant::AsyncFunction { // Set the `is_transition_function` flag. self.is_transition_function = matches!(function.variant, Variant::Transition); @@ -94,15 +94,15 @@ impl<'a> CodeGenerator<'a> { self.is_transition_function = false; // Attach the associated finalize to async transitions. - if function.variant == Variant::Transition && function.is_async { + if function.variant == Variant::AsyncTransition { // Set state variables. self.is_transition_function = false; self.finalize_caller = Some(function.identifier.name.clone()); // Generate code for the associated finalize function. let finalize = &self .symbol_table - .lookup_fn_symbol( - Location::new(Some(self.program_id.unwrap().name.name), + .lookup_fn_symbol(Location::new( + Some(self.program_id.unwrap().name.name), function.identifier.name, )) .unwrap() @@ -178,7 +178,7 @@ impl<'a> CodeGenerator<'a> { // Initialize the state of `self` with the appropriate values before visiting `function`. self.next_register = 0; self.variable_mapping = IndexMap::new(); - self.in_finalize = function.is_async && function.variant == Variant::Standard; + self.in_finalize = function.variant == Variant::AsyncFunction; // TODO: Figure out a better way to initialize. self.variable_mapping.insert(&sym::SelfLower, "self".to_string()); self.variable_mapping.insert(&sym::block, "block".to_string()); @@ -188,11 +188,11 @@ impl<'a> CodeGenerator<'a> { // If a function is a program function, generate an Aleo `function`, // if it is a standard function generate an Aleo `closure`, // otherwise, it is an inline function, in which case a function should not be generated. - let mut function_string = match (function.is_async, function.variant) { - (_, Variant::Transition) => format!("\nfunction {}:\n", function.identifier), - (false, Variant::Standard) => format!("\nclosure {}:\n", function.identifier), - (true, Variant::Standard) => format!("\nfinalize {}:\n", self.finalize_caller.unwrap()), - (_, Variant::Inline) => return String::from("\n"), + let mut function_string = match function.variant { + Variant::Transition | Variant::AsyncTransition => format!("\nfunction {}:\n", function.identifier), + Variant::Function => format!("\nclosure {}:\n", function.identifier), + Variant::AsyncFunction => format!("\nfinalize {}:\n", self.finalize_caller.unwrap()), + Variant::Inline => return String::from("\n"), }; // Construct and append the input declarations of the function. diff --git a/compiler/passes/src/common/symbol_table/function_symbol.rs b/compiler/passes/src/common/symbol_table/function_symbol.rs index 2b49feee76..644b590ee3 100644 --- a/compiler/passes/src/common/symbol_table/function_symbol.rs +++ b/compiler/passes/src/common/symbol_table/function_symbol.rs @@ -26,8 +26,6 @@ use crate::SymbolTable; pub struct FunctionSymbol { /// The index associated with the scope in the parent symbol table. pub(crate) id: usize, - /// Whether the function is asynchronous or not. - pub(crate) is_async: bool, /// The output type of the function. pub(crate) output_type: Type, /// Is this function a transition, inlined, or a regular function?. @@ -46,7 +44,6 @@ impl SymbolTable { pub(crate) fn new_function_symbol(id: usize, func: &Function) -> FunctionSymbol { FunctionSymbol { id, - is_async: func.is_async, output_type: func.output_type.clone(), variant: func.variant, _span: func.span, diff --git a/compiler/passes/src/common/symbol_table/mod.rs b/compiler/passes/src/common/symbol_table/mod.rs index f59fabdbbe..c936041048 100644 --- a/compiler/passes/src/common/symbol_table/mod.rs +++ b/compiler/passes/src/common/symbol_table/mod.rs @@ -247,7 +247,6 @@ mod tests { let func_loc = Location::new(Some(Symbol::intern("credits")), Symbol::intern("transfer_public")); let insert = Function { annotations: Vec::new(), - is_async: false, id: 0, output_type: Type::Address, variant: Variant::Inline, diff --git a/compiler/passes/src/dead_code_elimination/eliminate_program.rs b/compiler/passes/src/dead_code_elimination/eliminate_program.rs index 8887a49038..2a760912a5 100644 --- a/compiler/passes/src/dead_code_elimination/eliminate_program.rs +++ b/compiler/passes/src/dead_code_elimination/eliminate_program.rs @@ -29,7 +29,6 @@ impl ProgramReconstructor for DeadCodeEliminator<'_> { Function { annotations: input.annotations, - is_async: input.is_async, variant: input.variant, identifier: input.identifier, input: input.input, diff --git a/compiler/passes/src/flattening/flatten_expression.rs b/compiler/passes/src/flattening/flatten_expression.rs index d337d49999..22ec8aa891 100644 --- a/compiler/passes/src/flattening/flatten_expression.rs +++ b/compiler/passes/src/flattening/flatten_expression.rs @@ -14,9 +14,19 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{Flattener}; +use crate::Flattener; -use leo_ast::{Expression, ExpressionReconstructor, Location, Node, Statement, StructExpression, StructVariableInitializer, TernaryExpression, Type}; +use leo_ast::{ + Expression, + ExpressionReconstructor, + Location, + Node, + Statement, + StructExpression, + StructVariableInitializer, + TernaryExpression, + Type, +}; impl ExpressionReconstructor for Flattener<'_> { type AdditionalOutput = Vec; diff --git a/compiler/passes/src/flattening/flatten_program.rs b/compiler/passes/src/flattening/flatten_program.rs index 149fce8885..97bc356446 100644 --- a/compiler/passes/src/flattening/flatten_program.rs +++ b/compiler/passes/src/flattening/flatten_program.rs @@ -32,7 +32,6 @@ impl ProgramReconstructor for Flattener<'_> { Function { annotations: function.annotations, - is_async: function.is_async, variant: function.variant, identifier: function.identifier, input: function.input, diff --git a/compiler/passes/src/function_inlining/inline_expression.rs b/compiler/passes/src/function_inlining/inline_expression.rs index ecdfbf1fbe..7bcd25ef23 100644 --- a/compiler/passes/src/function_inlining/inline_expression.rs +++ b/compiler/passes/src/function_inlining/inline_expression.rs @@ -53,7 +53,6 @@ impl ExpressionReconstructor for FunctionInliner<'_> { // Inline the callee function, if required, otherwise, return the call expression. match callee.variant { - Variant::Transition | Variant::Standard => (Expression::Call(input), Default::default()), Variant::Inline => { // Construct a mapping from input variables of the callee function to arguments passed to the callee. let parameter_to_argument = callee @@ -103,6 +102,7 @@ impl ExpressionReconstructor for FunctionInliner<'_> { (result, inlined_statements) } + _ => (Expression::Call(input), Default::default()), } } } diff --git a/compiler/passes/src/loop_unrolling/unroll_program.rs b/compiler/passes/src/loop_unrolling/unroll_program.rs index 7ec8609b41..b40361fbaa 100644 --- a/compiler/passes/src/loop_unrolling/unroll_program.rs +++ b/compiler/passes/src/loop_unrolling/unroll_program.rs @@ -16,7 +16,7 @@ use leo_ast::*; -use crate::{Unroller}; +use crate::Unroller; impl ProgramReconstructor for Unroller<'_> { fn reconstruct_stub(&mut self, input: Stub) -> Stub { @@ -92,7 +92,6 @@ impl ProgramReconstructor for Unroller<'_> { // Reconstruct the function block. let reconstructed_function = Function { - is_async: function.is_async, annotations: function.annotations, variant: function.variant, identifier: function.identifier, diff --git a/compiler/passes/src/static_single_assignment/rename_expression.rs b/compiler/passes/src/static_single_assignment/rename_expression.rs index aeab2773fc..e24a5eb295 100644 --- a/compiler/passes/src/static_single_assignment/rename_expression.rs +++ b/compiler/passes/src/static_single_assignment/rename_expression.rs @@ -14,9 +14,33 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{StaticSingleAssigner}; +use crate::StaticSingleAssigner; -use leo_ast::{AccessExpression, ArrayAccess, ArrayExpression, AssociatedFunction, BinaryExpression, CallExpression, CastExpression, Composite, Expression, ExpressionConsumer, Identifier, Literal, Location, LocatorExpression, MemberAccess, Statement, StructExpression, StructVariableInitializer, TernaryExpression, TupleAccess, TupleExpression, UnaryExpression, UnitExpression}; +use leo_ast::{ + AccessExpression, + ArrayAccess, + ArrayExpression, + AssociatedFunction, + BinaryExpression, + CallExpression, + CastExpression, + Composite, + Expression, + ExpressionConsumer, + Identifier, + Literal, + Location, + LocatorExpression, + MemberAccess, + Statement, + StructExpression, + StructVariableInitializer, + TernaryExpression, + TupleAccess, + TupleExpression, + UnaryExpression, + UnitExpression, +}; use leo_span::{sym, Symbol}; use indexmap::IndexMap; diff --git a/compiler/passes/src/static_single_assignment/rename_program.rs b/compiler/passes/src/static_single_assignment/rename_program.rs index a7c49f70b5..0f09713cdc 100644 --- a/compiler/passes/src/static_single_assignment/rename_program.rs +++ b/compiler/passes/src/static_single_assignment/rename_program.rs @@ -81,7 +81,6 @@ impl FunctionConsumer for StaticSingleAssigner<'_> { Function { annotations: function.annotations, - is_async: function.is_async, variant: function.variant, identifier: function.identifier, input: function.input, diff --git a/compiler/passes/src/type_checking/check_expressions.rs b/compiler/passes/src/type_checking/check_expressions.rs index 6baf3d7f1c..a7a36c8940 100644 --- a/compiler/passes/src/type_checking/check_expressions.rs +++ b/compiler/passes/src/type_checking/check_expressions.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use crate::{TypeChecker}; +use crate::TypeChecker; use leo_ast::*; use leo_errors::{emitter::Handler, TypeCheckerError}; @@ -24,10 +24,10 @@ use itertools::Itertools; use leo_ast::{ CoreFunction::FutureAwait, Type::{Future, Tuple}, - Variant::Standard, }; use snarkvm::console::network::{MainnetV0, Network}; use std::str::FromStr; +use leo_ast::Variant::{Transition, Function, AsyncFunction, AsyncTransition}; fn return_incorrect_type(t1: Option, t2: Option, expected: &Option) -> Option { match (t1, t2) { @@ -101,7 +101,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { // Check core struct name and function. if let Some(core_instruction) = self.get_core_function_call(&access.variant, &access.name) { // Check that operation is not restricted to finalize blocks. - if !self.scope_state.is_finalize && core_instruction.is_finalize_command() { + if self.scope_state.variant != Some(Variant::AsyncFunction) && core_instruction.is_finalize_command() { self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span())); } @@ -142,7 +142,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { self.get_core_function_call(&Identifier::new(sym::Future, Default::default()), &call.name) { // Check that operation is not restricted to finalize blocks. - if !self.scope_state.is_finalize && core_instruction.is_finalize_command() { + if self.scope_state.variant != Some(AsyncFunction) && core_instruction.is_finalize_command() { self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span())); } @@ -222,7 +222,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { Expression::Identifier(identifier) if identifier.name == sym::SelfLower => match access.name.name { sym::caller => { // Check that the operation is not invoked in a `finalize` block. - if self.scope_state.is_finalize { + if self.scope_state.variant == Some(Variant::AsyncFunction) { self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize( "self.caller", access.name.span(), @@ -232,7 +232,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } sym::signer => { // Check that operation is not invoked in a `finalize` block. - if self.scope_state.is_finalize { + if self.scope_state.variant == Some(Variant::AsyncFunction) { self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize( "self.signer", access.name.span(), @@ -248,7 +248,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { Expression::Identifier(identifier) if identifier.name == sym::block => match access.name.name { sym::height => { // Check that the operation is invoked in a `finalize` block. - if !self.scope_state.is_finalize { + if self.scope_state.variant != Some(Variant::AsyncFunction) { self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize( "block.height", access.name.span(), @@ -636,22 +636,11 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { // Check that the call is valid. // Note that this unwrap is safe since we always set the variant before traversing the body of the function. match self.scope_state.variant.unwrap() { - // If the function is not a transition function, it can only call "inline" functions. - Variant::Inline | Variant::Standard => { - if !matches!(func.variant, Variant::Inline) { - self.emit_err(TypeCheckerError::can_only_call_inline_function(input.span)); - } - } - // If the function is a transition function, then check that the call is not to another local transition function. - Variant::Transition => { - if matches!(func.variant, Variant::Transition) - && input.program.unwrap() == self.scope_state.program_name.unwrap() - { - self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function( - input.span, - )); - } - } + Variant::AsyncFunction | Variant::Function if !matches!(func.variant, Variant::Inline) => self.emit_err(TypeCheckerError::can_only_call_inline_function(input.span)), + Variant::Transition | Variant::AsyncTransition if matches!(func.variant, Variant::Transition) && input.program.unwrap() == self.scope_state.program_name.unwrap() => self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function( + input.span, + )), + _ => {} } // Check that the call is not to an external `inline` function. @@ -661,7 +650,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span)); } // Async functions return a single future. - let mut ret = if func.is_async && func.variant == Standard { + let mut ret = if func.variant == AsyncFunction { if let Some(Type::Future(_)) = expected { Type::Future(FutureType::new(Vec::new())) } else { @@ -687,7 +676,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { func.input.iter().zip(input.arguments.iter()).for_each(|(expected, argument)| { let ty = self.visit_expression(argument, &Some(expected.type_())); // Extract information about futures that are being consumed. - if func.is_async && func.variant == Standard && matches!(expected.type_(), Type::Future(_)) { + if func.variant == AsyncFunction && matches!(expected.type_(), Type::Future(_)) { match argument { Expression::Identifier(_) | Expression::Call(_) @@ -732,20 +721,20 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } // Propagate futures from async functions and transitions. - if func.is_async { + if func.variant.is_async() { // Cannot have async calls in a conditional block. if self.scope_state.is_conditional { self.emit_err(TypeCheckerError::async_call_in_conditional(input.span)); } // Can only call async functions and external async transitions from an async transition body. - if !self.scope_state.is_async_transition { + if self.scope_state.variant != Some(AsyncTransition) { self.emit_err(TypeCheckerError::async_call_can_only_be_done_from_async_transition( input.span, )); } - if func.variant == Variant::Transition { + if func.variant.is_transition() { // Cannot call an external async transition after having called the async function. if self.scope_state.has_called_finalize { self.emit_err(TypeCheckerError::external_transition_call_must_be_before_finalize( @@ -776,7 +765,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { ret } }; - } else if func.variant == Variant::Standard { + } else if func.variant.is_function() { // Can only call an async function once in a transition function body. if self.scope_state.has_called_finalize { self.emit_err(TypeCheckerError::must_call_finalize_once(input.span)); @@ -837,8 +826,11 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { } fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output { - let struct_ = - self.symbol_table.borrow().lookup_struct(Location::new(self.scope_state.program_name, input.name.name)).cloned(); + let struct_ = self + .symbol_table + .borrow() + .lookup_struct(Location::new(self.scope_state.program_name, input.name.name)) + .cloned(); if let Some(struct_) = struct_ { // Check struct type name. let ret = self.check_expected_struct(&struct_, additional, input.name.span()); @@ -886,7 +878,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { fn visit_identifier(&mut self, input: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output { if let Some(var) = self.symbol_table.borrow().lookup_variable(Location::new(None, input.name)) { if matches!(var.type_, Type::Future(_)) && matches!(expected, Some(Type::Future(_))) { - if self.scope_state.is_async_transition && self.scope_state.is_call { + if self.scope_state.variant == Some(AsyncTransition) && self.scope_state.is_call { // Consume future. match self.scope_state.futures.remove(&input.name) { Some(future) => { diff --git a/compiler/passes/src/type_checking/check_program.rs b/compiler/passes/src/type_checking/check_program.rs index 178e10c649..a8b955c914 100644 --- a/compiler/passes/src/type_checking/check_program.rs +++ b/compiler/passes/src/type_checking/check_program.rs @@ -27,6 +27,7 @@ use leo_ast::{ Type::Future, }; use std::collections::HashSet; +use leo_ast::Variant::{AsyncFunction, AsyncTransition}; // TODO: Cleanup logic for tuples. @@ -88,7 +89,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { let scope_index = self.create_child_scope(); // Create future stubs. - if input.variant == Variant::Standard && input.is_async { + if input.variant == Variant::AsyncFunction { let finalize_input_map = &mut self.finalize_input_types; let mut future_stubs = input.future_locations.clone(); let resolved_inputs: Vec = input @@ -302,7 +303,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { } // Set type checker variables for function variant details. - self.scope_state.initialize_function_state(function.variant, function.is_async); + self.scope_state.initialize_function_state(function.variant); // Lookup function metadata in the symbol table. // Note that this unwrap is safe since function metadata is stored in a prior pass. @@ -328,7 +329,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { // Query helper function to type check function parameters and outputs. self.check_function_signature(function); - if self.scope_state.is_finalize { + if self.scope_state.variant == Some(Variant::AsyncFunction) { // Async functions cannot have empty blocks if function.block.statements.is_empty() { self.emit_err(TypeCheckerError::finalize_block_must_not_be_empty(function.block.span)); @@ -367,12 +368,12 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { self.exit_scope(function_index); // Make sure that async transitions call finalize. - if self.scope_state.is_async_transition && !self.scope_state.has_called_finalize { + if self.scope_state.variant == Some(AsyncTransition) && !self.scope_state.has_called_finalize { self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span)); } // Check that all futures were awaited exactly once. - if self.scope_state.is_finalize { + if self.scope_state.variant == Some(AsyncFunction) { // Throw error if not all futures awaits even appear once. if !self.await_checker.static_to_await.is_empty() { self.emit_err(TypeCheckerError::future_awaits_missing( diff --git a/compiler/passes/src/type_checking/check_statements.rs b/compiler/passes/src/type_checking/check_statements.rs index c7962badbf..9cc4f3946f 100644 --- a/compiler/passes/src/type_checking/check_statements.rs +++ b/compiler/passes/src/type_checking/check_statements.rs @@ -120,7 +120,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { // Create scope for checking awaits in `then` branch of conditional. let current_bst_nodes: Vec = - match self.await_checker.create_then_scope(self.scope_state.is_finalize, input.span) { + match self.await_checker.create_then_scope(self.scope_state.variant == Some(Variant::AsyncFunction), input.span) { Ok(nodes) => nodes, Err(err) => return self.emit_err(err), }; @@ -132,7 +132,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { then_block_has_return = self.scope_state.has_return; // Exit scope for checking awaits in `then` branch of conditional. - let saved_paths = self.await_checker.exit_then_scope(self.scope_state.is_finalize, current_bst_nodes); + let saved_paths = self.await_checker.exit_then_scope(self.scope_state.variant == Some(Variant::AsyncFunction), current_bst_nodes); if let Some(otherwise) = &input.otherwise { // Set the `has_return` flag for the otherwise-block. @@ -152,7 +152,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { } // Update the set of all possible BST paths. - self.await_checker.exit_statement_scope(self.scope_state.is_finalize, saved_paths); + self.await_checker.exit_statement_scope(self.scope_state.variant == Some(Variant::AsyncFunction), saved_paths); // Restore the previous `has_return` flag. self.scope_state.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return); @@ -385,17 +385,18 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { fn visit_return(&mut self, input: &'a ReturnStatement) { // Cannot return anything from finalize. - if self.scope_state.is_finalize { + if self.scope_state.variant == Some(Variant::AsyncFunction) { self.emit_err(TypeCheckerError::return_in_finalize(input.span())); } // We can safely unwrap all self.parent instances because // statements should always have some parent block let parent = self.scope_state.function.unwrap(); - let func = self.symbol_table.borrow().lookup_fn_symbol(Location::new(self.scope_state.program_name, parent)).cloned(); + let func = + self.symbol_table.borrow().lookup_fn_symbol(Location::new(self.scope_state.program_name, parent)).cloned(); let mut return_type = func.clone().map(|f| f.output_type.clone()); // Fully type the expected return value. - if self.scope_state.is_async_transition && self.scope_state.has_called_finalize { + if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.has_called_finalize { let inferred_future_type = match self.finalize_input_types.get(&func.unwrap().finalize.clone().unwrap()) { Some(types) => Future(FutureType::new(types.clone())), None => { diff --git a/compiler/passes/src/type_checking/checker.rs b/compiler/passes/src/type_checking/checker.rs index 5fda2f14a0..6332a35645 100644 --- a/compiler/passes/src/type_checking/checker.rs +++ b/compiler/passes/src/type_checking/checker.rs @@ -47,6 +47,7 @@ use leo_ast::{ Type::{Future, Tuple}, }; use std::cell::RefCell; +use leo_ast::Variant::AsyncTransition; pub struct TypeChecker<'a> { /// The symbol table for the program. @@ -975,7 +976,7 @@ impl<'a> TypeChecker<'a> { } CoreFunction::MappingGet => { // Check that the operation is invoked in a `finalize` block. - if !self.scope_state.is_finalize { + if self.scope_state.variant != Some(Variant::AsyncFunction) { self.handler .emit_err(TypeCheckerError::invalid_operation_outside_finalize("Mapping::get", function_span)) } @@ -991,7 +992,7 @@ impl<'a> TypeChecker<'a> { } CoreFunction::MappingGetOrUse => { // Check that the operation is invoked in a `finalize` block. - if !self.scope_state.is_finalize { + if self.scope_state.variant != Some(Variant::AsyncFunction) { self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize( "Mapping::get_or", function_span, @@ -1011,7 +1012,7 @@ impl<'a> TypeChecker<'a> { } CoreFunction::MappingSet => { // Check that the operation is invoked in a `finalize` block. - if !self.scope_state.is_finalize { + if self.scope_state.variant != Some(Variant::AsyncFunction) { self.handler .emit_err(TypeCheckerError::invalid_operation_outside_finalize("Mapping::set", function_span)) } @@ -1033,7 +1034,7 @@ impl<'a> TypeChecker<'a> { } CoreFunction::MappingRemove => { // Check that the operation is invoked in a `finalize` block. - if !self.scope_state.is_finalize { + if self.scope_state.variant != Some(Variant::AsyncFunction) { self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize( "Mapping::remove", function_span, @@ -1056,7 +1057,7 @@ impl<'a> TypeChecker<'a> { } CoreFunction::MappingContains => { // Check that the operation is invoked in a `finalize` block. - if !self.scope_state.is_finalize { + if self.scope_state.variant != Some(Variant::AsyncFunction) { self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize( "Mapping::contains", function_span, @@ -1266,7 +1267,7 @@ impl<'a> TypeChecker<'a> { self.scope_state.variant = Some(function.variant); // Special type checking for finalize blocks. Can skip for stubs. - if self.scope_state.is_finalize & !self.scope_state.is_stub { + if self.scope_state.variant == Some(Variant::AsyncFunction) && !self.scope_state.is_stub { // Finalize functions are not allowed to return values. if !function.output.is_empty() { self.emit_err(TypeCheckerError::finalize_function_cannot_return_value(function.span())); @@ -1335,7 +1336,7 @@ impl<'a> TypeChecker<'a> { } // Check that the finalize input parameter is not constant or private. - if self.scope_state.is_finalize + if self.scope_state.variant == Some(Variant::AsyncFunction) && (input_var.mode() == Mode::Constant || input_var.mode() == Mode::Private) && (input_var.mode() == Mode::Constant || input_var.mode() == Mode::Private) { @@ -1345,11 +1346,11 @@ impl<'a> TypeChecker<'a> { // Note that this unwrap is safe since we assign to `self.variant` above. match self.scope_state.variant.unwrap() { // If the function is a transition function, then check that the parameter mode is not a constant. - Variant::Transition if input_var.mode() == Mode::Constant => { + Variant::Transition | Variant::AsyncTransition if input_var.mode() == Mode::Constant => { self.emit_err(TypeCheckerError::transition_function_inputs_cannot_be_const(input_var.span())) } // If the function is not a transition function, then check that the parameters do not have an associated mode. - Variant::Standard | Variant::Inline if input_var.mode() != Mode::None => { + Variant::Function | Variant::AsyncFunction | Variant::Inline if input_var.mode() != Mode::None => { self.emit_err(TypeCheckerError::regular_function_inputs_cannot_have_modes(input_var.span())) } _ => {} // Do nothing. @@ -1357,8 +1358,9 @@ impl<'a> TypeChecker<'a> { // Add function inputs to the symbol table. Futures have already been added. if !matches!(&input_var.type_(), &Type::Future(_)) { - if let Err(err) = - self.symbol_table.borrow_mut().insert_variable(Location::new(None, input_var.identifier().name), VariableSymbol { + if let Err(err) = self.symbol_table.borrow_mut().insert_variable( + Location::new(None, input_var.identifier().name), + VariableSymbol { type_: input_var.type_(), span: input_var.identifier().span(), declaration: VariableType::Input(input_var.mode()), @@ -1409,7 +1411,7 @@ impl<'a> TypeChecker<'a> { self.emit_err(TypeCheckerError::cannot_have_constant_output_mode(function_output.span)); } // Async transitions must return exactly one future, and it must be in the last position. - if self.scope_state.is_async_transition + if self.scope_state.variant == Some(AsyncTransition) && ((index < function.output.len() - 1 && matches!(function_output.type_, Type::Future(_))) || (index == function.output.len() - 1 && !matches!(function_output.type_, Type::Future(_)))) @@ -1488,11 +1490,13 @@ impl<'a> TypeChecker<'a> { type_ }; // Insert the variable into the symbol table. - if let Err(err) = self.symbol_table.borrow_mut().insert_variable(Location::new(None, name.name), VariableSymbol { - type_: ty, - span, - declaration: VariableType::Mut, - }) { + if let Err(err) = + self.symbol_table.borrow_mut().insert_variable(Location::new(None, name.name), VariableSymbol { + type_: ty, + span, + declaration: VariableType::Mut, + }) + { self.handler.emit_err(err); } } diff --git a/compiler/passes/src/type_checking/scope_state.rs b/compiler/passes/src/type_checking/scope_state.rs index 885c525903..965b981536 100644 --- a/compiler/passes/src/type_checking/scope_state.rs +++ b/compiler/passes/src/type_checking/scope_state.rs @@ -62,10 +62,8 @@ impl ScopeState { } /// Initialize state variables for new function. - pub fn initialize_function_state(&mut self, variant: Variant, is_async: bool) { + pub fn initialize_function_state(&mut self, variant: Variant) { self.variant = Some(variant); - self.is_finalize = variant == Variant::Standard && is_async; - self.is_async_transition = variant == Variant::Transition && is_async; self.has_called_finalize = false; self.futures = IndexMap::new(); }