1
1
mirror of https://github.com/anoma/juvix.git synced 2024-10-26 09:45:47 +03:00

JuvixReg transformation: initialize variables assigned in other branches (#2650)

* Closes #2576 
* Adds a JuvixReg transformation `InitBranchVars` which inserts
assignments to initialize variables assigned in other branches. Assumes
the input is in SSA form (which is preserved).
* Adds tests for the `InitBranchVars` transformation.
* Depends on #2647
This commit is contained in:
Łukasz Czajka 2024-02-23 12:20:11 +01:00 committed by GitHub
parent dfd664e184
commit 49678b4e54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 165 additions and 8 deletions

View File

@ -6,8 +6,9 @@ import Juvix.Prelude
data TransformationId
= Identity
| SSA
| Cleanup
| SSA
| InitBranchVars
deriving stock (Data, Bounded, Enum, Show)
data PipelineId
@ -21,14 +22,15 @@ toCTransformations :: [TransformationId]
toCTransformations = [Cleanup]
toCairoTransformations :: [TransformationId]
toCairoTransformations = [Cleanup, SSA]
toCairoTransformations = [Cleanup, SSA, InitBranchVars]
instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
transformationText = \case
Identity -> strIdentity
SSA -> strSSA
Cleanup -> strCleanup
SSA -> strSSA
InitBranchVars -> strInitBranchVars
instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text

View File

@ -11,8 +11,11 @@ strCairoPipeline = "pipeline-cairo"
strIdentity :: Text
strIdentity = "identity"
strCleanup :: Text
strCleanup = "cleanup"
strSSA :: Text
strSSA = "ssa"
strCleanup :: Text
strCleanup = "cleanup"
strInitBranchVars :: Text
strInitBranchVars = "init-branch-vars"

View File

@ -9,7 +9,17 @@ data ForwardRecursorSig m c = ForwardRecursorSig
}
data BackwardRecursorSig m a = BackwardRecursorSig
{ _backwardFun :: Code -> a -> [a] -> m (a, Code),
{ -- | In `_backwardFun is a as`: `is = i : is'` is the instruction list
-- currently being processed (the head `i` is the processed instruction, the
-- tail `is'` contains the instructions after it); `a` is the accumulator
-- for `is'`; `as` contains the accumulator values for the branches (for
-- `Branch` and `Case` instructions, otherwise empty). For the `Case`
-- instruction, the accumulator for the default branch (if present) is the
-- last element of `as`.
_backwardFun :: Code -> a -> [a] -> m (a, Code),
-- | `backwardAdjust a` adjusts the accumulator value when going backwards
-- into a branch. See also `FoldSig` in `Asm.Extra.Recursors` for more
-- explanations.
_backwardAdjust :: a -> a
}
@ -125,3 +135,25 @@ ifoldFM f a0 is0 =
ifoldF :: (Monoid a) => (a -> Instruction -> a) -> a -> Code -> a
ifoldF f a is = runIdentity (ifoldFM (\a' -> return . f a') a is)
ifoldBM :: forall a m. (Monad m) => (a -> [a] -> Instruction -> m a) -> a -> Code -> m a
ifoldBM f a0 is0 =
fst
<$> recurseB
BackwardRecursorSig
{ _backwardFun = go,
_backwardAdjust = id
}
a0
is0
where
go :: Code -> a -> [a] -> m (a, Code)
go is a as = case is of
i : _ -> do
a' <- f a as i
return (a', is)
[] ->
return (a, is)
ifoldB :: (a -> [a] -> Instruction -> a) -> a -> Code -> a
ifoldB f a is = runIdentity (ifoldBM (\a' as' -> return . f a' as') a is)

View File

@ -9,6 +9,7 @@ import Juvix.Compiler.Reg.Data.TransformationId
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Reg.Transformation.Cleanup
import Juvix.Compiler.Reg.Transformation.Identity
import Juvix.Compiler.Reg.Transformation.InitBranchVars
import Juvix.Compiler.Reg.Transformation.SSA
applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
@ -17,5 +18,6 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
appTrans = \case
Identity -> return . identity
SSA -> return . computeSSA
Cleanup -> return . cleanup
SSA -> return . computeSSA
InitBranchVars -> return . initBranchVars

View File

@ -0,0 +1,91 @@
module Juvix.Compiler.Reg.Transformation.InitBranchVars where
import Data.Functor.Identity
import Data.HashSet qualified as HashSet
import Data.List qualified as List
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base
-- | Inserts assignments to initialize variables assigned in other branches.
-- Assumes the input is in SSA form (which is preserved).
initBranchVars :: InfoTable -> InfoTable
initBranchVars = mapT (const goFun)
where
goFun :: Code -> Code
goFun =
snd
. runIdentity
. recurseB
BackwardRecursorSig
{ _backwardFun = \is a as -> return (go is a as),
_backwardAdjust = const mempty
}
mempty
go :: Code -> HashSet VarRef -> [HashSet VarRef] -> (HashSet VarRef, Code)
go is a as = case is of
Branch InstrBranch {..} : is' -> case as of
[a1, a2] -> (a <> a', i' : is')
where
a' = a1 <> a2
a1' = HashSet.difference a' a1
a2' = HashSet.difference a' a2
i' =
Branch
InstrBranch
{ _instrBranchTrue = addInits a1' _instrBranchTrue,
_instrBranchFalse = addInits a2' _instrBranchFalse,
..
}
_ -> impossible
Case InstrCase {..} : is' ->
(a <> a', i' : is')
where
a' = mconcat as
as' = map (HashSet.difference a') as
n = length _instrCaseBranches
brs' = zipWithExact goBranch (take n as') _instrCaseBranches
def' = maybe Nothing (Just . addInits (List.last as')) _instrCaseDefault
i' =
Case
InstrCase
{ _instrCaseBranches = brs',
_instrCaseDefault = def',
..
}
goBranch :: HashSet VarRef -> CaseBranch -> CaseBranch
goBranch vars = over caseBranchCode (addInits vars)
i : _ ->
case getResultVar i of
Just v ->
(HashSet.insert v a <> mconcat as, is)
Nothing ->
(a <> mconcat as, is)
[] ->
(a <> mconcat as, is)
addInits :: HashSet VarRef -> Code -> Code
addInits vars is = map mk (toList vars) ++ is
where
mk :: VarRef -> Instruction
mk vref =
Assign
InstrAssign
{ _instrAssignResult = vref,
_instrAssignValue = Const ConstVoid
}
checkInitialized :: InfoTable -> Bool
checkInitialized tab = all (goFun . (^. functionCode)) (tab ^. infoFunctions)
where
goFun :: Code -> Bool
goFun = snd . ifoldB go (mempty, True)
where
go :: (HashSet VarRef, Bool) -> [(HashSet VarRef, Bool)] -> Instruction -> (HashSet VarRef, Bool)
go (v, b) ls i = case getResultVar i of
Just vref -> (HashSet.insert vref v', b')
Nothing -> (v', b')
where
v' = v <> mconcat (map fst ls)
b' = b && allSame (map fst ls) && and (map snd ls)

View File

@ -2,6 +2,7 @@ module Reg.Transformation where
import Base
import Reg.Transformation.Identity qualified as Identity
import Reg.Transformation.InitBranchVars qualified as InitBranchVars
import Reg.Transformation.SSA qualified as SSA
allTests :: TestTree
@ -9,5 +10,6 @@ allTests =
testGroup
"JuvixReg transformations"
[ Identity.allTests,
SSA.allTests
SSA.allTests,
InitBranchVars.allTests
]

View File

@ -0,0 +1,25 @@
module Reg.Transformation.InitBranchVars where
import Base
import Juvix.Compiler.Reg.Transformation
import Juvix.Compiler.Reg.Transformation.InitBranchVars
import Juvix.Compiler.Reg.Transformation.SSA
import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base
allTests :: TestTree
allTests = testGroup "InitBranchVars" (map liftTest Parse.tests)
pipe :: [TransformationId]
pipe = [SSA, InitBranchVars]
liftTest :: Parse.PosTest -> TestTree
liftTest _testRun =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = \tab -> do
unless (checkSSA tab) $ error "check ssa"
unless (checkInitialized tab) $ error "check initialized",
_testRun
}