Refactor TypeChecker unifyTypes

The dicttype case is more clear now
This commit is contained in:
craigmc08 2021-07-16 16:07:46 -04:00 committed by Craig McIlwrath
parent cf55594eda
commit 2f173bc8b7

View File

@ -57,7 +57,8 @@ import Analyzer.TypeChecker.TypeError
import qualified Analyzer.TypeDefinitions as TD import qualified Analyzer.TypeDefinitions as TD
import Control.Arrow (left) import Control.Arrow (left)
import Control.Monad (foldM) import Control.Monad (foldM)
import qualified Data.HashMap.Strict as H import Data.Foldable (foldl')
import qualified Data.HashMap.Strict as M
import Data.List.NonEmpty (NonEmpty ((:|)), nonEmpty, toList) import Data.List.NonEmpty (NonEmpty ((:|)), nonEmpty, toList)
check :: AST -> T TypedAST check :: AST -> T TypedAST
@ -124,7 +125,7 @@ checkExpr (P.List values) = do
checkExpr (P.Dict entries) = do checkExpr (P.Dict entries) = do
guardUnique $ map fst entries guardUnique $ map fst entries
typedEntries <- zip (map fst entries) <$> mapM (checkExpr . snd) entries typedEntries <- zip (map fst entries) <$> mapM (checkExpr . snd) entries
let dictType = H.fromList $ map (\(key, val) -> (key, DictRequired $ exprType val)) typedEntries let dictType = M.fromList $ map (\(key, val) -> (key, DictRequired $ exprType val)) typedEntries
return $ Dict typedEntries (DictType dictType) return $ Dict typedEntries (DictType dictType)
where where
guardUnique :: [String] -> T () guardUnique :: [String] -> T ()
@ -153,8 +154,8 @@ unify (expr :| exprs) = do
-- >>> unifyTypes StringType StringType -- >>> unifyTypes StringType StringType
-- Right StringType -- Right StringType
-- --
-- >>> unifyTypes (DictType $ H.empty) (DictType $ H.singleton "a" (DictRequired NumberType)) -- >>> unifyTypes (DictType $ M.empty) (DictType $ M.singleton "a" (DictRequired NumberType))
-- Right (DictType (H.singleton "a" (DictOptional NumberType))) -- Right (DictType (M.singleton "a" (DictOptional NumberType)))
unifyTypes :: Type -> Type -> Either TypeError Type unifyTypes :: Type -> Type -> Either TypeError Type
-- Trivial case: two identical types unify to themselves -- Trivial case: two identical types unify to themselves
unifyTypes s t unifyTypes s t
@ -169,30 +170,26 @@ unifyTypes s@(DeclType _) t = Left $ UnificationError ReasonDecl s t
unifyTypes s@(EnumType _) t = Left $ UnificationError ReasonEnum s t unifyTypes s@(EnumType _) t = Left $ UnificationError ReasonEnum s t
-- The unification of two dictionaries is defined by the [DictNone] and [DictSome] rules -- The unification of two dictionaries is defined by the [DictNone] and [DictSome] rules
unifyTypes typS@(DictType s) typT@(DictType t) = do unifyTypes typS@(DictType s) typT@(DictType t) = do
-- Rules are applied in both directions, then unioned because s may not let keys = M.keysSet s <> M.keysSet t
-- have keys that t does, or vice versa unifiedType <- foldMapM (\key -> M.singleton key <$> unifyEntryTypesForKey key) keys
-- TODO: should this be improved? return $ DictType unifiedType
onS <- foldMapMWithKey (go t) s
onT <- foldMapMWithKey (go s) t
return $ DictType $ onS <> onT
where where
-- Tries to apply [DictSome] and [DictNone] rules to s and u unifyEntryTypesForKey :: String -> Either TypeError DictEntryType
-- TODO: better name unifyEntryTypesForKey key = annotateError key $ case (M.lookup key s, M.lookup key t) of
go :: H.HashMap String DictEntryType -> String -> DictEntryType -> Either TypeError (H.HashMap String DictEntryType) (Nothing, Nothing) ->
go u k (DictRequired s') = annotateError k $ case H.lookup k u of error "impossible: unifyTypes.unifyEntryTypesForKey should be called with only the keys of s and t"
-- [DictSome] on s, [DictNone] on u -- [DictSome] on s, [DictNone] on t
Nothing -> Right $ H.singleton k (DictOptional s') (Just sType, Nothing) ->
-- No rules applied to s or u Right $ DictOptional $ dictEntryType sType
Just (DictRequired u') -> H.singleton k . DictRequired <$> unifyTypes s' u' -- [DictNone] on s, [DictSome] on t
-- [DictNone] on s (Nothing, Just tType) ->
Just (DictOptional u') -> H.singleton k . DictOptional <$> unifyTypes s' u' Right $ DictOptional $ dictEntryType tType
go u k (DictOptional s') = annotateError k $ case H.lookup k u of -- Both require @key@, so it must be a required entry of the unified entry types
-- [DictNone] on u (Just (DictRequired sType), Just (DictRequired tType)) ->
Nothing -> Right $ H.singleton k (DictOptional s') DictRequired <$> unifyTypes sType tType
-- [DictSome] on u -- One of s or t has @key@ optionally, so it must be an optional entry of the unified entry types
Just (DictRequired u') -> H.singleton k . DictOptional <$> unifyTypes s' u' (Just sType, Just tType) ->
-- No rules applied to s or u DictOptional <$> unifyTypes (dictEntryType sType) (dictEntryType tType)
Just (DictOptional u') -> H.singleton k . DictOptional <$> unifyTypes s' u'
annotateError :: String -> Either TypeError a -> Either TypeError a annotateError :: String -> Either TypeError a -> Either TypeError a
annotateError k = left (\e -> UnificationError (ReasonDictWrongKeyType k e) typS typT) annotateError k = left (\e -> UnificationError (ReasonDictWrongKeyType k e) typS typT)
@ -217,12 +214,12 @@ weaken (ListType typ') expr@(List vals _) =
mapM (weaken typ') vals mapM (weaken typ') vals
weaken (DictType typ') expr@(Dict entries _) = do weaken (DictType typ') expr@(Dict entries _) = do
entries' <- mapM weakenEntry entries entries' <- mapM weakenEntry entries
mapM_ guardHasEntry $ H.toList typ' mapM_ guardHasEntry $ M.toList typ'
return $ Dict entries' $ DictType typ' return $ Dict entries' $ DictType typ'
where where
-- Tries to apply [DictSome] and [DictNone] rules to the entries of the dict -- Tries to apply [DictSome] and [DictNone] rules to the entries of the dict
weakenEntry :: (String, TypedExpr) -> Either TypeError (Ident, TypedExpr) weakenEntry :: (String, TypedExpr) -> Either TypeError (Ident, TypedExpr)
weakenEntry (key, value) = case H.lookup key typ' of weakenEntry (key, value) = case M.lookup key typ' of
-- @key@ is missing from @typ'@ => extra keys are not allowed -- @key@ is missing from @typ'@ => extra keys are not allowed
Nothing -> Left $ WeakenError (ReasonDictExtraKey key) expr (DictType typ') Nothing -> Left $ WeakenError (ReasonDictExtraKey key) expr (DictType typ')
-- @key@ is required and present => only need to weaken the value's type -- @key@ is required and present => only need to weaken the value's type
@ -253,6 +250,5 @@ weaken (DictType typ') expr@(Dict entries _) = do
-- All other cases can not be weakened -- All other cases can not be weakened
weaken typ' expr = Left $ WeakenError ReasonUncoercable expr typ' weaken typ' expr = Left $ WeakenError ReasonUncoercable expr typ'
-- | Like @foldMap@, but runs in a monad @m@ and is specialised for entries of a "HashMap". foldMapM :: (Foldable t, Monad m, Monoid s) => (a -> m s) -> t a -> m s
foldMapMWithKey :: (Monad m, Monoid s) => (k -> v -> m s) -> H.HashMap k v -> m s foldMapM f = foldl' (\ms a -> ms >>= \s -> (s <>) <$> f a) $ pure mempty
foldMapMWithKey f = H.foldlWithKey' (\m k v -> m >>= \s -> (s <>) <$> f k v) $ return mempty