1
1
mirror of https://github.com/anoma/juvix.git synced 2025-01-08 16:51:53 +03:00

Bugfix: compiler looping with the specialize pragma (#2899)

* Closes #2884
This commit is contained in:
Łukasz Czajka 2024-07-15 15:08:31 +02:00 committed by GitHub
parent 7d2a59cc9f
commit 5a76e5d9dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 34 additions and 36 deletions

@ -1 +1 @@
Subproject commit 216cb609cbe5aec9badea858f151a5ea400f2e66
Subproject commit 17f22fcec5d78be511ea59984aee3499da5f3342

View File

@ -101,3 +101,6 @@ nonRecursiveIdents' tab =
HashSet.difference
(HashSet.fromList (HashMap.keys (tab ^. infoIdentifiers)))
(recursiveIdentsClosure tab)
nonRecursiveIdents :: Module -> HashSet Symbol
nonRecursiveIdents = nonRecursiveIdents' . computeCombinedInfoTable

View File

@ -39,6 +39,7 @@ data TransformationId
| SpecializeArgs
| CaseFolding
| CasePermutation
| ConstantFolding
| FilterUnreachable
| OptPhaseEval
| OptPhaseExec
@ -113,6 +114,7 @@ instance TransformationId' TransformationId where
SpecializeArgs -> strSpecializeArgs
CaseFolding -> strCaseFolding
CasePermutation -> strCasePermutation
ConstantFolding -> strConstantFolding
FilterUnreachable -> strFilterUnreachable
OptPhaseEval -> strOptPhaseEval
OptPhaseExec -> strOptPhaseExec

View File

@ -119,6 +119,9 @@ strCaseFolding = "case-folding"
strCasePermutation :: Text
strCasePermutation = "case-permutation"
strConstantFolding :: Text
strConstantFolding = "constant-folding"
strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

View File

@ -36,6 +36,7 @@ import Juvix.Compiler.Core.Transformation.Normalize
import Juvix.Compiler.Core.Transformation.Optimize.CaseCallLifting
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.CasePermutation (casePermutation)
import Juvix.Compiler.Core.Transformation.Optimize.ConstantFolding
import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnreachable)
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
@ -96,6 +97,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
SpecializeArgs -> return . specializeArgs
CaseFolding -> return . caseFolding
CasePermutation -> return . casePermutation
ConstantFolding -> constantFolding
FilterUnreachable -> return . filterUnreachable
OptPhaseEval -> Phase.Eval.optimize
OptPhaseExec -> Phase.Exec.optimize

View File

@ -17,7 +17,7 @@ isInlineableLambda inlineDepth md bl node = case node of
False
convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth recSyms md = dmapL go
convertNode inlineDepth nonRecSyms md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
@ -37,7 +37,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineNever ->
node
_
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isInlineableLambda inlineDepth md bl def
&& length args >= argsNum ->
mkApps def args
@ -57,7 +57,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineAlways -> def
Just InlineNever -> node
_
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
def
| otherwise ->
@ -76,7 +76,7 @@ convertNode inlineDepth recSyms md = dmapL go
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Nothing
| not (HashSet.member _identSymbol recSyms)
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def
&& checkDepth md bl inlineDepth def ->
NCase cs {_caseValue = mkApps def args}
@ -92,9 +92,9 @@ convertNode inlineDepth recSyms md = dmapL go
node
inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth recSyms md = mapT (const (convertNode inliningDepth recSyms md)) md
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (recursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) md

View File

@ -33,9 +33,6 @@ optimize' opts@CoreOptions {..} md =
tab :: InfoTable
tab = computeCombinedInfoTable md
recs :: HashSet Symbol
recs = recursiveIdents' tab
nonRecs :: HashSet Symbol
nonRecs = nonRecursiveIdents' tab
@ -48,12 +45,12 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecs
doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth recs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' md'
where
recs' =
nonRecs' =
if
| _optOptimizationLevel > 1 -> recursiveIdents md'
| otherwise -> recs
| _optOptimizationLevel > 1 -> nonRecursiveIdents md'
| otherwise -> nonRecs
doSimplification :: Int -> Module -> Module
doSimplification n =

View File

@ -9,20 +9,12 @@ mymap {A B} (f : A -> B) : List A -> List B
| (x :: xs) := f x :: mymap f xs;
{-# specialize: [2, 5], inline: false #-}
myf
: {A B : Type}
-> A
-> (A -> A -> B)
-> A
-> B
-> Bool
-> B
myf : {A B : Type} -> A -> (A -> A -> B) -> A -> B -> Bool -> B
| a0 f a b true := f a0 a
| a0 f a b false := b;
{-# inline: false #-}
myf'
: {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
myf' : {A B : Type} -> A -> (A -> A -> A -> B) -> A -> B -> B
| a0 f a b := myf a0 (f a0) a b true;
sum : List Nat -> Nat
@ -40,8 +32,7 @@ funa : {A : Type} -> (A -> A) -> A -> A
{-# specialize: true #-}
type Additive A := mkAdditive {add : A -> A -> A};
type Multiplicative A :=
mkMultiplicative {mul : A -> A -> A};
type Multiplicative A := mkMultiplicative {mul : A -> A -> A};
addNat : Additive Nat := mkAdditive (+);
@ -49,20 +40,20 @@ addNat : Additive Nat := mkAdditive (+);
mulNat : Multiplicative Nat := mkMultiplicative (*);
{-# inline: false #-}
fadd {A} (a : Additive A) (x y : A) : A :=
Additive.add a x y;
fadd {A} (a : Additive A) (x y : A) : A := Additive.add a x y;
{-# inline: false #-}
fmul {A} (m : Multiplicative A) (x y : A) : A :=
Multiplicative.mul m x y;
fmul {A} (m : Multiplicative A) (x y : A) : A := Multiplicative.mul m x y;
{-# specialize: [1] #-}
myfilter {A} (f : A → Bool) : List A → List A
| nil := nil
| (h :: hs) := ite (f h) (h :: myfilter f hs) (myfilter f hs);
main : Nat :=
sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum
(flatten
(mymap
(mymap λ {x := x + 2})
((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
sum (myfilter (const false) [])
+ sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum (flatten (mymap (mymap λ {x := x + 2}) ((1 :: nil) :: (2 :: 3 :: nil) :: nil)))
+ myf 3 (*) 2 5 true
+ myf 1 (+) 2 0 false
+ myf' 7 (const (+)) 2 0