1
1
mirror of https://github.com/anoma/juvix.git synced 2024-11-30 14:13:27 +03:00

Case value inlining (#2441)

* Introduces the `inline: case` pragma which causes an identifier to be
inlined if it is matched on. This is necessary to support GEB without
compromising optimization for other targets.
* Adapts to the new commits in
https://github.com/anoma/juvix-stdlib/pull/86
This commit is contained in:
Łukasz Czajka 2023-10-12 18:59:47 +02:00 committed by GitHub
parent 9e3e07d97c
commit c3bcf40db1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 90 additions and 35 deletions

View File

@ -366,9 +366,11 @@ printDefinition = replParseIdentifiers >=> printIdentifiers
printFunction :: Scoped.NameId -> Repl ()
printFunction fun = do
tbl :: Scoped.InfoTable <- getInfoTable
let def :: Scoped.FunctionInfo = tbl ^?! Scoped.infoFunctions . at fun . _Just
printLocation def
printConcreteLn def
case tbl ^. Scoped.infoFunctions . at fun of
Just def -> do
printLocation def
printConcreteLn def
Nothing -> return ()
printInductive :: Scoped.NameId -> Repl ()
printInductive ind = do

@ -1 +1 @@
Subproject commit 9a091c5453594ac66b3b25cde0c11a54a255a9c9
Subproject commit 6a76d4f2aed0ba36aac5f7cae94ac2e070ede154

View File

@ -30,6 +30,7 @@ data TransformationId
| FoldTypeSynonyms
| CaseCallLifting
| SimplifyIfs
| SimplifyComparisons
| SpecializeArgs
| CaseFolding
| CasePermutation

View File

@ -90,6 +90,7 @@ transformationText = \case
FoldTypeSynonyms -> strFoldTypeSynonyms
CaseCallLifting -> strCaseCallLifting
SimplifyIfs -> strSimplifyIfs
SimplifyComparisons -> strSimplifyComparisons
SpecializeArgs -> strSpecializeArgs
CaseFolding -> strCaseFolding
CasePermutation -> strCasePermutation
@ -205,6 +206,9 @@ strCaseCallLifting = "case-call-lifting"
strSimplifyIfs :: Text
strSimplifyIfs = "simplify-ifs"
strSimplifyComparisons :: Text
strSimplifyComparisons = "simplify-comparisons"
strSpecializeArgs :: Text
strSpecializeArgs = "specialize-args"

View File

@ -42,6 +42,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main
import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons)
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
@ -80,6 +81,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
FoldTypeSynonyms -> return . foldTypeSynonyms
CaseCallLifting -> return . caseCallLifting
SimplifyIfs -> return . simplifyIfs
SimplifyComparisons -> return . simplifyComparisons
SpecializeArgs -> return . specializeArgs
CaseFolding -> return . caseFolding
CasePermutation -> return . casePermutation

View File

@ -0,0 +1,21 @@
module Juvix.Compiler.Core.Transformation.Optimize.CaseValueInlining where
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base
convertNode :: InfoTable -> Node -> Node
convertNode tab = dmap go
where
go :: Node -> Node
go node = case node of
NCase cs@Case {..} -> case _caseValue of
NIdt Ident {..}
| Just InlineCase <- lookupIdentifierInfo tab _identSymbol ^. identifierPragmas . pragmasInline ->
NCase cs {_caseValue = lookupIdentifierNode tab _identSymbol}
_ ->
node
_ ->
node
caseValueInlining :: InfoTable -> InfoTable
caseValueInlining tab = mapAllNodes (convertNode tab) tab

View File

@ -61,21 +61,27 @@ convertNode inlineDepth recSyms tab = dmapL go
pi = ii ^. identifierPragmas . pragmasInline
argsNum = ii ^. identifierArgsNum
def = lookupIdentifierNode tab _identSymbol
-- inline zero-argument definitions automatically if inlining would result
-- inline zero-argument definitions (automatically) if inlining would result
-- in case reduction
NCase cs@Case {..} -> case _caseValue of
NIdt Ident {..}
| isNothing pi
&& not (HashSet.member _identSymbol recSyms)
&& isConstructorApp def
&& checkDepth tab bl inlineDepth def ->
NCase cs {_caseValue = def}
where
ii = lookupIdentifierInfo tab _identSymbol
pi = ii ^. identifierPragmas . pragmasInline
def = lookupIdentifierNode tab _identSymbol
_ ->
node
NCase cs@Case {..} ->
let (h, args) = unfoldApps _caseValue
in case h of
NIdt Ident {..} -> case pi of
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Nothing
| not (HashSet.member _identSymbol recSyms)
&& isConstructorApp def
&& checkDepth tab bl inlineDepth def ->
NCase cs {_caseValue = mkApps def args}
_ ->
node
where
ii = lookupIdentifierInfo tab _identSymbol
pi = ii ^. identifierPragmas . pragmasInline
def = lookupIdentifierNode tab _identSymbol
_ ->
node
_ ->
node

