Function signature checking w/ nested struct/future

This commit is contained in:
evan-schott 2024-02-19 18:28:33 -08:00
parent 8feb612ce7
commit 5a499937e6
6 changed files with 105 additions and 97 deletions

View File

@ -113,7 +113,8 @@ impl ParserContext<'_> {
// Parse the body of the program scope.
let mut consts: Vec<(Symbol, ConstDeclaration)> = Vec::new();
let (mut transitions, mut functions): (Vec<(Symbol, Function)>, Vec<(Symbol, Function)>) = (Vec::new(), Vec::new());
let (mut transitions, mut functions): (Vec<(Symbol, Function)>, Vec<(Symbol, Function)>) =
(Vec::new(), Vec::new());
let mut structs: Vec<(Symbol, Composite)> = Vec::new();
let mut mappings: Vec<(Symbol, Mapping)> = Vec::new();
@ -160,7 +161,14 @@ impl ParserContext<'_> {
// Parse `}`.
let end = self.expect(&Token::RightCurly)?;
Ok(ProgramScope { program_id, consts, functions: [transitions, functions].concat(), structs, mappings, span: start + end })
Ok(ProgramScope {
program_id,
consts,
functions: [transitions, functions].concat(),
structs,
mappings,
span: start + end,
})
}
/// Returns a [`Vec<Member>`] AST node if the next tokens represent a struct member.

View File

@ -55,7 +55,7 @@ impl SymbolTable {
output_type: func.output_type.clone(),
variant: func.variant,
_span: func.span,
input: func.input.clone()
input: func.input.clone(),
}
}
}

View File

@ -254,15 +254,13 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
expected,
access.span(),
));
}
else {
} else {
// Future arguments must be addressed by their index. Ex: `f.1.3`.
self.emit_err(TypeCheckerError::future_access_must_be_number(
access.name.name,
access.name.span(),
));
}
}
Some(type_) => {
self.emit_err(TypeCheckerError::type_should_be(type_, "struct", access.inner.span()));

View File

@ -22,8 +22,9 @@ use leo_span::sym;
use snarkvm::console::network::{Network, Testnet3};
use std::collections::HashSet;
use indexmap::IndexSet;
use leo_ast::Input::{External, Internal};
use std::collections::HashSet;
// TODO: Cleanup logic for tuples.
@ -269,6 +270,8 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
self.variant = Some(function.variant);
self.is_finalize = function.variant == Variant::Standard && function.is_async;
self.is_finalize_caller = function.variant == Variant::Transition && function.is_async;
self.has_finalize = false;
self.futures = IndexSet::new();
// Lookup function metadata in the symbol table.
// Note that this unwrap is safe since function metadata is stored in a prior pass.
@ -301,16 +304,20 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
}
// Initialize the list of input futures. Each one must be awaited before the end of the function.
self.to_await = function.input.iter().filter_map(|input| match input {
self.to_await = function
.input
.iter()
.filter_map(|input| match input {
Internal(parameter) => {
if matches!(parameter.type_, Type::Future(ty)) {
if let Some(Type::Future(ty)) = parameter.type_.clone() {
Some(parameter.identifier.name)
} else {
None
}
}
External(_) => None,
}).collect();
})
.collect();
}
self.visit_block(&function.block);
@ -326,7 +333,14 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Exit the function's scope.
self.exit_scope(function_index);
// Unset the function variant variables.
(self.variant, self.is_finalize_caller, self.is_finalize) = (None, false, false);
// Make sure that async transitions call finalize.
if self.is_finalize_caller && !self.has_finalize {
self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span));
}
// Must have awaited all futures.
if self.is_finalize && !self.to_await.is_empty() {
self.emit_err(TypeCheckerError::must_await_all_futures(&self.to_await, function.span()));
}
}
}

View File

