some work on extend HPT to infer LLVM types

This commit is contained in:
Csaba Hruska 2018-01-08 18:25:51 +01:00
parent 6b48936b40
commit 6e51acbd3a
6 changed files with 83 additions and 28 deletions

3
.gitignore vendored
View File

@ -1,7 +1,8 @@
dist dist
dist-* dist-*
cabal-dev cabal-dev
*.o *.s
*.ll
*.hi *.hi
*.chi *.chi
*.chs.h *.chs.h

View File

@ -7,6 +7,7 @@ module AbstractRunGrin
, RTLocVal(..) , RTLocVal(..)
, RTNode(..) , RTNode(..)
, RTVar(..) , RTVar(..)
, emptyComputer
) where ) where
import Debug.Trace import Debug.Trace
@ -39,9 +40,18 @@ type ADefMap = Map Name ADef
implement equasion solver for the specific example from the grin paper as a separate app implement equasion solver for the specific example from the grin paper as a separate app
-} -}
data CGType
= T_I64
| T_Unit
| T_Loc
| T_Tag
| T_UNKNOWN
| T_Fun String
deriving (Eq, Ord, Show)
data RTLocVal data RTLocVal
= RTLoc Int = RTLoc Int
| BAS | BAS CGType
| RTVar Name -- HACK | RTVar Name -- HACK
deriving (Eq, Ord, Show) deriving (Eq, Ord, Show)
@ -72,7 +82,11 @@ data Step
| StepAssign Name VarSet | StepAssign Name VarSet
deriving Show deriving Show
emptyComputer = Computer mempty mempty mempty emptyComputer = Computer
{ storeMap = mempty
, envMap = mempty
, steps = mempty
}
type GrinM = ReaderT ADefMap (State Computer) type GrinM = ReaderT ADefMap (State Computer)
@ -125,7 +139,12 @@ lookupEnv n = Map.findWithDefault (error $ "missing variable: " ++ n) n <$> gets
lookupStore :: Int -> GrinM NodeSet lookupStore :: Int -> GrinM NodeSet
lookupStore i = IntMap.findWithDefault (error $ "missing location: " ++ show i) i <$> gets storeMap lookupStore i = IntMap.findWithDefault (error $ "missing location: " ++ show i) i <$> gets storeMap
basVarSet = Set.singleton $ V BAS basVarSet cgType = Set.singleton . V . BAS $ cgType
boolVarSet = Set.fromList
[ N $ RTNode (Tag C "True" 0) []
, N $ RTNode (Tag C "False" 0) []
]
toRTLocVal :: RTVar -> RTLocVal toRTLocVal :: RTVar -> RTLocVal
toRTLocVal (V a) = a toRTLocVal (V a) = a
@ -137,7 +156,7 @@ toRTNode a = error $ "toRTNode: illegal value " ++ show a
-} -}
evalVal :: Val -> GrinM VarSet evalVal :: Val -> GrinM VarSet
evalVal = \case evalVal = \case
v@Lit{} -> pure basVarSet v@Lit{} -> pure $ basVarSet T_I64
Var n -> lookupEnv n Var n -> lookupEnv n
ConstTagNode t a -> Set.singleton . N . RTNode t <$> mapM (\x -> Set.map toRTLocVal <$> evalVal x) a ConstTagNode t a -> Set.singleton . N . RTNode t <$> mapM (\x -> Set.map toRTLocVal <$> evalVal x) a
{- {-
@ -148,15 +167,15 @@ evalVal = \case
-- TODO: support TagValue ; represent it as normal value instead of BAS -- TODO: support TagValue ; represent it as normal value instead of BAS
pure $ Set.fromList [N $ RTNode t args | t <- values] pure $ Set.fromList [N $ RTNode t args | t <- values]
-} -}
v@ValTag{} -> pure basVarSet v@ValTag{} -> pure $ basVarSet T_Tag
v@Unit -> pure basVarSet v@Unit -> pure $ basVarSet T_Unit
v@Loc{} -> pure basVarSet v@Loc{} -> pure $ basVarSet T_Loc
x -> fail $ "ERROR: evalVal: " ++ show x x -> fail $ "ERROR: evalVal: " ++ show x
selectRTNodeItem :: Maybe Int -> RTVar -> VarSet selectRTNodeItem :: Maybe Int -> RTVar -> VarSet
selectRTNodeItem Nothing val = Set.singleton val selectRTNodeItem Nothing val = Set.singleton val
selectRTNodeItem (Just 0) (N (RTNode tag args)) = basVarSet selectRTNodeItem (Just 0) (N (RTNode tag args)) = basVarSet T_Tag
selectRTNodeItem (Just i) (N (RTNode tag args)) = Set.map V $ (args !! (i - 1)) selectRTNodeItem (Just i) (N (RTNode tag args)) = Set.map V $ (args !! (i - 1))
evalSFetchF :: Maybe Int -> VarSet -> GrinM VarSet evalSFetchF :: Maybe Int -> VarSet -> GrinM VarSet
@ -165,7 +184,8 @@ evalSFetchF index vals = mconcat <$> mapM fetch (Set.toList vals) where
V (RTLoc l) -> {-Set.map N <$> -}mconcat . map (selectRTNodeItem index) . Set.toList <$> lookupStore l V (RTLoc l) -> {-Set.map N <$> -}mconcat . map (selectRTNodeItem index) . Set.toList <$> lookupStore l
x -> fail $ "ERROR: evalSimpleExp - Fetch expected location, got: " ++ show x x -> fail $ "ERROR: evalSimpleExp - Fetch expected location, got: " ++ show x
evalSUpdateF vals v' = mapM_ update vals >> pure basVarSet where evalSUpdateF :: VarSet-> NodeSet -> GrinM VarSet
evalSUpdateF vals v' = mapM_ update vals >> pure (basVarSet T_UNKNOWN) where
update = \case update = \case
V (RTLoc l) -> IntMap.member l <$> gets storeMap >>= \case V (RTLoc l) -> IntMap.member l <$> gets storeMap >>= \case
False -> fail $ "ERROR: evalSimpleExp - Update unknown location: " ++ show l False -> fail $ "ERROR: evalSimpleExp - Update unknown location: " ++ show l
@ -205,18 +225,19 @@ evalSAppF n rtVals = do
evalSimpleExp :: ASimpleExp -> GrinM VarSet evalSimpleExp :: ASimpleExp -> GrinM VarSet
evalSimpleExp = \case evalSimpleExp = \case
_ :< (SAppF n args) -> case n of _ :< (SAppF n args) -> do
rtVals <- mapM evalVal args -- Question: is this correct here?
case n of
-- Special case -- Special case
-- "eval" -> evalEval args -- "eval" -> evalEval args
-- Primitives -- Primitives
"add" -> pure basVarSet "add" -> pure $ basVarSet T_I64
"mul" -> pure basVarSet "mul" -> pure $ basVarSet T_I64
"intPrint" -> pure basVarSet "intPrint" -> pure $ basVarSet $ T_Fun "intPrint"
"intGT" -> pure basVarSet "intGT" -> pure $ basVarSet $ T_Fun "intGT" --boolVarSet
"intAdd" -> pure basVarSet "intAdd" -> pure $ basVarSet T_I64
-- User defined functions -- User defined functions
_ -> do _ -> do
rtVals <- mapM evalVal args -- Question: is this correct here?
evalSAppF n rtVals evalSAppF n rtVals
_ :< (SReturnF v) -> evalVal v _ :< (SReturnF v) -> evalVal v
@ -255,9 +276,10 @@ evalExp x = {-addStep x >> -}case x of
, AltF (NodePat alttag names) exp <- map unwrap alts , AltF (NodePat alttag names) exp <- map unwrap alts
, tag == alttag , tag == alttag
] ]
case Set.member (V BAS) vals of -- what is this???
False -> pure a case [() | V (BAS _) <- Set.toList vals] of
True -> do [] -> pure a
_ -> do
let notNodePat = \case let notNodePat = \case
NodePat{} -> False NodePat{} -> False
_ -> True _ -> True

