fix import type checking

This commit is contained in:
collin 2021-01-07 18:12:46 -05:00
parent cec3a40eb1
commit b4bcfe549a
7 changed files with 99 additions and 49 deletions

View File

@ -37,7 +37,7 @@ impl ImportParser {
/// It is okay if the imported program is already present since importing multiple symbols from
/// the same file is allowed.
///
pub(crate) fn insert_import(&mut self, file_name: String, program: Program) {
pub fn insert_import(&mut self, file_name: String, program: Program) {
// Insert the imported program.
let _program = self.imports.insert(file_name, program);
}
@ -49,7 +49,7 @@ impl ImportParser {
///
/// If the vector did have this file_name present, a duplicate import error is thrown.
///
pub(crate) fn insert_core_package(&mut self, package: &Package) -> Result<(), ImportParserError> {
pub fn insert_core_package(&mut self, package: &Package) -> Result<(), ImportParserError> {
// Check for duplicate core package name.
if self.core_packages.contains(package) {
return Err(ImportParserError::duplicate_core_package(package.name.clone()));

View File

@ -92,7 +92,7 @@ impl SymbolTableError {
pub fn unknown_symbol(symbol: &ImportSymbol, program: &Program) -> Self {
let message = format!(
"Cannot find imported symbol `{}` in imported file `{}`",
symbol, program.name
symbol.symbol, program.name
);
Self::new_from_span(message, symbol.span.to_owned())

View File

@ -38,10 +38,10 @@ pub struct SymbolTable {
names: IndexMap<String, UserDefinedType>,
/// Maps circuit name -> circuit type.
circuits: IndexMap<String, CircuitType>,
pub circuits: IndexMap<String, CircuitType>,
/// Maps function name -> function type.
functions: IndexMap<String, FunctionType>,
pub functions: IndexMap<String, FunctionType>,
/// The parent of this symbol table.
parent: Option<Box<SymbolTable>>,
@ -68,7 +68,7 @@ impl SymbolTable {
table.insert_input(input)?;
// Check for duplicate program and import names.
table.check_names(program, import_parser)?;
table.check_names(program, import_parser, input)?;
// Check for unknown or invalid types.
table.check_types(program)?;
@ -149,6 +149,24 @@ impl SymbolTable {
self.functions.insert(identifier.name, function_type)
}
///
/// Insert all circuit and function types from another symbol table.
/// Used when importing names from another package.
///
pub fn insert_table_types(&mut self, imported: &Self) -> Result<(), SymbolTableError> {
for (imported_circuit_name, imported_circuit_type) in imported.circuits.iter() {
self.circuits
.insert(imported_circuit_name.to_owned(), imported_circuit_type.to_owned());
}
for (imported_function_name, imported_function_type) in imported.functions.iter() {
self.functions
.insert(imported_function_name.to_owned(), imported_function_type.to_owned());
}
Ok(())
}
///
/// Returns a reference to the circuit type corresponding to the name.
///
@ -195,9 +213,14 @@ impl SymbolTable {
/// If a circuit or function name has no duplicates, then it is inserted into the symbol table.
/// Variables defined later in the unresolved program cannot have the same name.
///
pub fn check_names(&mut self, program: &Program, import_parser: &ImportParser) -> Result<(), SymbolTableError> {
pub fn check_names(
&mut self,
program: &Program,
import_parser: &ImportParser,
input: &Input,
) -> Result<(), SymbolTableError> {
// Check unresolved program import names.
self.check_import_names(&program.imports, import_parser)?;
self.check_import_names(&program.imports, import_parser, input)?;
// Check unresolved program circuit names.
self.check_circuit_names(&program.circuits)?;
@ -250,10 +273,11 @@ impl SymbolTable {
&mut self,
imports: &[ImportStatement],
import_parser: &ImportParser,
input: &Input,
) -> Result<(), SymbolTableError> {
// Iterate over imported names.
for import in imports {
self.check_import_statement(import, import_parser)?;
self.check_import_statement(import, import_parser, input)?;
}
Ok(())
@ -269,6 +293,7 @@ impl SymbolTable {
&mut self,
import: &ImportStatement,
import_parser: &ImportParser,
input: &Input,
) -> Result<(), SymbolTableError> {
// Check if the import name exists as core package.
let core_package = import_parser.get_core_package(&import.package);
@ -279,7 +304,7 @@ impl SymbolTable {
}
// Attempt to insert the imported names into the symbol table.
self.check_package(import, import_parser)
self.check_package(import, import_parser, input)
}
///
@ -320,6 +345,7 @@ impl SymbolTable {
&mut self,
import: &ImportStatement,
import_parser: &ImportParser,
input: &Input,
) -> Result<(), SymbolTableError> {
// Get imported symbols from statement.
let imported_symbols = ImportedSymbols::new(import);
@ -341,14 +367,11 @@ impl SymbolTable {
continue;
};
// Check the imported program for duplicate types.
self.check_names(program, import_parser)?;
// Check the imported program for duplicate or undefined types.
let import_symbol_table = SymbolTable::new(program, import_parser, input)?;
// Check the imported program for undefined types.
self.check_types(program)?;
// Store the imported symbol.
// self.insert_import_symbol(symbol, program)?; // TODO (collinc97) uncomment this line when public/private import scopes are implemented.
// Import symbols into the self symbol table.
self.insert_import_symbol(symbol, import_symbol_table, program)?;
}
Ok(())
@ -357,37 +380,34 @@ impl SymbolTable {
///
/// Inserts the imported symbol into the symbol table if it is present in the given program.
///
pub fn insert_import_symbol(&mut self, symbol: ImportSymbol, program: &Program) -> Result<(), SymbolTableError> {
pub fn insert_import_symbol(
&mut self,
symbol: ImportSymbol,
table: SymbolTable,
program: &Program,
) -> Result<(), SymbolTableError> {
// Check for import *.
if symbol.is_star() {
// Insert all program circuits.
self.check_circuit_names(&program.circuits)?;
// Insert all program functions.
self.check_function_names(&program.functions)
// Insert all program circuits and functions.
self.insert_table_types(&table)
} else {
// Check for a symbol alias.
let identifier = symbol.alias.to_owned().unwrap_or_else(|| symbol.symbol.to_owned());
let identifier = symbol.alias.to_owned().unwrap_or_else(|| symbol.symbol.clone());
// Check if the imported symbol is a circuit
match program.circuits.get(&symbol.symbol) {
Some(circuit) => {
// Insert imported circuit.
self.insert_circuit_name(identifier.to_string(), UserDefinedType::from(circuit.to_owned()))
}
None => {
// Check if the imported symbol is a function.
match program.functions.get(&symbol.symbol) {
Some(function) => {
// Insert the imported function.
self.insert_function_name(
identifier.to_string(),
UserDefinedType::from(function.to_owned()),
)
}
None => Err(SymbolTableError::unknown_symbol(&symbol, program)),
}
}
// Check if the imported symbol is a circuit or a function.
if let Some(circuit_type) = table.get_circuit_type(&symbol.symbol.name) {
// Insert the circuit into the self symbol table.
self.insert_circuit_type(identifier, circuit_type.to_owned());
Ok(())
} else if let Some(function_type) = table.get_function_type(&symbol.symbol.name) {
// Insert the function into the self symbol table.
self.insert_function_type(identifier, function_type.to_owned());
Ok(())
} else {
// Return an error if we cannot find the imported symbol.
Err(SymbolTableError::unknown_symbol(&symbol, program))
}
}
}

View File

@ -52,13 +52,10 @@ impl TestSymbolTable {
///
/// Expect no errors during parsing.
///
pub fn expect_success(self) {
pub fn expect_success(self, import_parser: ImportParser) {
// Get program.
let program = self.ast.into_repr();
// Create empty import parser.
let import_parser = ImportParser::default();
// Create empty input.
let input = Input::new();
@ -82,7 +79,9 @@ impl TestSymbolTable {
let import_parser = ImportParser::default();
// Run pass one and expect an error.
let error = static_check.check_names(&program, &import_parser).unwrap_err();
let error = static_check
.check_names(&program, &import_parser, &Input::new())
.unwrap_err();
match error {
SymbolTableError::Error(_) => {} // Ok
@ -106,7 +105,9 @@ impl TestSymbolTable {
let import_parser = ImportParser::default();
// Run the pass one and expect no errors.
static_check.check_names(&program, &import_parser).unwrap();
static_check
.check_names(&program, &import_parser, &Input::new())
.unwrap();
// Run the pass two and expect and error.
let error = static_check.check_types(&program).unwrap_err();

View File

@ -0,0 +1,8 @@
import bar.Bar as Baz;
circuit Bar {
b: u32,
}
function main() {
let b = Baz { b: 0u32 };
}

View File

@ -0,0 +1,3 @@
circuit Bar {
b: u32
}

View File

@ -16,6 +16,8 @@
use crate::TestSymbolTable;
use leo_imports::ImportParser;
///
/// Defines a circuit `Foo {}`.
/// Attempts to define a second circuit `Foo {}`.
@ -73,3 +75,19 @@ fn test_undefined_circuit() {
resolver.expect_pass_two_error();
}
#[test]
fn test_import_alias() {
let program_string = include_str!("import_alias.leo");
let import_string = include_str!("imports/bar.leo");
let program_table = TestSymbolTable::new(program_string);
let import_table = TestSymbolTable::new(import_string);
let import_program = import_table.ast.into_repr();
let mut imports = ImportParser::default();
imports.insert_import("bar".to_owned(), import_program);
program_table.expect_success(imports);
}