@ -14,8 +14,8 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use indexmap::IndexSet;
use crate::{TypeChecker, VariableSymbol, VariableType};
use indexmap::IndexSet;
use itertools::Itertools;
use leo_ast::*;
@ -91,9 +91,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
}
fn visit_block(&mut self, input: &'a Block) {
// Reset environment flag.
if self.is_finalize_caller { self.has_called_finalize = false; self.futures = IndexSet::new() };
// Create a new scope for the then-block.
let scope_index = self.create_child_scope();
@ -101,15 +98,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Exit the scope for the then-block.
self.exit_scope(scope_index);
// Must have awaited all futures.
if self.is_finalize && !self.to_await.is_empty() {
self.emit_err(TypeCheckerError::must_await_all_futures(&self.to_await, input.span()));
}
// Check that an async function call was made to propagate futures to a finalize block.
else if self.is_finalize_caller && !self.has_called_finalize {
self.emit_err(TypeCheckerError::async_transition_must_call_async_function(input.span()));
}
}
fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
@ -157,6 +145,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Restore the previous `has_return` flag.
self.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return);
// Restore the previous `has_finalize` flag.
// TODO: doesn't this mean that we allow multiple finalizes?
self.has_finalize = previous_has_finalize || (then_block_has_finalize && otherwise_block_has_finalize);
}
@ -394,14 +383,11 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// We can safely unwrap all self.parent instances because
// statements should always have some parent block
let parent = self.function.unwrap();
let return_type = &self.symbol_table.borrow().lookup_fn_symbol(self.program_name.unwrap(), parent).map(|f| {
match self.is_finalize {
// TODO: Check this.
// Note that this `unwrap()` is safe since we checked that the function has a finalize block.
true => f.finalize.as_ref().unwrap().output_type.clone(),
false => f.output_type.clone(),
}
});
let return_type = &self
.symbol_table
.borrow()
.lookup_fn_symbol(self.program_name.unwrap(), parent)
.map(|f| f.output_type.clone());
// Set the `has_return` flag.
self.has_return = true;
@ -421,43 +407,5 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
self.visit_expression(&input.expression, return_type);
// Unset the `is_return` flag.
self.is_return = false;
if let Some(arguments) = &input.finalize_arguments {
if self.is_finalize {
self.emit_err(TypeCheckerError::finalize_in_finalize(input.span()));
}
// Set the `has_finalize` flag.
self.has_finalize = true;
// Check that the function has a finalize block.
// Note that `self.function.unwrap()` is safe since every `self.function` is set for every function.
// Note that `(self.function.unwrap()).unwrap()` is safe since all functions have been checked to exist.
let finalize = self
.symbol_table
.borrow()
.lookup_fn_symbol(self.program_name.unwrap(), self.function.unwrap())
.unwrap()
.finalize
.clone();
match finalize {
None => self.emit_err(TypeCheckerError::finalize_without_finalize_block(input.span())),
Some(finalize) => {
// Check number of function arguments.
if finalize.input.len() != arguments.len() {
self.emit_err(TypeCheckerError::incorrect_num_args_to_finalize(
finalize.input.len(),
arguments.len(),
input.span(),
));
}
// Check function argument types.
finalize.input.iter().zip(arguments.iter()).for_each(|(expected, argument)| {
self.visit_expression(argument, &Some(expected.type_()));
});
}
}
}
}
}

View File

