Configure special types

This commit is contained in:
Chris Done 2017-04-18 11:46:41 +01:00
parent 93728e2c31
commit 0d4496a789

View File

@ -1,3 +1,4 @@
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
{-# OPTIONS -Wno-incomplete-patterns #-} {-# OPTIONS -Wno-incomplete-patterns #-}
@ -17,6 +18,8 @@ module THIH
-- * Setting up -- * Setting up
, addClass , addClass
, addInstance , addInstance
, defaultSpecialTypes
, SpecialTypes(..)
, ClassEnvironment(..) , ClassEnvironment(..)
, ReadException(..) , ReadException(..)
-- * Printers -- * Printers
@ -56,6 +59,14 @@ import Data.Typeable
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Types -- Types
data SpecialTypes = SpecialTypes
{ specialTypesBool :: Type
, specialTypesChar :: Type
, specialTypesString :: Type
, specialTypesFunction :: Type
, specialTypesList :: Type
} deriving (Show)
-- | Type inference monad. -- | Type inference monad.
newtype InferT m a = InferT newtype InferT m a = InferT
{ runInferT :: StateT InferState m a { runInferT :: StateT InferState m a
@ -65,6 +76,7 @@ newtype InferT m a = InferT
data InferState = InferState data InferState = InferState
{ inferStateSubstitutions :: ![Substitution] { inferStateSubstitutions :: ![Substitution]
, inferStateCounter :: !Int , inferStateCounter :: !Int
, inferStateSpecialTypes :: !SpecialTypes
} deriving (Show) } deriving (Show)
-- | An exception that may be thrown when reading in source code, -- | An exception that may be thrown when reading in source code,
@ -304,6 +316,7 @@ demo = do
[StarKind] [StarKind]
(Qualified [] (makeArrow (GenericType 0) (GenericType 0)))) (Qualified [] (makeArrow (GenericType 0) (GenericType 0))))
] ]
defaultSpecialTypes
[ BindGroup [ BindGroup
[ ExplicitlyTypedBinding [ ExplicitlyTypedBinding
"x" "x"
@ -336,10 +349,12 @@ demo = do
] ]
] ]
] ]
mapM_ (putStrLn . printTypeSignature) assumptions mapM_ (putStrLn . printTypeSignature defaultSpecialTypes) assumptions
where where
tInteger :: Type tInteger :: Type
tInteger = ConstructorType (TypeConstructor "Integer" StarKind) tInteger = ConstructorType (TypeConstructor "Integer" StarKind)
makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction defaultSpecialTypes) a) b
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Printer -- Printer
@ -347,12 +362,12 @@ demo = do
printIdentifier :: Identifier -> String printIdentifier :: Identifier -> String
printIdentifier (Identifier i) = i printIdentifier (Identifier i) = i
printTypeSignature :: TypeSignature -> String printTypeSignature :: SpecialTypes -> TypeSignature -> String
printTypeSignature (TypeSignature identifier scheme) = printTypeSignature specialTypes (TypeSignature identifier scheme) =
printIdentifier identifier ++ " :: " ++ printScheme scheme printIdentifier identifier ++ " :: " ++ printScheme specialTypes scheme
printScheme :: Scheme -> [Char] printScheme :: SpecialTypes -> Scheme -> [Char]
printScheme (Forall kinds qualifiedType') = printScheme specialTypes (Forall kinds qualifiedType') =
(if null kinds (if null kinds
then "" then ""
else "forall " ++ else "forall " ++
@ -364,7 +379,7 @@ printScheme (Forall kinds qualifiedType') =
[0 :: Int ..] [0 :: Int ..]
kinds) ++ kinds) ++
". ") ++ ". ") ++
printQualifiedType qualifiedType' printQualifiedType specialTypes qualifiedType'
printKind :: Kind -> [Char] printKind :: Kind -> [Char]
printKind = printKind =
@ -372,32 +387,32 @@ printKind =
StarKind -> "*" StarKind -> "*"
FunctionKind x' y -> printKind x' ++ " -> " ++ printKind y FunctionKind x' y -> printKind x' ++ " -> " ++ printKind y
printQualifiedType :: Qualified Type -> [Char] printQualifiedType :: SpecialTypes -> Qualified Type -> [Char]
printQualifiedType (Qualified predicates typ) = printQualifiedType specialTypes(Qualified predicates typ) =
case predicates of case predicates of
[] -> printTypeSansParens typ [] -> printTypeSansParens specialTypes typ
_ -> _ ->
"(" ++ "(" ++
intercalate ", " (map printPredicate predicates) ++ intercalate ", " (map (printPredicate specialTypes) predicates) ++
") => " ++ printTypeSansParens typ ") => " ++ printTypeSansParens specialTypes typ
printTypeSansParens :: Type -> [Char] printTypeSansParens :: SpecialTypes -> Type -> [Char]
printTypeSansParens = printTypeSansParens specialTypes =
\case \case
ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y' -> ApplicationType (ApplicationType func x') y' | func == specialTypesFunction specialTypes ->
printType x' ++ " -> " ++ printTypeSansParens y' printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y'
o -> printType o o -> printType specialTypes o
printType :: Type -> [Char] printType :: SpecialTypes -> Type -> [Char]
printType = printType specialTypes =
\case \case
VariableType v -> printTypeVariable v VariableType v -> printTypeVariable v
ConstructorType tyCon -> printTypeConstructor tyCon ConstructorType tyCon -> printTypeConstructor tyCon
ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y -> ApplicationType (ApplicationType func x') y | func == specialTypesFunction specialTypes ->
"(" ++ printType x' ++ " -> " ++ printTypeSansParens y ++ ")" "(" ++ printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y ++ ")"
ApplicationType (ConstructorType (TypeConstructor (Identifier "[]") _)) ty -> ApplicationType list ty | list == specialTypesList specialTypes ->
"[" ++ printTypeSansParens ty ++ "]" "[" ++ printTypeSansParens specialTypes ty ++ "]"
ApplicationType x' y -> "(" ++ printType x' ++ " " ++ printType y ++ ")" ApplicationType x' y -> "(" ++ printType specialTypes x' ++ " " ++ printType specialTypes y ++ ")"
GenericType int -> "a" ++ show int GenericType int -> "a" ++ show int
printTypeConstructor :: TypeConstructor -> String printTypeConstructor :: TypeConstructor -> String
@ -412,9 +427,9 @@ printTypeVariable (TypeVariable identifier kind) =
StarKind -> printIdentifier identifier StarKind -> printIdentifier identifier
_ -> "(" ++ printIdentifier identifier ++ " :: " ++ printKind kind ++ ")" _ -> "(" ++ printIdentifier identifier ++ " :: " ++ printKind kind ++ ")"
printPredicate :: Predicate -> [Char] printPredicate :: SpecialTypes -> Predicate -> [Char]
printPredicate (IsIn identifier types) = printPredicate specialTypes (IsIn identifier types) =
printIdentifier identifier ++ " " ++ unwords (map printType types) printIdentifier identifier ++ " " ++ unwords (map (printType specialTypes) types)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Type inference -- Type inference
@ -436,10 +451,11 @@ printPredicate (IsIn identifier types) =
typeCheckModule typeCheckModule
:: MonadThrow m :: MonadThrow m
=> ClassEnvironment -- ^ Set of defined type-classes. => ClassEnvironment -- ^ Set of defined type-classes.
-> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI. -> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI.
-> [BindGroup] -- ^ Bindings in the module. -> SpecialTypes -- ^ Special types that Haskell uses for pattern matching and literals.
-> m [TypeSignature] -- ^ Inferred types for all identifiers. -> [BindGroup] -- ^ Bindings in the module.
typeCheckModule ce as bgs = -> m [TypeSignature] -- ^ Inferred types for all identifiers.
typeCheckModule ce as specialTypes bgs =
evalStateT evalStateT
(runInferT $ do (runInferT $ do
(ps, as') <- inferSequenceTypes inferBindGroupTypes ce as bgs (ps, as') <- inferSequenceTypes inferBindGroupTypes ce as bgs
@ -447,54 +463,75 @@ typeCheckModule ce as bgs =
let rs = reduce ce (map (substitutePredicate s) ps) let rs = reduce ce (map (substitutePredicate s) ps)
s' <- defaultSubst ce [] rs s' <- defaultSubst ce [] rs
return (map (substituteTypeSignature (s' @@ s)) as')) return (map (substituteTypeSignature (s' @@ s)) as'))
(InferState nullSubst 0) (InferState nullSubst 0 specialTypes)
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Built-in types and classes -- Built-in types and classes
boolType :: Type -- | Special types that Haskell uses for pattern matching and literals.
boolType = ConstructorType (TypeConstructor "Bool" StarKind) defaultSpecialTypes :: SpecialTypes
defaultSpecialTypes =
SpecialTypes
{ specialTypesBool = ConstructorType (TypeConstructor "Bool" StarKind)
, specialTypesChar = ConstructorType (TypeConstructor "Char" StarKind)
, specialTypesString = makeListType (specialTypesChar defaultSpecialTypes)
, specialTypesFunction =
ConstructorType
(TypeConstructor
"(->)"
(FunctionKind StarKind (FunctionKind StarKind StarKind)))
, specialTypesList = listType
}
where
makeListType :: Type -> Type
makeListType t = ApplicationType listType t
listType :: Type
listType =
ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
charType :: Type -- boolType :: Type
charType = ConstructorType (TypeConstructor "Char" StarKind) -- boolType = ConstructorType (TypeConstructor "Bool" StarKind)
stringType :: Type -- charType :: Type
stringType = makeListType charType -- charType = ConstructorType (TypeConstructor "Char" StarKind)
makeListType :: Type -> Type -- stringType :: Type
makeListType t = ApplicationType listType t -- stringType = makeListType charType
listType :: Type -- makeListType :: Type -> Type
listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind)) -- makeListType t = ApplicationType listType t
makeArrow :: Type -> Type -> Type -- listType :: Type
a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b -- listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
tArrow :: Type -- makeArrow :: Type -> Type -> Type
tArrow = -- a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b
ConstructorType
(TypeConstructor
"(->)"
(FunctionKind StarKind (FunctionKind StarKind StarKind)))
numClasses :: [Identifier] -- tArrow :: Type
numClasses = -- tArrow =
["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"] -- ConstructorType
-- (TypeConstructor
-- "(->)"
-- (FunctionKind StarKind (FunctionKind StarKind StarKind)))
stdClasses :: [Identifier] -- numClasses :: [Identifier]
stdClasses = -- numClasses =
[ "Eq" -- ["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"]
, "Ord"
, "Show" -- stdClasses :: [Identifier]
, "Read" -- stdClasses =
, "Bounded" -- [ "Eq"
, "Enum" -- , "Ord"
, "Ix" -- , "Show"
, "Functor" -- , "Read"
, "Monad" -- , "Bounded"
, "MonadPlus" -- , "Enum"
] ++ -- , "Ix"
numClasses -- , "Functor"
-- , "Monad"
-- , "MonadPlus"
-- ] ++
-- numClasses
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- Substitution -- Substitution
@ -770,13 +807,15 @@ enumId n = Identifier ("v" ++ show n)
inferLiteralType inferLiteralType
:: Monad m :: Monad m
=> Literal -> InferT m ([Predicate], Type) => SpecialTypes -> Literal -> InferT m ([Predicate], Type)
inferLiteralType (CharacterLiteral _) = return ([], charType) inferLiteralType specialTypes (CharacterLiteral _) =
inferLiteralType (IntegerLiteral _) = do return ([], specialTypesChar specialTypes)
inferLiteralType _ (IntegerLiteral _) = do
v <- newVariableType StarKind v <- newVariableType StarKind
return ([IsIn "Num" [v]], v) return ([IsIn "Num" [v]], v)
inferLiteralType (StringLiteral _) = return ([], stringType) inferLiteralType specialTypes (StringLiteral _) =
inferLiteralType (RationalLiteral _) = do return ([], specialTypesString specialTypes)
inferLiteralType _ (RationalLiteral _) = do
v <- newVariableType StarKind v <- newVariableType StarKind
return ([IsIn "Fractional" [v]], v) return ([IsIn "Fractional" [v]], v)
@ -793,12 +832,16 @@ inferPattern (AsPattern i pat) = do
(ps, as, t) <- inferPattern pat (ps, as, t) <- inferPattern pat
return (ps, (TypeSignature i (toScheme t)) : as, t) return (ps, (TypeSignature i (toScheme t)) : as, t)
inferPattern (LiteralPattern l) = do inferPattern (LiteralPattern l) = do
(ps, t) <- inferLiteralType l specialTypes <- InferT (gets inferStateSpecialTypes)
(ps, t) <- inferLiteralType specialTypes l
return (ps, [], t) return (ps, [], t)
inferPattern (ConstructorPattern (TypeSignature _ sc) pats) = do inferPattern (ConstructorPattern (TypeSignature _ sc) pats) = do
(ps, as, ts) <- inferPatterns pats (ps, as, ts) <- inferPatterns pats
t' <- newVariableType StarKind t' <- newVariableType StarKind
(Qualified qs t) <- freshInst sc (Qualified qs t) <- freshInst sc
specialTypes <- InferT (gets inferStateSpecialTypes)
let makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
unify t (foldr makeArrow t' ts) unify t (foldr makeArrow t' ts)
return (ps ++ qs, as, t') return (ps ++ qs, as, t')
inferPattern (LazyPattern pat) = inferPattern pat inferPattern (LazyPattern pat) = inferPattern pat
@ -964,12 +1007,16 @@ inferExpressionType _ _ (ConstantExpression (TypeSignature _ sc)) = do
(Qualified ps t) <- freshInst sc (Qualified ps t) <- freshInst sc
return (ps, t) return (ps, t)
inferExpressionType _ _ (LiteralExpression l) = do inferExpressionType _ _ (LiteralExpression l) = do
(ps, t) <- inferLiteralType l specialTypes <- InferT (gets inferStateSpecialTypes)
(ps, t) <- inferLiteralType specialTypes l
return (ps, t) return (ps, t)
inferExpressionType ce as (ApplicationExpression e f) = do inferExpressionType ce as (ApplicationExpression e f) = do
(ps, te) <- inferExpressionType ce as e (ps, te) <- inferExpressionType ce as e
(qs, tf) <- inferExpressionType ce as f (qs, tf) <- inferExpressionType ce as f
t <- newVariableType StarKind t <- newVariableType StarKind
specialTypes <- InferT (gets inferStateSpecialTypes)
let makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
unify (tf `makeArrow` t) te unify (tf `makeArrow` t) te
return (ps ++ qs, t) return (ps ++ qs, t)
inferExpressionType ce as (LetExpression bg e) = do inferExpressionType ce as (LetExpression bg e) = do
@ -979,7 +1026,8 @@ inferExpressionType ce as (LetExpression bg e) = do
inferExpressionType ce as (LambdaExpression alt) = inferAltType ce as alt inferExpressionType ce as (LambdaExpression alt) = inferAltType ce as alt
inferExpressionType ce as (IfExpression e e1 e2) = do inferExpressionType ce as (IfExpression e e1 e2) = do
(ps, t) <- inferExpressionType ce as e (ps, t) <- inferExpressionType ce as e
unify t boolType specialTypes <- InferT (gets inferStateSpecialTypes)
unify t (specialTypesBool specialTypes)
(ps1, t1) <- inferExpressionType ce as e1 (ps1, t1) <- inferExpressionType ce as e1
(ps2, t2) <- inferExpressionType ce as e2 (ps2, t2) <- inferExpressionType ce as e2
unify t1 t2 unify t1 t2
@ -1005,6 +1053,9 @@ inferAltType
inferAltType ce as (Alternative pats e) = do inferAltType ce as (Alternative pats e) = do
(ps, as', ts) <- inferPatterns pats (ps, as', ts) <- inferPatterns pats
(qs, t) <- inferExpressionType ce (as' ++ as) e (qs, t) <- inferExpressionType ce (as' ++ as) e
specialTypes <- InferT (gets inferStateSpecialTypes)
let makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
return (ps ++ qs, foldr makeArrow t ts) return (ps ++ qs, foldr makeArrow t ts)
inferAltTypes inferAltTypes
@ -1034,8 +1085,8 @@ candidates ce (Ambiguity v qs) =
| let is = [i | IsIn i _ <- qs] | let is = [i | IsIn i _ <- qs]
ts = [t | IsIn _ t <- qs] ts = [t | IsIn _ t <- qs]
, all ([VariableType v] ==) ts , all ([VariableType v] ==) ts
, any (`elem` numClasses) is -- , any (`elem` numClasses) is
, all (`elem` stdClasses) is -- , all (`elem` stdClasses) is
, t' <- classEnvironmentDefaults ce , t' <- classEnvironmentDefaults ce
, all (entail ce []) [IsIn i [t'] | i <- is] , all (entail ce []) [IsIn i [t'] | i <- is]
] ]