type check record init expressions

This commit is contained in:
collin 2022-06-26 10:46:32 -10:00
parent 18a74cfb85
commit 21c6a2167a
15 changed files with 135 additions and 28 deletions

View File

@ -57,7 +57,7 @@ pub enum Expression {
Binary(BinaryExpression), Binary(BinaryExpression),
/// A call expression, e.g., `my_fun(args)`. /// A call expression, e.g., `my_fun(args)`.
Call(CallExpression), Call(CallExpression),
/// An expression constructing a structure like `Foo { bar: 42, baz }`. /// An expression constructing a circuit like `Foo { bar: 42, baz }`.
CircuitInit(CircuitInitExpression), CircuitInit(CircuitInitExpression),
/// An expression of type "error". /// An expression of type "error".
/// Will result in a compile error eventually. /// Will result in a compile error eventually.

View File

@ -225,6 +225,7 @@ pub trait ProgramVisitorDirector<'a>: VisitorDirector<'a> + StatementVisitorDire
.values() .values()
.for_each(|function| self.visit_function(function)); .for_each(|function| self.visit_function(function));
input.circuits.values().for_each(|circuit| self.visit_circuit(circuit)); input.circuits.values().for_each(|circuit| self.visit_circuit(circuit));
input.records.values().for_each(|record| self.visit_record(record));
} }
} }

View File

@ -36,7 +36,7 @@ impl RecordVariable {
} }
pub fn name(&self) -> Symbol { pub fn name(&self) -> Symbol {
return self.ident.name return self.ident.name;
} }
} }

View File

