diff --git a/dynamic-check/src/dynamic_check.rs b/dynamic-check/src/dynamic_check.rs index 333a34a41e..53b4fc8d4d 100644 --- a/dynamic-check/src/dynamic_check.rs +++ b/dynamic-check/src/dynamic_check.rs @@ -14,15 +14,15 @@ // You should have received a copy of the GNU General Public License // along with the Leo library. If not, see . +use leo_static_check::{FunctionInputType, FunctionType, SymbolTable, Type}; use leo_typed::{Expression, Function, Identifier, Program, Span, Statement}; -use leo_static_check::{FunctionType, SymbolTable, Type}; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; +use std::collections::HashMap; /// Performs a dynamic type inference check over a program. pub struct DynamicCheck { - symbol_table: SymbolTable, + table: SymbolTable, functions: Vec, } @@ -32,7 +32,7 @@ impl DynamicCheck { /// pub fn new(program: &Program, symbol_table: SymbolTable) -> Self { let mut dynamic_check = Self { - symbol_table, + table: symbol_table, functions: vec![], }; @@ -67,7 +67,7 @@ impl DynamicCheck { /// Collects a vector of `TypeAssertion` predicates from a function. /// fn parse_function(&mut self, function: &Function) { - let function_body = FunctionBody::new(function.clone(), self.symbol_table.clone()); + let function_body = FunctionBody::new(function.clone(), self.table.clone()); self.functions.push(function_body); } @@ -92,9 +92,9 @@ impl DynamicCheck { #[derive(Clone)] pub struct FunctionBody { function_type: FunctionType, - symbol_table: SymbolTable, + user_defined_types: SymbolTable, type_assertions: Vec, - type_variables: HashSet, + variable_table: VariableTable, } impl FunctionBody { @@ -107,18 +107,21 @@ impl FunctionBody { // Get function type from symbol table. let function_type = symbol_table.get_function(name).unwrap().clone(); + // Build symbol table for variables. + let mut variable_table = VariableTable::new(); + + // Initialize function inputs as variables. + variable_table.parse_function_inputs(&function_type.inputs); + // Create new function body struct. + // Update variables when encountering let/const variable definitions. let mut function_body = Self { function_type, - symbol_table, + user_defined_types: symbol_table, type_assertions: vec![], - type_variables: HashSet::new(), + variable_table, }; - // Build symbol table for variables. - // Initialize function inputs as variables. - // Update inputs when encountering let/const variable definitions. - // Create type assertions for function statements function_body.parse_statements(&function.statements); @@ -157,7 +160,7 @@ impl FunctionBody { let left = TypeElement::Type(output_type.clone()); // Create the right hand side from the statement return expression. - let right = TypeElement::new(expression, self.symbol_table.clone()); + let right = TypeElement::new(expression, self.user_defined_types.clone()); // Create a new type assertion for the statement return. let type_assertion = TypeAssertion::new(left, right); @@ -198,6 +201,44 @@ impl FunctionBody { } } +/// A structure for tracking the types of user defined variables in a program. +#[derive(Clone)] +pub struct VariableTable(pub HashMap); + +impl VariableTable { + /// + /// Return a new variable table + /// + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// + /// Insert a name -> type pair into the variable table. + /// + /// If the variable table did not have this key present, [`None`] is returned. + /// + /// If the variable table did have this key present, the type is updated, and the old + /// type is returned. + /// + pub fn insert(&mut self, name: String, type_: Type) -> Option { + self.0.insert(name, type_) + } + + /// + /// Inserts a vector of function input types into the variable table. + /// + pub fn parse_function_inputs(&mut self, function_inputs: &Vec) { + for input in function_inputs { + let input_name = input.identifier().name.clone(); + let input_type = input.type_().clone(); + + // TODO (collinc97) throw an error for duplicate function input names. + self.insert(input_name, input_type); + } + } +} + /// A predicate that evaluates equality between two `TypeElement`s. #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] pub struct TypeAssertion { diff --git a/static-check/src/static_check.rs b/static-check/src/static_check.rs index 41508fc03e..3cf3cfc525 100644 --- a/static-check/src/static_check.rs +++ b/static-check/src/static_check.rs @@ -15,7 +15,7 @@ // along with the Leo library. If not, see . use crate::{SymbolTable, SymbolTableError}; -use leo_typed::Program as UnresolvedProgram; +use leo_typed::Program; /// Performs a static type check over a program. pub struct StaticCheck { @@ -26,7 +26,7 @@ impl StaticCheck { /// /// Return a new `StaticCheck` from a given program. /// - pub fn new(program: &UnresolvedProgram) -> Result { + pub fn new(program: &Program) -> Result { let mut check = Self { table: SymbolTable::new(None), }; @@ -46,7 +46,7 @@ impl StaticCheck { /// If a circuit or function name has no duplicates, then it is inserted into the symbol table. /// Variables defined later in the unresolved program cannot have the same name. /// - pub fn pass_one(&mut self, program: &UnresolvedProgram) -> Result<(), SymbolTableError> { + pub fn pass_one(&mut self, program: &Program) -> Result<(), SymbolTableError> { // Check unresolved program circuit names. self.table.check_duplicate_circuits(&program.circuits)?; @@ -63,7 +63,7 @@ impl StaticCheck { /// symbol table. Variables defined later in the unresolved program can lookup the definition and /// refer to its expected types. /// - pub fn pass_two(&mut self, program: &UnresolvedProgram) -> Result<(), SymbolTableError> { + pub fn pass_two(&mut self, program: &Program) -> Result<(), SymbolTableError> { // Check unresolved program circuit definitions. self.table.check_unknown_types_circuits(&program.circuits)?;