1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-30 05:42:26 +03:00

Transform JuvixReg into SSA form (#2646)

* Closes #2560 
* Adds a transformation of JuvixReg into SSA form.
* Adds an "output variable" field to branching instructions (`Case`,
`Branch`) which indicates the output variable to which the result is
assigned in both branches. The output variable corresponds to top of
stack in JuvixAsm after executing the branches. In the SSA
transformation, differently renamed output variables are unified by
inserting assignment instructions at the end of branches.
* Adds tests for the SSA transformation.
* Depends on #2641.
This commit is contained in:
Łukasz Czajka 2024-02-20 11:45:14 +01:00 committed by GitHub
parent 9a48f1fd7c
commit cb808c1696
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 522 additions and 74 deletions

View File

@ -17,8 +17,8 @@ import Juvix.Compiler.Asm.Pretty
data RecursorSig m r a = RecursorSig
{ _recursorInfoTable :: InfoTable,
_recurseInstr :: m -> CmdInstr -> Sem r a,
_recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseBranch :: Bool -> m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: Bool -> m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseSave :: m -> CmdSave -> [a] -> Sem r a
}
@ -252,7 +252,7 @@ recurse' sig = go True
let mem0 = popValueStack 1 mem
(mem1, as1) <- go isTail mem0 _cmdBranchTrue
(mem2, as2) <- go isTail mem0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) mem cmd as1 as2
a' <- (sig ^. recurseBranch) isTail mem cmd as1 as2
mem' <- unifyMemory' loc (sig ^. recursorInfoTable) mem1 mem2
checkBranchInvariant 1 loc mem0 mem'
return (mem', a')
@ -268,7 +268,7 @@ recurse' sig = go True
rd <- maybe (return Nothing) (fmap Just . go isTail mem) _cmdCaseDefault
let md = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) mem cmd ass ad
a' <- (sig ^. recurseCase) isTail mem cmd ass ad
case mems of
[] -> return (fromMaybe mem md, a')
mem0 : mems' -> do
@ -333,25 +333,25 @@ recurseS :: forall r a. (Member (Error AsmError) r) => RecursorSig StackInfo r a
recurseS sig code = snd <$> recurseS' sig initialStackInfo code
recurseS' :: forall r a. (Member (Error AsmError) r) => RecursorSig StackInfo r a -> StackInfo -> Code -> Sem r (StackInfo, [a])
recurseS' sig = go
recurseS' sig = go True
where
go :: StackInfo -> Code -> Sem r (StackInfo, [a])
go si = \case
go :: Bool -> StackInfo -> Code -> Sem r (StackInfo, [a])
go isTail si = \case
[] -> return (si, [])
h : t -> case h of
Instr x -> do
goNextCmd (goInstr si x) t
goNextCmd isTail (goInstr si x) t
Branch x ->
goNextCmd (goBranch si x) t
goNextCmd isTail (goBranch (isTail && null t) si x) t
Case x ->
goNextCmd (goCase si x) t
goNextCmd isTail (goCase (isTail && null t) si x) t
Save x ->
goNextCmd (goSave si x) t
goNextCmd isTail (goSave si x) t
goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd mp t = do
goNextCmd :: Bool -> Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd isTail mp t = do
(si', r) <- mp
(si'', rs) <- go si' t
(si'', rs) <- go isTail si' t
return (si'', r : rs)
goInstr :: StackInfo -> CmdInstr -> Sem r (StackInfo, a)
@ -433,26 +433,26 @@ recurseS' sig = go
fixStackCallClosures si InstrCallClosures {..} = do
return $ stackInfoPopValueStack _callClosuresArgsNum si
goBranch :: StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch si cmd@CmdBranch {..} = do
goBranch :: Bool -> StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch isTail si cmd@CmdBranch {..} = do
let si0 = stackInfoPopValueStack 1 si
(si1, as1) <- go si0 _cmdBranchTrue
(si2, as2) <- go si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) si cmd as1 as2
(si1, as1) <- go isTail si0 _cmdBranchTrue
(si2, as2) <- go isTail si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) isTail si cmd as1 as2
checkStackInfo loc si1 si2
return (si1, a')
where
loc = cmd ^. cmdBranchInfo . commandInfoLocation
goCase :: StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase si cmd@CmdCase {..} = do
rs <- mapM (go si . (^. caseBranchCode)) _cmdCaseBranches
goCase :: Bool -> StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase isTail si cmd@CmdCase {..} = do
rs <- mapM (go isTail si . (^. caseBranchCode)) _cmdCaseBranches
let sis = map fst rs
ass = map snd rs
rd <- maybe (return Nothing) (fmap Just . go si) _cmdCaseDefault
rd <- maybe (return Nothing) (fmap Just . go isTail si) _cmdCaseDefault
let sd = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) si cmd ass ad
a' <- (sig ^. recurseCase) isTail si cmd ass ad
case sis of
[] -> return (fromMaybe si sd, a')
si0 : sis' -> do
@ -465,7 +465,7 @@ recurseS' sig = go
goSave :: StackInfo -> CmdSave -> Sem r (StackInfo, a)
goSave si cmd@CmdSave {..} = do
let si1 = stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
(si2, c) <- go si1 _cmdSaveCode
(si2, c) <- go _cmdSaveIsTail si1 _cmdSaveCode
c' <- (sig ^. recurseSave) si cmd c
let si' = if _cmdSaveIsTail then si2 else stackInfoPopTempStack 1 si2
return (si', c')
@ -528,7 +528,7 @@ foldS' sig si code acc = do
RecursorSig
{ _recursorInfoTable = sig ^. foldInfoTable,
_recurseInstr = \s cmd -> return ((sig ^. foldInstr) s cmd),
_recurseBranch = \s cmd br1 br2 ->
_recurseBranch = \_ s cmd br1 br2 ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a
@ -536,7 +536,7 @@ foldS' sig si code acc = do
a2 <- compose' br2 a'
(sig ^. foldBranch) s cmd a1 a2 a
),
_recurseCase = \s cmd brs md ->
_recurseCase = \_ s cmd brs md ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a

View File

@ -22,12 +22,12 @@ computeFunctionStackUsage tab fi = do
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \si _ -> return (si ^. stackInfoValueStackHeight, si ^. stackInfoTempStackHeight),
_recurseBranch = \si _ l r ->
_recurseBranch = \_ si _ l r ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map fst l)) (maximum (map fst r))),
max (si ^. stackInfoTempStackHeight) (max (maximum (map snd l)) (maximum (map snd r)))
),
_recurseCase = \si _ cs md ->
_recurseCase = \_ si _ cs md ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)),
max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md))