View File

@ -11,6 +11,12 @@ convertNode tab = dmap go
NIdt Ident {..}
| Just InlineAlways <- lookupIdentifierInfo tab _identSymbol ^. identifierPragmas . pragmasInline ->
lookupIdentifierNode tab _identSymbol
NCase cs@Case {..} -> case _caseValue of
NIdt Ident {..}
| Just InlineCase <- lookupIdentifierInfo tab _identSymbol ^. identifierPragmas . pragmasInline ->
NCase cs {_caseValue = lookupIdentifierNode tab _identSymbol}
_ ->
node
_ ->
node

View File

@ -2,6 +2,7 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval where
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CaseValueInlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining
@ -9,7 +10,11 @@ import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining
optimize :: InfoTable -> Sem r InfoTable
optimize =
return
. letFolding
. lambdaFolding
. letFolding
. caseFolding
. caseValueInlining
. letFolding
. lambdaFolding
. mandatoryInlining

View File

@ -18,15 +18,9 @@ optimize' CoreOptions {..} tab =
filterUnreachable
. compose
(4 * _optOptimizationLevel)
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
( doSimplification 2
. doInlining
. simplifyIfs' (_optOptimizationLevel <= 1)
. simplifyComparisons
. caseFolding
. casePermutation
. letFolding' (isInlineableLambda _optInliningDepth)
. lambdaFolding
. doSimplification 1
. specializeArgs
)
. letFolding
@ -43,6 +37,15 @@ optimize' CoreOptions {..} tab =
| _optOptimizationLevel > 1 -> recursiveIdents tab'
| otherwise -> recs
doSimplification :: Int -> InfoTable -> InfoTable
doSimplification n =
simplifyIfs' (_optOptimizationLevel <= 1)
. simplifyComparisons
. caseFolding
. casePermutation
. compose n (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
optimize :: (Member (Reader CoreOptions) r) => InfoTable -> Sem r InfoTable
optimize tab = do
opts <- ask

View File

@ -7,6 +7,7 @@ import Juvix.Prelude.Base
data PragmaInline
= InlineAlways
| InlineNever
| InlineCase
| InlineFullyApplied
| InlinePartiallyApplied {_pragmaInlineArgsNum :: Int}
deriving stock (Show, Eq, Ord, Data, Generic)
@ -135,6 +136,7 @@ instance FromJSON Pragmas where
case txt of
"always" -> return InlineAlways
"never" -> return InlineNever
"case" -> return InlineCase
_ -> throwCustomError ("unrecognized inline specification: " <> txt)
parseUnroll :: Parse YamlError PragmaUnroll
@ -220,6 +222,7 @@ adjustPragmaInline n = \case
InlinePartiallyApplied k -> InlinePartiallyApplied (k + n)
InlineAlways -> InlineAlways
InlineNever -> InlineNever
InlineCase -> InlineCase
InlineFullyApplied -> InlineFullyApplied
adjustPragmaSpecialiseArg :: Int -> PragmaSpecialiseArg -> PragmaSpecialiseArg

View File

@ -2,12 +2,12 @@
module test058;
import Stdlib.Prelude open hiding {for};
import Stdlib.Data.Nat.Range open;
import Stdlib.Data.Range open;
sum : Nat → Nat
| x := for (acc := 0) (n in 1 to x) {acc + n};
sum (x : Nat) : Nat :=
for (acc := 0) (n in 1 to x) {acc + n};
sum' : Nat → Nat
| x := for (acc := 0) (n in 1 to x step 2) {acc + n};
sum' (x : Nat) : Nat :=
for (acc := 0) (n in 1 to x step 2) {acc + n};
main : Nat := sum 100 + sum' 100;

View File

@ -71,10 +71,11 @@ tests:
- juvix
- repl
- ../examples/milestone/HelloWorld/HelloWorld.juvix
stdin: ":def + (+) (((+)))"
stdin: ":def >> (>>) (((>>)))"
stdout:
contains: |
+ {A} {{Natural A}} : A -> A -> A := Natural.+
builtin IO-sequence
axiom >> : IO → IO → IO
exit-status: 0
- name: repl-def-infix
@ -82,10 +83,11 @@ tests:
- juvix
- repl
- ../examples/milestone/HelloWorld/HelloWorld.juvix
stdin: ":def +"
stdin: ":def >>"
stdout:
contains: |
+ {A} {{Natural A}} : A -> A -> A := Natural.+
builtin IO-sequence
axiom >> : IO → IO → IO
exit-status: 0
- name: open