Finished TYC pass

This commit is contained in:
evan-schott 2024-02-28 19:54:58 -08:00
parent 29f1a97ee3
commit c516db61f0
5 changed files with 264 additions and 55 deletions

View File

@ -21,6 +21,7 @@ use leo_errors::{emitter::Handler, TypeCheckerError};
use leo_span::{sym, Span}; use leo_span::{sym, Span};
use itertools::Itertools; use itertools::Itertools;
use leo_ast::CoreFunction::FutureAwait;
use snarkvm::console::network::{Network, Testnet3}; use snarkvm::console::network::{Network, Testnet3};
use std::str::FromStr; use std::str::FromStr;
@ -95,7 +96,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
// Check core struct name and function. // Check core struct name and function.
if let Some(core_instruction) = self.get_core_function_call(&access.variant, &access.name) { if let Some(core_instruction) = self.get_core_function_call(&access.variant, &access.name) {
// Check that operation is not restricted to finalize blocks. // Check that operation is not restricted to finalize blocks.
if !self.is_finalize && core_instruction.is_finalize_command() { if !self.scope_state.is_finalize && core_instruction.is_finalize_command() {
self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span())); self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span()));
} }
@ -114,11 +115,45 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
self.assert_type(&return_type, expected, input.span()); self.assert_type(&return_type, expected, input.span());
} }
// Await futures here so that can use the argument variable names to lookup.
if core_instruction == FutureAwait {
if access.arguments.len() != 1 {
self.emit_err(TypeCheckerError::can_only_await_one_future_at_a_time(access.span));
return Some(Type::Unit);
}
self.assert_future_await(&access.arguments.get(0), input.span());
}
return return_type; return return_type;
} else { } else {
self.emit_err(TypeCheckerError::invalid_core_function_call(access, access.span())); self.emit_err(TypeCheckerError::invalid_core_function_call(access, access.span()));
} }
} }
AccessExpression::MethodCall(call) => {
if call.name.name == sym::Await {
// Check core struct name and function.
if let Some(core_instruction) =
self.get_core_function_call(&Identifier::new(sym::Future, Default::default()), &call.name)
{
// Check that operation is not restricted to finalize blocks.
if !self.scope_state.is_finalize && core_instruction.is_finalize_command() {
self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span()));
}
// Await futures here so that can use the argument variable names to lookup.
if core_instruction == FutureAwait {
self.assert_future_await(&Some(&call.receiver), input.span());
}
else {
self.emit_err(TypeCheckerError::invalid_method_call(call.span()));
}
return Some(Type::Unit);
} else {
self.emit_err(TypeCheckerError::invalid_method_call(call.span()));
}
}
}
AccessExpression::Tuple(access) => { AccessExpression::Tuple(access) => {
if let Some(type_) = self.visit_expression(&access.tuple, &None) { if let Some(type_) = self.visit_expression(&access.tuple, &None) {
match type_ { match type_ {
@ -162,7 +197,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
Expression::Identifier(identifier) if identifier.name == sym::SelfLower => match access.name.name { Expression::Identifier(identifier) if identifier.name == sym::SelfLower => match access.name.name {
sym::caller => { sym::caller => {
// Check that the operation is not invoked in a `finalize` block. // Check that the operation is not invoked in a `finalize` block.
if self.is_finalize { if self.scope_state.is_finalize {
self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize( self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize(
"self.caller", "self.caller",
access.name.span(), access.name.span(),
@ -172,7 +207,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
} }
sym::signer => { sym::signer => {
// Check that operation is not invoked in a `finalize` block. // Check that operation is not invoked in a `finalize` block.
if self.is_finalize { if self.scope_state.is_finalize {
self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize( self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize(
"self.signer", "self.signer",
access.name.span(), access.name.span(),
@ -188,7 +223,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
Expression::Identifier(identifier) if identifier.name == sym::block => match access.name.name { Expression::Identifier(identifier) if identifier.name == sym::block => match access.name.name {
sym::height => { sym::height => {
// Check that the operation is invoked in a `finalize` block. // Check that the operation is invoked in a `finalize` block.
if !self.is_finalize { if !self.scope_state.is_finalize {
self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize( self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize(
"block.height", "block.height",
access.name.span(), access.name.span(),
@ -236,8 +271,6 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
} }
} }
Some(Type::Future(f)) => { Some(Type::Future(f)) => {
// Retrieve the inferred input types for the future argument access.
// Make sure that the input parameter accessed is valid. // Make sure that the input parameter accessed is valid.
if let Some(arg_num) = access.name.name.to_string().parse::<usize>() { if let Some(arg_num) = access.name.name.to_string().parse::<usize>() {
// Make sure in range. // Make sure in range.
@ -288,6 +321,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
self.emit_err(TypeCheckerError::invalid_associated_constant(access, access.span)) self.emit_err(TypeCheckerError::invalid_associated_constant(access, access.span))
} }
} }
_ => {}
} }
None None
} }
@ -601,7 +635,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
if let Some(func) = func { if let Some(func) = func {
// Check that the call is valid. // Check that the call is valid.
// Note that this unwrap is safe since we always set the variant before traversing the body of the function. // Note that this unwrap is safe since we always set the variant before traversing the body of the function.
match self.variant.unwrap() { match self.scope_state.variant.unwrap() {
// If the function is not a transition function, it can only call "inline" functions. // If the function is not a transition function, it can only call "inline" functions.
Variant::Inline | Variant::Standard => { Variant::Inline | Variant::Standard => {
if !matches!(func.variant, Variant::Inline) { if !matches!(func.variant, Variant::Inline) {
@ -611,7 +645,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
// If the function is a transition function, then check that the call is not to another local transition function. // If the function is a transition function, then check that the call is not to another local transition function.
Variant::Transition => { Variant::Transition => {
if matches!(func.variant, Variant::Transition) if matches!(func.variant, Variant::Transition)
&& input.program.unwrap() == self.program_name.unwrap() && input.program.unwrap() == self.scope_state.program_name.unwrap()
{ {
self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function( self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function(
input.span, input.span,
@ -621,11 +655,13 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
} }
// Check that the call is not to an external `inline` function. // Check that the call is not to an external `inline` function.
if func.variant == Variant::Inline && input.program.unwrap() != self.program_name.unwrap() { if func.variant == Variant::Inline
&& input.program.unwrap() != self.scope_state.program_name.unwrap()
{
self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span)); self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span));
} }
let ret = self.assert_and_return_type(func.output_type, expected, input.span()); let mut ret = self.assert_and_return_type(func.output_type, expected, input.span());
// Check number of function arguments. // Check number of function arguments.
if func.input.len() != input.arguments.len() { if func.input.len() != input.arguments.len() {
@ -642,17 +678,101 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
}); });
// Add the call to the call graph. // Add the call to the call graph.
let caller_name = match self.function { let caller_name = match self.scope_state.function {
None => unreachable!("`self.function` is set every time a function is visited."), None => unreachable!("`self.function` is set every time a function is visited."),
Some(func) => func, Some(func) => func,
}; };
// Don't add external functions to call graph. // Don't add external functions to call graph. Since imports are acyclic, these can never produce a cycle.
// We check that there is no dependency cycle of imports, so we know that external functions can never lead to a call graph cycle if input.program.unwrap() == self.scope_state.program_name.unwrap() {
if input.program.unwrap() == self.program_name.unwrap() {
self.call_graph.add_edge(caller_name, ident.name); self.call_graph.add_edge(caller_name, ident.name);
} }
// Propagate futures from async functions and transitions.
if func.is_async {
// Cannot have async calls in a conditional block.
if self.scope_state.is_conditional {
self.emit_err(TypeCheckerError::async_call_in_conditional(input.span));
}
// Can only call async functions and external async transitions from an async transition body.
if !self.scope_state.is_async_transition {
self.emit_err(TypeCheckerError::async_call_can_only_be_done_from_async_transition(
input.span,
));
}
if func.variant == Variant::Transition {
// Cannot call an external async transition after having called the async function.
if self.scope_state.has_called_finalize {
self.emit_err(
TypeCheckerError::external_transition_call_must_be_before_finalize(
input.span,
),
);
}
// Fully infer future type.
let future_type = Type::Future(FutureType::new(
// Assumes that external function stubs have been processed.
self.finalize_input_types.get(&(input.program.unwrap(), ident.name)).unwrap().clone(),
));
ret = match ret.clone() {
Some(Type::Tuple(tup)) => {
// Replace first element of `tup.elements` with `future_type`. This will always be a future.
let mut elements: Vec<Type> = tup.elements().clone().to_vec();
elements[0] = future_type.clone();
Type::Tuple(TupleType::new(elements))
}
Some(Type::Future(f)) => future_type,
_ => {
self.emit_err(TypeCheckerError::async_transition_invalid_output(input.span));
ret
}
}
} else if func.variant == Variant::Standard {
// Can only call an async function once in a transition function body.
if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::must_call_finalize_once(input.span));
}
// Consume futures.
let st = self.symbol_table.borrow();
let mut inferred_finalize_inputs = Vec::new();
input.arguments.iter().for_each(|arg| {
if let Expression::Identifier(ident) = arg {
if let Some(variable) = st.lookup_variable(ident.name) {
if let Type::Future(_) = variable {
if !self.scope_state.futures.remove(ident) {
self.emit_err(TypeCheckerError::unknown_future_consumed(
ident.name, ident.span,
));
}
}
// Add to expected finalize inputs signature.
inferred_finalize_inputs.push(variable.clone().type_);
}
}
});
// Check that all futures consumed.
if !self.scope_state.futures.is_empty() {
self.emit_err(TypeCheckerError::not_all_futures_consumed(
self.scope_state.futures.iter().map(|f| f.name.to_string()).join(", "),
input.span,
));
}
// Create expectation for finalize inputs that will be checked when checking corresponding finalize function signature.
self.finalize_input_types.insert(
(self.scope_state.program_name.unwrap(), self.scope_state.function.unwrap()),
inferred_finalize_inputs.clone(),
);
// Set scope state flag.
self.scope_state.has_called_finalize = true;
// Update ret to reflect fully inferred future type.
ret = Type::Future(FutureType::new(inferred_finalize_inputs));
}
}
Some(ret) Some(ret)
} else { } else {
self.emit_err(TypeCheckerError::unknown_sym("function", ident.name, ident.span())); self.emit_err(TypeCheckerError::unknown_sym("function", ident.name, ident.span()));
@ -676,7 +796,8 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
} }
fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output { fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output {
let struct_ = self.symbol_table.borrow().lookup_struct(self.program_name.unwrap(), input.name.name).cloned(); let struct_ =
self.symbol_table.borrow().lookup_struct(self.scope_state.program_name.unwrap(), input.name.name).cloned();
if let Some(struct_) = struct_ { if let Some(struct_) = struct_ {
// Check struct type name. // Check struct type name.
let ret = self.check_expected_struct(&struct_, additional, input.name.span()); let ret = self.check_expected_struct(&struct_, additional, input.name.span());
@ -890,7 +1011,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
fn visit_unit(&mut self, input: &'a UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output { fn visit_unit(&mut self, input: &'a UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
// Unit expression are only allowed inside a return statement. // Unit expression are only allowed inside a return statement.
if !self.is_return { if !self.scope_state.is_return {
self.emit_err(TypeCheckerError::unit_expression_only_in_return_statements(input.span())); self.emit_err(TypeCheckerError::unit_expression_only_in_return_statements(input.span()));
} }
Some(Type::Unit) Some(Type::Unit)

View File

@ -23,9 +23,11 @@ use leo_span::sym;
use snarkvm::console::network::{Network, Testnet3}; use snarkvm::console::network::{Network, Testnet3};
use indexmap::IndexSet; use indexmap::IndexSet;
use leo_ast::Input::{External, Internal}; use leo_ast::{
Input::{External, Internal},
Type::Future,
};
use std::collections::HashSet; use std::collections::HashSet;
use leo_ast::Type::Future;
// TODO: Cleanup logic for tuples. // TODO: Cleanup logic for tuples.
@ -89,18 +91,24 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Create future stubs. // Create future stubs.
let finalize_input_map = &mut self.finalize_input_types; let finalize_input_map = &mut self.finalize_input_types;
let mut future_stubs = input.future_stubs.clone(); let mut future_stubs = input.future_stubs.clone();
let resolved_inputs = input.input.iter().map(|input_mode| { let resolved_inputs = input
.input
.iter()
.map(|input_mode| {
match input_mode { match input_mode {
Internal(function_input) => match &function_input.type_ { Internal(function_input) => match &function_input.type_ {
Future(_) => { Future(_) => {
// Since we traverse stubs in post-order, we can assume that the corresponding finalize stub has already been traversed. // Since we traverse stubs in post-order, we can assume that the corresponding finalize stub has already been traversed.
Future(FutureType::new(finalize_input_map.get(&future_stubs.pop().unwrap().to_key()).unwrap().clone())) Future(FutureType::new(
finalize_input_map.get(&future_stubs.pop().unwrap().to_key()).unwrap().clone(),
))
} }
_ => function_input.clone().type_, _ => function_input.clone().type_,
}, },
External(_) => {} External(_) => {}
} }
}).collect(); })
.collect();
assert!(future_stubs.is_empty(), "Disassembler produced malformed stub."); assert!(future_stubs.is_empty(), "Disassembler produced malformed stub.");
finalize_input_map.insert((self.scope_state.program_name.unwrap(), input.identifier.name), resolved_inputs); finalize_input_map.insert((self.scope_state.program_name.unwrap(), input.identifier.name), resolved_inputs);
@ -327,19 +335,20 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Initialize the list of input futures. Each one must be awaited before the end of the function. // Initialize the list of input futures. Each one must be awaited before the end of the function.
self.await_checker.set_futures( self.await_checker.set_futures(
function function
.input .input
.iter() .iter()
.filter_map(|input| match input { .filter_map(|input| match input {
Internal(parameter) => { Internal(parameter) => {
if let Some(Type::Future(ty)) = parameter.type_.clone() { if let Some(Type::Future(ty)) = parameter.type_.clone() {
Some(parameter.identifier) Some(parameter.identifier)
} else { } else {
None None
}
} }
} External(_) => None,
External(_) => None, })
}) .collect(),
.collect()); );
} }
self.visit_block(&function.block); self.visit_block(&function.block);
@ -356,7 +365,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
self.exit_scope(function_index); self.exit_scope(function_index);
// Make sure that async transitions call finalize. // Make sure that async transitions call finalize.
if self.scope_state.is_finalize_caller && !self.scope_state.has_finalize { if self.scope_state.is_async_transition && !self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span)); self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span));
} }

View File

@ -18,7 +18,7 @@ use crate::{ConditionalTreeNode, TreeNode, TypeChecker, VariableSymbol, Variable
use indexmap::IndexSet; use indexmap::IndexSet;
use itertools::Itertools; use itertools::Itertools;
use leo_ast::*; use leo_ast::{Type::Future, *};
use leo_errors::TypeCheckerError; use leo_errors::TypeCheckerError;
use leo_span::{Span, Symbol}; use leo_span::{Span, Symbol};
@ -77,6 +77,10 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
} }
_ => {} _ => {}
} }
// Prohibit reassignment of futures.
if let Type::Future(_) = var.type_ {
self.emit_err(TypeCheckerError::cannot_reassign_future_variable(var_name, var.span));
}
Some(var.type_.clone()) Some(var.type_.clone())
} else { } else {
@ -241,12 +245,17 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
} }
// Check the expression on the right-hand side. // Check the expression on the right-hand side.
self.visit_expression(&input.value, &Some(input.type_.clone())); let inferred_type = self.visit_expression(&input.value, &Some(input.type_.clone()));
// TODO: Dedup with unrolling pass. // TODO: Dedup with unrolling pass.
// Helper to insert the variables into the symbol table. // Helper to insert the variables into the symbol table.
let insert_variable = |symbol: Symbol, type_: Type, span: Span| { let insert_variable = |name: &Identifier, type_: Type, span: Span| {
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(symbol, VariableSymbol { // Add to list of futures that must be consumed.
if let Type::Future(_) = type_ {
self.scope_state.futures.insert(name.clone());
}
// Insert the variable into the symbol table.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(name.name, VariableSymbol {
type_, type_,
span, span,
declaration: VariableType::Mut, declaration: VariableType::Mut,
@ -257,9 +266,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Insert the variables into the symbol table. // Insert the variables into the symbol table.
match &input.place { match &input.place {
Expression::Identifier(identifier) => { Expression::Identifier(identifier) => insert_variable(identifier, input.type_.clone(), identifier.span),
insert_variable(identifier.name, input.type_.clone(), identifier.span)
}
Expression::Tuple(tuple_expression) => { Expression::Tuple(tuple_expression) => {
let tuple_type = match &input.type_ { let tuple_type = match &input.type_ {
Type::Tuple(tuple_type) => tuple_type, Type::Tuple(tuple_type) => tuple_type,
@ -285,7 +292,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
)); ));
} }
}; };
insert_variable(identifier.name, type_.clone(), identifier.span) insert_variable(identifier, type_.clone(), identifier.span)
}, },
); );
} }
@ -323,7 +330,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
} }
let prior_has_return = core::mem::take(&mut self.scope_state.has_return); let prior_has_return = core::mem::take(&mut self.scope_state.has_return);
let prior_has_finalize = core::mem::take(&mut self.scope_state.has_finalize); let prior_has_finalize = core::mem::take(&mut self.scope_state.has_called_finalize);
self.visit_block(&input.block); self.visit_block(&input.block);
@ -331,12 +338,12 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
self.emit_err(TypeCheckerError::loop_body_contains_return(input.span())); self.emit_err(TypeCheckerError::loop_body_contains_return(input.span()));
} }
if self.scope_state.has_finalize { if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::loop_body_contains_finalize(input.span())); self.emit_err(TypeCheckerError::loop_body_contains_finalize(input.span()));
} }
self.scope_state.has_return = prior_has_return; self.scope_state.has_return = prior_has_return;
self.scope_state.has_finalize = prior_has_finalize; self.scope_state.has_called_finalize = prior_has_finalize;
// Exit the scope. // Exit the scope.
self.exit_scope(scope_index); self.exit_scope(scope_index);
@ -384,15 +391,44 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
} }
fn visit_return(&mut self, input: &'a ReturnStatement) { fn visit_return(&mut self, input: &'a ReturnStatement) {
// Cannot return anything from finalize.
if self.scope_state.is_finalize {
self.emit_err(TypeCheckerError::return_in_finalize(input.span()));
}
// We can safely unwrap all self.parent instances because // We can safely unwrap all self.parent instances because
// statements should always have some parent block // statements should always have some parent block
let parent = self.scope_state.function.unwrap(); let parent = self.scope_state.function.unwrap();
let return_type = &self let mut return_type = &self
.symbol_table .symbol_table
.borrow() .borrow()
.lookup_fn_symbol(self.scope_state.program_name.unwrap(), parent) .lookup_fn_symbol(self.scope_state.program_name.unwrap(), parent)
.map(|f| f.output_type.clone()); .map(|f| f.output_type.clone());
// Fully type the expected return value.
if self.scope_state.is_async_transition {
let inferred_future_type = match self
.finalize_input_types
.get(&(self.scope_state.program_name.unwrap(), self.scope_state.function.unwrap()))
{
Some(types) => Future(FutureType::new(types.clone())),
None => {
return self.emit_err(TypeCheckerError::async_transition_missing_future_to_return(input.span()));
}
};
return_type = &match return_type {
Some(Future(_)) => Some(inferred_future_type),
Some(Type::Tuple(tuple)) => {
let mut elements = tuple.elements().clone().to_vec();
elements[0] = inferred_future_type;
Some(Type::new(elements))
}
_ => {
self.emit_err(TypeCheckerError::async_transition_invalid_output_type(input.span()));
None
}
}
}
// Set the `has_return` flag. // Set the `has_return` flag.
self.scope_state.has_return = true; self.scope_state.has_return = true;

View File

@ -16,7 +16,23 @@
use crate::{CallGraph, StructGraph, SymbolTable, TreeNode, TypeTable, VariableSymbol, VariableType}; use crate::{CallGraph, StructGraph, SymbolTable, TreeNode, TypeTable, VariableSymbol, VariableType};
use leo_ast::{Composite, CompositeType, CoreConstant, CoreFunction, Function, Identifier, Input, IntegerType, MappingType, Mode, Node, Output, Type, Variant}; use leo_ast::{
Composite,
CompositeType,
CoreConstant,
CoreFunction,
Expression,
Function,
Identifier,
Input,
IntegerType,
MappingType,
Mode,
Node,
Output,
Type,
Variant,
};
use leo_errors::{emitter::Handler, TypeCheckerError, TypeCheckerWarning}; use leo_errors::{emitter::Handler, TypeCheckerError, TypeCheckerWarning};
use leo_span::{Span, Symbol}; use leo_span::{Span, Symbol};
@ -1070,10 +1086,7 @@ impl<'a> TypeChecker<'a> {
// Return a boolean. // Return a boolean.
Some(Type::Boolean) Some(Type::Boolean)
} }
CoreFunction::FutureAwait => { CoreFunction::FutureAwait => Some(Type::Unit),
// TODO: check that were in finalize here?
None
}
} }
} }
@ -1357,13 +1370,11 @@ impl<'a> TypeChecker<'a> {
self.emit_err(TypeCheckerError::async_function_must_return_single_future(function_output.span)); self.emit_err(TypeCheckerError::async_function_must_return_single_future(function_output.span));
} }
// Async transitions must return one future in the first position. // Async transitions must return one future in the first position.
if self.scope_state.is_finalize_caller if self.scope_state.is_async_transition
&& ((index > 0 && matches!(function_output.type_, Type::Future(_))) && ((index > 0 && matches!(function_output.type_, Type::Future(_)))
|| (index == 0 && !matches!(function_output.type_, Type::Future(_)))) || (index == 0 && !matches!(function_output.type_, Type::Future(_))))
{ {
self.emit_err(TypeCheckerError::async_transition_must_return_future_as_first_output( self.emit_err(TypeCheckerError::async_transition_invalid_output(function_output.span));
function_output.span,
));
} }
} }
} }
@ -1383,6 +1394,31 @@ impl<'a> TypeChecker<'a> {
} }
} }
} }
/// Type checks the awaiting of a future.
pub(crate) fn assert_future_await(&mut self, future: &Option<&Expression>, span: Span) {
// Make sure that it is an identifier expression.
let future_variable = match future {
Some(Expression::Identifier(name)) => name,
_ => {
return self.emit_err(TypeCheckerError::invalid_await_call(span));
}
};
// Make sure that the future is defined.
match self.symbol_table.borrow().lookup_variable(future_variable.name) {
Some(var) => {
if !matches!(&var.type_, &Type::Future(_)) {
self.emit_err(TypeCheckerError::expected_future(future_variable.name, future_variable.span()));
}
// Mark the future as consumed.
self.await_checker.remove(future_variable);
}
None => {
self.emit_err(TypeCheckerError::expected_future(future_variable.name, future_variable.span()));
}
}
}
} }
fn types_to_string(types: &[Type]) -> String { fn types_to_string(types: &[Type]) -> String {

View File

@ -923,4 +923,11 @@ create_messages!(
msg: "Cannot return a value in an async function block.".to_string(), msg: "Cannot return a value in an async function block.".to_string(),
help: Some("Async functions execute on-chain. Since async transitions call async functions, and async transitions execute offline, it would be impossible for the async function to be able to return on-chain state to the transition function.".to_string()), help: Some("Async functions execute on-chain. Since async transitions call async functions, and async transitions execute offline, it would be impossible for the async function to be able to return on-chain state to the transition function.".to_string()),
} }
@formatted
async_transition_missing_future_to_return {
args: (),
msg: "An async transition must return a future.".to_string(),
help: Some("Call an async function inside of the async transition body so that there is a future to return.".to_string()),
}
); );