View File

@ -12,8 +12,8 @@ validateCode tab fi code = do
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \_ _ -> return (),
_recurseBranch = \_ _ _ _ -> return (),
_recurseCase = \_ _ _ _ -> return (),
_recurseBranch = \_ _ _ _ _ -> return (),
_recurseCase = \_ _ _ _ _ -> return (),
_recurseSave = \_ _ _ -> return ()
}

View File

@ -0,0 +1,58 @@
module Juvix.Compiler.Reg.Data.IndexMap where
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Reg.Language.Base hiding (lookup)
data IndexMap k = IndexMap
{ _indexMapFirstFree :: Int,
_indexMapTable :: HashMap k Index
}
makeLenses ''IndexMap
instance (Hashable k) => Semigroup (IndexMap k) where
m1 <> m2 =
IndexMap
{ _indexMapTable = m1 ^. indexMapTable <> m2 ^. indexMapTable,
_indexMapFirstFree = max (m1 ^. indexMapFirstFree) (m2 ^. indexMapFirstFree)
}
instance (Hashable k) => Monoid (IndexMap k) where
mempty =
IndexMap
{ _indexMapFirstFree = 0,
_indexMapTable = mempty
}
assign :: (Hashable k) => IndexMap k -> k -> (Index, IndexMap k)
assign IndexMap {..} k =
( _indexMapFirstFree,
IndexMap
{ _indexMapFirstFree = _indexMapFirstFree + 1,
_indexMapTable = HashMap.insert k _indexMapFirstFree _indexMapTable
}
)
lookup' :: (Hashable k) => IndexMap k -> k -> Maybe Index
lookup' IndexMap {..} k = HashMap.lookup k _indexMapTable
lookup :: (Hashable k) => IndexMap k -> k -> Index
lookup mp = fromJust . lookup' mp
combine :: forall k. (Hashable k) => IndexMap k -> IndexMap k -> IndexMap k
combine mp1 mp2 =
IndexMap
{ _indexMapFirstFree = max (mp1 ^. indexMapFirstFree) (mp2 ^. indexMapFirstFree),
_indexMapTable = mp
}
where
mp =
foldr
(\k -> HashMap.update (checkVal k) k)
(HashMap.intersection (mp1 ^. indexMapTable) (mp2 ^. indexMapTable))
(HashMap.keys (mp2 ^. indexMapTable))
checkVal :: k -> Index -> Maybe Index
checkVal k idx
| lookup mp2 k == idx = Just idx
| otherwise = Nothing

View File

