1
1
mirror of https://github.com/anoma/juvix.git synced 2025-01-05 22:46:08 +03:00

Lifting calls out of cases for the VampIR backend (#2218)

* Closes #2200 

For example,

```
def power' : Int → Int → Int → Int :=
  λ(acc : Int)
    λ(a : Int)
      λ(b : Int)
        if = b 0 then acc else if = (% b 2) 0 then power' acc (* a a) (/ b 2) else power' (* acc a) (* a a) (/ b 2);
```

is transformed into

```
def power' : Int → Int → Int → Int :=
  λ(acc : Int)
    λ(a : Int)
      λ(b : Int)
        if = b 0 then acc else let _X : Bool := = (% b 2) 0 in
        power' (if _X then acc else * acc a) (* a a) (/ b 2);
```
This commit is contained in:
Łukasz Czajka 2023-06-23 11:55:19 +02:00 committed by GitHub
parent 8201cb828c
commit f77e05513f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 362 additions and 64 deletions

View File

@ -27,9 +27,12 @@ data TransformationId
| LetHoisting
| Inlining
| FoldTypeSynonyms
| CaseCallLifting
| SimplifyIfs
| OptPhaseEval
| OptPhaseExec
| OptPhaseGeb
| OptPhaseVampIR
| OptPhaseMain
deriving stock (Data, Bounded, Enum, Show)
@ -69,7 +72,7 @@ toNormalizeTransformations :: [TransformationId]
toNormalizeTransformations = toEvalTransformations ++ [LetRecLifting, LetFolding, UnrollRecursion]
toVampIRTransformations :: [TransformationId]
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, LetFolding, UnrollRecursion, Normalize, LetHoisting]
toVampIRTransformations = toEvalTransformations ++ [CheckVampIR, LetRecLifting, OptPhaseVampIR, UnrollRecursion, Normalize, LetHoisting]
toStrippedTransformations :: [TransformationId]
toStrippedTransformations =

View File

@ -87,9 +87,12 @@ transformationText = \case
LetHoisting -> strLetHoisting
Inlining -> strInlining
FoldTypeSynonyms -> strFoldTypeSynonyms
CaseCallLifting -> strCaseCallLifting
SimplifyIfs -> strSimplifyIfs
OptPhaseEval -> strOptPhaseEval
OptPhaseExec -> strOptPhaseExec
OptPhaseGeb -> strOptPhaseGeb
OptPhaseVampIR -> strOptPhaseVampIR
OptPhaseMain -> strOptPhaseMain
parsePipeline :: MonadParsec e Text m => m PipelineId
@ -188,6 +191,12 @@ strInlining = "inlining"
strFoldTypeSynonyms :: Text
strFoldTypeSynonyms = "fold-type-synonyms"
strCaseCallLifting :: Text
strCaseCallLifting = "case-call-lifting"
strSimplifyIfs :: Text
strSimplifyIfs = "simplify-ifs"
strOptPhaseEval :: Text
strOptPhaseEval = "opt-phase-eval"
@ -197,5 +206,8 @@ strOptPhaseExec = "opt-phase-exec"
strOptPhaseGeb :: Text
strOptPhaseGeb = "opt-phase-geb"
strOptPhaseVampIR :: Text
strOptPhaseVampIR = "opt-phase-vampir"
strOptPhaseMain :: Text
strOptPhaseMain = "opt-phase-main"

View File

