mirror of
https://github.com/anoma/juvix.git
synced 2024-10-26 17:52:17 +03:00
JuvixReg transformation: initialize variables assigned in other branches (#2650)
* Closes #2576 * Adds a JuvixReg transformation `InitBranchVars` which inserts assignments to initialize variables assigned in other branches. Assumes the input is in SSA form (which is preserved). * Adds tests for the `InitBranchVars` transformation. * Depends on #2647
This commit is contained in:
parent
dfd664e184
commit
49678b4e54
@ -6,8 +6,9 @@ import Juvix.Prelude
|
||||
|
||||
data TransformationId
|
||||
= Identity
|
||||
| SSA
|
||||
| Cleanup
|
||||
| SSA
|
||||
| InitBranchVars
|
||||
deriving stock (Data, Bounded, Enum, Show)
|
||||
|
||||
data PipelineId
|
||||
@ -21,14 +22,15 @@ toCTransformations :: [TransformationId]
|
||||
toCTransformations = [Cleanup]
|
||||
|
||||
toCairoTransformations :: [TransformationId]
|
||||
toCairoTransformations = [Cleanup, SSA]
|
||||
toCairoTransformations = [Cleanup, SSA, InitBranchVars]
|
||||
|
||||
instance TransformationId' TransformationId where
|
||||
transformationText :: TransformationId -> Text
|
||||
transformationText = \case
|
||||
Identity -> strIdentity
|
||||
SSA -> strSSA
|
||||
Cleanup -> strCleanup
|
||||
SSA -> strSSA
|
||||
InitBranchVars -> strInitBranchVars
|
||||
|
||||
instance PipelineId' TransformationId PipelineId where
|
||||
pipelineText :: PipelineId -> Text
|
||||
|
@ -11,8 +11,11 @@ strCairoPipeline = "pipeline-cairo"
|
||||
strIdentity :: Text
|
||||
strIdentity = "identity"
|
||||
|
||||
strCleanup :: Text
|
||||
strCleanup = "cleanup"
|
||||
|
||||
strSSA :: Text
|
||||
strSSA = "ssa"
|
||||
|
||||
strCleanup :: Text
|
||||
strCleanup = "cleanup"
|
||||
strInitBranchVars :: Text
|
||||
strInitBranchVars = "init-branch-vars"
|
||||
|
@ -9,7 +9,17 @@ data ForwardRecursorSig m c = ForwardRecursorSig
|
||||
}
|
||||
|
||||
data BackwardRecursorSig m a = BackwardRecursorSig
|
||||
{ _backwardFun :: Code -> a -> [a] -> m (a, Code),
|
||||
{ -- | In `_backwardFun is a as`: `is = i : is'` is the instruction list
|
||||
-- currently being processed (the head `i` is the processed instruction, the
|
||||
-- tail `is'` contains the instructions after it); `a` is the accumulator
|
||||
-- for `is'`; `as` contains the accumulator values for the branches (for
|
||||
-- `Branch` and `Case` instructions, otherwise empty). For the `Case`
|
||||
-- instruction, the accumulator for the default branch (if present) is the
|
||||
-- last element of `as`.
|
||||
_backwardFun :: Code -> a -> [a] -> m (a, Code),
|
||||
-- | `backwardAdjust a` adjusts the accumulator value when going backwards
|
||||
-- into a branch. See also `FoldSig` in `Asm.Extra.Recursors` for more
|
||||
-- explanations.
|
||||
_backwardAdjust :: a -> a
|
||||
}
|
||||
|
||||
@ -125,3 +135,25 @@ ifoldFM f a0 is0 =
|
||||
|
||||
ifoldF :: (Monoid a) => (a -> Instruction -> a) -> a -> Code -> a
|
||||
ifoldF f a is = runIdentity (ifoldFM (\a' -> return . f a') a is)
|
||||
|
||||
ifoldBM :: forall a m. (Monad m) => (a -> [a] -> Instruction -> m a) -> a -> Code -> m a
|
||||
ifoldBM f a0 is0 =
|
||||
fst
|
||||
<$> recurseB
|
||||
BackwardRecursorSig
|
||||
{ _backwardFun = go,
|
||||
_backwardAdjust = id
|
||||
}
|
||||
a0
|
||||
is0
|
||||
where
|
||||
go :: Code -> a -> [a] -> m (a, Code)
|
||||
go is a as = case is of
|
||||
i : _ -> do
|
||||
a' <- f a as i
|
||||
return (a', is)
|
||||
[] ->
|
||||
return (a, is)
|
||||
|
||||
ifoldB :: (a -> [a] -> Instruction -> a) -> a -> Code -> a
|
||||
ifoldB f a is = runIdentity (ifoldBM (\a' as' -> return . f a' as') a is)
|
||||
|
@ -9,6 +9,7 @@ import Juvix.Compiler.Reg.Data.TransformationId
|
||||
import Juvix.Compiler.Reg.Transformation.Base
|
||||
import Juvix.Compiler.Reg.Transformation.Cleanup
|
||||
import Juvix.Compiler.Reg.Transformation.Identity
|
||||
import Juvix.Compiler.Reg.Transformation.InitBranchVars
|
||||
import Juvix.Compiler.Reg.Transformation.SSA
|
||||
|
||||
applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
|
||||
@ -17,5 +18,6 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
|
||||
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
|
||||
appTrans = \case
|
||||
Identity -> return . identity
|
||||
SSA -> return . computeSSA
|
||||
Cleanup -> return . cleanup
|
||||
SSA -> return . computeSSA
|
||||
InitBranchVars -> return . initBranchVars
|
||||
|
91
src/Juvix/Compiler/Reg/Transformation/InitBranchVars.hs
Normal file
91
src/Juvix/Compiler/Reg/Transformation/InitBranchVars.hs
Normal file
@ -0,0 +1,91 @@
|
||||
module Juvix.Compiler.Reg.Transformation.InitBranchVars where
|
||||
|
||||
import Data.Functor.Identity
|
||||
import Data.HashSet qualified as HashSet
|
||||
import Data.List qualified as List
|
||||
import Juvix.Compiler.Reg.Extra
|
||||
import Juvix.Compiler.Reg.Transformation.Base
|
||||
|
||||
-- | Inserts assignments to initialize variables assigned in other branches.
|
||||
-- Assumes the input is in SSA form (which is preserved).
|
||||
initBranchVars :: InfoTable -> InfoTable
|
||||
initBranchVars = mapT (const goFun)
|
||||
where
|
||||
goFun :: Code -> Code
|
||||
goFun =
|
||||
snd
|
||||
. runIdentity
|
||||
. recurseB
|
||||
BackwardRecursorSig
|
||||
{ _backwardFun = \is a as -> return (go is a as),
|
||||
_backwardAdjust = const mempty
|
||||
}
|
||||
mempty
|
||||
|
||||
go :: Code -> HashSet VarRef -> [HashSet VarRef] -> (HashSet VarRef, Code)
|
||||
go is a as = case is of
|
||||
Branch InstrBranch {..} : is' -> case as of
|
||||
[a1, a2] -> (a <> a', i' : is')
|
||||
where
|
||||
a' = a1 <> a2
|
||||
a1' = HashSet.difference a' a1
|
||||
a2' = HashSet.difference a' a2
|
||||
i' =
|
||||
Branch
|
||||
InstrBranch
|
||||
{ _instrBranchTrue = addInits a1' _instrBranchTrue,
|
||||
_instrBranchFalse = addInits a2' _instrBranchFalse,
|
||||
..
|
||||
}
|
||||
_ -> impossible
|
||||
Case InstrCase {..} : is' ->
|
||||
(a <> a', i' : is')
|
||||
where
|
||||
a' = mconcat as
|
||||
as' = map (HashSet.difference a') as
|
||||
n = length _instrCaseBranches
|
||||
brs' = zipWithExact goBranch (take n as') _instrCaseBranches
|
||||
def' = maybe Nothing (Just . addInits (List.last as')) _instrCaseDefault
|
||||
i' =
|
||||
Case
|
||||
InstrCase
|
||||
{ _instrCaseBranches = brs',
|
||||
_instrCaseDefault = def',
|
||||
..
|
||||
}
|
||||
|
||||
goBranch :: HashSet VarRef -> CaseBranch -> CaseBranch
|
||||
goBranch vars = over caseBranchCode (addInits vars)
|
||||
i : _ ->
|
||||
case getResultVar i of
|
||||
Just v ->
|
||||
(HashSet.insert v a <> mconcat as, is)
|
||||
Nothing ->
|
||||
(a <> mconcat as, is)
|
||||
[] ->
|
||||
(a <> mconcat as, is)
|
||||
|
||||
addInits :: HashSet VarRef -> Code -> Code
|
||||
addInits vars is = map mk (toList vars) ++ is
|
||||
where
|
||||
mk :: VarRef -> Instruction
|
||||
mk vref =
|
||||
Assign
|
||||
InstrAssign
|
||||
{ _instrAssignResult = vref,
|
||||
_instrAssignValue = Const ConstVoid
|
||||
}
|
||||
|
||||
checkInitialized :: InfoTable -> Bool
|
||||
checkInitialized tab = all (goFun . (^. functionCode)) (tab ^. infoFunctions)
|
||||
where
|
||||
goFun :: Code -> Bool
|
||||
goFun = snd . ifoldB go (mempty, True)
|
||||
where
|
||||
go :: (HashSet VarRef, Bool) -> [(HashSet VarRef, Bool)] -> Instruction -> (HashSet VarRef, Bool)
|
||||
go (v, b) ls i = case getResultVar i of
|
||||
Just vref -> (HashSet.insert vref v', b')
|
||||
Nothing -> (v', b')
|
||||
where
|
||||
v' = v <> mconcat (map fst ls)
|
||||
b' = b && allSame (map fst ls) && and (map snd ls)
|
@ -2,6 +2,7 @@ module Reg.Transformation where
|
||||
|
||||
import Base
|
||||
import Reg.Transformation.Identity qualified as Identity
|
||||
import Reg.Transformation.InitBranchVars qualified as InitBranchVars
|
||||
import Reg.Transformation.SSA qualified as SSA
|
||||
|
||||
allTests :: TestTree
|
||||
@ -9,5 +10,6 @@ allTests =
|
||||
testGroup
|
||||
"JuvixReg transformations"
|
||||
[ Identity.allTests,
|
||||
SSA.allTests
|
||||
SSA.allTests,
|
||||
InitBranchVars.allTests
|
||||
]
|
||||
|
25
test/Reg/Transformation/InitBranchVars.hs
Normal file
25
test/Reg/Transformation/InitBranchVars.hs
Normal file
@ -0,0 +1,25 @@
|
||||
module Reg.Transformation.InitBranchVars where
|
||||
|
||||
import Base
|
||||
import Juvix.Compiler.Reg.Transformation
|
||||
import Juvix.Compiler.Reg.Transformation.InitBranchVars
|
||||
import Juvix.Compiler.Reg.Transformation.SSA
|
||||
import Reg.Parse.Positive qualified as Parse
|
||||
import Reg.Transformation.Base
|
||||
|
||||
allTests :: TestTree
|
||||
allTests = testGroup "InitBranchVars" (map liftTest Parse.tests)
|
||||
|
||||
pipe :: [TransformationId]
|
||||
pipe = [SSA, InitBranchVars]
|
||||
|
||||
liftTest :: Parse.PosTest -> TestTree
|
||||
liftTest _testRun =
|
||||
fromTest
|
||||
Test
|
||||
{ _testTransformations = pipe,
|
||||
_testAssertion = \tab -> do
|
||||
unless (checkSSA tab) $ error "check ssa"
|
||||
unless (checkInitialized tab) $ error "check initialized",
|
||||
_testRun
|
||||
}
|
Loading…
Reference in New Issue
Block a user