diff --git a/imports/src/parser/import_parser.rs b/imports/src/parser/import_parser.rs index 3b4dc2fe7b..fde6aa4d95 100644 --- a/imports/src/parser/import_parser.rs +++ b/imports/src/parser/import_parser.rs @@ -37,7 +37,7 @@ impl ImportParser { /// It is okay if the imported program is already present since importing multiple symbols from /// the same file is allowed. /// - pub(crate) fn insert_import(&mut self, file_name: String, program: Program) { + pub fn insert_import(&mut self, file_name: String, program: Program) { // Insert the imported program. let _program = self.imports.insert(file_name, program); } @@ -49,7 +49,7 @@ impl ImportParser { /// /// If the vector did have this file_name present, a duplicate import error is thrown. /// - pub(crate) fn insert_core_package(&mut self, package: &Package) -> Result<(), ImportParserError> { + pub fn insert_core_package(&mut self, package: &Package) -> Result<(), ImportParserError> { // Check for duplicate core package name. if self.core_packages.contains(package) { return Err(ImportParserError::duplicate_core_package(package.name.clone())); diff --git a/symbol-table/src/errors/symbol_table.rs b/symbol-table/src/errors/symbol_table.rs index 7ec93ee2b1..e2ea24802c 100644 --- a/symbol-table/src/errors/symbol_table.rs +++ b/symbol-table/src/errors/symbol_table.rs @@ -92,7 +92,7 @@ impl SymbolTableError { pub fn unknown_symbol(symbol: &ImportSymbol, program: &Program) -> Self { let message = format!( "Cannot find imported symbol `{}` in imported file `{}`", - symbol, program.name + symbol.symbol, program.name ); Self::new_from_span(message, symbol.span.to_owned()) diff --git a/symbol-table/src/symbol_table.rs b/symbol-table/src/symbol_table.rs index 9c58914ada..d3cf5cdf1a 100644 --- a/symbol-table/src/symbol_table.rs +++ b/symbol-table/src/symbol_table.rs @@ -38,10 +38,10 @@ pub struct SymbolTable { names: IndexMap, /// Maps circuit name -> circuit type. - circuits: IndexMap, + pub circuits: IndexMap, /// Maps function name -> function type. - functions: IndexMap, + pub functions: IndexMap, /// The parent of this symbol table. parent: Option>, @@ -68,7 +68,7 @@ impl SymbolTable { table.insert_input(input)?; // Check for duplicate program and import names. - table.check_names(program, import_parser)?; + table.check_names(program, import_parser, input)?; // Check for unknown or invalid types. table.check_types(program)?; @@ -149,6 +149,24 @@ impl SymbolTable { self.functions.insert(identifier.name, function_type) } + /// + /// Insert all circuit and function types from another symbol table. + /// Used when importing names from another package. + /// + pub fn insert_table_types(&mut self, imported: &Self) -> Result<(), SymbolTableError> { + for (imported_circuit_name, imported_circuit_type) in imported.circuits.iter() { + self.circuits + .insert(imported_circuit_name.to_owned(), imported_circuit_type.to_owned()); + } + + for (imported_function_name, imported_function_type) in imported.functions.iter() { + self.functions + .insert(imported_function_name.to_owned(), imported_function_type.to_owned()); + } + + Ok(()) + } + /// /// Returns a reference to the circuit type corresponding to the name. /// @@ -195,9 +213,14 @@ impl SymbolTable { /// If a circuit or function name has no duplicates, then it is inserted into the symbol table. /// Variables defined later in the unresolved program cannot have the same name. /// - pub fn check_names(&mut self, program: &Program, import_parser: &ImportParser) -> Result<(), SymbolTableError> { + pub fn check_names( + &mut self, + program: &Program, + import_parser: &ImportParser, + input: &Input, + ) -> Result<(), SymbolTableError> { // Check unresolved program import names. - self.check_import_names(&program.imports, import_parser)?; + self.check_import_names(&program.imports, import_parser, input)?; // Check unresolved program circuit names. self.check_circuit_names(&program.circuits)?; @@ -250,10 +273,11 @@ impl SymbolTable { &mut self, imports: &[ImportStatement], import_parser: &ImportParser, + input: &Input, ) -> Result<(), SymbolTableError> { // Iterate over imported names. for import in imports { - self.check_import_statement(import, import_parser)?; + self.check_import_statement(import, import_parser, input)?; } Ok(()) @@ -269,6 +293,7 @@ impl SymbolTable { &mut self, import: &ImportStatement, import_parser: &ImportParser, + input: &Input, ) -> Result<(), SymbolTableError> { // Check if the import name exists as core package. let core_package = import_parser.get_core_package(&import.package); @@ -279,7 +304,7 @@ impl SymbolTable { } // Attempt to insert the imported names into the symbol table. - self.check_package(import, import_parser) + self.check_package(import, import_parser, input) } /// @@ -320,6 +345,7 @@ impl SymbolTable { &mut self, import: &ImportStatement, import_parser: &ImportParser, + input: &Input, ) -> Result<(), SymbolTableError> { // Get imported symbols from statement. let imported_symbols = ImportedSymbols::new(import); @@ -341,14 +367,11 @@ impl SymbolTable { continue; }; - // Check the imported program for duplicate types. - self.check_names(program, import_parser)?; + // Check the imported program for duplicate or undefined types. + let import_symbol_table = SymbolTable::new(program, import_parser, input)?; - // Check the imported program for undefined types. - self.check_types(program)?; - - // Store the imported symbol. - // self.insert_import_symbol(symbol, program)?; // TODO (collinc97) uncomment this line when public/private import scopes are implemented. + // Import symbols into the self symbol table. + self.insert_import_symbol(symbol, import_symbol_table, program)?; } Ok(()) @@ -357,37 +380,34 @@ impl SymbolTable { /// /// Inserts the imported symbol into the symbol table if it is present in the given program. /// - pub fn insert_import_symbol(&mut self, symbol: ImportSymbol, program: &Program) -> Result<(), SymbolTableError> { + pub fn insert_import_symbol( + &mut self, + symbol: ImportSymbol, + table: SymbolTable, + program: &Program, + ) -> Result<(), SymbolTableError> { // Check for import *. if symbol.is_star() { - // Insert all program circuits. - self.check_circuit_names(&program.circuits)?; - - // Insert all program functions. - self.check_function_names(&program.functions) + // Insert all program circuits and functions. + self.insert_table_types(&table) } else { // Check for a symbol alias. - let identifier = symbol.alias.to_owned().unwrap_or_else(|| symbol.symbol.to_owned()); + let identifier = symbol.alias.to_owned().unwrap_or_else(|| symbol.symbol.clone()); - // Check if the imported symbol is a circuit - match program.circuits.get(&symbol.symbol) { - Some(circuit) => { - // Insert imported circuit. - self.insert_circuit_name(identifier.to_string(), UserDefinedType::from(circuit.to_owned())) - } - None => { - // Check if the imported symbol is a function. - match program.functions.get(&symbol.symbol) { - Some(function) => { - // Insert the imported function. - self.insert_function_name( - identifier.to_string(), - UserDefinedType::from(function.to_owned()), - ) - } - None => Err(SymbolTableError::unknown_symbol(&symbol, program)), - } - } + // Check if the imported symbol is a circuit or a function. + if let Some(circuit_type) = table.get_circuit_type(&symbol.symbol.name) { + // Insert the circuit into the self symbol table. + self.insert_circuit_type(identifier, circuit_type.to_owned()); + + Ok(()) + } else if let Some(function_type) = table.get_function_type(&symbol.symbol.name) { + // Insert the function into the self symbol table. + self.insert_function_type(identifier, function_type.to_owned()); + + Ok(()) + } else { + // Return an error if we cannot find the imported symbol. + Err(SymbolTableError::unknown_symbol(&symbol, program)) } } } diff --git a/symbol-table/tests/mod.rs b/symbol-table/tests/mod.rs index 21c7710cb6..e73dc0d233 100644 --- a/symbol-table/tests/mod.rs +++ b/symbol-table/tests/mod.rs @@ -52,13 +52,10 @@ impl TestSymbolTable { /// /// Expect no errors during parsing. /// - pub fn expect_success(self) { + pub fn expect_success(self, import_parser: ImportParser) { // Get program. let program = self.ast.into_repr(); - // Create empty import parser. - let import_parser = ImportParser::default(); - // Create empty input. let input = Input::new(); @@ -82,7 +79,9 @@ impl TestSymbolTable { let import_parser = ImportParser::default(); // Run pass one and expect an error. - let error = static_check.check_names(&program, &import_parser).unwrap_err(); + let error = static_check + .check_names(&program, &import_parser, &Input::new()) + .unwrap_err(); match error { SymbolTableError::Error(_) => {} // Ok @@ -106,7 +105,9 @@ impl TestSymbolTable { let import_parser = ImportParser::default(); // Run the pass one and expect no errors. - static_check.check_names(&program, &import_parser).unwrap(); + static_check + .check_names(&program, &import_parser, &Input::new()) + .unwrap(); // Run the pass two and expect and error. let error = static_check.check_types(&program).unwrap_err(); diff --git a/symbol-table/tests/symbol_table/import_circuit_alias.leo b/symbol-table/tests/symbol_table/import_circuit_alias.leo new file mode 100644 index 0000000000..f0bbc156fd --- /dev/null +++ b/symbol-table/tests/symbol_table/import_circuit_alias.leo @@ -0,0 +1,9 @@ +import bar.Bar as Baz; + +circuit Bar { + r: bool, +} +function main() { + let z = Baz { z: 0u32 }; + let r = Bar { r: true }; +} \ No newline at end of file diff --git a/symbol-table/tests/symbol_table/import_function_alias.leo b/symbol-table/tests/symbol_table/import_function_alias.leo new file mode 100644 index 0000000000..125b2095c8 --- /dev/null +++ b/symbol-table/tests/symbol_table/import_function_alias.leo @@ -0,0 +1,10 @@ +import foo.foo as boo; + +function foo() -> bool { + return false +} + +function main() { + let z: u8 = boo(); + let r: bool = foo(); +} \ No newline at end of file diff --git a/symbol-table/tests/symbol_table/import_star.leo b/symbol-table/tests/symbol_table/import_star.leo new file mode 100644 index 0000000000..ad7be78c89 --- /dev/null +++ b/symbol-table/tests/symbol_table/import_star.leo @@ -0,0 +1,5 @@ +import foo.*; + +function main() { + let x: u8 = boo(); +} \ No newline at end of file diff --git a/symbol-table/tests/symbol_table/import_undefined.leo b/symbol-table/tests/symbol_table/import_undefined.leo new file mode 100644 index 0000000000..9e4063b297 --- /dev/null +++ b/symbol-table/tests/symbol_table/import_undefined.leo @@ -0,0 +1,5 @@ +import foo.boo; + +function main() { + let x: u8 = boo(); +} \ No newline at end of file diff --git a/symbol-table/tests/symbol_table/imports/bar.leo b/symbol-table/tests/symbol_table/imports/bar.leo new file mode 100644 index 0000000000..26ac259dee --- /dev/null +++ b/symbol-table/tests/symbol_table/imports/bar.leo @@ -0,0 +1,3 @@ +circuit Bar { + z: u32 +} \ No newline at end of file diff --git a/symbol-table/tests/symbol_table/imports/foo.leo b/symbol-table/tests/symbol_table/imports/foo.leo new file mode 100644 index 0000000000..b462c11d57 --- /dev/null +++ b/symbol-table/tests/symbol_table/imports/foo.leo @@ -0,0 +1,3 @@ +function foo() -> u8 { + return 5u8 +} \ No newline at end of file diff --git a/symbol-table/tests/symbol_table/mod.rs b/symbol-table/tests/symbol_table/mod.rs index e5b820d4a4..3a3d15b57f 100644 --- a/symbol-table/tests/symbol_table/mod.rs +++ b/symbol-table/tests/symbol_table/mod.rs @@ -16,6 +16,10 @@ use crate::TestSymbolTable; +use leo_ast::Input; +use leo_imports::ImportParser; +use leo_symbol_table::{SymbolTable, SymbolTableError}; + /// /// Defines a circuit `Foo {}`. /// Attempts to define a second circuit `Foo {}`. @@ -73,3 +77,106 @@ fn test_undefined_circuit() { resolver.expect_pass_two_error(); } + +/// +/// Imports an undefined function `boo` from file foo.leo. +/// +/// Expected output: SymbolTableError +/// Message: Cannot find imported symbol `boo` in imported file `` +/// +#[test] +fn test_import_undefined() { + let program_string = include_str!("import_undefined.leo"); + let import_string = include_str!("imports/foo.leo"); + + let program_table = TestSymbolTable::new(program_string); + let import_table = TestSymbolTable::new(import_string); + + let import_program = import_table.ast.into_repr(); + + let mut imports = ImportParser::default(); + imports.insert_import("foo".to_owned(), import_program); + + // Create new symbol table. + let static_check = &mut SymbolTable::default(); + + // Run pass one and expect an error. + let error = static_check + .check_names(&program_table.ast.into_repr(), &imports, &Input::new()) + .unwrap_err(); + + match error { + SymbolTableError::Error(_) => {} // Ok + error => panic!("Expected a symbol table error found `{}`", error), + } +} + +/// +/// Imports all functions from file foo.leo. +/// Calls function `foo` defined in foo.leo. +/// +/// Expected output: Test Pass +/// +#[test] +fn test_import_star() { + let program_string = include_str!("import_star.leo"); + let import_string = include_str!("imports/foo.leo"); + + let program_table = TestSymbolTable::new(program_string); + let import_table = TestSymbolTable::new(import_string); + + let import_program = import_table.ast.into_repr(); + + let mut imports = ImportParser::default(); + imports.insert_import("foo".to_owned(), import_program); + + program_table.expect_success(imports); +} + +/// +/// Imports a circuit named `Bar` from file bar.leo. +/// Renames `Bar` => `Baz`. +/// Defines a circuit named `Bar` in main.leo. +/// Instantiates circuits `Bar` and `Baz`. +/// +/// Expected output: Test Pass +/// +#[test] +fn test_import_circuit_alias() { + let program_string = include_str!("import_circuit_alias.leo"); + let import_string = include_str!("imports/bar.leo"); + + let program_table = TestSymbolTable::new(program_string); + let import_table = TestSymbolTable::new(import_string); + + let import_program = import_table.ast.into_repr(); + + let mut imports = ImportParser::default(); + imports.insert_import("bar".to_owned(), import_program); + + program_table.expect_success(imports); +} + +/// +/// Imports a function named `foo` from file foo.leo. +/// Renames `foo` => `boo`. +/// Defines a function named `foo` in main.leo. +/// Calls functions `foo` and `boo`. +/// +/// Expected output: Test Pass +/// +#[test] +fn test_import_function_alias() { + let program_string = include_str!("import_function_alias.leo"); + let import_string = include_str!("imports/foo.leo"); + + let program_table = TestSymbolTable::new(program_string); + let import_table = TestSymbolTable::new(import_string); + + let import_program = import_table.ast.into_repr(); + + let mut imports = ImportParser::default(); + imports.insert_import("foo".to_owned(), import_program); + + program_table.expect_success(imports); +}