View File

@ -16,6 +16,7 @@ import Data.Map (Map)
import qualified Data.Map as Map import qualified Data.Map as Map
import Grin import Grin
import AbstractRunGrin
import LLVM.AST hiding (callingConvention) import LLVM.AST hiding (callingConvention)
import LLVM.AST.Type import LLVM.AST.Type
@ -47,14 +48,38 @@ toLLVM fname mod = withContext $ \ctx -> do
BS.writeFile fname llvm BS.writeFile fname llvm
pure llvm pure llvm
-- TODO: create Tag map
{-
b2 -> {BAS}
n13 -> {BAS,sum}
n18 -> {BAS}
n28 -> {BAS}
n29 -> {BAS}
n30 -> {BAS}
n31 -> {BAS}
sum -> {BAS,sum}
-}
-- TODO: create Tag map ; get as parameter ; store in reader environment
{-
question: how to calculate from grin or hpt result?
-}
tagMap :: Map Tag (Type, Constant) tagMap :: Map Tag (Type, Constant)
tagMap = Map.fromList tagMap = Map.fromList
[ (Tag Grin.C "False" 0, (i1, Int 1 0)) [ (Tag Grin.C "False" 0, (i1, Int 1 0))
, (Tag Grin.C "True" 0, (i1, Int 1 1)) , (Tag Grin.C "True" 0, (i1, Int 1 1))
] ]
-- TODO: create Type map -- TODO: create Type map ; calculate once ; store in reader environment
{-
question: how to calculate from grin or hpt result?
ANSWER: lookup from HPT result ; function name = result type ; argument names = input type
TODO:
in pre passes build ; store in env
function type map (llvm type)
variable map (llvm type)
-}
typeMap :: Map Grin.Name Type typeMap :: Map Grin.Name Type
typeMap = Map.fromList typeMap = Map.fromList
[ ("b2" , i64) [ ("b2" , i64)
@ -103,6 +128,7 @@ data Env
, constantMap :: Map Grin.Name Operand , constantMap :: Map Grin.Name Operand
, currentBlockName :: AST.Name , currentBlockName :: AST.Name
, envTempCounter :: Int , envTempCounter :: Int
, envHPTResult :: HPTResult
} }
emptyEnv = Env emptyEnv = Env
@ -112,6 +138,7 @@ emptyEnv = Env
, constantMap = mempty , constantMap = mempty
, currentBlockName = mkName "" , currentBlockName = mkName ""
, envTempCounter = 0 , envTempCounter = 0
, envHPTResult = emptyComputer
} }
type CG = State Env type CG = State Env
@ -215,8 +242,8 @@ toModule Env{..} = defaultModule
, moduleDefinitions = envDefinitions , moduleDefinitions = envDefinitions
} }
codeGen :: Exp -> AST.Module codeGen :: HPTResult -> Exp -> AST.Module
codeGen = toModule . flip execState emptyEnv . para folder where codeGen hptResult = toModule . flip execState (emptyEnv {envHPTResult = hptResult}) . para folder where
folder :: ExpF (Exp, CG Result) -> CG Result folder :: ExpF (Exp, CG Result) -> CG Result
folder = \case folder = \case
SReturnF val -> O <$> codeGenVal val SReturnF val -> O <$> codeGenVal val