@ -36,9 +36,9 @@ use leo_span::{Span, Symbol};
use snarkvm::console::network::{Network, Testnet3};
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
use std::cell::RefCell;
use indexmap::{IndexMap, IndexSet};
pub struct TypeChecker<'a> {
/// The symbol table for the program.
@ -72,9 +72,9 @@ pub struct TypeChecker<'a> {
/// The futures that must be propagated to an async function.
pub(crate) futures: IndexSet<Symbol>,
/// Whether the finalize caller has called the finalize function.
pub(crate) has_called_finalize: bool,
pub(crate) has_finalize: bool,
/// Mapping from async function name to the inferred input types.
pub(crate) future_map: IndexMap<Symbol, Vec<Type>>
pub(crate) inferred_future_types: IndexMap<Symbol, Vec<Type>>,
}
const ADDRESS_TYPE: Type = Type::Address;
@ -144,8 +144,8 @@ impl<'a> TypeChecker<'a> {
is_finalize_caller: false,
to_await: IndexSet::new(),
futures: IndexSet::new(),
has_called_finalize: false,
future_map: IndexMap::new(),
has_finalize: false,
inferred_future_types: IndexMap::new(),
}
}
@ -201,7 +201,8 @@ impl<'a> TypeChecker<'a> {
}
(Type::Integer(left), Type::Integer(right)) => left.eq(right),
(Type::Mapping(left), Type::Mapping(right)) => {
self.check_eq_type_structure(&left.key, &right.key, span) && self.check_eq_type_structure(&left.value, &right.value, span)
self.check_eq_type_structure(&left.key, &right.key, span)
&& self.check_eq_type_structure(&left.value, &right.value, span)
}
(Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
.elements()
@ -221,8 +222,7 @@ impl<'a> TypeChecker<'a> {
span,
));
false
}
else {
} else {
true
}
}
@ -243,7 +243,9 @@ impl<'a> TypeChecker<'a> {
self.emit_err(TypeCheckerError::expected_one_type_of(t1.to_string(), t2, span));
}
}
(Some(type_), None) | (None, Some(type_)) => self.emit_err(TypeCheckerError::type_should_be("no type", type_, span)),
(Some(type_), None) | (None, Some(type_)) => {
self.emit_err(TypeCheckerError::type_should_be("no type", type_, span))
}
_ => {}
}
}
@ -1097,6 +1099,10 @@ impl<'a> TypeChecker<'a> {
// Return a boolean.
Some(Type::Boolean)
}
CoreFunction::FutureAwait => {
// TODO: check that were in finalize here?
None
}
}
}
@ -1252,8 +1258,34 @@ impl<'a> TypeChecker<'a> {
pub(crate) fn check_function_signature(&mut self, function: &Function) {
self.variant = Some(function.variant);
// Special type checking for finalize blocks.
if self.is_finalize {
if let Some(inferred_future_types) = self.inferred_future_types.borrow().get(&self.function.unwrap()) {
// Check same number of inputs as expected.
if inferred_future_types.len() != function.input.len() {
self.emit_err(TypeCheckerError::async_function_input_length_mismatch(
inferred_future_types.len(),
function.input.len(),
function.span(),
));
}
// Check that the input parameters match the inferred types from when the async function is invoked.
function
.input
.iter()
.zip_eq(inferred_future_types.iter())
.for_each(|(t1, t2)| self.check_eq_type(&t1.type_(), t2, t1.span()));
} else if function.input.len() > 0 {
self.emit_err(TypeCheckerError::async_function_input_length_mismatch(
0,
function.input.len(),
function.span(),
));
}
}
// Type check the function's parameters.
function.input.iter().for_each(|input_var| {
function.input.iter().enumerate().for_each(|(index, input_var)| {
// Check that the type of input parameter is defined.
self.assert_type_is_valid(&input_var.type_(), input_var.span());
// Check that the type of the input parameter is not a tuple.
@ -1263,7 +1295,8 @@ impl<'a> TypeChecker<'a> {
// Check that the input parameter is not a record.
else if let Type::Composite(struct_) = input_var.type_() {
// Note that this unwrap is safe, as the type is defined.
if !matches!(function.variant, Variant::Transition) && self
if !matches!(function.variant, Variant::Transition)
&& self
.symbol_table
.borrow()
.lookup_struct(struct_.program.unwrap(), struct_.id.name)
@ -1276,8 +1309,10 @@ impl<'a> TypeChecker<'a> {
// Check that the finalize input parameter is not constant or private.
if self.is_finalize && (self.mode() == Mode::Constant || input_var.mode() == Mode::Private) {
if (self.mode() == Mode::Constant || input_var.mode() == Mode::Private) {
self.emit_err(TypeCheckerError::finalize_input_mode_must_be_public(input_var.span()));
}
}
// Note that this unwrap is safe since we assign to `self.variant` above.
match self.variant.unwrap() {
@ -1348,8 +1383,13 @@ impl<'a> TypeChecker<'a> {
self.emit_err(TypeCheckerError::async_function_must_return_single_future(function_output.span));
}
// Async transitions must return one future in the first position.
if self.is_finalize_caller && ((index > 0 && matches!(function_output.type_, Type::Future(_))) || (index == 0 && !matches!(function_output.type_, Type::Future(_)))) {
self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output(function_output.span));
if self.is_finalize_caller
&& ((index > 0 && matches!(function_output.type_, Type::Future(_)))
|| (index == 0 && !matches!(function_output.type_, Type::Future(_))))
{
self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output(
function_output.span,
));
}
}
}