@ -6,6 +6,7 @@ import Juvix.Prelude
data TransformationId
= Identity
| SSA
deriving stock (Data, Bounded, Enum, Show)
data PipelineId
@ -19,12 +20,13 @@ toCTransformations :: [TransformationId]
toCTransformations = []
toCairoTransformations :: [TransformationId]
toCairoTransformations = []
toCairoTransformations = [SSA]
instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
transformationText = \case
Identity -> strIdentity
SSA -> strSSA
instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text

View File

@ -10,3 +10,6 @@ strCairoPipeline = "pipeline-cairo"
strIdentity :: Text
strIdentity = "identity"
strSSA :: Text
strSSA = "ssa"

View File

@ -0,0 +1,10 @@
module Juvix.Compiler.Reg.Extra
( module Juvix.Compiler.Reg.Extra.Base,
module Juvix.Compiler.Reg.Extra.Recursors,
module Juvix.Compiler.Reg.Extra.Info,
)
where
import Juvix.Compiler.Reg.Extra.Base
import Juvix.Compiler.Reg.Extra.Info
import Juvix.Compiler.Reg.Extra.Recursors

View File

@ -0,0 +1,156 @@
module Juvix.Compiler.Reg.Extra.Base where
import Juvix.Compiler.Reg.Language
getResultVar :: Instruction -> Maybe VarRef
getResultVar = \case
Binop x -> Just $ x ^. binaryOpResult
Show x -> Just $ x ^. instrShowResult
StrToInt x -> Just $ x ^. instrStrToIntResult
Assign x -> Just $ x ^. instrAssignResult
ArgsNum x -> Just $ x ^. instrArgsNumResult
Alloc x -> Just $ x ^. instrAllocResult
AllocClosure x -> Just $ x ^. instrAllocClosureResult
ExtendClosure x -> Just $ x ^. instrExtendClosureResult
Call x -> Just $ x ^. instrCallResult
CallClosures x -> Just $ x ^. instrCallClosuresResult
_ -> Nothing
setResultVar :: Instruction -> VarRef -> Instruction
setResultVar instr vref = case instr of
Binop x -> Binop $ set binaryOpResult vref x
Show x -> Show $ set instrShowResult vref x
StrToInt x -> StrToInt $ set instrStrToIntResult vref x
Assign x -> Assign $ set instrAssignResult vref x
ArgsNum x -> ArgsNum $ set instrArgsNumResult vref x
Alloc x -> Alloc $ set instrAllocResult vref x
AllocClosure x -> AllocClosure $ set instrAllocClosureResult vref x
ExtendClosure x -> ExtendClosure $ set instrExtendClosureResult vref x
Call x -> Call $ set instrCallResult vref x
CallClosures x -> CallClosures $ set instrCallClosuresResult vref x
_ -> impossible
overValueRefs :: (VarRef -> VarRef) -> Instruction -> Instruction
overValueRefs f = \case
Binop x -> Binop $ goBinop x
Show x -> Show $ goShow x
StrToInt x -> StrToInt $ goStrToInt x
Assign x -> Assign $ goAssign x
ArgsNum x -> ArgsNum $ goArgsNum x
Alloc x -> Alloc $ goAlloc x
AllocClosure x -> AllocClosure $ goAllocClosure x
ExtendClosure x -> ExtendClosure $ goExtendClosure x
Call x -> Call $ goCall x
CallClosures x -> CallClosures $ goCallClosures x
TailCall x -> TailCall $ goTailCall x
TailCallClosures x -> TailCallClosures $ goTailCallClosures x
Return x -> Return $ goReturn x
Branch x -> Branch $ goBranch x
Case x -> Case $ goCase x
Trace x -> Trace $ goTrace x
Dump -> Dump
Failure x -> Failure $ goFailure x
Prealloc x -> Prealloc $ goPrealloc x
Nop -> Nop
Block x -> Block $ goBlock x
where
goConstrField :: ConstrField -> ConstrField
goConstrField = over constrFieldRef f
goValue :: Value -> Value
goValue = \case
Const c -> Const c
CRef x -> CRef $ goConstrField x
VRef x -> VRef $ f x
goBinop :: BinaryOp -> BinaryOp
goBinop BinaryOp {..} =
BinaryOp
{ _binaryOpArg1 = goValue _binaryOpArg1,
_binaryOpArg2 = goValue _binaryOpArg2,
..
}
goShow :: InstrShow -> InstrShow
goShow = over instrShowValue goValue
goStrToInt :: InstrStrToInt -> InstrStrToInt
goStrToInt = over instrStrToIntValue goValue
goAssign :: InstrAssign -> InstrAssign
goAssign = over instrAssignValue goValue
goArgsNum :: InstrArgsNum -> InstrArgsNum
goArgsNum = over instrArgsNumValue goValue
goAlloc :: InstrAlloc -> InstrAlloc
goAlloc = over instrAllocArgs (map goValue)
goAllocClosure :: InstrAllocClosure -> InstrAllocClosure
goAllocClosure = over instrAllocClosureArgs (map goValue)
goExtendClosure :: InstrExtendClosure -> InstrExtendClosure
goExtendClosure InstrExtendClosure {..} =
InstrExtendClosure
{ _instrExtendClosureValue = f _instrExtendClosureValue,
_instrExtendClosureArgs = map goValue _instrExtendClosureArgs,
..
}
goCallType :: CallType -> CallType
goCallType = \case
CallFun sym -> CallFun sym
CallClosure cl -> CallClosure (f cl)
goCall :: InstrCall -> InstrCall
goCall InstrCall {..} =
InstrCall
{ _instrCallType = goCallType _instrCallType,
_instrCallArgs = map goValue _instrCallArgs,
..
}
goCallClosures :: InstrCallClosures -> InstrCallClosures
goCallClosures InstrCallClosures {..} =
InstrCallClosures
{ _instrCallClosuresArgs = map goValue _instrCallClosuresArgs,
_instrCallClosuresValue = f _instrCallClosuresValue,
..
}
goTailCall :: InstrTailCall -> InstrTailCall
goTailCall InstrTailCall {..} =
InstrTailCall
{ _instrTailCallType = goCallType _instrTailCallType,
_instrTailCallArgs = map goValue _instrTailCallArgs,
..
}
goTailCallClosures :: InstrTailCallClosures -> InstrTailCallClosures
goTailCallClosures InstrTailCallClosures {..} =
InstrTailCallClosures
{ _instrTailCallClosuresValue = f _instrTailCallClosuresValue,
_instrTailCallClosuresArgs = map goValue _instrTailCallClosuresArgs,
..
}
goReturn :: InstrReturn -> InstrReturn
goReturn = over instrReturnValue goValue
goBranch :: InstrBranch -> InstrBranch
goBranch = over instrBranchValue goValue
goCase :: InstrCase -> InstrCase
goCase = over instrCaseValue goValue
goTrace :: InstrTrace -> InstrTrace
goTrace = over instrTraceValue goValue
goFailure :: InstrFailure -> InstrFailure
goFailure = over instrFailureValue goValue
goPrealloc :: InstrPrealloc -> InstrPrealloc
goPrealloc x = x
goBlock :: InstrBlock -> InstrBlock
goBlock x = x