@ -16,7 +16,6 @@
use super::*; use super::*;
use leo_errors::{ParserError, Result}; use leo_errors::{ParserError, Result};
use leo_span::sym;
use snarkvm_dpc::{prelude::Address, testnet2::Testnet2}; use snarkvm_dpc::{prelude::Address, testnet2::Testnet2};
@ -537,17 +536,8 @@ impl ParserContext<'_> {
Token::Ident(name) => { Token::Ident(name) => {
let ident = Identifier { name, span }; let ident = Identifier { name, span };
if !self.disallow_circuit_construction && self.check(&Token::LeftCurly) { if !self.disallow_circuit_construction && self.check(&Token::LeftCurly) {
self.parse_circuit_expression(ident)? // Parse circuit and records inits as circuit expressions.
} else { // Enforce circuit or record type later at type checking.
Expression::Identifier(ident)
}
}
Token::SelfUpper => {
let ident = Identifier {
name: sym::SelfUpper,
span,
};
if !self.disallow_circuit_construction && self.check(&Token::LeftCurly) {
self.parse_circuit_expression(ident)? self.parse_circuit_expression(ident)?
} else { } else {
Expression::Identifier(ident) Expression::Identifier(ident)

View File

@ -185,7 +185,9 @@ impl ParserContext<'_> {
let actual_type = self.parse_all_types()?.0; let actual_type = self.parse_all_types()?.0;
if expected_name != actual_name.name || expected_type != actual_type { if expected_name != actual_name.name || expected_type != actual_type {
self.emit_err(ParserError::required_record_variable(expected_name, expected_type, actual_name.span()).into()); self.emit_err(
ParserError::required_record_variable(expected_name, expected_type, actual_name.span()).into(),
);
} }
// Emit an error for a record variable without an ending comma or semicolon. // Emit an error for a record variable without an ending comma or semicolon.

View File

@ -89,7 +89,6 @@ pub enum Token {
U32, U32,
U64, U64,
U128, U128,
SelfUpper,
Record, Record,
// Regular Keywords // Regular Keywords
@ -144,7 +143,6 @@ pub const KEYWORD_TOKENS: &[Token] = &[
Token::Record, Token::Record,
Token::Return, Token::Return,
Token::SelfLower, Token::SelfLower,
Token::SelfUpper,
Token::Scalar, Token::Scalar,
Token::Static, Token::Static,
Token::String, Token::String,
@ -190,7 +188,6 @@ impl Token {
Token::Return => sym::Return, Token::Return => sym::Return,
Token::Scalar => sym::scalar, Token::Scalar => sym::scalar,
Token::SelfLower => sym::SelfLower, Token::SelfLower => sym::SelfLower,
Token::SelfUpper => sym::SelfUpper,
Token::Static => sym::Static, Token::Static => sym::Static,
Token::String => sym::string, Token::String => sym::string,
Token::True => sym::True, Token::True => sym::True,
@ -270,7 +267,6 @@ impl fmt::Display for Token {
U32 => write!(f, "u32"), U32 => write!(f, "u32"),
U64 => write!(f, "u64"), U64 => write!(f, "u64"),
U128 => write!(f, "u128"), U128 => write!(f, "u128"),
SelfUpper => write!(f, "Self"),
Record => write!(f, "record"), Record => write!(f, "record"),
Circuit => write!(f, "circuit"), Circuit => write!(f, "circuit"),

View File

@ -54,4 +54,11 @@ impl<'a> ProgramVisitor<'a> for CreateSymbolTable<'a> {
} }
VisitResult::SkipChildren VisitResult::SkipChildren
} }
fn visit_record(&mut self, input: &'a Record) -> VisitResult {
if let Err(err) = self.symbol_table.insert_record(input.name(), input) {
self.handler.emit_err(err);
}
VisitResult::SkipChildren
}
} }

View File

@ -16,7 +16,7 @@
use std::fmt::Display; use std::fmt::Display;
use leo_ast::{Circuit, Function}; use leo_ast::{Circuit, Function, Record};
use leo_errors::{AstError, Result}; use leo_errors::{AstError, Result};
use leo_span::{Span, Symbol}; use leo_span::{Span, Symbol};
@ -32,6 +32,9 @@ pub struct SymbolTable<'a> {
/// Maps circuit names to circuit definitions. /// Maps circuit names to circuit definitions.
/// This field is populated at a first pass. /// This field is populated at a first pass.
circuits: IndexMap<Symbol, &'a Circuit>, circuits: IndexMap<Symbol, &'a Circuit>,
/// Maps record names to record definitions.
/// This field is populated at a first pass.
records: IndexMap<Symbol, &'a Record>,
/// Variables represents functions variable definitions and input variables. /// Variables represents functions variable definitions and input variables.
/// This field is not populated till necessary. /// This field is not populated till necessary.
pub(crate) variables: VariableScope<'a>, pub(crate) variables: VariableScope<'a>,
@ -62,10 +65,27 @@ impl<'a> SymbolTable<'a> {
// Return an error if the circuit name has already been inserted. // Return an error if the circuit name has already been inserted.
return Err(AstError::shadowed_circuit(symbol, insert.span).into()); return Err(AstError::shadowed_circuit(symbol, insert.span).into());
} }
if self.records.contains_key(&symbol) {
// Return an error if the record name has already been inserted.
return Err(AstError::shadowed_record(symbol, insert.span).into());
}
self.circuits.insert(symbol, insert); self.circuits.insert(symbol, insert);
Ok(()) Ok(())
} }
pub fn insert_record(&mut self, symbol: Symbol, insert: &'a Record) -> Result<()> {
if self.circuits.contains_key(&symbol) {
// Return an error if the circuit name has already been inserted.
return Err(AstError::shadowed_circuit(symbol, insert.span).into());
}
if self.records.contains_key(&symbol) {
// Return an error if the record name has already been inserted.
return Err(AstError::shadowed_record(symbol, insert.span).into());
}
self.records.insert(symbol, insert);
Ok(())
}
pub fn insert_variable(&mut self, symbol: Symbol, insert: VariableSymbol<'a>) -> Result<()> { pub fn insert_variable(&mut self, symbol: Symbol, insert: VariableSymbol<'a>) -> Result<()> {
self.check_shadowing(symbol, insert.span)?; self.check_shadowing(symbol, insert.span)?;
self.variables.variables.insert(symbol, insert); self.variables.variables.insert(symbol, insert);
@ -80,6 +100,10 @@ impl<'a> SymbolTable<'a> {
self.circuits.get(symbol) self.circuits.get(symbol)
} }
pub fn lookup_record(&self, symbol: &Symbol) -> Option<&&'a Record> {
self.records.get(symbol)
}
pub fn lookup_variable(&self, symbol: &Symbol) -> Option<&VariableSymbol<'a>> { pub fn lookup_variable(&self, symbol: &Symbol) -> Option<&VariableSymbol<'a>> {
self.variables.lookup_variable(symbol) self.variables.lookup_variable(symbol)
} }
@ -111,6 +135,10 @@ impl<'a> Display for SymbolTable<'a> {
write!(f, "{circ}")?; write!(f, "{circ}")?;
} }
for rec in self.records.values() {
write!(f, "{rec}")?;
}
write!(f, "{}", self.variables) write!(f, "{}", self.variables)
} }
} }

View File

