From 5a499937e60dea3f81a83b6f5aa82f21a85cd87c Mon Sep 17 00:00:00 2001 From: evan-schott <53463459+evan-schott@users.noreply.github.com> Date: Mon, 19 Feb 2024 18:28:33 -0800 Subject: [PATCH] Function signature checking w/ nested struct/future --- compiler/parser/src/parser/file.rs | 12 ++- .../common/symbol_table/function_symbol.rs | 2 +- .../src/type_checking/check_expressions.rs | 4 +- .../passes/src/type_checking/check_program.rs | 38 ++++++--- .../src/type_checking/check_statements.rs | 66 ++------------- compiler/passes/src/type_checking/checker.rs | 80 ++++++++++++++----- 6 files changed, 105 insertions(+), 97 deletions(-) diff --git a/compiler/parser/src/parser/file.rs b/compiler/parser/src/parser/file.rs index 87c4777708..a152a77de8 100644 --- a/compiler/parser/src/parser/file.rs +++ b/compiler/parser/src/parser/file.rs @@ -113,7 +113,8 @@ impl ParserContext<'_> { // Parse the body of the program scope. let mut consts: Vec<(Symbol, ConstDeclaration)> = Vec::new(); - let (mut transitions, mut functions): (Vec<(Symbol, Function)>, Vec<(Symbol, Function)>) = (Vec::new(), Vec::new()); + let (mut transitions, mut functions): (Vec<(Symbol, Function)>, Vec<(Symbol, Function)>) = + (Vec::new(), Vec::new()); let mut structs: Vec<(Symbol, Composite)> = Vec::new(); let mut mappings: Vec<(Symbol, Mapping)> = Vec::new(); @@ -160,7 +161,14 @@ impl ParserContext<'_> { // Parse `}`. let end = self.expect(&Token::RightCurly)?; - Ok(ProgramScope { program_id, consts, functions: [transitions, functions].concat(), structs, mappings, span: start + end }) + Ok(ProgramScope { + program_id, + consts, + functions: [transitions, functions].concat(), + structs, + mappings, + span: start + end, + }) } /// Returns a [`Vec`] AST node if the next tokens represent a struct member. diff --git a/compiler/passes/src/common/symbol_table/function_symbol.rs b/compiler/passes/src/common/symbol_table/function_symbol.rs index 9afa9bb0fa..c1a54d7773 100644 --- a/compiler/passes/src/common/symbol_table/function_symbol.rs +++ b/compiler/passes/src/common/symbol_table/function_symbol.rs @@ -55,7 +55,7 @@ impl SymbolTable { output_type: func.output_type.clone(), variant: func.variant, _span: func.span, - input: func.input.clone() + input: func.input.clone(), } } } diff --git a/compiler/passes/src/type_checking/check_expressions.rs b/compiler/passes/src/type_checking/check_expressions.rs index 182fa691fa..ab75d08a4c 100644 --- a/compiler/passes/src/type_checking/check_expressions.rs +++ b/compiler/passes/src/type_checking/check_expressions.rs @@ -254,15 +254,13 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { expected, access.span(), )); - } - else { + } else { // Future arguments must be addressed by their index. Ex: `f.1.3`. self.emit_err(TypeCheckerError::future_access_must_be_number( access.name.name, access.name.span(), )); } - } Some(type_) => { self.emit_err(TypeCheckerError::type_should_be(type_, "struct", access.inner.span())); diff --git a/compiler/passes/src/type_checking/check_program.rs b/compiler/passes/src/type_checking/check_program.rs index a31d33afd0..9061695e63 100644 --- a/compiler/passes/src/type_checking/check_program.rs +++ b/compiler/passes/src/type_checking/check_program.rs @@ -22,8 +22,9 @@ use leo_span::sym; use snarkvm::console::network::{Network, Testnet3}; -use std::collections::HashSet; +use indexmap::IndexSet; use leo_ast::Input::{External, Internal}; +use std::collections::HashSet; // TODO: Cleanup logic for tuples. @@ -269,6 +270,8 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { self.variant = Some(function.variant); self.is_finalize = function.variant == Variant::Standard && function.is_async; self.is_finalize_caller = function.variant == Variant::Transition && function.is_async; + self.has_finalize = false; + self.futures = IndexSet::new(); // Lookup function metadata in the symbol table. // Note that this unwrap is safe since function metadata is stored in a prior pass. @@ -301,16 +304,20 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { } // Initialize the list of input futures. Each one must be awaited before the end of the function. - self.to_await = function.input.iter().filter_map(|input| match input { - Internal(parameter) => { - if matches!(parameter.type_, Type::Future(ty)) { - Some(parameter.identifier.name) - } else { - None + self.to_await = function + .input + .iter() + .filter_map(|input| match input { + Internal(parameter) => { + if let Some(Type::Future(ty)) = parameter.type_.clone() { + Some(parameter.identifier.name) + } else { + None + } } - } - External(_) => None, - }).collect(); + External(_) => None, + }) + .collect(); } self.visit_block(&function.block); @@ -326,7 +333,14 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> { // Exit the function's scope. self.exit_scope(function_index); - // Unset the function variant variables. - (self.variant, self.is_finalize_caller, self.is_finalize) = (None, false, false); + // Make sure that async transitions call finalize. + if self.is_finalize_caller && !self.has_finalize { + self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span)); + } + + // Must have awaited all futures. + if self.is_finalize && !self.to_await.is_empty() { + self.emit_err(TypeCheckerError::must_await_all_futures(&self.to_await, function.span())); + } } } diff --git a/compiler/passes/src/type_checking/check_statements.rs b/compiler/passes/src/type_checking/check_statements.rs index e47fa732dc..633fcfd052 100644 --- a/compiler/passes/src/type_checking/check_statements.rs +++ b/compiler/passes/src/type_checking/check_statements.rs @@ -14,8 +14,8 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . -use indexmap::IndexSet; use crate::{TypeChecker, VariableSymbol, VariableType}; +use indexmap::IndexSet; use itertools::Itertools; use leo_ast::*; @@ -91,9 +91,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { } fn visit_block(&mut self, input: &'a Block) { - // Reset environment flag. - if self.is_finalize_caller { self.has_called_finalize = false; self.futures = IndexSet::new() }; - // Create a new scope for the then-block. let scope_index = self.create_child_scope(); @@ -101,15 +98,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { // Exit the scope for the then-block. self.exit_scope(scope_index); - - // Must have awaited all futures. - if self.is_finalize && !self.to_await.is_empty() { - self.emit_err(TypeCheckerError::must_await_all_futures(&self.to_await, input.span())); - } - // Check that an async function call was made to propagate futures to a finalize block. - else if self.is_finalize_caller && !self.has_called_finalize { - self.emit_err(TypeCheckerError::async_transition_must_call_async_function(input.span())); - } } fn visit_conditional(&mut self, input: &'a ConditionalStatement) { @@ -157,6 +145,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { // Restore the previous `has_return` flag. self.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return); // Restore the previous `has_finalize` flag. + // TODO: doesn't this mean that we allow multiple finalizes? self.has_finalize = previous_has_finalize || (then_block_has_finalize && otherwise_block_has_finalize); } @@ -394,14 +383,11 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { // We can safely unwrap all self.parent instances because // statements should always have some parent block let parent = self.function.unwrap(); - let return_type = &self.symbol_table.borrow().lookup_fn_symbol(self.program_name.unwrap(), parent).map(|f| { - match self.is_finalize { - // TODO: Check this. - // Note that this `unwrap()` is safe since we checked that the function has a finalize block. - true => f.finalize.as_ref().unwrap().output_type.clone(), - false => f.output_type.clone(), - } - }); + let return_type = &self + .symbol_table + .borrow() + .lookup_fn_symbol(self.program_name.unwrap(), parent) + .map(|f| f.output_type.clone()); // Set the `has_return` flag. self.has_return = true; @@ -421,43 +407,5 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> { self.visit_expression(&input.expression, return_type); // Unset the `is_return` flag. self.is_return = false; - - if let Some(arguments) = &input.finalize_arguments { - if self.is_finalize { - self.emit_err(TypeCheckerError::finalize_in_finalize(input.span())); - } - - // Set the `has_finalize` flag. - self.has_finalize = true; - - // Check that the function has a finalize block. - // Note that `self.function.unwrap()` is safe since every `self.function` is set for every function. - // Note that `(self.function.unwrap()).unwrap()` is safe since all functions have been checked to exist. - let finalize = self - .symbol_table - .borrow() - .lookup_fn_symbol(self.program_name.unwrap(), self.function.unwrap()) - .unwrap() - .finalize - .clone(); - match finalize { - None => self.emit_err(TypeCheckerError::finalize_without_finalize_block(input.span())), - Some(finalize) => { - // Check number of function arguments. - if finalize.input.len() != arguments.len() { - self.emit_err(TypeCheckerError::incorrect_num_args_to_finalize( - finalize.input.len(), - arguments.len(), - input.span(), - )); - } - - // Check function argument types. - finalize.input.iter().zip(arguments.iter()).for_each(|(expected, argument)| { - self.visit_expression(argument, &Some(expected.type_())); - }); - } - } - } } } diff --git a/compiler/passes/src/type_checking/checker.rs b/compiler/passes/src/type_checking/checker.rs index 3ccb79aa82..6a381dbfb7 100644 --- a/compiler/passes/src/type_checking/checker.rs +++ b/compiler/passes/src/type_checking/checker.rs @@ -36,9 +36,9 @@ use leo_span::{Span, Symbol}; use snarkvm::console::network::{Network, Testnet3}; +use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use std::cell::RefCell; -use indexmap::{IndexMap, IndexSet}; pub struct TypeChecker<'a> { /// The symbol table for the program. @@ -72,9 +72,9 @@ pub struct TypeChecker<'a> { /// The futures that must be propagated to an async function. pub(crate) futures: IndexSet, /// Whether the finalize caller has called the finalize function. - pub(crate) has_called_finalize: bool, + pub(crate) has_finalize: bool, /// Mapping from async function name to the inferred input types. - pub(crate) future_map: IndexMap> + pub(crate) inferred_future_types: IndexMap>, } const ADDRESS_TYPE: Type = Type::Address; @@ -144,8 +144,8 @@ impl<'a> TypeChecker<'a> { is_finalize_caller: false, to_await: IndexSet::new(), futures: IndexSet::new(), - has_called_finalize: false, - future_map: IndexMap::new(), + has_finalize: false, + inferred_future_types: IndexMap::new(), } } @@ -201,7 +201,8 @@ impl<'a> TypeChecker<'a> { } (Type::Integer(left), Type::Integer(right)) => left.eq(right), (Type::Mapping(left), Type::Mapping(right)) => { - self.check_eq_type_structure(&left.key, &right.key, span) && self.check_eq_type_structure(&left.value, &right.value, span) + self.check_eq_type_structure(&left.key, &right.key, span) + && self.check_eq_type_structure(&left.value, &right.value, span) } (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left .elements() @@ -221,8 +222,7 @@ impl<'a> TypeChecker<'a> { span, )); false - } - else { + } else { true } } @@ -243,7 +243,9 @@ impl<'a> TypeChecker<'a> { self.emit_err(TypeCheckerError::expected_one_type_of(t1.to_string(), t2, span)); } } - (Some(type_), None) | (None, Some(type_)) => self.emit_err(TypeCheckerError::type_should_be("no type", type_, span)), + (Some(type_), None) | (None, Some(type_)) => { + self.emit_err(TypeCheckerError::type_should_be("no type", type_, span)) + } _ => {} } } @@ -1097,6 +1099,10 @@ impl<'a> TypeChecker<'a> { // Return a boolean. Some(Type::Boolean) } + CoreFunction::FutureAwait => { + // TODO: check that were in finalize here? + None + } } } @@ -1252,8 +1258,34 @@ impl<'a> TypeChecker<'a> { pub(crate) fn check_function_signature(&mut self, function: &Function) { self.variant = Some(function.variant); + // Special type checking for finalize blocks. + if self.is_finalize { + if let Some(inferred_future_types) = self.inferred_future_types.borrow().get(&self.function.unwrap()) { + // Check same number of inputs as expected. + if inferred_future_types.len() != function.input.len() { + self.emit_err(TypeCheckerError::async_function_input_length_mismatch( + inferred_future_types.len(), + function.input.len(), + function.span(), + )); + } + // Check that the input parameters match the inferred types from when the async function is invoked. + function + .input + .iter() + .zip_eq(inferred_future_types.iter()) + .for_each(|(t1, t2)| self.check_eq_type(&t1.type_(), t2, t1.span())); + } else if function.input.len() > 0 { + self.emit_err(TypeCheckerError::async_function_input_length_mismatch( + 0, + function.input.len(), + function.span(), + )); + } + } + // Type check the function's parameters. - function.input.iter().for_each(|input_var| { + function.input.iter().enumerate().for_each(|(index, input_var)| { // Check that the type of input parameter is defined. self.assert_type_is_valid(&input_var.type_(), input_var.span()); // Check that the type of the input parameter is not a tuple. @@ -1263,12 +1295,13 @@ impl<'a> TypeChecker<'a> { // Check that the input parameter is not a record. else if let Type::Composite(struct_) = input_var.type_() { // Note that this unwrap is safe, as the type is defined. - if !matches!(function.variant, Variant::Transition) && self - .symbol_table - .borrow() - .lookup_struct(struct_.program.unwrap(), struct_.id.name) - .unwrap() - .is_record + if !matches!(function.variant, Variant::Transition) + && self + .symbol_table + .borrow() + .lookup_struct(struct_.program.unwrap(), struct_.id.name) + .unwrap() + .is_record { self.emit_err(TypeCheckerError::function_cannot_input_or_output_a_record(input_var.span())) } @@ -1276,7 +1309,9 @@ impl<'a> TypeChecker<'a> { // Check that the finalize input parameter is not constant or private. if self.is_finalize && (self.mode() == Mode::Constant || input_var.mode() == Mode::Private) { - self.emit_err(TypeCheckerError::finalize_input_mode_must_be_public(input_var.span())); + if (self.mode() == Mode::Constant || input_var.mode() == Mode::Private) { + self.emit_err(TypeCheckerError::finalize_input_mode_must_be_public(input_var.span())); + } } // Note that this unwrap is safe since we assign to `self.variant` above. @@ -1306,7 +1341,7 @@ impl<'a> TypeChecker<'a> { // Type check the function's return type. // Note that checking that each of the component types are defined is sufficient to check that `output_type` is defined. - function.output.iter().enumerate().for_each(|(index,output)| { + function.output.iter().enumerate().for_each(|(index, output)| { match output { Output::External(external) => { // If the function is not a transition function, then it cannot output a record. @@ -1348,8 +1383,13 @@ impl<'a> TypeChecker<'a> { self.emit_err(TypeCheckerError::async_function_must_return_single_future(function_output.span)); } // Async transitions must return one future in the first position. - if self.is_finalize_caller && ((index > 0 && matches!(function_output.type_, Type::Future(_))) || (index == 0 && !matches!(function_output.type_, Type::Future(_)))) { - self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output(function_output.span)); + if self.is_finalize_caller + && ((index > 0 && matches!(function_output.type_, Type::Future(_))) + || (index == 0 && !matches!(function_output.type_, Type::Future(_)))) + { + self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output( + function_output.span, + )); } } }