View File

@ -5,7 +5,7 @@ import Juvix.Compiler.Reg.Language
data ForwardRecursorSig m c = ForwardRecursorSig
{ _forwardFun :: Instruction -> c -> m (c, Instruction),
_forwardCombine :: NonEmpty c -> c
_forwardCombine :: Instruction -> NonEmpty c -> (c, Instruction)
}
data BackwardRecursorSig m a = BackwardRecursorSig
@ -25,16 +25,16 @@ recurseF sig c = \case
Branch x@InstrBranch {..} -> do
(c1, is1) <- recurseF sig c0 _instrBranchTrue
(c2, is2) <- recurseF sig c0 _instrBranchFalse
let c' = (sig ^. forwardCombine) (c1 :| [c2])
return (c', Branch x {_instrBranchTrue = is1, _instrBranchFalse = is2})
let i' = Branch x {_instrBranchTrue = is1, _instrBranchFalse = is2}
return $ (sig ^. forwardCombine) i' (c1 :| [c2])
Case x@InstrCase {..} -> do
brs' <- mapM goBranch _instrCaseBranches
def' <- maybe (return Nothing) (\is -> Just <$> recurseF sig c0 is) _instrCaseDefault
let cs = map fst brs' ++ maybe [] (\md -> [fst md]) def'
brs = map snd brs'
def = maybe Nothing (Just . snd) def'
c' = (sig ^. forwardCombine) (nonEmpty' cs)
return (c', Case x {_instrCaseBranches = brs, _instrCaseDefault = def})
i' = Case x {_instrCaseBranches = brs, _instrCaseDefault = def}
return $ (sig ^. forwardCombine) i' (nonEmpty' cs)
where
goBranch :: CaseBranch -> m (c, CaseBranch)
goBranch br@CaseBranch {..} = do
@ -118,7 +118,7 @@ ifoldFM f a0 is0 =
{ _forwardFun = \i a -> do
a' <- f a i
return (a', i),
_forwardCombine = mconcat . toList
_forwardCombine = \i a -> (mconcat (toList a), i)
}
a0
is0

View File

@ -10,9 +10,6 @@ data Value
= Const Constant
| CRef ConstrField
| VRef VarRef
deriving stock (Eq)
type Index = Int
-- | Reference to a constructor field (argument).
data ConstrField = ConstrField
@ -24,41 +21,59 @@ data ConstrField = ConstrField
_constrFieldRef :: VarRef,
_constrFieldIndex :: Index
}
deriving stock (Eq)
data VarGroup
= VarGroupArgs
| VarGroupLocal
deriving stock (Eq)
deriving stock (Eq, Generic)
instance Hashable VarGroup
data VarRef = VarRef
{ _varRefGroup :: VarGroup,
_varRefIndex :: Index,
_varRefName :: Maybe Text
}
deriving stock (Eq)
makeLenses ''VarRef
makeLenses ''ConstrField
instance Hashable VarRef where
hashWithSalt salt VarRef {..} = hashWithSalt salt (_varRefGroup, _varRefIndex)
instance Eq VarRef where
vr1 == vr2 =
vr1 ^. varRefGroup == vr2 ^. varRefGroup
&& vr1 ^. varRefIndex == vr2 ^. varRefIndex
deriving stock instance (Eq ConstrField)
deriving stock instance (Eq Value)
data Instruction
= Nop -- no operation
| Binop BinaryOp
= Binop BinaryOp
| Show InstrShow
| StrToInt InstrStrToInt
| Assign InstrAssign
| Trace InstrTrace
| Dump
| Failure InstrFailure
| ArgsNum InstrArgsNum
| Prealloc InstrPrealloc
| Alloc InstrAlloc
| AllocClosure InstrAllocClosure
| ExtendClosure InstrExtendClosure
| Call InstrCall
| TailCall InstrTailCall
| CallClosures InstrCallClosures
| ----
TailCall InstrTailCall
| TailCallClosures InstrTailCallClosures
| Return InstrReturn
| Branch InstrBranch
| ----
Branch InstrBranch
| Case InstrCase
| ----
Trace InstrTrace
| Dump
| Failure InstrFailure
| Prealloc InstrPrealloc
| Nop -- no operation
| Block InstrBlock
deriving stock (Eq)
@ -156,7 +171,7 @@ data InstrCall = InstrCall
{ _instrCallResult :: VarRef,
_instrCallType :: CallType,
_instrCallArgs :: [Value],
-- | Variables live at the point of the call. Live variables need to be
-- | Variables live after the call. Live variables need to be
-- saved before the call and restored after it.
_instrCallLiveVars :: [VarRef]
}
@ -190,7 +205,10 @@ newtype InstrReturn = InstrReturn
data InstrBranch = InstrBranch
{ _instrBranchValue :: Value,
_instrBranchTrue :: Code,
_instrBranchFalse :: Code
_instrBranchFalse :: Code,
-- | Output variable storing the result (corresponds to the top of the value
-- stack in JuvixAsm after executing the branches)
_instrBranchOutVar :: Maybe VarRef
}
deriving stock (Eq)
@ -199,7 +217,8 @@ data InstrCase = InstrCase
_instrCaseInductive :: Symbol,
_instrCaseIndRep :: IndRep,
_instrCaseBranches :: [CaseBranch],
_instrCaseDefault :: Maybe Code
_instrCaseDefault :: Maybe Code,
_instrCaseOutVar :: Maybe VarRef
}
deriving stock (Eq)
@ -217,7 +236,6 @@ newtype InstrBlock = InstrBlock
}
deriving stock (Eq)
makeLenses ''ConstrField
makeLenses ''BinaryOp
makeLenses ''InstrAssign
makeLenses ''InstrTrace
@ -232,6 +250,10 @@ makeLenses ''InstrBranch
makeLenses ''InstrCase
makeLenses ''CaseBranch
makeLenses ''InstrReturn
makeLenses ''InstrShow
makeLenses ''InstrStrToInt
makeLenses ''InstrArgsNum
makeLenses ''InstrTailCall
mkVarRef :: VarGroup -> Index -> VarRef
mkVarRef g i =

