mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-11-23 23:23:50 +03:00
Function signature checking w/ nested struct/future
This commit is contained in:
parent
8feb612ce7
commit
5a499937e6
@ -113,7 +113,8 @@ impl ParserContext<'_> {
|
||||
|
||||
// Parse the body of the program scope.
|
||||
let mut consts: Vec<(Symbol, ConstDeclaration)> = Vec::new();
|
||||
let (mut transitions, mut functions): (Vec<(Symbol, Function)>, Vec<(Symbol, Function)>) = (Vec::new(), Vec::new());
|
||||
let (mut transitions, mut functions): (Vec<(Symbol, Function)>, Vec<(Symbol, Function)>) =
|
||||
(Vec::new(), Vec::new());
|
||||
let mut structs: Vec<(Symbol, Composite)> = Vec::new();
|
||||
let mut mappings: Vec<(Symbol, Mapping)> = Vec::new();
|
||||
|
||||
@ -160,7 +161,14 @@ impl ParserContext<'_> {
|
||||
// Parse `}`.
|
||||
let end = self.expect(&Token::RightCurly)?;
|
||||
|
||||
Ok(ProgramScope { program_id, consts, functions: [transitions, functions].concat(), structs, mappings, span: start + end })
|
||||
Ok(ProgramScope {
|
||||
program_id,
|
||||
consts,
|
||||
functions: [transitions, functions].concat(),
|
||||
structs,
|
||||
mappings,
|
||||
span: start + end,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a [`Vec<Member>`] AST node if the next tokens represent a struct member.
|
||||
|
@ -55,7 +55,7 @@ impl SymbolTable {
|
||||
output_type: func.output_type.clone(),
|
||||
variant: func.variant,
|
||||
_span: func.span,
|
||||
input: func.input.clone()
|
||||
input: func.input.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -254,15 +254,13 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
|
||||
expected,
|
||||
access.span(),
|
||||
));
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
// Future arguments must be addressed by their index. Ex: `f.1.3`.
|
||||
self.emit_err(TypeCheckerError::future_access_must_be_number(
|
||||
access.name.name,
|
||||
access.name.span(),
|
||||
));
|
||||
}
|
||||
|
||||
}
|
||||
Some(type_) => {
|
||||
self.emit_err(TypeCheckerError::type_should_be(type_, "struct", access.inner.span()));
|
||||
|
@ -22,8 +22,9 @@ use leo_span::sym;
|
||||
|
||||
use snarkvm::console::network::{Network, Testnet3};
|
||||
|
||||
use std::collections::HashSet;
|
||||
use indexmap::IndexSet;
|
||||
use leo_ast::Input::{External, Internal};
|
||||
use std::collections::HashSet;
|
||||
|
||||
// TODO: Cleanup logic for tuples.
|
||||
|
||||
@ -269,6 +270,8 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
|
||||
self.variant = Some(function.variant);
|
||||
self.is_finalize = function.variant == Variant::Standard && function.is_async;
|
||||
self.is_finalize_caller = function.variant == Variant::Transition && function.is_async;
|
||||
self.has_finalize = false;
|
||||
self.futures = IndexSet::new();
|
||||
|
||||
// Lookup function metadata in the symbol table.
|
||||
// Note that this unwrap is safe since function metadata is stored in a prior pass.
|
||||
@ -301,16 +304,20 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
|
||||
}
|
||||
|
||||
// Initialize the list of input futures. Each one must be awaited before the end of the function.
|
||||
self.to_await = function.input.iter().filter_map(|input| match input {
|
||||
Internal(parameter) => {
|
||||
if matches!(parameter.type_, Type::Future(ty)) {
|
||||
Some(parameter.identifier.name)
|
||||
} else {
|
||||
None
|
||||
self.to_await = function
|
||||
.input
|
||||
.iter()
|
||||
.filter_map(|input| match input {
|
||||
Internal(parameter) => {
|
||||
if let Some(Type::Future(ty)) = parameter.type_.clone() {
|
||||
Some(parameter.identifier.name)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
External(_) => None,
|
||||
}).collect();
|
||||
External(_) => None,
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
self.visit_block(&function.block);
|
||||
@ -326,7 +333,14 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
|
||||
// Exit the function's scope.
|
||||
self.exit_scope(function_index);
|
||||
|
||||
// Unset the function variant variables.
|
||||
(self.variant, self.is_finalize_caller, self.is_finalize) = (None, false, false);
|
||||
// Make sure that async transitions call finalize.
|
||||
if self.is_finalize_caller && !self.has_finalize {
|
||||
self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span));
|
||||
}
|
||||
|
||||
// Must have awaited all futures.
|
||||
if self.is_finalize && !self.to_await.is_empty() {
|
||||
self.emit_err(TypeCheckerError::must_await_all_futures(&self.to_await, function.span()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,8 +14,8 @@
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
use indexmap::IndexSet;
|
||||
use crate::{TypeChecker, VariableSymbol, VariableType};
|
||||
use indexmap::IndexSet;
|
||||
use itertools::Itertools;
|
||||
|
||||
use leo_ast::*;
|
||||
@ -91,9 +91,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
}
|
||||
|
||||
fn visit_block(&mut self, input: &'a Block) {
|
||||
// Reset environment flag.
|
||||
if self.is_finalize_caller { self.has_called_finalize = false; self.futures = IndexSet::new() };
|
||||
|
||||
// Create a new scope for the then-block.
|
||||
let scope_index = self.create_child_scope();
|
||||
|
||||
@ -101,15 +98,6 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
|
||||
// Exit the scope for the then-block.
|
||||
self.exit_scope(scope_index);
|
||||
|
||||
// Must have awaited all futures.
|
||||
if self.is_finalize && !self.to_await.is_empty() {
|
||||
self.emit_err(TypeCheckerError::must_await_all_futures(&self.to_await, input.span()));
|
||||
}
|
||||
// Check that an async function call was made to propagate futures to a finalize block.
|
||||
else if self.is_finalize_caller && !self.has_called_finalize {
|
||||
self.emit_err(TypeCheckerError::async_transition_must_call_async_function(input.span()));
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_conditional(&mut self, input: &'a ConditionalStatement) {
|
||||
@ -157,6 +145,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
// Restore the previous `has_return` flag.
|
||||
self.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return);
|
||||
// Restore the previous `has_finalize` flag.
|
||||
// TODO: doesn't this mean that we allow multiple finalizes?
|
||||
self.has_finalize = previous_has_finalize || (then_block_has_finalize && otherwise_block_has_finalize);
|
||||
}
|
||||
|
||||
@ -394,14 +383,11 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
// We can safely unwrap all self.parent instances because
|
||||
// statements should always have some parent block
|
||||
let parent = self.function.unwrap();
|
||||
let return_type = &self.symbol_table.borrow().lookup_fn_symbol(self.program_name.unwrap(), parent).map(|f| {
|
||||
match self.is_finalize {
|
||||
// TODO: Check this.
|
||||
// Note that this `unwrap()` is safe since we checked that the function has a finalize block.
|
||||
true => f.finalize.as_ref().unwrap().output_type.clone(),
|
||||
false => f.output_type.clone(),
|
||||
}
|
||||
});
|
||||
let return_type = &self
|
||||
.symbol_table
|
||||
.borrow()
|
||||
.lookup_fn_symbol(self.program_name.unwrap(), parent)
|
||||
.map(|f| f.output_type.clone());
|
||||
|
||||
// Set the `has_return` flag.
|
||||
self.has_return = true;
|
||||
@ -421,43 +407,5 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
|
||||
self.visit_expression(&input.expression, return_type);
|
||||
// Unset the `is_return` flag.
|
||||
self.is_return = false;
|
||||
|
||||
if let Some(arguments) = &input.finalize_arguments {
|
||||
if self.is_finalize {
|
||||
self.emit_err(TypeCheckerError::finalize_in_finalize(input.span()));
|
||||
}
|
||||
|
||||
// Set the `has_finalize` flag.
|
||||
self.has_finalize = true;
|
||||
|
||||
// Check that the function has a finalize block.
|
||||
// Note that `self.function.unwrap()` is safe since every `self.function` is set for every function.
|
||||
// Note that `(self.function.unwrap()).unwrap()` is safe since all functions have been checked to exist.
|
||||
let finalize = self
|
||||
.symbol_table
|
||||
.borrow()
|
||||
.lookup_fn_symbol(self.program_name.unwrap(), self.function.unwrap())
|
||||
.unwrap()
|
||||
.finalize
|
||||
.clone();
|
||||
match finalize {
|
||||
None => self.emit_err(TypeCheckerError::finalize_without_finalize_block(input.span())),
|
||||
Some(finalize) => {
|
||||
// Check number of function arguments.
|
||||
if finalize.input.len() != arguments.len() {
|
||||
self.emit_err(TypeCheckerError::incorrect_num_args_to_finalize(
|
||||
finalize.input.len(),
|
||||
arguments.len(),
|
||||
input.span(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check function argument types.
|
||||
finalize.input.iter().zip(arguments.iter()).for_each(|(expected, argument)| {
|
||||
self.visit_expression(argument, &Some(expected.type_()));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -36,9 +36,9 @@ use leo_span::{Span, Symbol};
|
||||
|
||||
use snarkvm::console::network::{Network, Testnet3};
|
||||
|
||||
use indexmap::{IndexMap, IndexSet};
|
||||
use itertools::Itertools;
|
||||
use std::cell::RefCell;
|
||||
use indexmap::{IndexMap, IndexSet};
|
||||
|
||||
pub struct TypeChecker<'a> {
|
||||
/// The symbol table for the program.
|
||||
@ -72,9 +72,9 @@ pub struct TypeChecker<'a> {
|
||||
/// The futures that must be propagated to an async function.
|
||||
pub(crate) futures: IndexSet<Symbol>,
|
||||
/// Whether the finalize caller has called the finalize function.
|
||||
pub(crate) has_called_finalize: bool,
|
||||
pub(crate) has_finalize: bool,
|
||||
/// Mapping from async function name to the inferred input types.
|
||||
pub(crate) future_map: IndexMap<Symbol, Vec<Type>>
|
||||
pub(crate) inferred_future_types: IndexMap<Symbol, Vec<Type>>,
|
||||
}
|
||||
|
||||
const ADDRESS_TYPE: Type = Type::Address;
|
||||
@ -144,8 +144,8 @@ impl<'a> TypeChecker<'a> {
|
||||
is_finalize_caller: false,
|
||||
to_await: IndexSet::new(),
|
||||
futures: IndexSet::new(),
|
||||
has_called_finalize: false,
|
||||
future_map: IndexMap::new(),
|
||||
has_finalize: false,
|
||||
inferred_future_types: IndexMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,7 +201,8 @@ impl<'a> TypeChecker<'a> {
|
||||
}
|
||||
(Type::Integer(left), Type::Integer(right)) => left.eq(right),
|
||||
(Type::Mapping(left), Type::Mapping(right)) => {
|
||||
self.check_eq_type_structure(&left.key, &right.key, span) && self.check_eq_type_structure(&left.value, &right.value, span)
|
||||
self.check_eq_type_structure(&left.key, &right.key, span)
|
||||
&& self.check_eq_type_structure(&left.value, &right.value, span)
|
||||
}
|
||||
(Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
|
||||
.elements()
|
||||
@ -221,8 +222,7 @@ impl<'a> TypeChecker<'a> {
|
||||
span,
|
||||
));
|
||||
false
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
@ -243,7 +243,9 @@ impl<'a> TypeChecker<'a> {
|
||||
self.emit_err(TypeCheckerError::expected_one_type_of(t1.to_string(), t2, span));
|
||||
}
|
||||
}
|
||||
(Some(type_), None) | (None, Some(type_)) => self.emit_err(TypeCheckerError::type_should_be("no type", type_, span)),
|
||||
(Some(type_), None) | (None, Some(type_)) => {
|
||||
self.emit_err(TypeCheckerError::type_should_be("no type", type_, span))
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@ -1097,6 +1099,10 @@ impl<'a> TypeChecker<'a> {
|
||||
// Return a boolean.
|
||||
Some(Type::Boolean)
|
||||
}
|
||||
CoreFunction::FutureAwait => {
|
||||
// TODO: check that were in finalize here?
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1252,8 +1258,34 @@ impl<'a> TypeChecker<'a> {
|
||||
pub(crate) fn check_function_signature(&mut self, function: &Function) {
|
||||
self.variant = Some(function.variant);
|
||||
|
||||
// Special type checking for finalize blocks.
|
||||
if self.is_finalize {
|
||||
if let Some(inferred_future_types) = self.inferred_future_types.borrow().get(&self.function.unwrap()) {
|
||||
// Check same number of inputs as expected.
|
||||
if inferred_future_types.len() != function.input.len() {
|
||||
self.emit_err(TypeCheckerError::async_function_input_length_mismatch(
|
||||
inferred_future_types.len(),
|
||||
function.input.len(),
|
||||
function.span(),
|
||||
));
|
||||
}
|
||||
// Check that the input parameters match the inferred types from when the async function is invoked.
|
||||
function
|
||||
.input
|
||||
.iter()
|
||||
.zip_eq(inferred_future_types.iter())
|
||||
.for_each(|(t1, t2)| self.check_eq_type(&t1.type_(), t2, t1.span()));
|
||||
} else if function.input.len() > 0 {
|
||||
self.emit_err(TypeCheckerError::async_function_input_length_mismatch(
|
||||
0,
|
||||
function.input.len(),
|
||||
function.span(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Type check the function's parameters.
|
||||
function.input.iter().for_each(|input_var| {
|
||||
function.input.iter().enumerate().for_each(|(index, input_var)| {
|
||||
// Check that the type of input parameter is defined.
|
||||
self.assert_type_is_valid(&input_var.type_(), input_var.span());
|
||||
// Check that the type of the input parameter is not a tuple.
|
||||
@ -1263,12 +1295,13 @@ impl<'a> TypeChecker<'a> {
|
||||
// Check that the input parameter is not a record.
|
||||
else if let Type::Composite(struct_) = input_var.type_() {
|
||||
// Note that this unwrap is safe, as the type is defined.
|
||||
if !matches!(function.variant, Variant::Transition) && self
|
||||
.symbol_table
|
||||
.borrow()
|
||||
.lookup_struct(struct_.program.unwrap(), struct_.id.name)
|
||||
.unwrap()
|
||||
.is_record
|
||||
if !matches!(function.variant, Variant::Transition)
|
||||
&& self
|
||||
.symbol_table
|
||||
.borrow()
|
||||
.lookup_struct(struct_.program.unwrap(), struct_.id.name)
|
||||
.unwrap()
|
||||
.is_record
|
||||
{
|
||||
self.emit_err(TypeCheckerError::function_cannot_input_or_output_a_record(input_var.span()))
|
||||
}
|
||||
@ -1276,7 +1309,9 @@ impl<'a> TypeChecker<'a> {
|
||||
|
||||
// Check that the finalize input parameter is not constant or private.
|
||||
if self.is_finalize && (self.mode() == Mode::Constant || input_var.mode() == Mode::Private) {
|
||||
self.emit_err(TypeCheckerError::finalize_input_mode_must_be_public(input_var.span()));
|
||||
if (self.mode() == Mode::Constant || input_var.mode() == Mode::Private) {
|
||||
self.emit_err(TypeCheckerError::finalize_input_mode_must_be_public(input_var.span()));
|
||||
}
|
||||
}
|
||||
|
||||
// Note that this unwrap is safe since we assign to `self.variant` above.
|
||||
@ -1306,7 +1341,7 @@ impl<'a> TypeChecker<'a> {
|
||||
|
||||
// Type check the function's return type.
|
||||
// Note that checking that each of the component types are defined is sufficient to check that `output_type` is defined.
|
||||
function.output.iter().enumerate().for_each(|(index,output)| {
|
||||
function.output.iter().enumerate().for_each(|(index, output)| {
|
||||
match output {
|
||||
Output::External(external) => {
|
||||
// If the function is not a transition function, then it cannot output a record.
|
||||
@ -1348,8 +1383,13 @@ impl<'a> TypeChecker<'a> {
|
||||
self.emit_err(TypeCheckerError::async_function_must_return_single_future(function_output.span));
|
||||
}
|
||||
// Async transitions must return one future in the first position.
|
||||
if self.is_finalize_caller && ((index > 0 && matches!(function_output.type_, Type::Future(_))) || (index == 0 && !matches!(function_output.type_, Type::Future(_)))) {
|
||||
self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output(function_output.span));
|
||||
if self.is_finalize_caller
|
||||
&& ((index > 0 && matches!(function_output.type_, Type::Future(_)))
|
||||
|| (index == 0 && !matches!(function_output.type_, Type::Future(_))))
|
||||
{
|
||||
self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output(
|
||||
function_output.span,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user