Refactor Variant to account for asynchronous function variants

This commit is contained in:
evan-schott 2024-04-03 17:00:03 -07:00
parent e7ac5e0bf1
commit 9889aa01c5
21 changed files with 192 additions and 122 deletions

View File

@ -46,8 +46,6 @@ use std::fmt;
pub struct Function {
/// Annotations on the function.
pub annotations: Vec<Annotation>,
/// Is this function asynchronous or synchronous?
pub is_async: bool,
/// Is this function a transition, inlined, or a regular function?.
pub variant: Variant,
/// The function identifier, e.g., `foo` in `function foo(...) { ... }`.
@ -79,7 +77,6 @@ impl Function {
#[allow(clippy::too_many_arguments)]
pub fn new(
annotations: Vec<Annotation>,
is_async: bool,
variant: Variant,
identifier: Identifier,
input: Vec<Input>,
@ -100,7 +97,7 @@ impl Function {
_ => Type::Tuple(TupleType::new(output.iter().map(get_output_type).collect())),
};
Function { annotations, is_async, variant, identifier, input, output, output_type, block, span, id }
Function { annotations, variant, identifier, input, output, output_type, block, span, id }
}
/// Returns function name.
@ -114,8 +111,8 @@ impl Function {
fn format(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.variant {
Variant::Inline => write!(f, "inline ")?,
Variant::Standard => write!(f, "function ")?,
Variant::Transition => write!(f, "transition ")?,
Variant::Function | Variant::AsyncFunction => write!(f, "function ")?,
Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?,
}
write!(f, "{}", self.identifier)?;
@ -135,7 +132,6 @@ impl From<FunctionStub> for Function {
fn from(function: FunctionStub) -> Self {
Self {
annotations: function.annotations,
is_async: function.is_async,
variant: function.variant,
identifier: function.identifier,
input: function.input,

View File

@ -16,13 +16,43 @@
use serde::{Deserialize, Serialize};
/// Functions are always one of three variants.
/// Functions are always one of five variants.
/// A transition function is permitted the ability to manipulate records.
/// An asynchronous transition function is a transition function that calls an asynchronous function.
/// A regular function is not permitted to manipulate records.
/// An asynchronous function contains on-chain operations.
/// An inline function is directly copied at the call site.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum Variant {
Inline,
Standard,
Function,
Transition,
AsyncTransition,
AsyncFunction,
}
impl Variant {
/// Returns true if the variant is async.
pub fn is_async(self) -> bool {
match self {
Variant::AsyncFunction | Variant::AsyncTransition => true,
_ => false,
}
}
/// Returns true if the variant is a transition.
pub fn is_transition(self) -> bool {
match self {
Variant::Transition | Variant::AsyncTransition => true,
_ => false,
}
}
/// Returns true if the variant is a function.
pub fn is_function(self) -> bool {
match self {
Variant::Function | Variant::AsyncFunction => true,
_ => false,
}
}
}

View File

@ -475,7 +475,6 @@ pub trait ProgramReconstructor: StatementReconstructor {
fn reconstruct_function(&mut self, input: Function) -> Function {
Function {
annotations: input.annotations,
is_async: input.is_async,
variant: input.variant,
identifier: input.identifier,
input: input.input,

View File

@ -54,8 +54,6 @@ use std::fmt;
pub struct FunctionStub {
/// Annotations on the function.
pub annotations: Vec<Annotation>,
/// Is this function asynchronous or synchronous?
pub is_async: bool,
/// Is this function a transition, inlined, or a regular function?.
pub variant: Variant,
/// The function identifier, e.g., `foo` in `function foo(...) { ... }`.
@ -109,7 +107,6 @@ impl FunctionStub {
FunctionStub {
annotations,
is_async,
variant,
identifier,
future_locations: Vec::new(),
@ -137,8 +134,8 @@ impl FunctionStub {
fn format(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.variant {
Variant::Inline => write!(f, "inline ")?,
Variant::Standard => write!(f, "function ")?,
Variant::Transition => write!(f, "transition ")?,
Variant::Function | Variant::AsyncFunction => write!(f, "function ")?,
Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?,
}
write!(f, "{}", self.identifier)?;
@ -218,8 +215,10 @@ impl FunctionStub {
Self {
annotations: Vec::new(),
is_async: function.finalize_logic().is_some(),
variant: Variant::Transition,
variant: match function.finalize_logic().is_some() {
true => Variant::AsyncTransition,
false => Variant::Transition,
},
identifier: Identifier::from(function.name()),
future_locations: Vec::new(),
input: function
@ -281,8 +280,7 @@ impl FunctionStub {
) -> Self {
Self {
annotations: Vec::new(),
is_async: true,
variant: Variant::Standard,
variant: Variant::AsyncFunction,
identifier: Identifier::new(name, Default::default()),
future_locations: function
.finalize_logic()
@ -291,7 +289,7 @@ impl FunctionStub {
.iter()
.filter_map(|input| match input.finalize_type() {
FinalizeType::Future(val) => Some(Location::new(
Identifier::from(val.program_id().name()).name,
Some(Identifier::from(val.program_id().name()).name),
Symbol::intern(&format!("finalize/{}", val.resource())),
)),
_ => None,
@ -361,8 +359,7 @@ impl FunctionStub {
};
Self {
annotations: Vec::new(),
is_async: false,
variant: Variant::Standard,
variant: Variant::Function,
identifier: Identifier::from(closure.name()),
future_locations: Vec::new(),
input: closure
@ -397,7 +394,6 @@ impl From<Function> for FunctionStub {
fn from(function: Function) -> Self {
Self {
annotations: function.annotations,
is_async: function.is_async,
variant: function.variant,
identifier: function.identifier,
future_locations: Vec::new(),

View File

@ -138,7 +138,7 @@ impl ParserContext<'_> {
let (id, function) = self.parse_function()?;
// Partition into transitions and functions so that don't have to sort later.
if function.variant == Variant::Transition {
if function.variant.is_transition() {
transitions.push((id, function));
} else {
functions.push((id, function));
@ -409,10 +409,12 @@ impl ParserContext<'_> {
let (is_async, start_async) =
if self.token.token == Token::Async { (true, self.expect(&Token::Async)?) } else { (false, Span::dummy()) };
// Parse `<variant> IDENT`, where `<variant>` is `function`, `transition`, or `inline`.
let (variant, start) = match self.token.token {
let (variant, start) = match self.token.token.clone() {
Token::Inline => (Variant::Inline, self.expect(&Token::Inline)?),
Token::Function => (Variant::Standard, self.expect(&Token::Function)?),
Token::Transition => (Variant::Transition, self.expect(&Token::Transition)?),
Token::Function => {
(if is_async { Variant::AsyncFunction } else { Variant::Function }, self.expect(&Token::Function)?)
}
Token::Transition => (if is_async { Variant::AsyncTransition } else { Variant::Transition }, self.expect(&Token::Transition)?),
_ => self.unexpected("'function', 'transition', or 'inline'")?,
};
let name = self.expect_identifier()?;
@ -450,7 +452,6 @@ impl ParserContext<'_> {
name.name,
Function::new(
annotations,
is_async,
variant,
name,
inputs,

View File

@ -14,8 +14,35 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{CodeGenerator};
use leo_ast::{AccessExpression, ArrayAccess, ArrayExpression, AssociatedConstant, AssociatedFunction, BinaryExpression, BinaryOperation, CallExpression, CastExpression, ErrExpression, Expression, Identifier, Literal, Location, LocatorExpression, MemberAccess, MethodCall, Node, StructExpression, TernaryExpression, TupleExpression, Type, UnaryExpression, UnaryOperation, UnitExpression, Variant};
use crate::CodeGenerator;
use leo_ast::{
AccessExpression,
ArrayAccess,
ArrayExpression,
AssociatedConstant,
AssociatedFunction,
BinaryExpression,
BinaryOperation,
CallExpression,
CastExpression,
ErrExpression,
Expression,
Identifier,
Literal,
Location,
LocatorExpression,
MemberAccess,
MethodCall,
Node,
StructExpression,
TernaryExpression,
TupleExpression,
Type,
UnaryExpression,
UnaryOperation,
UnitExpression,
Variant,
};
use leo_span::sym;
use std::borrow::Borrow;
@ -513,7 +540,7 @@ impl<'a> CodeGenerator<'a> {
} else {
// Lookup in symbol table to determine if its an async function.
if let Some(func) = self.symbol_table.lookup_fn_symbol(Location::new(input.program, function_name)) {
if func.is_async && input.program.unwrap() == self.program_id.unwrap().name.name {
if func.variant.is_async() && input.program.unwrap() == self.program_id.unwrap().name.name {
format!(" async {}", self.current_function.unwrap().identifier)
} else {
format!(" call {}", input.function)
@ -534,8 +561,7 @@ impl<'a> CodeGenerator<'a> {
let mut destinations = Vec::new();
// Create operands for the output registers.
let func =
&self.symbol_table.lookup_fn_symbol(Location::new(Some(main_program), function_name)).unwrap();
let func = &self.symbol_table.lookup_fn_symbol(Location::new(Some(main_program), function_name)).unwrap();
match func.output_type.clone() {
Type::Unit => {} // Do nothing
Type::Tuple(tuple) => match tuple.length() {
@ -556,7 +582,7 @@ impl<'a> CodeGenerator<'a> {
}
// Add a register for async functions to represent the future created.
if func.is_async && func.variant == Variant::Standard {
if func.variant == Variant::AsyncFunction {
let destination_register = format!("r{}", self.next_register);
destinations.push(destination_register);
self.next_register += 1;

View File

@ -16,7 +16,7 @@
use crate::CodeGenerator;
use leo_ast::{functions, Composite, Function, Mapping, Mode, Program, ProgramScope, Type, Variant, Location};
use leo_ast::{functions, Composite, Function, Location, Mapping, Mode, Program, ProgramScope, Type, Variant};
use indexmap::IndexMap;
use itertools::Itertools;
@ -84,7 +84,7 @@ impl<'a> CodeGenerator<'a> {
.functions
.iter()
.map(|(_, function)| {
if !(function.is_async && function.variant == Variant::Standard) {
if function.variant != Variant::AsyncFunction {
// Set the `is_transition_function` flag.
self.is_transition_function = matches!(function.variant, Variant::Transition);
@ -94,15 +94,15 @@ impl<'a> CodeGenerator<'a> {
self.is_transition_function = false;
// Attach the associated finalize to async transitions.
if function.variant == Variant::Transition && function.is_async {
if function.variant == Variant::AsyncTransition {
// Set state variables.
self.is_transition_function = false;
self.finalize_caller = Some(function.identifier.name.clone());
// Generate code for the associated finalize function.
let finalize = &self
.symbol_table
.lookup_fn_symbol(
Location::new(Some(self.program_id.unwrap().name.name),
.lookup_fn_symbol(Location::new(
Some(self.program_id.unwrap().name.name),
function.identifier.name,
))
.unwrap()
@ -178,7 +178,7 @@ impl<'a> CodeGenerator<'a> {
// Initialize the state of `self` with the appropriate values before visiting `function`.
self.next_register = 0;
self.variable_mapping = IndexMap::new();
self.in_finalize = function.is_async && function.variant == Variant::Standard;
self.in_finalize = function.variant == Variant::AsyncFunction;
// TODO: Figure out a better way to initialize.
self.variable_mapping.insert(&sym::SelfLower, "self".to_string());
self.variable_mapping.insert(&sym::block, "block".to_string());
@ -188,11 +188,11 @@ impl<'a> CodeGenerator<'a> {
// If a function is a program function, generate an Aleo `function`,
// if it is a standard function generate an Aleo `closure`,
// otherwise, it is an inline function, in which case a function should not be generated.
let mut function_string = match (function.is_async, function.variant) {
(_, Variant::Transition) => format!("\nfunction {}:\n", function.identifier),
(false, Variant::Standard) => format!("\nclosure {}:\n", function.identifier),
(true, Variant::Standard) => format!("\nfinalize {}:\n", self.finalize_caller.unwrap()),
(_, Variant::Inline) => return String::from("\n"),
let mut function_string = match function.variant {
Variant::Transition | Variant::AsyncTransition => format!("\nfunction {}:\n", function.identifier),
Variant::Function => format!("\nclosure {}:\n", function.identifier),
Variant::AsyncFunction => format!("\nfinalize {}:\n", self.finalize_caller.unwrap()),
Variant::Inline => return String::from("\n"),
};
// Construct and append the input declarations of the function.

View File

@ -26,8 +26,6 @@ use crate::SymbolTable;
pub struct FunctionSymbol {
/// The index associated with the scope in the parent symbol table.
pub(crate) id: usize,
/// Whether the function is asynchronous or not.
pub(crate) is_async: bool,
/// The output type of the function.
pub(crate) output_type: Type,
/// Is this function a transition, inlined, or a regular function?.
@ -46,7 +44,6 @@ impl SymbolTable {
pub(crate) fn new_function_symbol(id: usize, func: &Function) -> FunctionSymbol {
FunctionSymbol {
id,
is_async: func.is_async,
output_type: func.output_type.clone(),
variant: func.variant,
_span: func.span,

View File

@ -247,7 +247,6 @@ mod tests {
let func_loc = Location::new(Some(Symbol::intern("credits")), Symbol::intern("transfer_public"));
let insert = Function {
annotations: Vec::new(),
is_async: false,
id: 0,
output_type: Type::Address,
variant: Variant::Inline,

View File

@ -29,7 +29,6 @@ impl ProgramReconstructor for DeadCodeEliminator<'_> {
Function {
annotations: input.annotations,
is_async: input.is_async,
variant: input.variant,
identifier: input.identifier,
input: input.input,

View File

@ -14,9 +14,19 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{Flattener};
use crate::Flattener;
use leo_ast::{Expression, ExpressionReconstructor, Location, Node, Statement, StructExpression, StructVariableInitializer, TernaryExpression, Type};
use leo_ast::{
Expression,
ExpressionReconstructor,
Location,
Node,
Statement,
StructExpression,
StructVariableInitializer,
TernaryExpression,
Type,
};
impl ExpressionReconstructor for Flattener<'_> {
type AdditionalOutput = Vec<Statement>;

View File

@ -32,7 +32,6 @@ impl ProgramReconstructor for Flattener<'_> {
Function {
annotations: function.annotations,
is_async: function.is_async,
variant: function.variant,
identifier: function.identifier,
input: function.input,

View File

@ -53,7 +53,6 @@ impl ExpressionReconstructor for FunctionInliner<'_> {
// Inline the callee function, if required, otherwise, return the call expression.
match callee.variant {
Variant::Transition | Variant::Standard => (Expression::Call(input), Default::default()),
Variant::Inline => {
// Construct a mapping from input variables of the callee function to arguments passed to the callee.
let parameter_to_argument = callee
@ -103,6 +102,7 @@ impl ExpressionReconstructor for FunctionInliner<'_> {
(result, inlined_statements)
}
_ => (Expression::Call(input), Default::default()),
}
}
}

View File

@ -16,7 +16,7 @@
use leo_ast::*;
use crate::{Unroller};
use crate::Unroller;
impl ProgramReconstructor for Unroller<'_> {
fn reconstruct_stub(&mut self, input: Stub) -> Stub {
@ -92,7 +92,6 @@ impl ProgramReconstructor for Unroller<'_> {
// Reconstruct the function block.
let reconstructed_function = Function {
is_async: function.is_async,
annotations: function.annotations,
variant: function.variant,
identifier: function.identifier,

View File

@ -14,9 +14,33 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{StaticSingleAssigner};
use crate::StaticSingleAssigner;
use leo_ast::{AccessExpression, ArrayAccess, ArrayExpression, AssociatedFunction, BinaryExpression, CallExpression, CastExpression, Composite, Expression, ExpressionConsumer, Identifier, Literal, Location, LocatorExpression, MemberAccess, Statement, StructExpression, StructVariableInitializer, TernaryExpression, TupleAccess, TupleExpression, UnaryExpression, UnitExpression};
use leo_ast::{
AccessExpression,
ArrayAccess,
ArrayExpression,
AssociatedFunction,
BinaryExpression,
CallExpression,
CastExpression,
Composite,
Expression,
ExpressionConsumer,
Identifier,
Literal,
Location,
LocatorExpression,
MemberAccess,
Statement,
StructExpression,
StructVariableInitializer,
TernaryExpression,
TupleAccess,
TupleExpression,
UnaryExpression,
UnitExpression,
};
use leo_span::{sym, Symbol};
use indexmap::IndexMap;

View File

@ -81,7 +81,6 @@ impl FunctionConsumer for StaticSingleAssigner<'_> {
Function {
annotations: function.annotations,
is_async: function.is_async,
variant: function.variant,
identifier: function.identifier,
input: function.input,

View File

@ -14,7 +14,7 @@
// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
use crate::{TypeChecker};
use crate::TypeChecker;
use leo_ast::*;
use leo_errors::{emitter::Handler, TypeCheckerError};
@ -24,10 +24,10 @@ use itertools::Itertools;
use leo_ast::{
CoreFunction::FutureAwait,
Type::{Future, Tuple},
Variant::Standard,
};
use snarkvm::console::network::{MainnetV0, Network};
use std::str::FromStr;
use leo_ast::Variant::{Transition, Function, AsyncFunction, AsyncTransition};
fn return_incorrect_type(t1: Option<Type>, t2: Option<Type>, expected: &Option<Type>) -> Option<Type> {
match (t1, t2) {
@ -101,7 +101,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
// Check core struct name and function.
if let Some(core_instruction) = self.get_core_function_call(&access.variant, &access.name) {
// Check that operation is not restricted to finalize blocks.
if !self.scope_state.is_finalize && core_instruction.is_finalize_command() {
if self.scope_state.variant != Some(Variant::AsyncFunction) && core_instruction.is_finalize_command() {
self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span()));
}
@ -142,7 +142,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
self.get_core_function_call(&Identifier::new(sym::Future, Default::default()), &call.name)
{
// Check that operation is not restricted to finalize blocks.
if !self.scope_state.is_finalize && core_instruction.is_finalize_command() {
if self.scope_state.variant != Some(AsyncFunction) && core_instruction.is_finalize_command() {
self.emit_err(TypeCheckerError::operation_must_be_in_finalize_block(input.span()));
}
@ -222,7 +222,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
Expression::Identifier(identifier) if identifier.name == sym::SelfLower => match access.name.name {
sym::caller => {
// Check that the operation is not invoked in a `finalize` block.
if self.scope_state.is_finalize {
if self.scope_state.variant == Some(Variant::AsyncFunction) {
self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize(
"self.caller",
access.name.span(),
@ -232,7 +232,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
}
sym::signer => {
// Check that operation is not invoked in a `finalize` block.
if self.scope_state.is_finalize {
if self.scope_state.variant == Some(Variant::AsyncFunction) {
self.handler.emit_err(TypeCheckerError::invalid_operation_inside_finalize(
"self.signer",
access.name.span(),
@ -248,7 +248,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
Expression::Identifier(identifier) if identifier.name == sym::block => match access.name.name {
sym::height => {
// Check that the operation is invoked in a `finalize` block.
if !self.scope_state.is_finalize {
if self.scope_state.variant != Some(Variant::AsyncFunction) {
self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize(
"block.height",
access.name.span(),
@ -636,22 +636,11 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
// Check that the call is valid.
// Note that this unwrap is safe since we always set the variant before traversing the body of the function.
match self.scope_state.variant.unwrap() {
// If the function is not a transition function, it can only call "inline" functions.
Variant::Inline | Variant::Standard => {
if !matches!(func.variant, Variant::Inline) {
self.emit_err(TypeCheckerError::can_only_call_inline_function(input.span));
}
}
// If the function is a transition function, then check that the call is not to another local transition function.
Variant::Transition => {
if matches!(func.variant, Variant::Transition)
&& input.program.unwrap() == self.scope_state.program_name.unwrap()
{
self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function(
input.span,
));
}
}
Variant::AsyncFunction | Variant::Function if !matches!(func.variant, Variant::Inline) => self.emit_err(TypeCheckerError::can_only_call_inline_function(input.span)),
Variant::Transition | Variant::AsyncTransition if matches!(func.variant, Variant::Transition) && input.program.unwrap() == self.scope_state.program_name.unwrap() => self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function(
input.span,
)),
_ => {}
}
// Check that the call is not to an external `inline` function.
@ -661,7 +650,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span));
}
// Async functions return a single future.
let mut ret = if func.is_async && func.variant == Standard {
let mut ret = if func.variant == AsyncFunction {
if let Some(Type::Future(_)) = expected {
Type::Future(FutureType::new(Vec::new()))
} else {
@ -687,7 +676,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
func.input.iter().zip(input.arguments.iter()).for_each(|(expected, argument)| {
let ty = self.visit_expression(argument, &Some(expected.type_()));
// Extract information about futures that are being consumed.
if func.is_async && func.variant == Standard && matches!(expected.type_(), Type::Future(_)) {
if func.variant == AsyncFunction && matches!(expected.type_(), Type::Future(_)) {
match argument {
Expression::Identifier(_)
| Expression::Call(_)
@ -732,20 +721,20 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
}
// Propagate futures from async functions and transitions.
if func.is_async {
if func.variant.is_async() {
// Cannot have async calls in a conditional block.
if self.scope_state.is_conditional {
self.emit_err(TypeCheckerError::async_call_in_conditional(input.span));
}
// Can only call async functions and external async transitions from an async transition body.
if !self.scope_state.is_async_transition {
if self.scope_state.variant != Some(AsyncTransition) {
self.emit_err(TypeCheckerError::async_call_can_only_be_done_from_async_transition(
input.span,
));
}
if func.variant == Variant::Transition {
if func.variant.is_transition() {
// Cannot call an external async transition after having called the async function.
if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::external_transition_call_must_be_before_finalize(
@ -776,7 +765,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
ret
}
};
} else if func.variant == Variant::Standard {
} else if func.variant.is_function() {
// Can only call an async function once in a transition function body.
if self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::must_call_finalize_once(input.span));
@ -837,8 +826,11 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
}
fn visit_struct_init(&mut self, input: &'a StructExpression, additional: &Self::AdditionalInput) -> Self::Output {
let struct_ =
self.symbol_table.borrow().lookup_struct(Location::new(self.scope_state.program_name, input.name.name)).cloned();
let struct_ = self
.symbol_table
.borrow()
.lookup_struct(Location::new(self.scope_state.program_name, input.name.name))
.cloned();
if let Some(struct_) = struct_ {
// Check struct type name.
let ret = self.check_expected_struct(&struct_, additional, input.name.span());
@ -886,7 +878,7 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> {
fn visit_identifier(&mut self, input: &'a Identifier, expected: &Self::AdditionalInput) -> Self::Output {
if let Some(var) = self.symbol_table.borrow().lookup_variable(Location::new(None, input.name)) {
if matches!(var.type_, Type::Future(_)) && matches!(expected, Some(Type::Future(_))) {
if self.scope_state.is_async_transition && self.scope_state.is_call {
if self.scope_state.variant == Some(AsyncTransition) && self.scope_state.is_call {
// Consume future.
match self.scope_state.futures.remove(&input.name) {
Some(future) => {

View File

@ -27,6 +27,7 @@ use leo_ast::{
Type::Future,
};
use std::collections::HashSet;
use leo_ast::Variant::{AsyncFunction, AsyncTransition};
// TODO: Cleanup logic for tuples.
@ -88,7 +89,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
let scope_index = self.create_child_scope();
// Create future stubs.
if input.variant == Variant::Standard && input.is_async {
if input.variant == Variant::AsyncFunction {
let finalize_input_map = &mut self.finalize_input_types;
let mut future_stubs = input.future_locations.clone();
let resolved_inputs: Vec<Type> = input
@ -302,7 +303,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
}
// Set type checker variables for function variant details.
self.scope_state.initialize_function_state(function.variant, function.is_async);
self.scope_state.initialize_function_state(function.variant);
// Lookup function metadata in the symbol table.
// Note that this unwrap is safe since function metadata is stored in a prior pass.
@ -328,7 +329,7 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
// Query helper function to type check function parameters and outputs.
self.check_function_signature(function);
if self.scope_state.is_finalize {
if self.scope_state.variant == Some(Variant::AsyncFunction) {
// Async functions cannot have empty blocks
if function.block.statements.is_empty() {
self.emit_err(TypeCheckerError::finalize_block_must_not_be_empty(function.block.span));
@ -367,12 +368,12 @@ impl<'a> ProgramVisitor<'a> for TypeChecker<'a> {
self.exit_scope(function_index);
// Make sure that async transitions call finalize.
if self.scope_state.is_async_transition && !self.scope_state.has_called_finalize {
if self.scope_state.variant == Some(AsyncTransition) && !self.scope_state.has_called_finalize {
self.emit_err(TypeCheckerError::async_transition_must_call_async_function(function.span));
}
// Check that all futures were awaited exactly once.
if self.scope_state.is_finalize {
if self.scope_state.variant == Some(AsyncFunction) {
// Throw error if not all futures awaits even appear once.
if !self.await_checker.static_to_await.is_empty() {
self.emit_err(TypeCheckerError::future_awaits_missing(

View File

@ -120,7 +120,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
// Create scope for checking awaits in `then` branch of conditional.
let current_bst_nodes: Vec<ConditionalTreeNode> =
match self.await_checker.create_then_scope(self.scope_state.is_finalize, input.span) {
match self.await_checker.create_then_scope(self.scope_state.variant == Some(Variant::AsyncFunction), input.span) {
Ok(nodes) => nodes,
Err(err) => return self.emit_err(err),
};
@ -132,7 +132,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
then_block_has_return = self.scope_state.has_return;
// Exit scope for checking awaits in `then` branch of conditional.
let saved_paths = self.await_checker.exit_then_scope(self.scope_state.is_finalize, current_bst_nodes);
let saved_paths = self.await_checker.exit_then_scope(self.scope_state.variant == Some(Variant::AsyncFunction), current_bst_nodes);
if let Some(otherwise) = &input.otherwise {
// Set the `has_return` flag for the otherwise-block.
@ -152,7 +152,7 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
}
// Update the set of all possible BST paths.
self.await_checker.exit_statement_scope(self.scope_state.is_finalize, saved_paths);
self.await_checker.exit_statement_scope(self.scope_state.variant == Some(Variant::AsyncFunction), saved_paths);
// Restore the previous `has_return` flag.
self.scope_state.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return);
@ -385,17 +385,18 @@ impl<'a> StatementVisitor<'a> for TypeChecker<'a> {
fn visit_return(&mut self, input: &'a ReturnStatement) {
// Cannot return anything from finalize.
if self.scope_state.is_finalize {
if self.scope_state.variant == Some(Variant::AsyncFunction) {
self.emit_err(TypeCheckerError::return_in_finalize(input.span()));
}
// We can safely unwrap all self.parent instances because
// statements should always have some parent block
let parent = self.scope_state.function.unwrap();
let func = self.symbol_table.borrow().lookup_fn_symbol(Location::new(self.scope_state.program_name, parent)).cloned();
let func =
self.symbol_table.borrow().lookup_fn_symbol(Location::new(self.scope_state.program_name, parent)).cloned();
let mut return_type = func.clone().map(|f| f.output_type.clone());
// Fully type the expected return value.
if self.scope_state.is_async_transition && self.scope_state.has_called_finalize {
if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.has_called_finalize {
let inferred_future_type = match self.finalize_input_types.get(&func.unwrap().finalize.clone().unwrap()) {
Some(types) => Future(FutureType::new(types.clone())),
None => {

View File

@ -47,6 +47,7 @@ use leo_ast::{
Type::{Future, Tuple},
};
use std::cell::RefCell;
use leo_ast::Variant::AsyncTransition;
pub struct TypeChecker<'a> {
/// The symbol table for the program.
@ -975,7 +976,7 @@ impl<'a> TypeChecker<'a> {
}
CoreFunction::MappingGet => {
// Check that the operation is invoked in a `finalize` block.
if !self.scope_state.is_finalize {
if self.scope_state.variant != Some(Variant::AsyncFunction) {
self.handler
.emit_err(TypeCheckerError::invalid_operation_outside_finalize("Mapping::get", function_span))
}
@ -991,7 +992,7 @@ impl<'a> TypeChecker<'a> {
}
CoreFunction::MappingGetOrUse => {
// Check that the operation is invoked in a `finalize` block.
if !self.scope_state.is_finalize {
if self.scope_state.variant != Some(Variant::AsyncFunction) {
self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize(
"Mapping::get_or",
function_span,
@ -1011,7 +1012,7 @@ impl<'a> TypeChecker<'a> {
}
CoreFunction::MappingSet => {
// Check that the operation is invoked in a `finalize` block.
if !self.scope_state.is_finalize {
if self.scope_state.variant != Some(Variant::AsyncFunction) {
self.handler
.emit_err(TypeCheckerError::invalid_operation_outside_finalize("Mapping::set", function_span))
}
@ -1033,7 +1034,7 @@ impl<'a> TypeChecker<'a> {
}
CoreFunction::MappingRemove => {
// Check that the operation is invoked in a `finalize` block.
if !self.scope_state.is_finalize {
if self.scope_state.variant != Some(Variant::AsyncFunction) {
self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize(
"Mapping::remove",
function_span,
@ -1056,7 +1057,7 @@ impl<'a> TypeChecker<'a> {
}
CoreFunction::MappingContains => {
// Check that the operation is invoked in a `finalize` block.
if !self.scope_state.is_finalize {
if self.scope_state.variant != Some(Variant::AsyncFunction) {
self.handler.emit_err(TypeCheckerError::invalid_operation_outside_finalize(
"Mapping::contains",
function_span,
@ -1266,7 +1267,7 @@ impl<'a> TypeChecker<'a> {
self.scope_state.variant = Some(function.variant);
// Special type checking for finalize blocks. Can skip for stubs.
if self.scope_state.is_finalize & !self.scope_state.is_stub {
if self.scope_state.variant == Some(Variant::AsyncFunction) && !self.scope_state.is_stub {
// Finalize functions are not allowed to return values.
if !function.output.is_empty() {
self.emit_err(TypeCheckerError::finalize_function_cannot_return_value(function.span()));
@ -1335,7 +1336,7 @@ impl<'a> TypeChecker<'a> {
}
// Check that the finalize input parameter is not constant or private.
if self.scope_state.is_finalize
if self.scope_state.variant == Some(Variant::AsyncFunction)
&& (input_var.mode() == Mode::Constant || input_var.mode() == Mode::Private)
&& (input_var.mode() == Mode::Constant || input_var.mode() == Mode::Private)
{
@ -1345,11 +1346,11 @@ impl<'a> TypeChecker<'a> {
// Note that this unwrap is safe since we assign to `self.variant` above.
match self.scope_state.variant.unwrap() {
// If the function is a transition function, then check that the parameter mode is not a constant.
Variant::Transition if input_var.mode() == Mode::Constant => {
Variant::Transition | Variant::AsyncTransition if input_var.mode() == Mode::Constant => {
self.emit_err(TypeCheckerError::transition_function_inputs_cannot_be_const(input_var.span()))
}
// If the function is not a transition function, then check that the parameters do not have an associated mode.
Variant::Standard | Variant::Inline if input_var.mode() != Mode::None => {
Variant::Function | Variant::AsyncFunction | Variant::Inline if input_var.mode() != Mode::None => {
self.emit_err(TypeCheckerError::regular_function_inputs_cannot_have_modes(input_var.span()))
}
_ => {} // Do nothing.
@ -1357,8 +1358,9 @@ impl<'a> TypeChecker<'a> {
// Add function inputs to the symbol table. Futures have already been added.
if !matches!(&input_var.type_(), &Type::Future(_)) {
if let Err(err) =
self.symbol_table.borrow_mut().insert_variable(Location::new(None, input_var.identifier().name), VariableSymbol {
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(
Location::new(None, input_var.identifier().name),
VariableSymbol {
type_: input_var.type_(),
span: input_var.identifier().span(),
declaration: VariableType::Input(input_var.mode()),
@ -1409,7 +1411,7 @@ impl<'a> TypeChecker<'a> {
self.emit_err(TypeCheckerError::cannot_have_constant_output_mode(function_output.span));
}
// Async transitions must return exactly one future, and it must be in the last position.
if self.scope_state.is_async_transition
if self.scope_state.variant == Some(AsyncTransition)
&& ((index < function.output.len() - 1 && matches!(function_output.type_, Type::Future(_)))
|| (index == function.output.len() - 1
&& !matches!(function_output.type_, Type::Future(_))))
@ -1488,11 +1490,13 @@ impl<'a> TypeChecker<'a> {
type_
};
// Insert the variable into the symbol table.
if let Err(err) = self.symbol_table.borrow_mut().insert_variable(Location::new(None, name.name), VariableSymbol {
type_: ty,
span,
declaration: VariableType::Mut,
}) {
if let Err(err) =
self.symbol_table.borrow_mut().insert_variable(Location::new(None, name.name), VariableSymbol {
type_: ty,
span,
declaration: VariableType::Mut,
})
{
self.handler.emit_err(err);
}
}

View File

@ -62,10 +62,8 @@ impl ScopeState {
}
/// Initialize state variables for new function.
pub fn initialize_function_state(&mut self, variant: Variant, is_async: bool) {
pub fn initialize_function_state(&mut self, variant: Variant) {
self.variant = Some(variant);
self.is_finalize = variant == Variant::Standard && is_async;
self.is_async_transition = variant == Variant::Transition && is_async;
self.has_called_finalize = false;
self.futures = IndexMap::new();
}