@ -1,12 +1,15 @@
module Juvix.Compiler.Core.Extra.Recursors
( module Juvix.Compiler.Core.Extra.Recursors.Fold,
module Juvix.Compiler.Core.Extra.Recursors.Collector,
module Juvix.Compiler.Core.Extra.Recursors.Fold.Named,
module Juvix.Compiler.Core.Extra.Recursors.Map,
module Juvix.Compiler.Core.Extra.Recursors.Map.Named,
module Juvix.Compiler.Core.Extra.Recursors.SFold,
module Juvix.Compiler.Core.Extra.Recursors.SFold.Named,
module Juvix.Compiler.Core.Extra.Recursors.SMap.Named,
module Juvix.Compiler.Core.Extra.Recursors.RMap,
module Juvix.Compiler.Core.Extra.Recursors.RMap.Named,
module Juvix.Compiler.Core.Extra.Recursors.Fold.Named,
module Juvix.Compiler.Core.Extra.Recursors.Recur,
module Juvix.Compiler.Core.Extra.Recursors.Collector,
)
where
@ -19,3 +22,6 @@ import Juvix.Compiler.Core.Extra.Recursors.Map.Named
import Juvix.Compiler.Core.Extra.Recursors.RMap
import Juvix.Compiler.Core.Extra.Recursors.RMap.Named
import Juvix.Compiler.Core.Extra.Recursors.Recur
import Juvix.Compiler.Core.Extra.Recursors.SFold
import Juvix.Compiler.Core.Extra.Recursors.SFold.Named
import Juvix.Compiler.Core.Extra.Recursors.SMap.Named

View File

@ -4,15 +4,8 @@ module Juvix.Compiler.Core.Extra.Recursors.Map where
import Data.Functor.Identity
import Data.Kind qualified as GHC
import Data.Singletons.TH
import Juvix.Compiler.Core.Extra.Base
import Juvix.Compiler.Core.Extra.Recursors.Base
import Juvix.Compiler.Core.Extra.Recursors.Parameters
type DirTy :: Direction -> GHC.Type -> GHC.Type
type family DirTy d c = res | res -> d where
DirTy 'TopDown c = Recur' c
DirTy 'BottomUp _ = Node -- For bottom up maps we never recur on the children
-- | `umapG` maps the nodes bottom-up, i.e., when invoking the mapper function the
-- recursive subnodes have already been mapped
@ -119,14 +112,3 @@ fromRecur' d =
End' x -> End' x
Recur' (c, x) -> Recur' ((c, d), x)
)
nodeMapG' ::
(Monad m) =>
Sing dir ->
Collector (Int, [Binder]) c ->
(c -> Node -> m (DirTy dir c)) ->
Node ->
m Node
nodeMapG' sdir = case sdir of
STopDown -> dmapG
SBottomUp -> umapG

View File

