mirror of
https://github.com/chrisdone-archive/duet.git
synced 2024-11-29 09:25:33 +03:00
Configure special types
This commit is contained in:
parent
93728e2c31
commit
0d4496a789
205
src/THIH.hs
205
src/THIH.hs
@ -1,3 +1,4 @@
|
||||
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
|
||||
{-# LANGUAGE FlexibleInstances #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
{-# OPTIONS -Wno-incomplete-patterns #-}
|
||||
@ -17,6 +18,8 @@ module THIH
|
||||
-- * Setting up
|
||||
, addClass
|
||||
, addInstance
|
||||
, defaultSpecialTypes
|
||||
, SpecialTypes(..)
|
||||
, ClassEnvironment(..)
|
||||
, ReadException(..)
|
||||
-- * Printers
|
||||
@ -56,6 +59,14 @@ import Data.Typeable
|
||||
--------------------------------------------------------------------------------
|
||||
-- Types
|
||||
|
||||
data SpecialTypes = SpecialTypes
|
||||
{ specialTypesBool :: Type
|
||||
, specialTypesChar :: Type
|
||||
, specialTypesString :: Type
|
||||
, specialTypesFunction :: Type
|
||||
, specialTypesList :: Type
|
||||
} deriving (Show)
|
||||
|
||||
-- | Type inference monad.
|
||||
newtype InferT m a = InferT
|
||||
{ runInferT :: StateT InferState m a
|
||||
@ -65,6 +76,7 @@ newtype InferT m a = InferT
|
||||
data InferState = InferState
|
||||
{ inferStateSubstitutions :: ![Substitution]
|
||||
, inferStateCounter :: !Int
|
||||
, inferStateSpecialTypes :: !SpecialTypes
|
||||
} deriving (Show)
|
||||
|
||||
-- | An exception that may be thrown when reading in source code,
|
||||
@ -304,6 +316,7 @@ demo = do
|
||||
[StarKind]
|
||||
(Qualified [] (makeArrow (GenericType 0) (GenericType 0))))
|
||||
]
|
||||
defaultSpecialTypes
|
||||
[ BindGroup
|
||||
[ ExplicitlyTypedBinding
|
||||
"x"
|
||||
@ -336,10 +349,12 @@ demo = do
|
||||
]
|
||||
]
|
||||
]
|
||||
mapM_ (putStrLn . printTypeSignature) assumptions
|
||||
mapM_ (putStrLn . printTypeSignature defaultSpecialTypes) assumptions
|
||||
where
|
||||
tInteger :: Type
|
||||
tInteger = ConstructorType (TypeConstructor "Integer" StarKind)
|
||||
makeArrow :: Type -> Type -> Type
|
||||
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction defaultSpecialTypes) a) b
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Printer
|
||||
@ -347,12 +362,12 @@ demo = do
|
||||
printIdentifier :: Identifier -> String
|
||||
printIdentifier (Identifier i) = i
|
||||
|
||||
printTypeSignature :: TypeSignature -> String
|
||||
printTypeSignature (TypeSignature identifier scheme) =
|
||||
printIdentifier identifier ++ " :: " ++ printScheme scheme
|
||||
printTypeSignature :: SpecialTypes -> TypeSignature -> String
|
||||
printTypeSignature specialTypes (TypeSignature identifier scheme) =
|
||||
printIdentifier identifier ++ " :: " ++ printScheme specialTypes scheme
|
||||
|
||||
printScheme :: Scheme -> [Char]
|
||||
printScheme (Forall kinds qualifiedType') =
|
||||
printScheme :: SpecialTypes -> Scheme -> [Char]
|
||||
printScheme specialTypes (Forall kinds qualifiedType') =
|
||||
(if null kinds
|
||||
then ""
|
||||
else "forall " ++
|
||||
@ -364,7 +379,7 @@ printScheme (Forall kinds qualifiedType') =
|
||||
[0 :: Int ..]
|
||||
kinds) ++
|
||||
". ") ++
|
||||
printQualifiedType qualifiedType'
|
||||
printQualifiedType specialTypes qualifiedType'
|
||||
|
||||
printKind :: Kind -> [Char]
|
||||
printKind =
|
||||
@ -372,32 +387,32 @@ printKind =
|
||||
StarKind -> "*"
|
||||
FunctionKind x' y -> printKind x' ++ " -> " ++ printKind y
|
||||
|
||||
printQualifiedType :: Qualified Type -> [Char]
|
||||
printQualifiedType (Qualified predicates typ) =
|
||||
printQualifiedType :: SpecialTypes -> Qualified Type -> [Char]
|
||||
printQualifiedType specialTypes(Qualified predicates typ) =
|
||||
case predicates of
|
||||
[] -> printTypeSansParens typ
|
||||
[] -> printTypeSansParens specialTypes typ
|
||||
_ ->
|
||||
"(" ++
|
||||
intercalate ", " (map printPredicate predicates) ++
|
||||
") => " ++ printTypeSansParens typ
|
||||
intercalate ", " (map (printPredicate specialTypes) predicates) ++
|
||||
") => " ++ printTypeSansParens specialTypes typ
|
||||
|
||||
printTypeSansParens :: Type -> [Char]
|
||||
printTypeSansParens =
|
||||
printTypeSansParens :: SpecialTypes -> Type -> [Char]
|
||||
printTypeSansParens specialTypes =
|
||||
\case
|
||||
ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y' ->
|
||||
printType x' ++ " -> " ++ printTypeSansParens y'
|
||||
o -> printType o
|
||||
ApplicationType (ApplicationType func x') y' | func == specialTypesFunction specialTypes ->
|
||||
printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y'
|
||||
o -> printType specialTypes o
|
||||
|
||||
printType :: Type -> [Char]
|
||||
printType =
|
||||
printType :: SpecialTypes -> Type -> [Char]
|
||||
printType specialTypes =
|
||||
\case
|
||||
VariableType v -> printTypeVariable v
|
||||
ConstructorType tyCon -> printTypeConstructor tyCon
|
||||
ApplicationType (ApplicationType (ConstructorType (TypeConstructor (Identifier "(->)") _)) x') y ->
|
||||
"(" ++ printType x' ++ " -> " ++ printTypeSansParens y ++ ")"
|
||||
ApplicationType (ConstructorType (TypeConstructor (Identifier "[]") _)) ty ->
|
||||
"[" ++ printTypeSansParens ty ++ "]"
|
||||
ApplicationType x' y -> "(" ++ printType x' ++ " " ++ printType y ++ ")"
|
||||
ApplicationType (ApplicationType func x') y | func == specialTypesFunction specialTypes ->
|
||||
"(" ++ printType specialTypes x' ++ " -> " ++ printTypeSansParens specialTypes y ++ ")"
|
||||
ApplicationType list ty | list == specialTypesList specialTypes ->
|
||||
"[" ++ printTypeSansParens specialTypes ty ++ "]"
|
||||
ApplicationType x' y -> "(" ++ printType specialTypes x' ++ " " ++ printType specialTypes y ++ ")"
|
||||
GenericType int -> "a" ++ show int
|
||||
|
||||
printTypeConstructor :: TypeConstructor -> String
|
||||
@ -412,9 +427,9 @@ printTypeVariable (TypeVariable identifier kind) =
|
||||
StarKind -> printIdentifier identifier
|
||||
_ -> "(" ++ printIdentifier identifier ++ " :: " ++ printKind kind ++ ")"
|
||||
|
||||
printPredicate :: Predicate -> [Char]
|
||||
printPredicate (IsIn identifier types) =
|
||||
printIdentifier identifier ++ " " ++ unwords (map printType types)
|
||||
printPredicate :: SpecialTypes -> Predicate -> [Char]
|
||||
printPredicate specialTypes (IsIn identifier types) =
|
||||
printIdentifier identifier ++ " " ++ unwords (map (printType specialTypes) types)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Type inference
|
||||
@ -436,10 +451,11 @@ printPredicate (IsIn identifier types) =
|
||||
typeCheckModule
|
||||
:: MonadThrow m
|
||||
=> ClassEnvironment -- ^ Set of defined type-classes.
|
||||
-> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI.
|
||||
-> [BindGroup] -- ^ Bindings in the module.
|
||||
-> m [TypeSignature] -- ^ Inferred types for all identifiers.
|
||||
typeCheckModule ce as bgs =
|
||||
-> [TypeSignature] -- ^ Pre-defined type signatures e.g. for built-ins or FFI.
|
||||
-> SpecialTypes -- ^ Special types that Haskell uses for pattern matching and literals.
|
||||
-> [BindGroup] -- ^ Bindings in the module.
|
||||
-> m [TypeSignature] -- ^ Inferred types for all identifiers.
|
||||
typeCheckModule ce as specialTypes bgs =
|
||||
evalStateT
|
||||
(runInferT $ do
|
||||
(ps, as') <- inferSequenceTypes inferBindGroupTypes ce as bgs
|
||||
@ -447,54 +463,75 @@ typeCheckModule ce as bgs =
|
||||
let rs = reduce ce (map (substitutePredicate s) ps)
|
||||
s' <- defaultSubst ce [] rs
|
||||
return (map (substituteTypeSignature (s' @@ s)) as'))
|
||||
(InferState nullSubst 0)
|
||||
(InferState nullSubst 0 specialTypes)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Built-in types and classes
|
||||
|
||||
boolType :: Type
|
||||
boolType = ConstructorType (TypeConstructor "Bool" StarKind)
|
||||
-- | Special types that Haskell uses for pattern matching and literals.
|
||||
defaultSpecialTypes :: SpecialTypes
|
||||
defaultSpecialTypes =
|
||||
SpecialTypes
|
||||
{ specialTypesBool = ConstructorType (TypeConstructor "Bool" StarKind)
|
||||
, specialTypesChar = ConstructorType (TypeConstructor "Char" StarKind)
|
||||
, specialTypesString = makeListType (specialTypesChar defaultSpecialTypes)
|
||||
, specialTypesFunction =
|
||||
ConstructorType
|
||||
(TypeConstructor
|
||||
"(->)"
|
||||
(FunctionKind StarKind (FunctionKind StarKind StarKind)))
|
||||
, specialTypesList = listType
|
||||
}
|
||||
where
|
||||
makeListType :: Type -> Type
|
||||
makeListType t = ApplicationType listType t
|
||||
listType :: Type
|
||||
listType =
|
||||
ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
|
||||
|
||||
charType :: Type
|
||||
charType = ConstructorType (TypeConstructor "Char" StarKind)
|
||||
-- boolType :: Type
|
||||
-- boolType = ConstructorType (TypeConstructor "Bool" StarKind)
|
||||
|
||||
stringType :: Type
|
||||
stringType = makeListType charType
|
||||
-- charType :: Type
|
||||
-- charType = ConstructorType (TypeConstructor "Char" StarKind)
|
||||
|
||||
makeListType :: Type -> Type
|
||||
makeListType t = ApplicationType listType t
|
||||
-- stringType :: Type
|
||||
-- stringType = makeListType charType
|
||||
|
||||
listType :: Type
|
||||
listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
|
||||
-- makeListType :: Type -> Type
|
||||
-- makeListType t = ApplicationType listType t
|
||||
|
||||
makeArrow :: Type -> Type -> Type
|
||||
a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b
|
||||
-- listType :: Type
|
||||
-- listType = ConstructorType (TypeConstructor "[]" (FunctionKind StarKind StarKind))
|
||||
|
||||
tArrow :: Type
|
||||
tArrow =
|
||||
ConstructorType
|
||||
(TypeConstructor
|
||||
"(->)"
|
||||
(FunctionKind StarKind (FunctionKind StarKind StarKind)))
|
||||
-- makeArrow :: Type -> Type -> Type
|
||||
-- a `makeArrow` b = ApplicationType (ApplicationType tArrow a) b
|
||||
|
||||
numClasses :: [Identifier]
|
||||
numClasses =
|
||||
["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"]
|
||||
-- tArrow :: Type
|
||||
-- tArrow =
|
||||
-- ConstructorType
|
||||
-- (TypeConstructor
|
||||
-- "(->)"
|
||||
-- (FunctionKind StarKind (FunctionKind StarKind StarKind)))
|
||||
|
||||
stdClasses :: [Identifier]
|
||||
stdClasses =
|
||||
[ "Eq"
|
||||
, "Ord"
|
||||
, "Show"
|
||||
, "Read"
|
||||
, "Bounded"
|
||||
, "Enum"
|
||||
, "Ix"
|
||||
, "Functor"
|
||||
, "Monad"
|
||||
, "MonadPlus"
|
||||
] ++
|
||||
numClasses
|
||||
-- numClasses :: [Identifier]
|
||||
-- numClasses =
|
||||
-- ["Num", "Integral", "Floating", "Fractional", "Real", "RealFloat", "RealFrac"]
|
||||
|
||||
-- stdClasses :: [Identifier]
|
||||
-- stdClasses =
|
||||
-- [ "Eq"
|
||||
-- , "Ord"
|
||||
-- , "Show"
|
||||
-- , "Read"
|
||||
-- , "Bounded"
|
||||
-- , "Enum"
|
||||
-- , "Ix"
|
||||
-- , "Functor"
|
||||
-- , "Monad"
|
||||
-- , "MonadPlus"
|
||||
-- ] ++
|
||||
-- numClasses
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- Substitution
|
||||
@ -770,13 +807,15 @@ enumId n = Identifier ("v" ++ show n)
|
||||
|
||||
inferLiteralType
|
||||
:: Monad m
|
||||
=> Literal -> InferT m ([Predicate], Type)
|
||||
inferLiteralType (CharacterLiteral _) = return ([], charType)
|
||||
inferLiteralType (IntegerLiteral _) = do
|
||||
=> SpecialTypes -> Literal -> InferT m ([Predicate], Type)
|
||||
inferLiteralType specialTypes (CharacterLiteral _) =
|
||||
return ([], specialTypesChar specialTypes)
|
||||
inferLiteralType _ (IntegerLiteral _) = do
|
||||
v <- newVariableType StarKind
|
||||
return ([IsIn "Num" [v]], v)
|
||||
inferLiteralType (StringLiteral _) = return ([], stringType)
|
||||
inferLiteralType (RationalLiteral _) = do
|
||||
inferLiteralType specialTypes (StringLiteral _) =
|
||||
return ([], specialTypesString specialTypes)
|
||||
inferLiteralType _ (RationalLiteral _) = do
|
||||
v <- newVariableType StarKind
|
||||
return ([IsIn "Fractional" [v]], v)
|
||||
|
||||
@ -793,12 +832,16 @@ inferPattern (AsPattern i pat) = do
|
||||
(ps, as, t) <- inferPattern pat
|
||||
return (ps, (TypeSignature i (toScheme t)) : as, t)
|
||||
inferPattern (LiteralPattern l) = do
|
||||
(ps, t) <- inferLiteralType l
|
||||
specialTypes <- InferT (gets inferStateSpecialTypes)
|
||||
(ps, t) <- inferLiteralType specialTypes l
|
||||
return (ps, [], t)
|
||||
inferPattern (ConstructorPattern (TypeSignature _ sc) pats) = do
|
||||
(ps, as, ts) <- inferPatterns pats
|
||||
t' <- newVariableType StarKind
|
||||
(Qualified qs t) <- freshInst sc
|
||||
specialTypes <- InferT (gets inferStateSpecialTypes)
|
||||
let makeArrow :: Type -> Type -> Type
|
||||
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
|
||||
unify t (foldr makeArrow t' ts)
|
||||
return (ps ++ qs, as, t')
|
||||
inferPattern (LazyPattern pat) = inferPattern pat
|
||||
@ -964,12 +1007,16 @@ inferExpressionType _ _ (ConstantExpression (TypeSignature _ sc)) = do
|
||||
(Qualified ps t) <- freshInst sc
|
||||
return (ps, t)
|
||||
inferExpressionType _ _ (LiteralExpression l) = do
|
||||
(ps, t) <- inferLiteralType l
|
||||
specialTypes <- InferT (gets inferStateSpecialTypes)
|
||||
(ps, t) <- inferLiteralType specialTypes l
|
||||
return (ps, t)
|
||||
inferExpressionType ce as (ApplicationExpression e f) = do
|
||||
(ps, te) <- inferExpressionType ce as e
|
||||
(qs, tf) <- inferExpressionType ce as f
|
||||
t <- newVariableType StarKind
|
||||
specialTypes <- InferT (gets inferStateSpecialTypes)
|
||||
let makeArrow :: Type -> Type -> Type
|
||||
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
|
||||
unify (tf `makeArrow` t) te
|
||||
return (ps ++ qs, t)
|
||||
inferExpressionType ce as (LetExpression bg e) = do
|
||||
@ -979,7 +1026,8 @@ inferExpressionType ce as (LetExpression bg e) = do
|
||||
inferExpressionType ce as (LambdaExpression alt) = inferAltType ce as alt
|
||||
inferExpressionType ce as (IfExpression e e1 e2) = do
|
||||
(ps, t) <- inferExpressionType ce as e
|
||||
unify t boolType
|
||||
specialTypes <- InferT (gets inferStateSpecialTypes)
|
||||
unify t (specialTypesBool specialTypes)
|
||||
(ps1, t1) <- inferExpressionType ce as e1
|
||||
(ps2, t2) <- inferExpressionType ce as e2
|
||||
unify t1 t2
|
||||
@ -1005,6 +1053,9 @@ inferAltType
|
||||
inferAltType ce as (Alternative pats e) = do
|
||||
(ps, as', ts) <- inferPatterns pats
|
||||
(qs, t) <- inferExpressionType ce (as' ++ as) e
|
||||
specialTypes <- InferT (gets inferStateSpecialTypes)
|
||||
let makeArrow :: Type -> Type -> Type
|
||||
a `makeArrow` b = ApplicationType (ApplicationType (specialTypesFunction specialTypes) a) b
|
||||
return (ps ++ qs, foldr makeArrow t ts)
|
||||
|
||||
inferAltTypes
|
||||
@ -1034,8 +1085,8 @@ candidates ce (Ambiguity v qs) =
|
||||
| let is = [i | IsIn i _ <- qs]
|
||||
ts = [t | IsIn _ t <- qs]
|
||||
, all ([VariableType v] ==) ts
|
||||
, any (`elem` numClasses) is
|
||||
, all (`elem` stdClasses) is
|
||||
-- , any (`elem` numClasses) is
|
||||
-- , all (`elem` stdClasses) is
|
||||
, t' <- classEnvironmentDefaults ce
|
||||
, all (entail ce []) [IsIn i [t'] | i <- is]
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user