Configure special types

This commit is contained in:
Chris Done 2017-04-18 11:46:41 +01:00
parent 93728e2c31
commit 0d4496a789

View File

@ -1,3 +1,4 @@
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS -Wno-incomplete-patterns #-}
@ -17,6 +18,8 @@ module THIH
-- * Setting up
, addClass
, addInstance
, defaultSpecialTypes
, SpecialTypes(..)
, ClassEnvironment(..)
, ReadException(..)
-- * Printers
@ -56,6 +59,14 @@ import Data.Typeable
--------------------------------------------------------------------------------
-- Types
data SpecialTypes = SpecialTypes
{ specialTypesBool :: Type
, specialTypesChar :: Type
, specialTypesString :: Type
, specialTypesFunction :: Type
, specialTypesList :: Type
} deriving (Show)
-- | Type inference monad.
newtype InferT m a = InferT
{ runInferT :: StateT InferState m a
@ -65,6 +76,7 @@ newtype InferT m a = InferT
data InferState = InferState
{ inferStateSubstitutions :: ![Substitution]
, inferStateCounter :: !Int
, inferStateSpecialTypes :: !SpecialTypes
} deriving (Show)
-- | An exception that may be thrown when reading in source code,
@ -304,6 +316,7 @@ demo = do
[StarKind]
(Qualified [] (makeArrow (GenericType 0) (GenericType 0))))
]
defaultSpecialTypes
[ BindGroup
[ ExplicitlyTypedBinding
"x"
@ -336,10 +349,12 @@ demo = do
]
]
]
mapM_ (putStrLn . printTypeSignature) assumptions
mapM_ (putStrLn . printTypeSignature defaultSpecialTypes) assumptions
where
tInteger :: Type
tInteger = ConstructorType (TypeConstructor "Integer" StarKind)
makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction defaultSpecialTypes) a) b
--------------------------------------------------------------------------------
-- Printer
@ -347,12 +362,12 @@ demo = do
printIdentifier :: Identifier -> String
printIdentifier (Identifier i) = i
printTypeSignature :: TypeSignature -> String
printTypeSignature (TypeSignature identifier scheme) =
printIdentifier identifier ++ " :: " ++ printScheme scheme
printTypeSignature :: SpecialTypes -> TypeSignature -> String
printTypeSignature specialTypes (TypeSignature identifier scheme) =
printIdentifier identifier ++ " :: " ++ printScheme specialTypes scheme
printScheme :: Scheme -> [Char]
printScheme (Forall kinds qualifiedType') =
printScheme :: SpecialTypes -> Scheme -> [Char]
printScheme specialTypes (Forall kinds qualifiedType') =
(if null kinds
then ""
else "forall " ++
@ -364,7 +379,7 @@ printScheme (Forall kinds qualifiedType') =
[0 :: Int ..]
kinds) ++
". ") ++
printQualifiedType qualifiedType'
printQualifiedType specialTypes qualifiedType'
printKind :: Kind -> [Char]
printKind =
@ -372,32 +387,32 @@ printKind =
StarKind -> "*"
FunctionKind x' y -> printKind x' ++ " -> " ++ printKind y
printQualifiedType :: Qualified Type -> [Char]
printQualifiedType (Qualified predicates typ) =
printQualifiedType :: SpecialTypes -> Qualified Type -> [Char]
printQualifiedType specialTypes(Qualified predicates typ) =
case predicates of
[] -> printTypeSansParens typ
[] -> printTypeSansParens specialTypes typ
_ ->
"(" ++
intercalate ", " (map printPredicate predicates) ++
") => " ++ printTypeSansParens typ
intercalate ", " (map (printPredicate specialTypes) predicates) ++
") => " ++ printTypeSansParens specialTypes typ
printTypeSansParens :: Type -> [Char]
printTypeSansParens =
printTypeSansParens :: SpecialTypes -> Type -> [Char]
printTypeSansParens specialTypes =
\case
ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y' ->
printType x' ++ " -> " ++ printTypeSansParens y'
o -> printType o
ApplicationType (ApplicationType func x') y' | func == specialTypesFunction specialTypes ->
printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y'
o -> printType specialTypes o
printType :: Type -> [Char]
printType =
printType :: SpecialTypes -> Type -> [Char]
printType specialTypes =
\case
VariableType v -> printTypeVariable v
ConstructorType tyCon -> printTypeConstructor tyCon
ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y ->
"(" ++ printType x' ++ " -> " ++ printTypeSansParens y ++ ")"
ApplicationType (ConstructorType (TypeConstructor (Identifier "[]") _)) ty ->
"[" ++ printTypeSansParens ty ++ "]"
ApplicationType x' y -> "(" ++ printType x' ++ " " ++ printType y ++ ")"
ApplicationType (ApplicationType func x') y | func == specialTypesFunction specialTypes ->
"(" ++ printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y ++ ")"
ApplicationType list ty | list == specialTypesList specialTypes ->
"[" ++ printTypeSansParens specialTypes ty ++ "]"
ApplicationType x' y -> "(" ++ printType specialTypes x' ++ " " ++ printType specialTypes y ++ ")"
GenericType int -> "a" ++ show int
printTypeConstructor :: TypeConstructor -> String
@ -412,9 +427,9 @@ printTypeVariable (TypeVariable identifier kind) =
StarKind -> printIdentifier identifier
_ -> "(" ++ printIdentifier identifier ++ " :: " ++ printKind kind ++ ")"
printPredicate :: Predicate -> [Char]
printPredicate (IsIn identifier types) =
printIdentifier identifier ++ " " ++ unwords (map printType types)
printPredicate :: SpecialTypes -> Predicate -> [Char]
printPredicate specialTypes (IsIn identifier types) =
printIdentifier identifier ++ " " ++ unwords (map (printType specialTypes) types)
--------------------------------------------------------------------------------
-- Type inference
@ -437,9 +452,10 @@ typeCheckModule
:: MonadThrow m
=> ClassEnvironment -- ^ Set of defined type-classes.
-> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI.
-> SpecialTypes -- ^ Special types that Haskell uses for pattern matching and literals.
-> [BindGroup] -- ^ Bindings in the module.
-> m [TypeSignature] -- ^ Inferred types for all identifiers.
typeCheckModule ce as bgs =
typeCheckModule ce as specialTypes bgs =
evalStateT
(runInferT $ do
(ps, as') <- inferSequenceTypes inferBindGroupTypes ce as bgs
@ -447,54 +463,75 @@ typeCheckModule ce as bgs =
let rs = reduce ce (map (substitutePredicate s) ps)
s' <- defaultSubst ce [] rs
return (map (substituteTypeSignature (s' @@ s)) as'))
(InferState nullSubst 0)
(InferState nullSubst 0 specialTypes)
--------------------------------------------------------------------------------
-- Built-in types and classes
boolType :: Type
boolType = ConstructorType (TypeConstructor "Bool" StarKind)
charType :: Type
charType = ConstructorType (TypeConstructor "Char" StarKind)
stringType :: Type
stringType = makeListType charType
makeListType :: Type -> Type
makeListType t = ApplicationType listType t
listType :: Type
listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b
tArrow :: Type
tArrow =
-- | Special types that Haskell uses for pattern matching and literals.
defaultSpecialTypes :: SpecialTypes
defaultSpecialTypes =
SpecialTypes
{ specialTypesBool = ConstructorType (TypeConstructor "Bool" StarKind)
, specialTypesChar = ConstructorType (TypeConstructor "Char" StarKind)
, specialTypesString = makeListType (specialTypesChar defaultSpecialTypes)
, specialTypesFunction =
ConstructorType
(TypeConstructor
"(->)"
(FunctionKind StarKind (FunctionKind StarKind StarKind)))
, specialTypesList = listType
}
where
makeListType :: Type -> Type
makeListType t = ApplicationType listType t
listType :: Type
listType =
ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
numClasses :: [Identifier]
numClasses =
["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"]
-- boolType :: Type
-- boolType = ConstructorType (TypeConstructor "Bool" StarKind)
stdClasses :: [Identifier]
stdClasses =
[ "Eq"
, "Ord"
, "Show"
, "Read"
, "Bounded"
, "Enum"
, "Ix"
, "Functor"
, "Monad"
, "MonadPlus"
] ++
numClasses
-- charType :: Type
-- charType = ConstructorType (TypeConstructor "Char" StarKind)
-- stringType :: Type
-- stringType = makeListType charType
-- makeListType :: Type -> Type
-- makeListType t = ApplicationType listType t
-- listType :: Type
-- listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
-- makeArrow :: Type -> Type -> Type
-- a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b
-- tArrow :: Type
-- tArrow =
-- ConstructorType
-- (TypeConstructor
-- "(->)"
-- (FunctionKind StarKind (FunctionKind StarKind StarKind)))
-- numClasses :: [Identifier]
-- numClasses =
-- ["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"]
-- stdClasses :: [Identifier]
-- stdClasses =
-- [ "Eq"
-- , "Ord"
-- , "Show"
-- , "Read"
-- , "Bounded"
-- , "Enum"
-- , "Ix"
-- , "Functor"
-- , "Monad"
-- , "MonadPlus"
-- ] ++
-- numClasses
--------------------------------------------------------------------------------
-- Substitution
@ -770,13 +807,15 @@ enumId n = Identifier ("v" ++ show n)
inferLiteralType
:: Monad m
=> Literal -> InferT m ([Predicate], Type)
inferLiteralType (CharacterLiteral _) = return ([], charType)
inferLiteralType (IntegerLiteral _) = do
=> SpecialTypes -> Literal -> InferT m ([Predicate], Type)
inferLiteralType specialTypes (CharacterLiteral _) =
return ([], specialTypesChar specialTypes)
inferLiteralType _ (IntegerLiteral _) = do
v <- newVariableType StarKind
return ([IsIn "Num" [v]], v)
inferLiteralType (StringLiteral _) = return ([], stringType)
inferLiteralType (RationalLiteral _) = do
inferLiteralType specialTypes (StringLiteral _) =
return ([], specialTypesString specialTypes)
inferLiteralType _ (RationalLiteral _) = do
v <- newVariableType StarKind
return ([IsIn "Fractional" [v]], v)
@ -793,12 +832,16 @@ inferPattern (AsPattern i pat) = do
(ps, as, t) <- inferPattern pat
return (ps, (TypeSignature i (toScheme t)) : as, t)
inferPattern (LiteralPattern l) = do
(ps, t) <- inferLiteralType l
specialTypes <- InferT (gets inferStateSpecialTypes)
(ps, t) <- inferLiteralType specialTypes l
return (ps, [], t)
inferPattern (ConstructorPattern (TypeSignature _ sc) pats) = do
(ps, as, ts) <- inferPatterns pats
t' <- newVariableType StarKind
(Qualified qs t) <- freshInst sc
specialTypes <- InferT (gets inferStateSpecialTypes)
let makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
unify t (foldr makeArrow t' ts)
return (ps ++ qs, as, t')
inferPattern (LazyPattern pat) = inferPattern pat
@ -964,12 +1007,16 @@ inferExpressionType _ _ (ConstantExpression (TypeSignature _ sc)) = do
(Qualified ps t) <- freshInst sc
return (ps, t)
inferExpressionType _ _ (LiteralExpression l) = do
(ps, t) <- inferLiteralType l
specialTypes <- InferT (gets inferStateSpecialTypes)
(ps, t) <- inferLiteralType specialTypes l
return (ps, t)
inferExpressionType ce as (ApplicationExpression e f) = do
(ps, te) <- inferExpressionType ce as e
(qs, tf) <- inferExpressionType ce as f
t <- newVariableType StarKind
specialTypes <- InferT (gets inferStateSpecialTypes)
let makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
unify (tf `makeArrow` t) te
return (ps ++ qs, t)
inferExpressionType ce as (LetExpression bg e) = do
@ -979,7 +1026,8 @@ inferExpressionType ce as (LetExpression bg e) = do
inferExpressionType ce as (LambdaExpression alt) = inferAltType ce as alt
inferExpressionType ce as (IfExpression e e1 e2) = do
(ps, t) <- inferExpressionType ce as e
unify t boolType
specialTypes <- InferT (gets inferStateSpecialTypes)
unify t (specialTypesBool specialTypes)
(ps1, t1) <- inferExpressionType ce as e1
(ps2, t2) <- inferExpressionType ce as e2
unify t1 t2
@ -1005,6 +1053,9 @@ inferAltType
inferAltType ce as (Alternative pats e) = do
(ps, as', ts) <- inferPatterns pats
(qs, t) <- inferExpressionType ce (as' ++ as) e
specialTypes <- InferT (gets inferStateSpecialTypes)
let makeArrow :: Type -> Type -> Type
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
return (ps ++ qs, foldr makeArrow t ts)
inferAltTypes
@ -1034,8 +1085,8 @@ candidates ce (Ambiguity v qs) =
| let is = [i | IsIn i _ <- qs]
ts = [t | IsIn _ t <- qs]
, all ([VariableType v] ==) ts
, any (`elem` numClasses) is
, all (`elem` stdClasses) is
-- , any (`elem` numClasses) is
-- , all (`elem` stdClasses) is
, t' <- classEnvironmentDefaults ce
, all (entail ce []) [IsIn i [t'] | i <- is]
]