diff --git a/src/THIH.hs b/src/THIH.hs index 560abc5..d487648 100644 --- a/src/THIH.hs +++ b/src/THIH.hs @@ -1,3 +1,4 @@ +{-# OPTIONS_GHC -fno-warn-name-shadowing #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# OPTIONS -Wno-incomplete-patterns #-} @@ -17,6 +18,8 @@ module THIH -- * Setting up , addClass , addInstance + , defaultSpecialTypes + , SpecialTypes(..) , ClassEnvironment(..) , ReadException(..) -- * Printers @@ -56,6 +59,14 @@ import Data.Typeable -------------------------------------------------------------------------------- -- Types +data SpecialTypes = SpecialTypes + { specialTypesBool :: Type + , specialTypesChar :: Type + , specialTypesString :: Type + , specialTypesFunction :: Type + , specialTypesList :: Type + } deriving (Show) + -- | Type inference monad. newtype InferT m a = InferT { runInferT :: StateT InferState m a @@ -65,6 +76,7 @@ newtype InferT m a = InferT data InferState = InferState { inferStateSubstitutions :: ![Substitution] , inferStateCounter :: !Int + , inferStateSpecialTypes :: !SpecialTypes } deriving (Show) -- | An exception that may be thrown when reading in source code, @@ -304,6 +316,7 @@ demo = do [StarKind] (Qualified [] (makeArrow (GenericType 0) (GenericType 0)))) ] + defaultSpecialTypes [ BindGroup [ ExplicitlyTypedBinding "x" @@ -336,10 +349,12 @@ demo = do ] ] ] - mapM_ (putStrLn . printTypeSignature) assumptions + mapM_ (putStrLn . printTypeSignature defaultSpecialTypes) assumptions where tInteger :: Type tInteger = ConstructorType (TypeConstructor "Integer" StarKind) + makeArrow :: Type -> Type -> Type + a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction defaultSpecialTypes) a) b -------------------------------------------------------------------------------- -- Printer @@ -347,12 +362,12 @@ demo = do printIdentifier :: Identifier -> String printIdentifier (Identifier i) = i -printTypeSignature :: TypeSignature -> String -printTypeSignature (TypeSignature identifier scheme) = - printIdentifier identifier ++ " :: " ++ printScheme scheme +printTypeSignature :: SpecialTypes -> TypeSignature -> String +printTypeSignature specialTypes (TypeSignature identifier scheme) = + printIdentifier identifier ++ " :: " ++ printScheme specialTypes scheme -printScheme :: Scheme -> [Char] -printScheme (Forall kinds qualifiedType') = +printScheme :: SpecialTypes -> Scheme -> [Char] +printScheme specialTypes (Forall kinds qualifiedType') = (if null kinds then "" else "forall " ++ @@ -364,7 +379,7 @@ printScheme (Forall kinds qualifiedType') = [0 :: Int ..] kinds) ++ ". ") ++ - printQualifiedType qualifiedType' + printQualifiedType specialTypes qualifiedType' printKind :: Kind -> [Char] printKind = @@ -372,32 +387,32 @@ printKind = StarKind -> "*" FunctionKind x' y -> printKind x' ++ " -> " ++ printKind y -printQualifiedType :: Qualified Type -> [Char] -printQualifiedType (Qualified predicates typ) = +printQualifiedType :: SpecialTypes -> Qualified Type -> [Char] +printQualifiedType specialTypes(Qualified predicates typ) = case predicates of - [] -> printTypeSansParens typ + [] -> printTypeSansParens specialTypes typ _ -> "(" ++ - intercalate ", " (map printPredicate predicates) ++ - ") => " ++ printTypeSansParens typ + intercalate ", " (map (printPredicate specialTypes) predicates) ++ + ") => " ++ printTypeSansParens specialTypes typ -printTypeSansParens :: Type -> [Char] -printTypeSansParens = +printTypeSansParens :: SpecialTypes -> Type -> [Char] +printTypeSansParens specialTypes = \case - ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y' -> - printType x' ++ " -> " ++ printTypeSansParens y' - o -> printType o + ApplicationType (ApplicationType func x') y' | func == specialTypesFunction specialTypes -> + printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y' + o -> printType specialTypes o -printType :: Type -> [Char] -printType = +printType :: SpecialTypes -> Type -> [Char] +printType specialTypes = \case VariableType v -> printTypeVariable v ConstructorType tyCon -> printTypeConstructor tyCon - ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y -> - "(" ++ printType x' ++ " -> " ++ printTypeSansParens y ++ ")" - ApplicationType (ConstructorType (TypeConstructor (Identifier "[]") _)) ty -> - "[" ++ printTypeSansParens ty ++ "]" - ApplicationType x' y -> "(" ++ printType x' ++ " " ++ printType y ++ ")" + ApplicationType (ApplicationType func x') y | func == specialTypesFunction specialTypes -> + "(" ++ printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y ++ ")" + ApplicationType list ty | list == specialTypesList specialTypes -> + "[" ++ printTypeSansParens specialTypes ty ++ "]" + ApplicationType x' y -> "(" ++ printType specialTypes x' ++ " " ++ printType specialTypes y ++ ")" GenericType int -> "a" ++ show int printTypeConstructor :: TypeConstructor -> String @@ -412,9 +427,9 @@ printTypeVariable (TypeVariable identifier kind) = StarKind -> printIdentifier identifier _ -> "(" ++ printIdentifier identifier ++ " :: " ++ printKind kind ++ ")" -printPredicate :: Predicate -> [Char] -printPredicate (IsIn identifier types) = - printIdentifier identifier ++ " " ++ unwords (map printType types) +printPredicate :: SpecialTypes -> Predicate -> [Char] +printPredicate specialTypes (IsIn identifier types) = + printIdentifier identifier ++ " " ++ unwords (map (printType specialTypes) types) -------------------------------------------------------------------------------- -- Type inference @@ -436,10 +451,11 @@ printPredicate (IsIn identifier types) = typeCheckModule :: MonadThrow m => ClassEnvironment -- ^ Set of defined type-classes. - -> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI. - -> [BindGroup] -- ^ Bindings in the module. - -> m [TypeSignature] -- ^ Inferred types for all identifiers. -typeCheckModule ce as bgs = + -> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI. + -> SpecialTypes -- ^ Special types that Haskell uses for pattern matching and literals. + -> [BindGroup] -- ^ Bindings in the module. + -> m [TypeSignature] -- ^ Inferred types for all identifiers. +typeCheckModule ce as specialTypes bgs = evalStateT (runInferT $ do (ps, as') <- inferSequenceTypes inferBindGroupTypes ce as bgs @@ -447,54 +463,75 @@ typeCheckModule ce as bgs = let rs = reduce ce (map (substitutePredicate s) ps) s' <- defaultSubst ce [] rs return (map (substituteTypeSignature (s' @@ s)) as')) - (InferState nullSubst 0) + (InferState nullSubst 0 specialTypes) -------------------------------------------------------------------------------- -- Built-in types and classes -boolType :: Type -boolType = ConstructorType (TypeConstructor "Bool" StarKind) +-- | Special types that Haskell uses for pattern matching and literals. +defaultSpecialTypes :: SpecialTypes +defaultSpecialTypes = + SpecialTypes + { specialTypesBool = ConstructorType (TypeConstructor "Bool" StarKind) + , specialTypesChar = ConstructorType (TypeConstructor "Char" StarKind) + , specialTypesString = makeListType (specialTypesChar defaultSpecialTypes) + , specialTypesFunction = + ConstructorType + (TypeConstructor + "(->)" + (FunctionKind StarKind (FunctionKind StarKind StarKind))) + , specialTypesList = listType + } + where + makeListType :: Type -> Type + makeListType t = ApplicationType listType t + listType :: Type + listType = + ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind)) -charType :: Type -charType = ConstructorType (TypeConstructor "Char" StarKind) +-- boolType :: Type +-- boolType = ConstructorType (TypeConstructor "Bool" StarKind) -stringType :: Type -stringType = makeListType charType +-- charType :: Type +-- charType = ConstructorType (TypeConstructor "Char" StarKind) -makeListType :: Type -> Type -makeListType t = ApplicationType listType t +-- stringType :: Type +-- stringType = makeListType charType -listType :: Type -listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind)) +-- makeListType :: Type -> Type +-- makeListType t = ApplicationType listType t -makeArrow :: Type -> Type -> Type -a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b +-- listType :: Type +-- listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind)) -tArrow :: Type -tArrow = - ConstructorType - (TypeConstructor - "(->)" - (FunctionKind StarKind (FunctionKind StarKind StarKind))) +-- makeArrow :: Type -> Type -> Type +-- a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b -numClasses :: [Identifier] -numClasses = - ["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"] +-- tArrow :: Type +-- tArrow = +-- ConstructorType +-- (TypeConstructor +-- "(->)" +-- (FunctionKind StarKind (FunctionKind StarKind StarKind))) -stdClasses :: [Identifier] -stdClasses = - [ "Eq" - , "Ord" - , "Show" - , "Read" - , "Bounded" - , "Enum" - , "Ix" - , "Functor" - , "Monad" - , "MonadPlus" - ] ++ - numClasses +-- numClasses :: [Identifier] +-- numClasses = +-- ["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"] + +-- stdClasses :: [Identifier] +-- stdClasses = +-- [ "Eq" +-- , "Ord" +-- , "Show" +-- , "Read" +-- , "Bounded" +-- , "Enum" +-- , "Ix" +-- , "Functor" +-- , "Monad" +-- , "MonadPlus" +-- ] ++ +-- numClasses -------------------------------------------------------------------------------- -- Substitution @@ -770,13 +807,15 @@ enumId n = Identifier ("v" ++ show n) inferLiteralType :: Monad m - => Literal -> InferT m ([Predicate], Type) -inferLiteralType (CharacterLiteral _) = return ([], charType) -inferLiteralType (IntegerLiteral _) = do + => SpecialTypes -> Literal -> InferT m ([Predicate], Type) +inferLiteralType specialTypes (CharacterLiteral _) = + return ([], specialTypesChar specialTypes) +inferLiteralType _ (IntegerLiteral _) = do v <- newVariableType StarKind return ([IsIn "Num" [v]], v) -inferLiteralType (StringLiteral _) = return ([], stringType) -inferLiteralType (RationalLiteral _) = do +inferLiteralType specialTypes (StringLiteral _) = + return ([], specialTypesString specialTypes) +inferLiteralType _ (RationalLiteral _) = do v <- newVariableType StarKind return ([IsIn "Fractional" [v]], v) @@ -793,12 +832,16 @@ inferPattern (AsPattern i pat) = do (ps, as, t) <- inferPattern pat return (ps, (TypeSignature i (toScheme t)) : as, t) inferPattern (LiteralPattern l) = do - (ps, t) <- inferLiteralType l + specialTypes <- InferT (gets inferStateSpecialTypes) + (ps, t) <- inferLiteralType specialTypes l return (ps, [], t) inferPattern (ConstructorPattern (TypeSignature _ sc) pats) = do (ps, as, ts) <- inferPatterns pats t' <- newVariableType StarKind (Qualified qs t) <- freshInst sc + specialTypes <- InferT (gets inferStateSpecialTypes) + let makeArrow :: Type -> Type -> Type + a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b unify t (foldr makeArrow t' ts) return (ps ++ qs, as, t') inferPattern (LazyPattern pat) = inferPattern pat @@ -964,12 +1007,16 @@ inferExpressionType _ _ (ConstantExpression (TypeSignature _ sc)) = do (Qualified ps t) <- freshInst sc return (ps, t) inferExpressionType _ _ (LiteralExpression l) = do - (ps, t) <- inferLiteralType l + specialTypes <- InferT (gets inferStateSpecialTypes) + (ps, t) <- inferLiteralType specialTypes l return (ps, t) inferExpressionType ce as (ApplicationExpression e f) = do (ps, te) <- inferExpressionType ce as e (qs, tf) <- inferExpressionType ce as f t <- newVariableType StarKind + specialTypes <- InferT (gets inferStateSpecialTypes) + let makeArrow :: Type -> Type -> Type + a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b unify (tf `makeArrow` t) te return (ps ++ qs, t) inferExpressionType ce as (LetExpression bg e) = do @@ -979,7 +1026,8 @@ inferExpressionType ce as (LetExpression bg e) = do inferExpressionType ce as (LambdaExpression alt) = inferAltType ce as alt inferExpressionType ce as (IfExpression e e1 e2) = do (ps, t) <- inferExpressionType ce as e - unify t boolType + specialTypes <- InferT (gets inferStateSpecialTypes) + unify t (specialTypesBool specialTypes) (ps1, t1) <- inferExpressionType ce as e1 (ps2, t2) <- inferExpressionType ce as e2 unify t1 t2 @@ -1005,6 +1053,9 @@ inferAltType inferAltType ce as (Alternative pats e) = do (ps, as', ts) <- inferPatterns pats (qs, t) <- inferExpressionType ce (as' ++ as) e + specialTypes <- InferT (gets inferStateSpecialTypes) + let makeArrow :: Type -> Type -> Type + a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b return (ps ++ qs, foldr makeArrow t ts) inferAltTypes @@ -1034,8 +1085,8 @@ candidates ce (Ambiguity v qs) = | let is = [i | IsIn i _ <- qs] ts = [t | IsIn _ t <- qs] , all ([VariableType v] ==) ts - , any (`elem` numClasses) is - , all (`elem` stdClasses) is + -- , any (`elem` numClasses) is + -- , all (`elem` stdClasses) is , t' <- classEnvironmentDefaults ce , all (entail ce []) [IsIn i [t'] | i <- is] ]