View File

@ -5,6 +5,6 @@ module Juvix.Compiler.Reg.Language.Base
)
where
import Juvix.Compiler.Core.Language.Base hiding (Index)
import Juvix.Compiler.Core.Language.Base
import Juvix.Compiler.Tree.Language.Base (Constant (..))
import Juvix.Compiler.Tree.Language.Rep

View File

@ -108,6 +108,13 @@ ppLiveVars vars
vars' <- mapM ppCode vars
return $ comma <+> primitive "live:" <+> arglist vars'
ppOutVar :: (Member (Reader Options) r) => Maybe VarRef -> Sem r (Doc Ann)
ppOutVar = \case
Nothing -> return mempty
Just var -> do
var' <- ppCode var
return $ comma <+> primitive "out:" <+> var'
instance PrettyCode InstrPrealloc where
ppCode InstrPrealloc {..} = do
vars <- ppLiveVars _instrPreallocLiveVars
@ -210,9 +217,11 @@ instance PrettyCode InstrBranch where
val <- ppCode _instrBranchValue
br1 <- ppCodeCode _instrBranchTrue
br2 <- ppCodeCode _instrBranchFalse
var <- ppOutVar _instrBranchOutVar
return $
primitive Str.br
<+> val
<> var
<+> braces'
( constr Str.true_ <> colon
<+> braces' br1
@ -240,8 +249,9 @@ instance PrettyCode InstrCase where
val <- ppCode _instrCaseValue
brs <- mapM ppCode _instrCaseBranches
def <- maybe (return Nothing) (fmap Just . ppDefaultBranch) _instrCaseDefault
var <- ppOutVar _instrCaseOutVar
let brs' = brs ++ catMaybes [def]
return $ primitive Str.case_ <> brackets ind <+> val <+> braces' (vsep brs')
return $ primitive Str.case_ <> brackets ind <+> val <> var <+> braces' (vsep brs')
instance PrettyCode InstrBlock where
ppCode InstrBlock {..} = braces' <$> ppCodeCode _instrBlockCode

