From 621a2f2a9503b1ee31e8b2a8448e4e4054b38c0e Mon Sep 17 00:00:00 2001 From: Pranav Gaddamadugu Date: Tue, 10 Oct 2023 21:06:01 -0400 Subject: [PATCH] Add expressions to the type map --- .../passes/src/common/symbol_table/mod.rs | 4 ++-- .../src/type_checking/check_expressions.rs | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/compiler/passes/src/common/symbol_table/mod.rs b/compiler/passes/src/common/symbol_table/mod.rs index 70abaef771..06107598fb 100644 --- a/compiler/passes/src/common/symbol_table/mod.rs +++ b/compiler/passes/src/common/symbol_table/mod.rs @@ -31,6 +31,7 @@ use serde::{Deserialize, Serialize}; use serde_json; // TODO (@d0cd) Consider a safe interface for the symbol table. +// TODO (@d0cd) Cleanup API #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct SymbolTable { /// The parent scope if it exists. @@ -105,9 +106,8 @@ impl SymbolTable { } /// Inserts a type for a node ID into the symbol table. - pub fn insert_type(&mut self, node_id: NodeID, type_: Type) -> Result<()> { + pub fn insert_type(&mut self, node_id: NodeID, type_: Type) { self.types.insert(node_id, type_); - Ok(()) } /// Removes a variable from the symbol table. diff --git a/compiler/passes/src/type_checking/check_expressions.rs b/compiler/passes/src/type_checking/check_expressions.rs index d420d58163..5e8cd3b922 100644 --- a/compiler/passes/src/type_checking/check_expressions.rs +++ b/compiler/passes/src/type_checking/check_expressions.rs @@ -42,6 +42,29 @@ impl<'a> ExpressionVisitor<'a> for TypeChecker<'a> { type AdditionalInput = Option; type Output = Option; + fn visit_expression(&mut self, input: &'a Expression, additional: &Self::AdditionalInput) -> Self::Output { + let output = match input { + Expression::Access(access) => self.visit_access(access, additional), + Expression::Binary(binary) => self.visit_binary(binary, additional), + Expression::Call(call) => self.visit_call(call, additional), + Expression::Cast(cast) => self.visit_cast(cast, additional), + Expression::Struct(struct_) => self.visit_struct_init(struct_, additional), + Expression::Err(err) => self.visit_err(err, additional), + Expression::Identifier(identifier) => self.visit_identifier(identifier, additional), + Expression::Literal(literal) => self.visit_literal(literal, additional), + Expression::Ternary(ternary) => self.visit_ternary(ternary, additional), + Expression::Tuple(tuple) => self.visit_tuple(tuple, additional), + Expression::Unary(unary) => self.visit_unary(unary, additional), + Expression::Unit(unit) => self.visit_unit(unit, additional), + }; + // If the output type is known, add the expression and its associated type to the symbol table. + if let Some(type_) = &output { + self.symbol_table.borrow_mut().insert_type(input.id(), type_.clone()); + } + // Return the output type. + output + } + fn visit_access(&mut self, input: &'a AccessExpression, expected: &Self::AdditionalInput) -> Self::Output { match input { AccessExpression::Array(access) => {