@ -3,7 +3,6 @@ module Juvix.Compiler.Core.Extra.Recursors.Map.Named where
import Data.Functor.Identity
import Juvix.Compiler.Core.Extra.Recursors.Base
import Juvix.Compiler.Core.Extra.Recursors.Map
import Juvix.Compiler.Core.Extra.Recursors.Parameters
{-
@ -50,7 +49,7 @@ dmapLM :: (Monad m) => (BinderList Binder -> Node -> m Node) -> Node -> m Node
dmapLM f = dmapLM' (mempty, f)
umapLM :: (Monad m) => (BinderList Binder -> Node -> m Node) -> Node -> m Node
umapLM f = nodeMapG' SBottomUp binderInfoCollector f
umapLM f = umapG binderInfoCollector f
dmapNRM :: (Monad m) => (Level -> Node -> m Recur) -> Node -> m Node
dmapNRM f = dmapNRM' (0, f)
@ -59,34 +58,34 @@ dmapNM :: (Monad m) => (Level -> Node -> m Node) -> Node -> m Node
dmapNM f = dmapNM' (0, f)
umapNM :: (Monad m) => (Level -> Node -> m Node) -> Node -> m Node
umapNM f = nodeMapG' SBottomUp binderNumCollector f
umapNM f = umapG binderNumCollector f
dmapRM :: (Monad m) => (Node -> m Recur) -> Node -> m Node
dmapRM f = nodeMapG' STopDown unitCollector (const (fromRecur mempty . f))
dmapRM f = dmapG unitCollector (const (fromRecur mempty . f))
dmapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
dmapM f = nodeMapG' STopDown unitCollector (const (fromSimple mempty . f))
dmapM f = dmapG unitCollector (const (fromSimple mempty . f))
umapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
umapM f = nodeMapG' SBottomUp unitCollector (const f)
umapM f = umapG unitCollector (const f)
dmapLRM' :: (Monad m) => (BinderList Binder, BinderList Binder -> Node -> m Recur) -> Node -> m Node
dmapLRM' f = nodeMapG' STopDown (binderInfoCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
dmapLRM' f = dmapG (binderInfoCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
dmapLM' :: (Monad m) => (BinderList Binder, BinderList Binder -> Node -> m Node) -> Node -> m Node
dmapLM' f = nodeMapG' STopDown (binderInfoCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
dmapLM' f = dmapG (binderInfoCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
umapLM' :: (Monad m) => (BinderList Binder, BinderList Binder -> Node -> m Node) -> Node -> m Node
umapLM' f = nodeMapG' SBottomUp (binderInfoCollector' (fst f)) (snd f)
umapLM' f = umapG (binderInfoCollector' (fst f)) (snd f)
dmapNRM' :: (Monad m) => (Level, Level -> Node -> m Recur) -> Node -> m Node
dmapNRM' f = nodeMapG' STopDown (binderNumCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
dmapNRM' f = dmapG (binderNumCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)
dmapNM' :: (Monad m) => (Level, Level -> Node -> m Node) -> Node -> m Node
dmapNM' f = nodeMapG' STopDown (binderNumCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
dmapNM' f = dmapG (binderNumCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)
umapNM' :: (Monad m) => (Level, Level -> Node -> m Node) -> Node -> m Node
umapNM' f = nodeMapG' SBottomUp (binderNumCollector' (fst f)) (snd f)
umapNM' f = umapG (binderNumCollector' (fst f)) (snd f)
dmapLR :: (BinderList Binder -> Node -> Recur) -> Node -> Node
dmapLR f = runIdentity . dmapLRM (embedIden f)
@ -134,25 +133,25 @@ umapN' :: (Level, Level -> Node -> Node) -> Node -> Node
umapN' f = runIdentity . umapNM' (embedIden f)
dmapCLM' :: (Monad m) => (BinderList Binder, c -> BinderList Binder -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCLM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
dmapCLM' f ini = dmapG (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
dmapCLRM' :: (Monad m) => (BinderList Binder, c -> BinderList Binder -> Node -> m (Recur' c)) -> c -> Node -> m Node
dmapCLRM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
dmapCLRM' f ini = dmapG (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
dmapCNRM' :: (Monad m) => (Level, c -> Level -> Node -> m (Recur' c)) -> c -> Node -> m Node
dmapCNRM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
dmapCNRM' f ini = dmapG (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)
dmapCLM :: (Monad m) => (c -> BinderList Binder -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCLM f = dmapCLM' (mempty, f)
dmapCNM' :: (Monad m) => (Level, c -> Level -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCNM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
dmapCNM' f ini = dmapG (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)
dmapCNM :: (Monad m) => (c -> Level -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCNM f = dmapCNM' (0, f)
dmapCM :: (Monad m) => (c -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCM f ini = nodeMapG' STopDown (identityCollector ini) (\c -> fmap Recur' . f c)
dmapCM f ini = dmapG (identityCollector ini) (\c -> fmap Recur' . f c)
dmapCL' :: (BinderList Binder, c -> BinderList Binder -> Node -> (c, Node)) -> c -> Node -> Node
dmapCL' f ini = runIdentity . dmapCLM' (embedIden f) ini

View File

@ -1,26 +0,0 @@
module Juvix.Compiler.Core.Extra.Recursors.Parameters where
import Juvix.Prelude
data CollectorIni
= NoIni
| Ini
data Ctx
= CtxBinderList
| CtxBinderNum
| CtxNone
data Monadic
= Monadic
| NonMonadic
data Ret
= RetRecur
| RetSimple
data Direction
= TopDown
| BottomUp
$(genSingletons [''CollectorIni, ''Ctx, ''Ret, ''Monadic, ''Direction])

View File

@ -0,0 +1,23 @@
-- | Shallow fold recursors over 'Node'.
module Juvix.Compiler.Core.Extra.Recursors.SFold where
import Juvix.Compiler.Core.Extra.Base
import Juvix.Compiler.Core.Extra.Recursors.Base
sfoldG ::
forall a f.
(Applicative f) =>
(a -> [a] -> a) ->
(Node -> f a) ->
Node ->
f a
sfoldG uplus f = go
where
go :: Node -> f a
go n = do
mas' <- sequenceA mas
n' <- f n
pure (uplus n' mas')
where
mas :: [f a]
mas = map go (schildren n)

View File

@ -0,0 +1,24 @@
module Juvix.Compiler.Core.Extra.Recursors.SFold.Named where
import Data.Functor.Identity
import Juvix.Compiler.Core.Extra.Recursors.Base
import Juvix.Compiler.Core.Extra.Recursors.SFold
{-
The shallow folding recursors are analogous to general folding recursors (see
Core/Extra/Recursors/Fold/Named.hs) except that they don't go under binders.
-}
sfoldA :: (Applicative f) => (a -> [a] -> a) -> (Node -> f a) -> Node -> f a
sfoldA uplus f = sfoldG uplus f
swalk :: (Applicative f) => (Node -> f ()) -> Node -> f ()
swalk = sfoldA (foldr mappend)
sfold :: (a -> [a] -> a) -> (Node -> a) -> Node -> a
sfold uplus f = runIdentity . sfoldA uplus (return . f)
sgather :: (a -> Node -> a) -> a -> Node -> a
sgather f acc = run . execState acc . swalk (\n' -> modify' (`f` n'))

View File

@ -0,0 +1,68 @@
{-# LANGUAGE UndecidableInstances #-}
module Juvix.Compiler.Core.Extra.Recursors.SMap where
import Data.Functor.Identity
import Data.Kind qualified as GHC
import Juvix.Compiler.Core.Extra.Base
import Juvix.Compiler.Core.Extra.Recursors.Base
-- | `sumapG` is the shallow version of `umapG`, i.e., it doesn't recurse under binders
sumapG ::
forall m.
(Monad m) =>
(Node -> m Node) ->
Node ->
m Node
sumapG f = go
where
go :: Node -> m Node
go n =
let ni = destruct n
in do
ns <- mapM goChild (ni ^. nodeChildren)
f (reassembleDetails ni ns)
where
goChild :: NodeChild -> m Node
goChild nc
| nc ^. childBindersNum == 0 = go (nc ^. childNode)
| otherwise = return $ nc ^. childNode
sdmapG ::
forall m.
(Monad m) =>
(Node -> m Recur) ->
Node ->
m Node
sdmapG f = go
where
go :: Node -> m Node
go n = do
r <- f n
case r of
End n' -> return n'
Recur n' ->
let ni = destruct n'
in reassembleDetails ni <$> mapM goChild (ni ^. nodeChildren)
where
goChild :: NodeChild -> m Node
goChild ch
| ch ^. childBindersNum == 0 = go (ch ^. childNode)
| otherwise = return $ ch ^. childNode
type OverIdentity :: GHC.Type -> GHC.Type
type family OverIdentity t = res where
OverIdentity (a -> b) = a -> OverIdentity b
OverIdentity leaf = Identity leaf
class EmbedIdentity a where
embedIden :: a -> OverIdentity a
instance (EmbedIdentity b) => EmbedIdentity (a -> b) where
embedIden f = embedIden . f
instance EmbedIdentity Node where
embedIden = Identity
instance EmbedIdentity Recur where
embedIden = Identity

View File

@ -0,0 +1,30 @@
module Juvix.Compiler.Core.Extra.Recursors.SMap.Named where
import Data.Functor.Identity
import Juvix.Compiler.Core.Extra.Recursors.Base
import Juvix.Compiler.Core.Extra.Recursors.SMap
{-
The shallow mapping recursors are analogous to ordinary mapping recursors (see
Core/Extra/Recursors/Map/Named.hs) except that they don't go under binders.
-}
sdmapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
sdmapM f = sdmapG (fmap Recur . f)
sumapM :: (Monad m) => (Node -> m Node) -> Node -> m Node
sumapM f = sumapG f
sdmapRM :: (Monad m) => (Node -> m Recur) -> Node -> m Node
sdmapRM f = sdmapG f
sdmapR :: (Node -> Recur) -> Node -> Node
sdmapR f = runIdentity . sdmapRM (embedIden f)
sdmap :: (Node -> Node) -> Node -> Node
sdmap f = runIdentity . sdmapM (embedIden f)
sumap :: (Node -> Node) -> Node -> Node
sumap f = runIdentity . sumapM (embedIden f)

View File

@ -29,6 +29,7 @@ import Juvix.Compiler.Core.Transformation.MoveApps
import Juvix.Compiler.Core.Transformation.NaiveMatchToCase qualified as Naive
import Juvix.Compiler.Core.Transformation.NatToPrimInt
import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
@ -36,6 +37,8 @@ import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main
import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
import Juvix.Compiler.Core.Transformation.TopEtaExpand
import Juvix.Compiler.Core.Transformation.UnrollRecursion
@ -69,7 +72,10 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
LetHoisting -> return . letHoisting
Inlining -> inlining
FoldTypeSynonyms -> return . foldTypeSynonyms
CaseCallLifting -> return . caseCallLifting
SimplifyIfs -> return . simplifyIfs
OptPhaseEval -> Phase.Eval.optimize
OptPhaseExec -> Phase.Exec.optimize
OptPhaseGeb -> Phase.Geb.optimize
OptPhaseVampIR -> Phase.VampIR.optimize
OptPhaseMain -> Phase.Main.optimize

View File

@ -0,0 +1,136 @@
module Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting (caseCallLifting) where
import Data.HashSet qualified as HashSet
import Data.List qualified as List
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base
convertNode :: InfoTable -> Node -> Node
convertNode tab = umap go
where
go :: Node -> Node
go = \case
NCase Case {..}
| not (null idents) ->
if
| isCaseBoolean _caseBranches && not (isImmediate _caseValue) ->
mkLet'
mkTypeBool'
_caseValue
(liftApps 0 _caseInductive (mkVar' 0) (brs' 1) (def' 1) idents)
| otherwise ->
liftApps 0 _caseInductive _caseValue (brs' 0) (def' 0) idents
where
bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault
idts = foldr (flip gatherIdents) mempty bodies
idents = filter (\sym -> all (\x -> countApps sym x == 1) bodies) (toList idts)
n = length idents
brs' k = map (over caseBranchBody (shift (n + k))) _caseBranches
def' k = fmap (shift (n + k)) _caseDefault
node -> node
liftApps :: Level -> Symbol -> Node -> [CaseBranch] -> Maybe Node -> [Symbol] -> Node
liftApps lvl ind val brs def = \case
[] ->
NCase
Case
{ _caseInfo = mempty,
_caseInductive = ind,
_caseValue = shift lvl val,
_caseBranches = brs,
_caseDefault = def
}
sym : syms -> mkLet' ty app (liftApps (lvl + 1) ind val brs' def' syms)
where
idx = length syms
args0 = map (fromJust . gatherAppArgs sym . (^. caseBranchBody)) brs
dargs0 = fmap (fromJust . gatherAppArgs sym) def
appArgs = computeArgs args0 dargs0
app = mkApps' (mkIdent' sym) appArgs
(tyargs, tgt) = unfoldPi' (lookupIdentifierInfo tab sym ^. identifierType)
tyargs' = drop (length appArgs) tyargs
ty = substs appArgs (mkPis' tyargs' tgt)
brs' = map (\br -> over caseBranchBody (substApps sym (mkVar' (br ^. caseBranchBindersNum + idx))) br) brs
def' = fmap (substApps sym (mkVar' idx)) def
computeArgs :: [[Node]] -> Maybe [Node] -> [Node]
computeArgs args dargs
| null (List.head args) = []
| otherwise =
shift
(-idx - 1)
(mkCase' ind (shift (lvl + 1) val) (zipWithExact (set caseBranchBody) hbs brs) hdef)
: computeArgs args' dargs'
where
hbs = map List.head args
hdef = fmap List.head dargs
args' = map List.tail args
dargs' = fmap List.tail dargs
gatherIdents :: HashSet Symbol -> Node -> HashSet Symbol
gatherIdents = sgather go'
where
go' :: HashSet Symbol -> Node -> HashSet Symbol
go' acc node = case node of
NApp {} ->
let (h, args) = unfoldApps' node
in case h of
NIdt Ident {..}
| length args == lookupIdentifierInfo tab _identSymbol ^. identifierArgsNum ->
HashSet.insert _identSymbol acc
_ -> acc
_ -> acc
countApps :: Symbol -> Node -> Int
countApps sym = sgather go' 0
where
argsNum = lookupIdentifierInfo tab sym ^. identifierArgsNum
go' :: Int -> Node -> Int
go' acc node = case node of
NApp {} ->
let (h, args) = unfoldApps' node
in case h of
NIdt Ident {..}
| _identSymbol == sym
&& length args == argsNum ->
acc + 1
_ -> acc
_ -> acc
gatherAppArgs :: Symbol -> Node -> Maybe [Node]
gatherAppArgs sym = sgather go' Nothing
where
argsNum = lookupIdentifierInfo tab sym ^. identifierArgsNum
go' :: Maybe [Node] -> Node -> Maybe [Node]
go' acc node = case node of
NApp {} ->
let (h, args) = unfoldApps' node
in case h of
NIdt Ident {..}
| _identSymbol == sym
&& length args == argsNum ->
Just args
_ -> acc
_ -> acc
substApps :: Symbol -> Node -> Node -> Node
substApps sym snode = sumap go'
where
argsNum = lookupIdentifierInfo tab sym ^. identifierArgsNum
go' :: Node -> Node
go' node = case node of
NApp {} ->
let (h, args) = unfoldApps' node
in case h of
NIdt Ident {..}
| _identSymbol == sym
&& length args == argsNum ->
snode
_ -> node
_ -> node
caseCallLifting :: InfoTable -> InfoTable
caseCallLifting tab = mapAllNodes (convertNode tab) tab

View File

@ -22,6 +22,7 @@ convertNode isFoldable = rmap go
go recur = \case
NLet Let {..}
| isImmediate (_letItem ^. letItemValue)
|| _letBody == mkVar' 0
|| isFoldable (_letItem ^. letItemValue) ->
go (recur . (mkBCRemove (_letItem ^. letItemBinder) val' :)) _letBody
where

View File

@ -0,0 +1,13 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR where
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
optimize :: Member (Reader CoreOptions) r => InfoTable -> Sem r InfoTable
optimize =
withOptimizationLevel 1 $
return . letFolding . simplifyIfs . caseCallLifting . letFolding . lambdaFolding

View File

@ -0,0 +1,21 @@
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs (simplifyIfs) where
import Data.List qualified as List
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base
convertNode :: Node -> Node
convertNode = umap go
where
go :: Node -> Node
go node = case node of
NCase Case {..}
| isCaseBoolean _caseBranches
&& all (== List.head bodies) (List.tail bodies) ->
List.head bodies
where
bodies = map (^. caseBranchBody) _caseBranches ++ maybeToList _caseDefault
_ -> node
simplifyIfs :: InfoTable -> InfoTable
simplifyIfs = mapAllNodes convertNode