View File

@ -8,6 +8,7 @@ where
import Juvix.Compiler.Reg.Data.TransformationId
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Reg.Transformation.Identity
import Juvix.Compiler.Reg.Transformation.SSA
applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations ts tbl = foldM (flip appTrans) tbl ts
@ -15,3 +16,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
appTrans = \case
Identity -> return . identity
SSA -> return . computeSSA

View File

@ -0,0 +1,142 @@
module Juvix.Compiler.Reg.Transformation.SSA where
import Data.Functor.Identity
import Data.HashSet qualified as HashSet
import Data.List.NonEmpty qualified as NonEmpty
import Data.Monoid
import Juvix.Compiler.Reg.Data.IndexMap (IndexMap)
import Juvix.Compiler.Reg.Data.IndexMap qualified as IndexMap
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base
computeFunctionSSA :: Code -> Code
computeFunctionSSA =
snd
. runIdentity
. recurseF
ForwardRecursorSig
{ _forwardFun = \i acc -> return (go i acc),
_forwardCombine = combine
}
mempty
where
go :: Instruction -> IndexMap VarRef -> (IndexMap VarRef, Instruction)
go instr mp = case getResultVar instr' of
Just vref -> (mp', updateLiveVars mp' (setResultVar instr' (mkVarRef VarGroupLocal idx)))
where
(idx, mp') = IndexMap.assign mp vref
Nothing -> (mp, updateLiveVars mp instr')
where
instr' = overValueRefs (adjustVarRef mp) instr
updateLiveVars :: IndexMap VarRef -> Instruction -> Instruction
updateLiveVars mp = \case
Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe (adjustVarRef' mp)) x
Call x -> Call $ over instrCallLiveVars (mapMaybe (adjustVarRef' mp)) x
CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe (adjustVarRef' mp)) x
instr -> instr
-- For branches, when necessary we insert assignments unifying the renamed
-- output variables into a single output variable for both branches.
combine :: Instruction -> NonEmpty (IndexMap VarRef) -> (IndexMap VarRef, Instruction)
combine instr mps = case instr of
Branch InstrBranch {..} -> case mps of
mp1 :| mp2 : []
| isNothing _instrBranchOutVar ->
(mp, instr)
| idx1 == idx2 ->
( mp,
Branch
InstrBranch
{ _instrBranchOutVar = Just $ mkVarRef VarGroupLocal idx1,
..
}
)
| otherwise ->
( mp',
Branch
InstrBranch
{ _instrBranchTrue = assignInBranch _instrBranchTrue idx' idx1,
_instrBranchFalse = assignInBranch _instrBranchFalse idx' idx2,
_instrBranchOutVar = Just $ mkVarRef VarGroupLocal idx',
..
}
)
where
var = fromJust _instrBranchOutVar
idx1 = IndexMap.lookup mp1 var
idx2 = IndexMap.lookup mp2 var
mp = IndexMap.combine mp1 mp2
(idx', mp') = IndexMap.assign mp var
_ -> impossible
Case InstrCase {..} -> case mps of
mp0 :| mps'
| isNothing _instrCaseOutVar ->
(mp, instr)
| all (== head idxs) (NonEmpty.tail idxs) ->
( mp,
Case
InstrCase
{ _instrCaseOutVar = Just $ mkVarRef VarGroupLocal (head idxs),
..
}
)
| otherwise ->
( mp',
Case
InstrCase
{ _instrCaseBranches = brs',
_instrCaseDefault = def',
_instrCaseOutVar = Just $ mkVarRef VarGroupLocal idx',
..
}
)
where
var = fromJust _instrCaseOutVar
idxs = fmap (flip IndexMap.lookup var) mps
mp = foldr IndexMap.combine mp0 mps'
(idx', mp') = IndexMap.assign mp var
n = length _instrCaseBranches
brs' = zipWithExact updateBranch _instrCaseBranches (take n (toList idxs))
def' = fmap (\is -> assignInBranch is idx' (last idxs)) _instrCaseDefault
updateBranch :: CaseBranch -> Index -> CaseBranch
updateBranch br idx =
over caseBranchCode (\is -> assignInBranch is idx' idx) br
_ -> impossible
adjustVarRef :: IndexMap VarRef -> VarRef -> VarRef
adjustVarRef mpv vref@VarRef {..} = case _varRefGroup of
VarGroupArgs -> vref
VarGroupLocal -> mkVarRef VarGroupLocal (IndexMap.lookup mpv vref)
adjustVarRef' :: IndexMap VarRef -> VarRef -> Maybe VarRef
adjustVarRef' mpv vref = case IndexMap.lookup' mpv vref of
Just idx -> Just $ mkVarRef VarGroupLocal idx
Nothing -> case vref ^. varRefGroup of
VarGroupArgs -> Just vref
VarGroupLocal -> Nothing
assignInBranch :: Code -> Index -> Index -> Code
assignInBranch is idx idx' =
is
++ [ Assign
InstrAssign
{ _instrAssignResult = mkVarRef VarGroupLocal idx,
_instrAssignValue = VRef $ mkVarRef VarGroupLocal idx'
}
]
computeSSA :: InfoTable -> InfoTable
computeSSA = mapT (const computeFunctionSSA)
checkSSA :: InfoTable -> Bool
checkSSA tab = all (checkFun . (^. functionCode)) (tab ^. infoFunctions)
where
checkFun :: Code -> Bool
checkFun is = getAll $ snd $ ifoldF check (mempty, All True) is
where
check :: (HashSet VarRef, All) -> Instruction -> (HashSet VarRef, All)
check (refs, b) instr = case getResultVar instr of
Just var -> (HashSet.insert var refs, b <> All (not (HashSet.member var refs)))
Nothing -> (refs, b)

View File

@ -97,7 +97,7 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
-- Live variables *after* executing the instruction. `k` is the number of
-- value stack cells that will be popped by the instruction. TODO: proper
-- liveness analysis in JuvixAsm.
-- liveness analysis.
liveVars :: Int -> [VarRef]
liveVars k =
map (mkVarRef VarGroupLocal) [0 .. si ^. Asm.stackInfoTempStackHeight - 1]
@ -255,33 +255,39 @@ fromAsmInstr funInfo tab si Asm.CmdInstr {..} =
fromAsmBranch ::
Asm.FunctionInfo ->
Bool ->
Asm.StackInfo ->
Asm.CmdBranch ->
Code ->
Code ->
Sem r Instruction
fromAsmBranch fi si Asm.CmdBranch {} codeTrue codeFalse =
fromAsmBranch fi isTail si Asm.CmdBranch {} codeTrue codeFalse =
return $
Branch $
InstrBranch
{ _instrBranchValue = VRef $ mkVarRef VarGroupLocal (fromJust (fi ^. Asm.functionExtra) ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1),
{ _instrBranchValue = VRef $ mkVarRef VarGroupLocal topIdx,
_instrBranchTrue = codeTrue,
_instrBranchFalse = codeFalse
_instrBranchFalse = codeFalse,
_instrBranchOutVar = if isTail then Nothing else Just $ mkVarRef VarGroupLocal topIdx
}
where
topIdx :: Int
topIdx = fromJust (fi ^. Asm.functionExtra) ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1
fromAsmCase ::
Asm.FunctionInfo ->
Asm.InfoTable ->
Bool ->
Asm.StackInfo ->
Asm.CmdCase ->
[Code] ->
Maybe Code ->
Sem r Instruction
fromAsmCase fi tab si Asm.CmdCase {..} brs def =
fromAsmCase fi tab isTail si Asm.CmdCase {..} brs def =
return $
Case $
InstrCase
{ _instrCaseValue = VRef $ mkVarRef VarGroupLocal (fromJust (fi ^. Asm.functionExtra) ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1),
{ _instrCaseValue = VRef $ mkVarRef VarGroupLocal topIdx,
_instrCaseInductive = _cmdCaseInductive,
_instrCaseIndRep = ii ^. Asm.inductiveRepresentation,
_instrCaseBranches =
@ -300,9 +306,11 @@ fromAsmCase fi tab si Asm.CmdCase {..} brs def =
)
_cmdCaseBranches
brs,
_instrCaseDefault = def
_instrCaseDefault = def,
_instrCaseOutVar = if isTail then Nothing else Just $ mkVarRef VarGroupLocal topIdx
}
where
topIdx = fromJust (fi ^. Asm.functionExtra) ^. Asm.functionMaxTempStackHeight + si ^. Asm.stackInfoValueStackHeight - 1
ii =
fromMaybe impossible $
HashMap.lookup _cmdCaseInductive (tab ^. Asm.infoInductives)

View File

@ -262,6 +262,13 @@ liveVars = do
P.try (comma >> symbol "live:")
parens (P.sepBy varRef comma)
outVar ::
(Members '[Reader ParserSig, InfoTableBuilder, State LocalParams] r) =>
ParsecS r VarRef
outVar = do
P.try (comma >> symbol "out:")
varRef
parseArgs ::
(Members '[Reader ParserSig, InfoTableBuilder, State LocalParams] r) =>
ParsecS r [Value]
@ -352,6 +359,7 @@ instrBranch ::
instrBranch = do
kw kwBr
val <- value
var <- optional outVar
(br1, br2) <- braces $ do
symbol "true:"
br1 <- braces parseCode
@ -364,7 +372,8 @@ instrBranch = do
InstrBranch
{ _instrBranchValue = val,
_instrBranchTrue = br1,
_instrBranchFalse = br2
_instrBranchFalse = br2,
_instrBranchOutVar = var
}
instrCase ::
@ -374,6 +383,7 @@ instrCase = do
kw kwCase
sym <- brackets (indSymbol @Code @() @VarRef)
val <- value
var <- optional outVar
lbrace
brs <- many caseBranch
def <- optional defaultBranch
@ -384,7 +394,8 @@ instrCase = do
_instrCaseInductive = sym,
_instrCaseIndRep = IndRepStandard,
_instrCaseBranches = brs,
_instrCaseDefault = def
_instrCaseDefault = def,
_instrCaseOutVar = var
}
caseBranch ::

View File

@ -2,10 +2,12 @@ module Reg.Transformation where
import Base
import Reg.Transformation.Identity qualified as Identity
import Reg.Transformation.SSA qualified as SSA
allTests :: TestTree
allTests =
testGroup
"JuvixReg transformations"
[ Identity.allTests
[ Identity.allTests,
SSA.allTests
]

View File

@ -0,0 +1,22 @@
module Reg.Transformation.SSA where
import Base
import Juvix.Compiler.Reg.Transformation
import Juvix.Compiler.Reg.Transformation.SSA
import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base
allTests :: TestTree
allTests = testGroup "SSA" (map liftTest Parse.tests)
pipe :: [TransformationId]
pipe = [SSA]
liftTest :: Parse.PosTest -> TestTree
liftTest _testRun =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = \tab -> unless (checkSSA tab) $ error "check SSA",
_testRun
}

View File

@ -10,7 +10,7 @@ function main() : * {
tmp[0] = 3;
tmp[1] = 0;
tmp[0] = lt tmp[1] tmp[0];
br tmp[0] {
br tmp[0], out: tmp[0] {
true: {
tmp[0] = 1;
};
@ -21,7 +21,7 @@ function main() : * {
tmp[1] = 1;
tmp[2] = 2;
tmp[1] = le tmp[2] tmp[1];
br tmp[1] {
br tmp[1], out: tmp[1] {
true: {
tmp[1] = call loop (), live: (tmp[0]);
};
@ -29,7 +29,7 @@ function main() : * {
tmp[1] = 7;
tmp[2] = 8;
tmp[1] = le tmp[2] tmp[1];
br tmp[1] {
br tmp[1], out: tmp[1] {
true: {
tmp[1] = call loop (), live: (tmp[0]);
};

View File

@ -6,7 +6,7 @@ function sum(integer) : integer {
tmp[0] = arg[0];
tmp[1] = 0;
tmp[0] = eq tmp[1] tmp[0];
br tmp[0] {
br tmp[0], out: tmp[0] {
true: {
tmp[0] = 0;
};

View File

@ -56,7 +56,7 @@ function f(tree) : integer {
{
tmp[2] = tmp[5];
tmp[5] = tmp[1];
case[tree] tmp[5] {
case[tree] tmp[5], out: tmp[5] {
leaf: {
nop;
tmp[5] = 3;
@ -79,7 +79,7 @@ function f(tree) : integer {
{
tmp[3] = tmp[5];
tmp[5] = tmp[2];
case[tree] tmp[5] {
case[tree] tmp[5], out: tmp[5] {
node: {
{
tmp[4] = tmp[5];