View File

@ -8,6 +8,8 @@ import qualified STReduceGrin
import qualified ReduceGrin import qualified ReduceGrin
import qualified JITLLVM import qualified JITLLVM
import qualified CodeGenLLVM import qualified CodeGenLLVM
import qualified AbstractRunGrin
import Transformations (assignStoreIDs)
data Reducer data Reducer
= PureReducer = PureReducer
@ -24,7 +26,8 @@ eval' reducer fname = do
case reducer of case reducer of
PureReducer -> pure $ ReduceGrin.reduceFun e "grinMain" PureReducer -> pure $ ReduceGrin.reduceFun e "grinMain"
STReducer -> pure $ STReduceGrin.reduceFun e "grinMain" STReducer -> pure $ STReduceGrin.reduceFun e "grinMain"
LLVMReducer -> JITLLVM.eagerJit (CodeGenLLVM.codeGen (Program e)) "grinMain" LLVMReducer -> JITLLVM.eagerJit (CodeGenLLVM.codeGen hptResult (Program e)) "grinMain" where
(result, hptResult) = AbstractRunGrin.abstractRun (assignStoreIDs $ Program e) "grinMain"
evalProgram :: Reducer -> Program -> Val evalProgram :: Reducer -> Program -> Val
evalProgram reducer (Program defs) = evalProgram reducer (Program defs) =

View File

@ -148,8 +148,9 @@ printGrinM color = do
jitLLVM :: PipelineM () jitLLVM :: PipelineM ()
jitLLVM = do jitLLVM = do
e <- use psExp e <- use psExp
Just hptResult <- use psHPTResult
liftIO $ do liftIO $ do
val <- JITLLVM.eagerJit (CGLLVM.codeGen e) "grinMain" val <- JITLLVM.eagerJit (CGLLVM.codeGen hptResult e) "grinMain"
print $ pretty val print $ pretty val
printAST :: PipelineM () printAST :: PipelineM ()
@ -172,9 +173,10 @@ saveLLVM :: FilePath -> PipelineM ()
saveLLVM fname' = do saveLLVM fname' = do
e <- use psExp e <- use psExp
n <- use psTransStep n <- use psTransStep
Just hptResult <- use psHPTResult
o <- view poOutputDir o <- view poOutputDir
let fname = o </> concat [fname',".",show n] let fname = o </> concat [fname',".",show n]
code = CGLLVM.codeGen e code = CGLLVM.codeGen hptResult e
llName = printf "%s.ll" fname llName = printf "%s.ll" fname
sName = printf "%s.s" fname sName = printf "%s.s" fname
liftIO . void $ do liftIO . void $ do

View File

@ -28,7 +28,7 @@ instance Pretty a => Pretty (Set a) where
instance Pretty RTLocVal where instance Pretty RTLocVal where
pretty = \case pretty = \case
RTLoc l -> int l RTLoc l -> int l
BAS -> text "BAS" bas@BAS{} -> text $ show bas
RTVar name -> ondullblack $ red $ text name RTVar name -> ondullblack $ red $ text name
instance Pretty RTNode where instance Pretty RTNode where