diff --git a/app/Commands/Repl.hs b/app/Commands/Repl.hs index ff22040c5..c3f8e2c63 100644 --- a/app/Commands/Repl.hs +++ b/app/Commands/Repl.hs @@ -366,9 +366,11 @@ printDefinition = replParseIdentifiers >=> printIdentifiers printFunction :: Scoped.NameId -> Repl () printFunction fun = do tbl :: Scoped.InfoTable <- getInfoTable - let def :: Scoped.FunctionInfo = tbl ^?! Scoped.infoFunctions . at fun . _Just - printLocation def - printConcreteLn def + case tbl ^. Scoped.infoFunctions . at fun of + Just def -> do + printLocation def + printConcreteLn def + Nothing -> return () printInductive :: Scoped.NameId -> Repl () printInductive ind = do diff --git a/juvix-stdlib b/juvix-stdlib index 9a091c545..6a76d4f2a 160000 --- a/juvix-stdlib +++ b/juvix-stdlib @@ -1 +1 @@ -Subproject commit 9a091c5453594ac66b3b25cde0c11a54a255a9c9 +Subproject commit 6a76d4f2aed0ba36aac5f7cae94ac2e070ede154 diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index d1a426e4f..e09f057c1 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -30,6 +30,7 @@ data TransformationId | FoldTypeSynonyms | CaseCallLifting | SimplifyIfs + | SimplifyComparisons | SpecializeArgs | CaseFolding | CasePermutation diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs index 9da953aba..754312375 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs @@ -90,6 +90,7 @@ transformationText = \case FoldTypeSynonyms -> strFoldTypeSynonyms CaseCallLifting -> strCaseCallLifting SimplifyIfs -> strSimplifyIfs + SimplifyComparisons -> strSimplifyComparisons SpecializeArgs -> strSpecializeArgs CaseFolding -> strCaseFolding CasePermutation -> strCasePermutation @@ -205,6 +206,9 @@ strCaseCallLifting = "case-call-lifting" strSimplifyIfs :: Text strSimplifyIfs = "simplify-ifs" +strSimplifyComparisons :: Text +strSimplifyComparisons = "simplify-comparisons" + strSpecializeArgs :: Text strSpecializeArgs = "specialize-args" diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index c615d46bb..9a3c3aabe 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -42,6 +42,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase import Juvix.Compiler.Core.Transformation.Optimize.Phase.Geb qualified as Phase.Geb import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR +import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs import Juvix.Compiler.Core.Transformation.RemoveTypeArgs @@ -80,6 +81,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts FoldTypeSynonyms -> return . foldTypeSynonyms CaseCallLifting -> return . caseCallLifting SimplifyIfs -> return . simplifyIfs + SimplifyComparisons -> return . simplifyComparisons SpecializeArgs -> return . specializeArgs CaseFolding -> return . caseFolding CasePermutation -> return . casePermutation diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/CaseValueInlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/CaseValueInlining.hs new file mode 100644 index 000000000..00aa04c14 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/CaseValueInlining.hs @@ -0,0 +1,21 @@ +module Juvix.Compiler.Core.Transformation.Optimize.CaseValueInlining where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Transformation.Base + +convertNode :: InfoTable -> Node -> Node +convertNode tab = dmap go + where + go :: Node -> Node + go node = case node of + NCase cs@Case {..} -> case _caseValue of + NIdt Ident {..} + | Just InlineCase <- lookupIdentifierInfo tab _identSymbol ^. identifierPragmas . pragmasInline -> + NCase cs {_caseValue = lookupIdentifierNode tab _identSymbol} + _ -> + node + _ -> + node + +caseValueInlining :: InfoTable -> InfoTable +caseValueInlining tab = mapAllNodes (convertNode tab) tab diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs index 9874fc73e..b6f532967 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs @@ -61,21 +61,27 @@ convertNode inlineDepth recSyms tab = dmapL go pi = ii ^. identifierPragmas . pragmasInline argsNum = ii ^. identifierArgsNum def = lookupIdentifierNode tab _identSymbol - -- inline zero-argument definitions automatically if inlining would result + -- inline zero-argument definitions (automatically) if inlining would result -- in case reduction - NCase cs@Case {..} -> case _caseValue of - NIdt Ident {..} - | isNothing pi - && not (HashSet.member _identSymbol recSyms) - && isConstructorApp def - && checkDepth tab bl inlineDepth def -> - NCase cs {_caseValue = def} - where - ii = lookupIdentifierInfo tab _identSymbol - pi = ii ^. identifierPragmas . pragmasInline - def = lookupIdentifierNode tab _identSymbol - _ -> - node + NCase cs@Case {..} -> + let (h, args) = unfoldApps _caseValue + in case h of + NIdt Ident {..} -> case pi of + Just InlineCase -> + NCase cs {_caseValue = mkApps def args} + Nothing + | not (HashSet.member _identSymbol recSyms) + && isConstructorApp def + && checkDepth tab bl inlineDepth def -> + NCase cs {_caseValue = mkApps def args} + _ -> + node + where + ii = lookupIdentifierInfo tab _identSymbol + pi = ii ^. identifierPragmas . pragmasInline + def = lookupIdentifierNode tab _identSymbol + _ -> + node _ -> node diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/MandatoryInlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/MandatoryInlining.hs index 96b6b3284..cd15fb73a 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/MandatoryInlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/MandatoryInlining.hs @@ -11,6 +11,12 @@ convertNode tab = dmap go NIdt Ident {..} | Just InlineAlways <- lookupIdentifierInfo tab _identSymbol ^. identifierPragmas . pragmasInline -> lookupIdentifierNode tab _identSymbol + NCase cs@Case {..} -> case _caseValue of + NIdt Ident {..} + | Just InlineCase <- lookupIdentifierInfo tab _identSymbol ^. identifierPragmas . pragmasInline -> + NCase cs {_caseValue = lookupIdentifierNode tab _identSymbol} + _ -> + node _ -> node diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Eval.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Eval.hs index d0153e4ba..e793ff91e 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Eval.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Eval.hs @@ -2,6 +2,7 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval where import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.CaseValueInlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining @@ -9,7 +10,11 @@ import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining optimize :: InfoTable -> Sem r InfoTable optimize = return + . letFolding + . lambdaFolding + . letFolding . caseFolding + . caseValueInlining . letFolding . lambdaFolding . mandatoryInlining diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index c88b3e7ff..1c8f55d41 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -18,15 +18,9 @@ optimize' CoreOptions {..} tab = filterUnreachable . compose (4 * _optOptimizationLevel) - ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) - . lambdaFolding + ( doSimplification 2 . doInlining - . simplifyIfs' (_optOptimizationLevel <= 1) - . simplifyComparisons - . caseFolding - . casePermutation - . letFolding' (isInlineableLambda _optInliningDepth) - . lambdaFolding + . doSimplification 1 . specializeArgs ) . letFolding @@ -43,6 +37,15 @@ optimize' CoreOptions {..} tab = | _optOptimizationLevel > 1 -> recursiveIdents tab' | otherwise -> recs + doSimplification :: Int -> InfoTable -> InfoTable + doSimplification n = + simplifyIfs' (_optOptimizationLevel <= 1) + . simplifyComparisons + . caseFolding + . casePermutation + . compose n (letFolding' (isInlineableLambda _optInliningDepth)) + . lambdaFolding + optimize :: (Member (Reader CoreOptions) r) => InfoTable -> Sem r InfoTable optimize tab = do opts <- ask diff --git a/src/Juvix/Data/Pragmas.hs b/src/Juvix/Data/Pragmas.hs index 8a21f73f1..0773dbbee 100644 --- a/src/Juvix/Data/Pragmas.hs +++ b/src/Juvix/Data/Pragmas.hs @@ -7,6 +7,7 @@ import Juvix.Prelude.Base data PragmaInline = InlineAlways | InlineNever + | InlineCase | InlineFullyApplied | InlinePartiallyApplied {_pragmaInlineArgsNum :: Int} deriving stock (Show, Eq, Ord, Data, Generic) @@ -135,6 +136,7 @@ instance FromJSON Pragmas where case txt of "always" -> return InlineAlways "never" -> return InlineNever + "case" -> return InlineCase _ -> throwCustomError ("unrecognized inline specification: " <> txt) parseUnroll :: Parse YamlError PragmaUnroll @@ -220,6 +222,7 @@ adjustPragmaInline n = \case InlinePartiallyApplied k -> InlinePartiallyApplied (k + n) InlineAlways -> InlineAlways InlineNever -> InlineNever + InlineCase -> InlineCase InlineFullyApplied -> InlineFullyApplied adjustPragmaSpecialiseArg :: Int -> PragmaSpecialiseArg -> PragmaSpecialiseArg diff --git a/tests/Compilation/positive/test058.juvix b/tests/Compilation/positive/test058.juvix index 699b35ccc..662cdcd3f 100644 --- a/tests/Compilation/positive/test058.juvix +++ b/tests/Compilation/positive/test058.juvix @@ -2,12 +2,12 @@ module test058; import Stdlib.Prelude open hiding {for}; -import Stdlib.Data.Nat.Range open; +import Stdlib.Data.Range open; -sum : Nat → Nat - | x := for (acc := 0) (n in 1 to x) {acc + n}; +sum (x : Nat) : Nat := + for (acc := 0) (n in 1 to x) {acc + n}; -sum' : Nat → Nat - | x := for (acc := 0) (n in 1 to x step 2) {acc + n}; +sum' (x : Nat) : Nat := + for (acc := 0) (n in 1 to x step 2) {acc + n}; main : Nat := sum 100 + sum' 100; diff --git a/tests/smoke/Commands/repl.smoke.yaml b/tests/smoke/Commands/repl.smoke.yaml index f2e32c621..084b94a6c 100644 --- a/tests/smoke/Commands/repl.smoke.yaml +++ b/tests/smoke/Commands/repl.smoke.yaml @@ -71,10 +71,11 @@ tests: - juvix - repl - ../examples/milestone/HelloWorld/HelloWorld.juvix - stdin: ":def + (+) (((+)))" + stdin: ":def >> (>>) (((>>)))" stdout: contains: | - + {A} {{Natural A}} : A -> A -> A := Natural.+ + builtin IO-sequence + axiom >> : IO → IO → IO exit-status: 0 - name: repl-def-infix @@ -82,10 +83,11 @@ tests: - juvix - repl - ../examples/milestone/HelloWorld/HelloWorld.juvix - stdin: ":def +" + stdin: ":def >>" stdout: contains: | - + {A} {{Natural A}} : A -> A -> A := Natural.+ + builtin IO-sequence + axiom >> : IO → IO → IO exit-status: 0 - name: open