mirror of
https://github.com/anoma/juvix.git
synced 2025-01-05 22:46:08 +03:00
Lifting calls out of cases for the VampIR backend (#2218)
* Closes #2200 For example, ``` def power' : Int → Int → Int → Int := λ(acc : Int) λ(a : Int) λ(b : Int) if = b 0 then acc else if = (% b 2) 0 then power' acc (* a a) (/ b 2) else power' (* acc a) (* a a) (/ b 2); ``` is transformed into ``` def power' : Int → Int → Int → Int := λ(acc : Int) λ(a : Int) λ(b : Int) if = b 0 then acc else let _X : Bool := = (% b 2) 0 in power' (if _X then acc else * acc a) (* a a) (/ b 2); ```
This commit is contained in:
parent
8201cb828c
commit
f77e05513f
@ -27,9 +27,12 @@ data TransformationId
|
||||
| LetHoisting
|
||||
| Inlining
|
||||
| FoldTypeSynonyms
|
||||
| CaseCallLifting
|
||||
| SimplifyIfs
|
||||
| OptPhaseEval
|
||||
| OptPhaseExec
|
||||
| OptPhaseGeb
|
||||
| OptPhaseVampIR
|
||||
| OptPhaseMain
|
||||
deriving stock (Data, Bounded, Enum, Show)
|
||||
|
||||
@ -69,7 +72,7 @@ toNormalizeTransformations :: [TransformationId]
|
||||
toNormalizeTransformations = toEvalTransformations ++ [LetRecLifting, LetFolding, UnrollRecursion]
|
||||
|
||||
toVampIRTransformations :: [TransformationId]
|
||||
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, LetFolding, UnrollRecursion, Normalize, LetHoisting]
|
||||
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, OptPhaseVampIR, UnrollRecursion, Normalize, LetHoisting]
|
||||
|
||||
toStrippedTransformations :: [TransformationId]
|
||||
toStrippedTransformations =
|
||||
|
@ -87,9 +87,12 @@ transformationText = \case
|
||||
LetHoisting -> strLetHoisting
|
||||
Inlining -> strInlining
|
||||
FoldTypeSynonyms -> strFoldTypeSynonyms
|
||||
CaseCallLifting -> strCaseCallLifting
|
||||
SimplifyIfs -> strSimplifyIfs
|
||||
OptPhaseEval -> strOptPhaseEval
|
||||
OptPhaseExec -> strOptPhaseExec
|
||||
OptPhaseGeb -> strOptPhaseGeb
|
||||
OptPhaseVampIR -> strOptPhaseVampIR
|
||||
OptPhaseMain -> strOptPhaseMain
|
||||
|
||||
parsePipeline :: MonadParsec e Text m => m PipelineId
|
||||
@ -188,6 +191,12 @@ strInlining = "inlining"
|
||||
strFoldTypeSynonyms :: Text
|
||||
strFoldTypeSynonyms = "fold-type-synonyms"
|
||||
|
||||
strCaseCallLifting :: Text
|
||||
strCaseCallLifting = "case-call-lifting"
|
||||
|
||||
strSimplifyIfs :: Text
|
||||
strSimplifyIfs = "simplify-ifs"
|
||||
|
||||
strOptPhaseEval :: Text
|
||||
strOptPhaseEval = "opt-phase-eval"
|
||||
|
||||
@ -197,5 +206,8 @@ strOptPhaseExec = "opt-phase-exec"
|
||||
strOptPhaseGeb :: Text
|
||||
strOptPhaseGeb = "opt-phase-geb"
|
||||
|
||||
strOptPhaseVampIR :: Text
|
||||
strOptPhaseVampIR = "opt-phase-vampir"
|
||||
|
||||
strOptPhaseMain :: Text
|
||||
strOptPhaseMain = "opt-phase-main"
|
||||
|
@ -1,12 +1,15 @@
|
||||
module Juvix.Compiler.Core.Extra.Recursors
|
||||
( module Juvix.Compiler.Core.Extra.Recursors.Fold,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Collector,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Fold.Named,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Map,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Map.Named,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SFold,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SFold.Named,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SMap.Named,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.RMap,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.RMap.Named,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Fold.Named,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Recur,
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Collector,
|
||||
)
|
||||
where
|
||||
|
||||
@ -19,3 +22,6 @@ import Juvix.Compiler.Core.Extra.Recursors.Map.Named
|
||||
import Juvix.Compiler.Core.Extra.Recursors.RMap
|
||||
import Juvix.Compiler.Core.Extra.Recursors.RMap.Named
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Recur
|
||||
import Juvix.Compiler.Core.Extra.Recursors.SFold
|
||||
import Juvix.Compiler.Core.Extra.Recursors.SFold.Named
|
||||
import Juvix.Compiler.Core.Extra.Recursors.SMap.Named
|
||||
|
@ -4,15 +4,8 @@ module Juvix.Compiler.Core.Extra.Recursors.Map where
|
||||
|
||||
import Data.Functor.Identity
|
||||
import Data.Kind qualified as GHC
|
||||
import Data.Singletons.TH
|
||||
import Juvix.Compiler.Core.Extra.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Parameters
|
||||
|
||||
type DirTy :: Direction -> GHC.Type -> GHC.Type
|
||||
type family DirTy d c = res | res -> d where
|
||||
DirTy 'TopDown c = Recur' c
|
||||
DirTy 'BottomUp _ = Node -- For bottom up maps we never recur on the children
|
||||
|
||||
-- | `umapG` maps the nodes bottom-up, i.e., when invoking the mapper function the
|
||||
-- recursive subnodes have already been mapped
|
||||
@ -119,14 +112,3 @@ fromRecur' d =
|
||||
End' x -> End' x
|
||||
Recur' (c, x) -> Recur' ((c, d), x)
|
||||
)
|
||||
|
||||
nodeMapG' ::
|
||||
(Monad m) =>
|
||||
Sing dir ->
|
||||
Collector (Int, [Binder]) c ->
|
||||
(c -> Node -> m (DirTy dir c)) ->
|
||||
Node ->
|
||||
m Node
|
||||
nodeMapG' sdir = case sdir of
|
||||
STopDown -> dmapG
|
||||
SBottomUp -> umapG
|
||||
|
@ -3,7 +3,6 @@ module Juvix.Compiler.Core.Extra.Recursors.Map.Named where
|
||||
import Data.Functor.Identity
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Map
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Parameters
|
||||
|
||||
{-
|
||||
|
||||
@ -50,7 +49,7 @@ dmapLM :: (Monad m) => (BinderList Binder -> Node -> m Node) -> Node -> m Node
|
||||
dmapLM f = dmapLM' (mempty, f)
|
||||
|
||||
umapLM :: (Monad m) => (BinderList Binder -> Node -> m Node) -> Node -> m Node
|
||||
umapLM f = nodeMapG' SBottomUp binderInfoCollector f
|
||||
umapLM f = umapG binderInfoCollector f
|
||||
|
||||
dmapNRM :: (Monad m) => (Level -> Node -> m Recur) -> Node -> m Node
|
||||
dmapNRM f = dmapNRM' (0, f)
|
||||
@ -59,34 +58,34 @@ dmapNM :: (Monad m) => (Level -> Node -> m Node) -> Node -> m Node
|
||||
dmapNM f = dmapNM' (0, f)
|
||||
|
||||
umapNM :: (Monad m) => (Level -> Node -> m Node) -> Node -> m Node
|
||||
umapNM f = nodeMapG' SBottomUp binderNumCollector f
|
||||
umapNM f = umapG binderNumCollector f
|
||||
|
||||
dmapRM :: (Monad m) => (Node -> m Recur) -> Node -> m Node
|
||||
dmapRM f = nodeMapG' STopDown unitCollector (const (fromRecur mempty . f))
|
||||
dmapRM f = dmapG unitCollector (const (fromRecur mempty . f))
|
||||
|
||||
dmapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
|
||||
dmapM f = nodeMapG' STopDown unitCollector (const (fromSimple mempty . f))
|
||||
dmapM f = dmapG unitCollector (const (fromSimple mempty . f))
|
||||
|
||||
umapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
|
||||
umapM f = nodeMapG' SBottomUp unitCollector (const f)
|
||||
umapM f = umapG unitCollector (const f)
|
||||
|
||||
dmapLRM' :: (Monad m) => (BinderList Binder, BinderList Binder -> Node -> m Recur) -> Node -> m Node
|
||||
dmapLRM' f = nodeMapG' STopDown (binderInfoCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
|
||||
dmapLRM' f = dmapG (binderInfoCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
|
||||
|
||||
dmapLM' :: (Monad m) => (BinderList Binder, BinderList Binder -> Node -> m Node) -> Node -> m Node
|
||||
dmapLM' f = nodeMapG' STopDown (binderInfoCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
|
||||
dmapLM' f = dmapG (binderInfoCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
|
||||
|
||||
umapLM' :: (Monad m) => (BinderList Binder, BinderList Binder -> Node -> m Node) -> Node -> m Node
|
||||
umapLM' f = nodeMapG' SBottomUp (binderInfoCollector' (fst f)) (snd f)
|
||||
umapLM' f = umapG (binderInfoCollector' (fst f)) (snd f)
|
||||
|
||||
dmapNRM' :: (Monad m) => (Level, Level -> Node -> m Recur) -> Node -> m Node
|
||||
dmapNRM' f = nodeMapG' STopDown (binderNumCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
|
||||
dmapNRM' f = dmapG (binderNumCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
|
||||
|
||||
dmapNM' :: (Monad m) => (Level, Level -> Node -> m Node) -> Node -> m Node
|
||||
dmapNM' f = nodeMapG' STopDown (binderNumCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
|
||||
dmapNM' f = dmapG (binderNumCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
|
||||
|
||||
umapNM' :: (Monad m) => (Level, Level -> Node -> m Node) -> Node -> m Node
|
||||
umapNM' f = nodeMapG' SBottomUp (binderNumCollector' (fst f)) (snd f)
|
||||
umapNM' f = umapG (binderNumCollector' (fst f)) (snd f)
|
||||
|
||||
dmapLR :: (BinderList Binder -> Node -> Recur) -> Node -> Node
|
||||
dmapLR f = runIdentity . dmapLRM (embedIden f)
|
||||
@ -134,25 +133,25 @@ umapN' :: (Level, Level -> Node -> Node) -> Node -> Node
|
||||
umapN' f = runIdentity . umapNM' (embedIden f)
|
||||
|
||||
dmapCLM' :: (Monad m) => (BinderList Binder, c -> BinderList Binder -> Node -> m (c, Node)) -> c -> Node -> m Node
|
||||
dmapCLM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
|
||||
dmapCLM' f ini = dmapG (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
|
||||
|
||||
dmapCLRM' :: (Monad m) => (BinderList Binder, c -> BinderList Binder -> Node -> m (Recur' c)) -> c -> Node -> m Node
|
||||
dmapCLRM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
|
||||
dmapCLRM' f ini = dmapG (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
|
||||
|
||||
dmapCNRM' :: (Monad m) => (Level, c -> Level -> Node -> m (Recur' c)) -> c -> Node -> m Node
|
||||
dmapCNRM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
|
||||
dmapCNRM' f ini = dmapG (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
|
||||
|
||||
dmapCLM :: (Monad m) => (c -> BinderList Binder -> Node -> m (c, Node)) -> c -> Node -> m Node
|
||||
dmapCLM f = dmapCLM' (mempty, f)
|
||||
|
||||
dmapCNM' :: (Monad m) => (Level, c -> Level -> Node -> m (c, Node)) -> c -> Node -> m Node
|
||||
dmapCNM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
|
||||
dmapCNM' f ini = dmapG (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
|
||||
|
||||
dmapCNM :: (Monad m) => (c -> Level -> Node -> m (c, Node)) -> c -> Node -> m Node
|
||||
dmapCNM f = dmapCNM' (0, f)
|
||||
|
||||
dmapCM :: (Monad m) => (c -> Node -> m (c, Node)) -> c -> Node -> m Node
|
||||
dmapCM f ini = nodeMapG' STopDown (identityCollector ini) (\c -> fmap Recur' . f c)
|
||||
dmapCM f ini = dmapG (identityCollector ini) (\c -> fmap Recur' . f c)
|
||||
|
||||
dmapCL' :: (BinderList Binder, c -> BinderList Binder -> Node -> (c, Node)) -> c -> Node -> Node
|
||||
dmapCL' f ini = runIdentity . dmapCLM' (embedIden f) ini
|
||||
|
@ -1,26 +0,0 @@
|
||||
module Juvix.Compiler.Core.Extra.Recursors.Parameters where
|
||||
|
||||
import Juvix.Prelude
|
||||
|
||||
data CollectorIni
|
||||
= NoIni
|
||||
| Ini
|
||||
|
||||
data Ctx
|
||||
= CtxBinderList
|
||||
| CtxBinderNum
|
||||
| CtxNone
|
||||
|
||||
data Monadic
|
||||
= Monadic
|
||||
| NonMonadic
|
||||
|
||||
data Ret
|
||||
= RetRecur
|
||||
| RetSimple
|
||||
|
||||
data Direction
|
||||
= TopDown
|
||||
| BottomUp
|
||||
|
||||
$(genSingletons [''CollectorIni, ''Ctx, ''Ret, ''Monadic, ''Direction])
|
23
src/Juvix/Compiler/Core/Extra/Recursors/SFold.hs
Normal file
23
src/Juvix/Compiler/Core/Extra/Recursors/SFold.hs
Normal file
@ -0,0 +1,23 @@
|
||||
-- | Shallow fold recursors over 'Node'.
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SFold where
|
||||
|
||||
import Juvix.Compiler.Core.Extra.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Base
|
||||
|
||||
sfoldG ::
|
||||
forall a f.
|
||||
(Applicative f) =>
|
||||
(a -> [a] -> a) ->
|
||||
(Node -> f a) ->
|
||||
Node ->
|
||||
f a
|
||||
sfoldG uplus f = go
|
||||
where
|
||||
go :: Node -> f a
|
||||
go n = do
|
||||
mas' <- sequenceA mas
|
||||
n' <- f n
|
||||
pure (uplus n' mas')
|
||||
where
|
||||
mas :: [f a]
|
||||
mas = map go (schildren n)
|
24
src/Juvix/Compiler/Core/Extra/Recursors/SFold/Named.hs
Normal file
24
src/Juvix/Compiler/Core/Extra/Recursors/SFold/Named.hs
Normal file
@ -0,0 +1,24 @@
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SFold.Named where
|
||||
|
||||
import Data.Functor.Identity
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.SFold
|
||||
|
||||
{-
|
||||
|
||||
The shallow folding recursors are analogous to general folding recursors (see
|
||||
Core/Extra/Recursors/Fold/Named.hs) except that they don't go under binders.
|
||||
|
||||
-}
|
||||
|
||||
sfoldA :: (Applicative f) => (a -> [a] -> a) -> (Node -> f a) -> Node -> f a
|
||||
sfoldA uplus f = sfoldG uplus f
|
||||
|
||||
swalk :: (Applicative f) => (Node -> f ()) -> Node -> f ()
|
||||
swalk = sfoldA (foldr mappend)
|
||||
|
||||
sfold :: (a -> [a] -> a) -> (Node -> a) -> Node -> a
|
||||
sfold uplus f = runIdentity . sfoldA uplus (return . f)
|
||||
|
||||
sgather :: (a -> Node -> a) -> a -> Node -> a
|
||||
sgather f acc = run . execState acc . swalk (\n' -> modify' (`f` n'))
|
68
src/Juvix/Compiler/Core/Extra/Recursors/SMap.hs
Normal file
68
src/Juvix/Compiler/Core/Extra/Recursors/SMap.hs
Normal file
@ -0,0 +1,68 @@
|
||||
{-# LANGUAGE UndecidableInstances #-}
|
||||
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SMap where
|
||||
|
||||
import Data.Functor.Identity
|
||||
import Data.Kind qualified as GHC
|
||||
import Juvix.Compiler.Core.Extra.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Base
|
||||
|
||||
-- | `sumapG` is the shallow version of `umapG`, i.e., it doesn't recurse under binders
|
||||
sumapG ::
|
||||
forall m.
|
||||
(Monad m) =>
|
||||
(Node -> m Node) ->
|
||||
Node ->
|
||||
m Node
|
||||
sumapG f = go
|
||||
where
|
||||
go :: Node -> m Node
|
||||
go n =
|
||||
let ni = destruct n
|
||||
in do
|
||||
ns <- mapM goChild (ni ^. nodeChildren)
|
||||
f (reassembleDetails ni ns)
|
||||
where
|
||||
goChild :: NodeChild -> m Node
|
||||
goChild nc
|
||||
| nc ^. childBindersNum == 0 = go (nc ^. childNode)
|
||||
| otherwise = return $ nc ^. childNode
|
||||
|
||||
sdmapG ::
|
||||
forall m.
|
||||
(Monad m) =>
|
||||
(Node -> m Recur) ->
|
||||
Node ->
|
||||
m Node
|
||||
sdmapG f = go
|
||||
where
|
||||
go :: Node -> m Node
|
||||
go n = do
|
||||
r <- f n
|
||||
case r of
|
||||
End n' -> return n'
|
||||
Recur n' ->
|
||||
let ni = destruct n'
|
||||
in reassembleDetails ni <$> mapM goChild (ni ^. nodeChildren)
|
||||
where
|
||||
goChild :: NodeChild -> m Node
|
||||
goChild ch
|
||||
| ch ^. childBindersNum == 0 = go (ch ^. childNode)
|
||||
| otherwise = return $ ch ^. childNode
|
||||
|
||||
type OverIdentity :: GHC.Type -> GHC.Type
|
||||
type family OverIdentity t = res where
|
||||
OverIdentity (a -> b) = a -> OverIdentity b
|
||||
OverIdentity leaf = Identity leaf
|
||||
|
||||
class EmbedIdentity a where
|
||||
embedIden :: a -> OverIdentity a
|
||||
|
||||
instance (EmbedIdentity b) => EmbedIdentity (a -> b) where
|
||||
embedIden f = embedIden . f
|
||||
|
||||
instance EmbedIdentity Node where
|
||||
embedIden = Identity
|
||||
|
||||
instance EmbedIdentity Recur where
|
||||
embedIden = Identity
|
30
src/Juvix/Compiler/Core/Extra/Recursors/SMap/Named.hs
Normal file
30
src/Juvix/Compiler/Core/Extra/Recursors/SMap/Named.hs
Normal file
@ -0,0 +1,30 @@
|
||||
module Juvix.Compiler.Core.Extra.Recursors.SMap.Named where
|
||||
|
||||
import Data.Functor.Identity
|
||||
import Juvix.Compiler.Core.Extra.Recursors.Base
|
||||
import Juvix.Compiler.Core.Extra.Recursors.SMap
|
||||
|
||||
{-
|
||||
|
||||
The shallow mapping recursors are analogous to ordinary mapping recursors (see
|
||||
Core/Extra/Recursors/Map/Named.hs) except that they don't go under binders.
|
||||
|
||||
-}
|
||||
|
||||
sdmapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
|
||||
sdmapM f = sdmapG (fmap Recur . f)
|
||||
|
||||
sumapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
|
||||
sumapM f = sumapG f
|
||||
|
||||
sdmapRM :: (Monad m) => (Node -> m Recur) -> Node -> m Node
|
||||
sdmapRM f = sdmapG f
|
||||
|
||||
sdmapR :: (Node -> Recur) -> Node -> Node
|
||||
sdmapR f = runIdentity . sdmapRM (embedIden f)
|
||||
|
||||
sdmap :: (Node -> Node) -> Node -> Node
|
||||
sdmap f = runIdentity . sdmapM (embedIden f)
|
||||
|
||||
sumap :: (Node -> Node) -> Node -> Node
|
||||
sumap f = runIdentity . sumapM (embedIden f)
|
@ -29,6 +29,7 @@ import Juvix.Compiler.Core.Transformation.MoveApps
|
||||
import Juvix.Compiler.Core.Transformation.NaiveMatchToCase qualified as Naive
|
||||
import Juvix.Compiler.Core.Transformation.NatToPrimInt
|
||||
import Juvix.Compiler.Core.Transformation.Normalize
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
@ -36,6 +37,8 @@ import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
|
||||
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
|
||||
import Juvix.Compiler.Core.Transformation.TopEtaExpand
|
||||
import Juvix.Compiler.Core.Transformation.UnrollRecursion
|
||||
@ -69,7 +72,10 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
|
||||
LetHoisting -> return . letHoisting
|
||||
Inlining -> inlining
|
||||
FoldTypeSynonyms -> return . foldTypeSynonyms
|
||||
CaseCallLifting -> return . caseCallLifting
|
||||
SimplifyIfs -> return . simplifyIfs
|
||||
OptPhaseEval -> Phase.Eval.optimize
|
||||
OptPhaseExec -> Phase.Exec.optimize
|
||||
OptPhaseGeb -> Phase.Geb.optimize
|
||||
OptPhaseVampIR -> Phase.VampIR.optimize
|
||||
OptPhaseMain -> Phase.Main.optimize
|
||||
|
@ -0,0 +1,136 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting (caseCallLifting) where
|
||||
|
||||
import Data.HashSet qualified as HashSet
|
||||
import Data.List qualified as List
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
convertNode :: InfoTable -> Node -> Node
|
||||
convertNode tab = umap go
|
||||
where
|
||||
go :: Node -> Node
|
||||
go = \case
|
||||
NCase Case {..}
|
||||
| not (null idents) ->
|
||||
if
|
||||
| isCaseBoolean _caseBranches && not (isImmediate _caseValue) ->
|
||||
mkLet'
|
||||
mkTypeBool'
|
||||
_caseValue
|
||||
(liftApps 0 _caseInductive (mkVar' 0) (brs' 1) (def' 1) idents)
|
||||
| otherwise ->
|
||||
liftApps 0 _caseInductive _caseValue (brs' 0) (def' 0) idents
|
||||
where
|
||||
bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault
|
||||
idts = foldr (flip gatherIdents) mempty bodies
|
||||
idents = filter (\sym -> all (\x -> countApps sym x == 1) bodies) (toList idts)
|
||||
n = length idents
|
||||
brs' k = map (over caseBranchBody (shift (n + k))) _caseBranches
|
||||
def' k = fmap (shift (n + k)) _caseDefault
|
||||
node -> node
|
||||
|
||||
liftApps :: Level -> Symbol -> Node -> [CaseBranch] -> Maybe Node -> [Symbol] -> Node
|
||||
liftApps lvl ind val brs def = \case
|
||||
[] ->
|
||||
NCase
|
||||
Case
|
||||
{ _caseInfo = mempty,
|
||||
_caseInductive = ind,
|
||||
_caseValue = shift lvl val,
|
||||
_caseBranches = brs,
|
||||
_caseDefault = def
|
||||
}
|
||||
sym : syms -> mkLet' ty app (liftApps (lvl + 1) ind val brs' def' syms)
|
||||
where
|
||||
idx = length syms
|
||||
args0 = map (fromJust . gatherAppArgs sym . (^. caseBranchBody)) brs
|
||||
dargs0 = fmap (fromJust . gatherAppArgs sym) def
|
||||
appArgs = computeArgs args0 dargs0
|
||||
app = mkApps' (mkIdent' sym) appArgs
|
||||
(tyargs, tgt) = unfoldPi' (lookupIdentifierInfo tab sym ^. identifierType)
|
||||
tyargs' = drop (length appArgs) tyargs
|
||||
ty = substs appArgs (mkPis' tyargs' tgt)
|
||||
brs' = map (\br -> over caseBranchBody (substApps sym (mkVar' (br ^. caseBranchBindersNum + idx))) br) brs
|
||||
def' = fmap (substApps sym (mkVar' idx)) def
|
||||
|
||||
computeArgs :: [[Node]] -> Maybe [Node] -> [Node]
|
||||
computeArgs args dargs
|
||||
| null (List.head args) = []
|
||||
| otherwise =
|
||||
shift
|
||||
(-idx - 1)
|
||||
(mkCase' ind (shift (lvl + 1) val) (zipWithExact (set caseBranchBody) hbs brs) hdef)
|
||||
: computeArgs args' dargs'
|
||||
where
|
||||
hbs = map List.head args
|
||||
hdef = fmap List.head dargs
|
||||
args' = map List.tail args
|
||||
dargs' = fmap List.tail dargs
|
||||
|
||||
gatherIdents :: HashSet Symbol -> Node -> HashSet Symbol
|
||||
gatherIdents = sgather go'
|
||||
where
|
||||
go' :: HashSet Symbol -> Node -> HashSet Symbol
|
||||
go' acc node = case node of
|
||||
NApp {} ->
|
||||
let (h, args) = unfoldApps' node
|
||||
in case h of
|
||||
NIdt Ident {..}
|
||||
| length args == lookupIdentifierInfo tab _identSymbol ^. identifierArgsNum ->
|
||||
HashSet.insert _identSymbol acc
|
||||
_ -> acc
|
||||
_ -> acc
|
||||
|
||||
countApps :: Symbol -> Node -> Int
|
||||
countApps sym = sgather go' 0
|
||||
where
|
||||
argsNum = lookupIdentifierInfo tab sym ^. identifierArgsNum
|
||||
|
||||
go' :: Int -> Node -> Int
|
||||
go' acc node = case node of
|
||||
NApp {} ->
|
||||
let (h, args) = unfoldApps' node
|
||||
in case h of
|
||||
NIdt Ident {..}
|
||||
| _identSymbol == sym
|
||||
&& length args == argsNum ->
|
||||
acc + 1
|
||||
_ -> acc
|
||||
_ -> acc
|
||||
|
||||
gatherAppArgs :: Symbol -> Node -> Maybe [Node]
|
||||
gatherAppArgs sym = sgather go' Nothing
|
||||
where
|
||||
argsNum = lookupIdentifierInfo tab sym ^. identifierArgsNum
|
||||
|
||||
go' :: Maybe [Node] -> Node -> Maybe [Node]
|
||||
go' acc node = case node of
|
||||
NApp {} ->
|
||||
let (h, args) = unfoldApps' node
|
||||
in case h of
|
||||
NIdt Ident {..}
|
||||
| _identSymbol == sym
|
||||
&& length args == argsNum ->
|
||||
Just args
|
||||
_ -> acc
|
||||
_ -> acc
|
||||
|
||||
substApps :: Symbol -> Node -> Node -> Node
|
||||
substApps sym snode = sumap go'
|
||||
where
|
||||
argsNum = lookupIdentifierInfo tab sym ^. identifierArgsNum
|
||||
|
||||
go' :: Node -> Node
|
||||
go' node = case node of
|
||||
NApp {} ->
|
||||
let (h, args) = unfoldApps' node
|
||||
in case h of
|
||||
NIdt Ident {..}
|
||||
| _identSymbol == sym
|
||||
&& length args == argsNum ->
|
||||
snode
|
||||
_ -> node
|
||||
_ -> node
|
||||
|
||||
caseCallLifting :: InfoTable -> InfoTable
|
||||
caseCallLifting tab = mapAllNodes (convertNode tab) tab
|
@ -22,6 +22,7 @@ convertNode isFoldable = rmap go
|
||||
go recur = \case
|
||||
NLet Let {..}
|
||||
| isImmediate (_letItem ^. letItemValue)
|
||||
|| _letBody == mkVar' 0
|
||||
|| isFoldable (_letItem ^. letItemValue) ->
|
||||
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
|
||||
where
|
||||
|
@ -0,0 +1,13 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR where
|
||||
|
||||
import Juvix.Compiler.Core.Options
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
|
||||
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
|
||||
|
||||
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
|
||||
optimize =
|
||||
withOptimizationLevel 1 $
|
||||
return . letFolding . simplifyIfs . caseCallLifting . letFolding . lambdaFolding
|
@ -0,0 +1,21 @@
|
||||
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs) where
|
||||
|
||||
import Data.List qualified as List
|
||||
import Juvix.Compiler.Core.Extra
|
||||
import Juvix.Compiler.Core.Transformation.Base
|
||||
|
||||
convertNode :: Node -> Node
|
||||
convertNode = umap go
|
||||
where
|
||||
go :: Node -> Node
|
||||
go node = case node of
|
||||
NCase Case {..}
|
||||
| isCaseBoolean _caseBranches
|
||||
&& all (== List.head bodies) (List.tail bodies) ->
|
||||
List.head bodies
|
||||
where
|
||||
bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault
|
||||
_ -> node
|
||||
|
||||
simplifyIfs :: InfoTable -> InfoTable
|
||||
simplifyIfs = mapAllNodes convertNode
|
Loading…
Reference in New Issue
Block a user