@ -70,6 +70,12 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
expected, expected,
circuit.span(), circuit.span(),
)); ));
} else if let Some(record) = self.visitor.symbol_table.clone().lookup_record(&var.name) {
return Some(self.visitor.assert_expected_option(
Type::Identifier(record.identifier.clone()),
expected,
record.span(),
));
} else if let VisitResult::VisitChildren = self.visitor.visit_identifier(var) { } else if let VisitResult::VisitChildren = self.visitor.visit_identifier(var) {
if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&var.name) { if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&var.name) {
return Some(self.visitor.assert_expected_option(*var.type_, expected, var.span)); return Some(self.visitor.assert_expected_option(*var.type_, expected, var.span));
@ -624,7 +630,62 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> {
input: &'a CircuitInitExpression, input: &'a CircuitInitExpression,
additional: &Self::AdditionalInput, additional: &Self::AdditionalInput,
) -> Option<Self::Output> { ) -> Option<Self::Output> {
if let Some(circ) = self.visitor.symbol_table.clone().lookup_circuit(&input.name.name) { // Type check record init expression.
if let Some(expected) = self.visitor.symbol_table.clone().lookup_record(&input.name.name) {
// Check record type name.
let ret = self
.visitor
.assert_expected_circuit(expected.identifier, additional, input.name.span());
// Check number of record data variables.
if expected.data.len() != input.members.len() - 2 {
self.visitor.handler.emit_err(
TypeCheckerError::incorrect_num_record_variables(
expected.data.len(),
input.members.len(),
input.span(),
)
.into(),
);
}
// Check record variable types.
input.members.iter().for_each(|actual| {
// Check record owner.
if actual.identifier.matches(&expected.owner.ident) {
if let Some(owner_expr) = &actual.expression {
self.visit_expression(owner_expr, &Some(Type::Address));
}
}
// Check record balance.
if actual.identifier.matches(&expected.balance.ident) {
if let Some(balance_expr) = &actual.expression {
self.visit_expression(balance_expr, &Some(Type::IntegerType(IntegerType::U64)));
}
}
// Check record data variable.
if let Some(expected_var) = expected
.data
.iter()
.find(|member| member.ident.matches(&actual.identifier))
{
if let Some(var_expr) = &actual.expression {
self.visit_expression(var_expr, &Some(expected_var.type_));
}
} else {
self.visitor.handler.emit_err(
TypeCheckerError::unknown_sym("record variable", actual.identifier, actual.identifier.span())
.into(),
);
}
});
Some(ret)
} else if let Some(circ) = self.visitor.symbol_table.clone().lookup_circuit(&input.name.name) {
// Type check circuit init expression.
// Check circuit type name. // Check circuit type name.
let ret = self let ret = self
.visitor .visitor

View File

@ -68,6 +68,7 @@ impl<'a> ProgramVisitorDirector<'a> for Director<'a> {
} }
fn visit_record(&mut self, input: &'a Record) { fn visit_record(&mut self, input: &'a Record) {
println!("visit record");
if let VisitResult::VisitChildren = self.visitor_ref().visit_record(input) { if let VisitResult::VisitChildren = self.visitor_ref().visit_record(input) {
// Check for conflicting record member names. // Check for conflicting record member names.
let mut used = HashSet::new(); let mut used = HashSet::new();

View File

@ -147,7 +147,7 @@ impl<'a> TypeChecker<'a> {
/// Returns the `circuit` type and emits an error if the `expected` type does not match. /// Returns the `circuit` type and emits an error if the `expected` type does not match.
pub(crate) fn assert_expected_circuit(&mut self, circuit: Identifier, expected: &Option<Type>, span: Span) -> Type { pub(crate) fn assert_expected_circuit(&mut self, circuit: Identifier, expected: &Option<Type>, span: Span) -> Type {
if let Some(Type::Identifier(expected)) = expected { if let Some(Type::Identifier(expected)) = expected {
if expected.name != circuit.name { if !circuit.matches(expected) {
self.handler self.handler
.emit_err(TypeCheckerError::type_should_be(circuit.name, expected.name, span).into()); .emit_err(TypeCheckerError::type_should_be(circuit.name, expected.name, span).into());
} }

View File

@ -1,13 +1,16 @@
record Token { circuit Token {
// The token owner. // The token owner.
owner: address,
// The Aleo balance (in gates).
balance: u64, balance: u64,
// The token amount. data: u64
amount: u64, }
function mint() -> Token {
let tok: Token = Token { balance: 1u64, data: 1u64};
return tok;
} }
function main(a: u8) -> group { function main(a: u8) -> group {
return Pedersen64::hash(a); return Pedersen64::hash(a);
} }

View File

@ -42,7 +42,7 @@ function transfer(r0: Token, r1: Receiver) -> Token {
let r4: Token = mint(r0.owner, r0.amount); let r4: Token = mint(r0.owner, r0.amount);
// return (r3, r4); // return (r3, r4);
return r3 return r3;
} }
function main() -> u8 { function main() -> u8 {

View File

@ -155,6 +155,14 @@ create_messages!(
help: None, help: None,
} }
/// For when a user shadows a record.
@formatted
shadowed_record {
args: (record: impl Display),
msg: format!("record `{record}` shadowed by"),
help: None,
}
/// For when a user shadows a variable. /// For when a user shadows a variable.
@formatted @formatted
shadowed_variable { shadowed_variable {

View File

@ -191,6 +191,16 @@ create_messages!(
help: None, help: None,
} }
/// For when the user tries initialize a circuit with the incorrect number of args.
@formatted
incorrect_num_record_variables {
args: (expected: impl Display, received: impl Display),
msg: format!(
"Record expected `{expected}` variables, but got `{received}`",
),
help: None,
}
/// An invalid access call is made e.g., `bool::MAX` /// An invalid access call is made e.g., `bool::MAX`
@formatted @formatted
invalid_access_expression { invalid_access_expression {