diff --git a/compiler/ast/src/expressions/mod.rs b/compiler/ast/src/expressions/mod.rs index b1a52e6dcf..4fcca54b4b 100644 --- a/compiler/ast/src/expressions/mod.rs +++ b/compiler/ast/src/expressions/mod.rs @@ -57,7 +57,7 @@ pub enum Expression { Binary(BinaryExpression), /// A call expression, e.g., `my_fun(args)`. Call(CallExpression), - /// An expression constructing a structure like `Foo { bar: 42, baz }`. + /// An expression constructing a circuit like `Foo { bar: 42, baz }`. CircuitInit(CircuitInitExpression), /// An expression of type "error". /// Will result in a compile error eventually. diff --git a/compiler/ast/src/passes/visitor_director.rs b/compiler/ast/src/passes/visitor_director.rs index 7df85c34f2..8eb633da4d 100644 --- a/compiler/ast/src/passes/visitor_director.rs +++ b/compiler/ast/src/passes/visitor_director.rs @@ -225,6 +225,7 @@ pub trait ProgramVisitorDirector<'a>: VisitorDirector<'a> + StatementVisitorDire .values() .for_each(|function| self.visit_function(function)); input.circuits.values().for_each(|circuit| self.visit_circuit(circuit)); + input.records.values().for_each(|record| self.visit_record(record)); } } diff --git a/compiler/ast/src/records/record_variable.rs b/compiler/ast/src/records/record_variable.rs index d87d95b4c1..8b944503b3 100644 --- a/compiler/ast/src/records/record_variable.rs +++ b/compiler/ast/src/records/record_variable.rs @@ -36,7 +36,7 @@ impl RecordVariable { } pub fn name(&self) -> Symbol { - return self.ident.name + return self.ident.name; } } diff --git a/compiler/parser/src/parser/expression.rs b/compiler/parser/src/parser/expression.rs index b56a69719c..1cdf2233e7 100644 --- a/compiler/parser/src/parser/expression.rs +++ b/compiler/parser/src/parser/expression.rs @@ -16,7 +16,6 @@ use super::*; use leo_errors::{ParserError, Result}; -use leo_span::sym; use snarkvm_dpc::{prelude::Address, testnet2::Testnet2}; @@ -537,17 +536,8 @@ impl ParserContext<'_> { Token::Ident(name) => { let ident = Identifier { name, span }; if !self.disallow_circuit_construction && self.check(&Token::LeftCurly) { - self.parse_circuit_expression(ident)? - } else { - Expression::Identifier(ident) - } - } - Token::SelfUpper => { - let ident = Identifier { - name: sym::SelfUpper, - span, - }; - if !self.disallow_circuit_construction && self.check(&Token::LeftCurly) { + // Parse circuit and records inits as circuit expressions. + // Enforce circuit or record type later at type checking. self.parse_circuit_expression(ident)? } else { Expression::Identifier(ident) diff --git a/compiler/parser/src/parser/file.rs b/compiler/parser/src/parser/file.rs index e2d86e0ea6..59b2f66491 100644 --- a/compiler/parser/src/parser/file.rs +++ b/compiler/parser/src/parser/file.rs @@ -185,7 +185,9 @@ impl ParserContext<'_> { let actual_type = self.parse_all_types()?.0; if expected_name != actual_name.name || expected_type != actual_type { - self.emit_err(ParserError::required_record_variable(expected_name, expected_type, actual_name.span()).into()); + self.emit_err( + ParserError::required_record_variable(expected_name, expected_type, actual_name.span()).into(), + ); } // Emit an error for a record variable without an ending comma or semicolon. diff --git a/compiler/parser/src/tokenizer/token.rs b/compiler/parser/src/tokenizer/token.rs index 71a661e6df..965e98fbf9 100644 --- a/compiler/parser/src/tokenizer/token.rs +++ b/compiler/parser/src/tokenizer/token.rs @@ -89,7 +89,6 @@ pub enum Token { U32, U64, U128, - SelfUpper, Record, // Regular Keywords @@ -144,7 +143,6 @@ pub const KEYWORD_TOKENS: &[Token] = &[ Token::Record, Token::Return, Token::SelfLower, - Token::SelfUpper, Token::Scalar, Token::Static, Token::String, @@ -190,7 +188,6 @@ impl Token { Token::Return => sym::Return, Token::Scalar => sym::scalar, Token::SelfLower => sym::SelfLower, - Token::SelfUpper => sym::SelfUpper, Token::Static => sym::Static, Token::String => sym::string, Token::True => sym::True, @@ -270,7 +267,6 @@ impl fmt::Display for Token { U32 => write!(f, "u32"), U64 => write!(f, "u64"), U128 => write!(f, "u128"), - SelfUpper => write!(f, "Self"), Record => write!(f, "record"), Circuit => write!(f, "circuit"), diff --git a/compiler/passes/src/symbol_table/create.rs b/compiler/passes/src/symbol_table/create.rs index c13fc8c0e0..bd537484db 100644 --- a/compiler/passes/src/symbol_table/create.rs +++ b/compiler/passes/src/symbol_table/create.rs @@ -54,4 +54,11 @@ impl<'a> ProgramVisitor<'a> for CreateSymbolTable<'a> { } VisitResult::SkipChildren } + + fn visit_record(&mut self, input: &'a Record) -> VisitResult { + if let Err(err) = self.symbol_table.insert_record(input.name(), input) { + self.handler.emit_err(err); + } + VisitResult::SkipChildren + } } diff --git a/compiler/passes/src/symbol_table/table.rs b/compiler/passes/src/symbol_table/table.rs index 9402e2d2a1..e8e494bb42 100644 --- a/compiler/passes/src/symbol_table/table.rs +++ b/compiler/passes/src/symbol_table/table.rs @@ -16,7 +16,7 @@ use std::fmt::Display; -use leo_ast::{Circuit, Function}; +use leo_ast::{Circuit, Function, Record}; use leo_errors::{AstError, Result}; use leo_span::{Span, Symbol}; @@ -32,6 +32,9 @@ pub struct SymbolTable<'a> { /// Maps circuit names to circuit definitions. /// This field is populated at a first pass. circuits: IndexMap, + /// Maps record names to record definitions. + /// This field is populated at a first pass. + records: IndexMap, /// Variables represents functions variable definitions and input variables. /// This field is not populated till necessary. pub(crate) variables: VariableScope<'a>, @@ -62,10 +65,27 @@ impl<'a> SymbolTable<'a> { // Return an error if the circuit name has already been inserted. return Err(AstError::shadowed_circuit(symbol, insert.span).into()); } + if self.records.contains_key(&symbol) { + // Return an error if the record name has already been inserted. + return Err(AstError::shadowed_record(symbol, insert.span).into()); + } self.circuits.insert(symbol, insert); Ok(()) } + pub fn insert_record(&mut self, symbol: Symbol, insert: &'a Record) -> Result<()> { + if self.circuits.contains_key(&symbol) { + // Return an error if the circuit name has already been inserted. + return Err(AstError::shadowed_circuit(symbol, insert.span).into()); + } + if self.records.contains_key(&symbol) { + // Return an error if the record name has already been inserted. + return Err(AstError::shadowed_record(symbol, insert.span).into()); + } + self.records.insert(symbol, insert); + Ok(()) + } + pub fn insert_variable(&mut self, symbol: Symbol, insert: VariableSymbol<'a>) -> Result<()> { self.check_shadowing(symbol, insert.span)?; self.variables.variables.insert(symbol, insert); @@ -80,6 +100,10 @@ impl<'a> SymbolTable<'a> { self.circuits.get(symbol) } + pub fn lookup_record(&self, symbol: &Symbol) -> Option<&&'a Record> { + self.records.get(symbol) + } + pub fn lookup_variable(&self, symbol: &Symbol) -> Option<&VariableSymbol<'a>> { self.variables.lookup_variable(symbol) } @@ -111,6 +135,10 @@ impl<'a> Display for SymbolTable<'a> { write!(f, "{circ}")?; } + for rec in self.records.values() { + write!(f, "{rec}")?; + } + write!(f, "{}", self.variables) } } diff --git a/compiler/passes/src/type_checker/check_expressions.rs b/compiler/passes/src/type_checker/check_expressions.rs index 8744d7e172..67c2ffabc5 100644 --- a/compiler/passes/src/type_checker/check_expressions.rs +++ b/compiler/passes/src/type_checker/check_expressions.rs @@ -70,6 +70,12 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { expected, circuit.span(), )); + } else if let Some(record) = self.visitor.symbol_table.clone().lookup_record(&var.name) { + return Some(self.visitor.assert_expected_option( + Type::Identifier(record.identifier.clone()), + expected, + record.span(), + )); } else if let VisitResult::VisitChildren = self.visitor.visit_identifier(var) { if let Some(var) = self.visitor.symbol_table.clone().lookup_variable(&var.name) { return Some(self.visitor.assert_expected_option(*var.type_, expected, var.span)); @@ -624,7 +630,62 @@ impl<'a> ExpressionVisitorDirector<'a> for Director<'a> { input: &'a CircuitInitExpression, additional: &Self::AdditionalInput, ) -> Option { - if let Some(circ) = self.visitor.symbol_table.clone().lookup_circuit(&input.name.name) { + // Type check record init expression. + if let Some(expected) = self.visitor.symbol_table.clone().lookup_record(&input.name.name) { + // Check record type name. + let ret = self + .visitor + .assert_expected_circuit(expected.identifier, additional, input.name.span()); + + // Check number of record data variables. + if expected.data.len() != input.members.len() - 2 { + self.visitor.handler.emit_err( + TypeCheckerError::incorrect_num_record_variables( + expected.data.len(), + input.members.len(), + input.span(), + ) + .into(), + ); + } + + // Check record variable types. + input.members.iter().for_each(|actual| { + // Check record owner. + if actual.identifier.matches(&expected.owner.ident) { + if let Some(owner_expr) = &actual.expression { + self.visit_expression(owner_expr, &Some(Type::Address)); + } + } + + // Check record balance. + if actual.identifier.matches(&expected.balance.ident) { + if let Some(balance_expr) = &actual.expression { + self.visit_expression(balance_expr, &Some(Type::IntegerType(IntegerType::U64))); + } + } + + // Check record data variable. + if let Some(expected_var) = expected + .data + .iter() + .find(|member| member.ident.matches(&actual.identifier)) + { + if let Some(var_expr) = &actual.expression { + self.visit_expression(var_expr, &Some(expected_var.type_)); + } + } else { + self.visitor.handler.emit_err( + TypeCheckerError::unknown_sym("record variable", actual.identifier, actual.identifier.span()) + .into(), + ); + } + }); + + Some(ret) + } else if let Some(circ) = self.visitor.symbol_table.clone().lookup_circuit(&input.name.name) { + // Type check circuit init expression. + // Check circuit type name. let ret = self .visitor diff --git a/compiler/passes/src/type_checker/check_file.rs b/compiler/passes/src/type_checker/check_file.rs index fa4c8e136f..66c0c186d0 100644 --- a/compiler/passes/src/type_checker/check_file.rs +++ b/compiler/passes/src/type_checker/check_file.rs @@ -68,6 +68,7 @@ impl<'a> ProgramVisitorDirector<'a> for Director<'a> { } fn visit_record(&mut self, input: &'a Record) { + println!("visit record"); if let VisitResult::VisitChildren = self.visitor_ref().visit_record(input) { // Check for conflicting record member names. let mut used = HashSet::new(); diff --git a/compiler/passes/src/type_checker/checker.rs b/compiler/passes/src/type_checker/checker.rs index d59457bee9..034eb787a5 100644 --- a/compiler/passes/src/type_checker/checker.rs +++ b/compiler/passes/src/type_checker/checker.rs @@ -147,7 +147,7 @@ impl<'a> TypeChecker<'a> { /// Returns the `circuit` type and emits an error if the `expected` type does not match. pub(crate) fn assert_expected_circuit(&mut self, circuit: Identifier, expected: &Option, span: Span) -> Type { if let Some(Type::Identifier(expected)) = expected { - if expected.name != circuit.name { + if !circuit.matches(expected) { self.handler .emit_err(TypeCheckerError::type_should_be(circuit.name, expected.name, span).into()); } diff --git a/examples/hello-world/src/main.leo b/examples/hello-world/src/main.leo index 065bdc8447..14548f4501 100644 --- a/examples/hello-world/src/main.leo +++ b/examples/hello-world/src/main.leo @@ -1,13 +1,16 @@ -record Token { +circuit Token { // The token owner. - owner: address, - // The Aleo balance (in gates). balance: u64, - // The token amount. - amount: u64, + data: u64 +} + +function mint() -> Token { + let tok: Token = Token { balance: 1u64, data: 1u64}; + return tok; } function main(a: u8) -> group { + return Pedersen64::hash(a); } \ No newline at end of file diff --git a/examples/token/src/main.leo b/examples/token/src/main.leo index aa55207da6..d7adbc5b55 100644 --- a/examples/token/src/main.leo +++ b/examples/token/src/main.leo @@ -42,7 +42,7 @@ function transfer(r0: Token, r1: Receiver) -> Token { let r4: Token = mint(r0.owner, r0.amount); // return (r3, r4); - return r3 + return r3; } function main() -> u8 { diff --git a/leo/errors/src/errors/ast/ast_errors.rs b/leo/errors/src/errors/ast/ast_errors.rs index 6e751086e8..e58f21e12f 100644 --- a/leo/errors/src/errors/ast/ast_errors.rs +++ b/leo/errors/src/errors/ast/ast_errors.rs @@ -155,6 +155,14 @@ create_messages!( help: None, } + /// For when a user shadows a record. + @formatted + shadowed_record { + args: (record: impl Display), + msg: format!("record `{record}` shadowed by"), + help: None, + } + /// For when a user shadows a variable. @formatted shadowed_variable { diff --git a/leo/errors/src/errors/type_checker/type_checker_error.rs b/leo/errors/src/errors/type_checker/type_checker_error.rs index 9a1454f7fa..6c8992c2b1 100644 --- a/leo/errors/src/errors/type_checker/type_checker_error.rs +++ b/leo/errors/src/errors/type_checker/type_checker_error.rs @@ -191,6 +191,16 @@ create_messages!( help: None, } + /// For when the user tries initialize a circuit with the incorrect number of args. + @formatted + incorrect_num_record_variables { + args: (expected: impl Display, received: impl Display), + msg: format!( + "Record expected `{expected}` variables, but got `{received}`", + ), + help: None, + } + /// An invalid access call is made e.g., `bool::MAX` @formatted invalid_access_expression {